Binary Skipped ascii filetype fix (#108)
* chore: refactor code for improved readability and maintainability - Standardize variable naming conventions for consistency. - Improve logging messages for better clarity and debugging. - Remove unnecessary imports and clean up code structure. - Enhance error handling and logging in various modules. - Update comments and docstrings for better understanding. - Optimize imports and organize them logically. - Ensure consistent formatting across files for better readability. - Refactor functions to reduce complexity and improve performance. - Add missing type hints and annotations for better code clarity. - Improve test coverage and organization in test files. style(tests): apply consistent formatting and spacing in test files for improved readability and maintainability * chore(tests): remove redundant test for ensure_tables_created with no models to streamline test suite and reduce maintenance overhead * fix(memory.py): update is_binary_file function to correctly identify binary files by returning True for non-text mime types
This commit is contained in:
parent
429f854fb8
commit
e960a68d29
|
|
@ -6,9 +6,8 @@ This script creates a baseline migration representing the current database schem
|
|||
It serves as the foundation for future schema changes.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add the project root to the Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
|
@ -20,6 +19,7 @@ from ra_aid.logging_config import get_logger, setup_logging
|
|||
setup_logging(verbose=True)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_initial_migration():
|
||||
"""
|
||||
Create the initial migration for the current database schema.
|
||||
|
|
@ -49,6 +49,7 @@ def create_initial_migration():
|
|||
print(f"❌ Error creating initial migration: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Creating initial database migration...")
|
||||
success = create_initial_migration()
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from ra_aid.agent_utils import (
|
|||
run_planning_agent,
|
||||
run_research_agent,
|
||||
)
|
||||
from ra_aid.database import init_db, close_db, DatabaseManager, ensure_migrations_applied
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
DEFAULT_RECURSION_LIMIT,
|
||||
|
|
@ -25,6 +24,10 @@ from ra_aid.config import (
|
|||
VALID_PROVIDERS,
|
||||
)
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.database import (
|
||||
DatabaseManager,
|
||||
ensure_migrations_applied,
|
||||
)
|
||||
from ra_aid.dependencies import check_dependencies
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
|
|
@ -171,8 +174,9 @@ Examples:
|
|||
"--aider-config", type=str, help="Specify the aider config file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-aider", action="store_true",
|
||||
help="Use aider for code modifications instead of default file tools (file_str_replace, put_complete_file_contents)"
|
||||
"--use-aider",
|
||||
action="store_true",
|
||||
help="Use aider for code modifications instead of default file tools (file_str_replace, put_complete_file_contents)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-cmd",
|
||||
|
|
@ -343,24 +347,37 @@ def main():
|
|||
try:
|
||||
migration_result = ensure_migrations_applied()
|
||||
if not migration_result:
|
||||
logger.warning("Database migrations failed but execution will continue")
|
||||
logger.warning(
|
||||
"Database migrations failed but execution will continue"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
|
||||
validate_environment(args)
|
||||
) # Will exit if main env vars missing
|
||||
(
|
||||
expert_enabled,
|
||||
expert_missing,
|
||||
web_research_enabled,
|
||||
web_research_missing,
|
||||
) = validate_environment(args) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
# Validate model configuration early
|
||||
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
|
||||
model_config = models_params.get(args.provider, {}).get(
|
||||
args.model or "", {}
|
||||
)
|
||||
supports_temperature = model_config.get(
|
||||
"supports_temperature",
|
||||
args.provider
|
||||
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
|
||||
in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"deepseek",
|
||||
],
|
||||
)
|
||||
|
||||
if supports_temperature and args.temperature is None:
|
||||
|
|
@ -377,7 +394,12 @@ def main():
|
|||
status = build_status(args, expert_enabled, web_research_enabled)
|
||||
|
||||
console.print(
|
||||
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
|
||||
Panel(
|
||||
status,
|
||||
title=f"RA.Aid v{__version__}",
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
# Handle chat mode
|
||||
|
|
@ -449,7 +471,9 @@ def main():
|
|||
CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
web_research_section=(
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
if web_research_enabled
|
||||
else ""
|
||||
),
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
|
|
@ -502,7 +526,9 @@ def main():
|
|||
_global_memory["config"]["research_provider"] = (
|
||||
args.research_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["research_model"] = args.research_model or args.model
|
||||
_global_memory["config"]["research_model"] = (
|
||||
args.research_model or args.model
|
||||
)
|
||||
|
||||
# Store temperature in global config
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Optional, Set
|
||||
from typing import Optional
|
||||
|
||||
# Thread-local storage for context variables
|
||||
_thread_local = threading.local()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import threading
|
|||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
|
|
@ -30,6 +30,12 @@ from rich.console import Console
|
|||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
agent_context,
|
||||
is_completed,
|
||||
reset_completion_flags,
|
||||
should_exit,
|
||||
)
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.agents_alias import RAgents
|
||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||
|
|
@ -72,14 +78,6 @@ from ra_aid.tool_configs import (
|
|||
get_web_research_tools,
|
||||
)
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.agent_context import (
|
||||
agent_context,
|
||||
get_current_context,
|
||||
is_completed,
|
||||
reset_completion_flags,
|
||||
get_completion_message,
|
||||
should_exit,
|
||||
)
|
||||
from ra_aid.tools.memory import (
|
||||
_global_memory,
|
||||
get_memory_value,
|
||||
|
|
@ -250,8 +248,12 @@ def is_anthropic_claude(config: Dict[str, Any]) -> bool:
|
|||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
result = (
|
||||
(provider.lower() == "anthropic" and model_name and "claude" in model_name.lower())
|
||||
or (provider.lower() == "openrouter" and model_name.lower().startswith("anthropic/claude-"))
|
||||
provider.lower() == "anthropic"
|
||||
and model_name
|
||||
and "claude" in model_name.lower()
|
||||
) or (
|
||||
provider.lower() == "openrouter"
|
||||
and model_name.lower().startswith("anthropic/claude-")
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
@ -955,7 +957,8 @@ def run_agent_with_retry(
|
|||
check_interrupt()
|
||||
|
||||
# Check if the agent has crashed before attempting to run it
|
||||
from ra_aid.agent_context import is_crashed, get_crash_message
|
||||
from ra_aid.agent_context import get_crash_message, is_crashed
|
||||
|
||||
if is_crashed():
|
||||
crash_message = get_crash_message()
|
||||
logger.error("Agent has crashed: %s", crash_message)
|
||||
|
|
@ -982,6 +985,7 @@ def run_agent_with_retry(
|
|||
error_str = str(e).lower()
|
||||
if "400" in error_str or "bad request" in error_str:
|
||||
from ra_aid.agent_context import mark_agent_crashed
|
||||
|
||||
crash_message = f"Unretryable error: {str(e)}"
|
||||
mark_agent_crashed(crash_message)
|
||||
logger.error("Agent has crashed: %s", crash_message)
|
||||
|
|
@ -1007,8 +1011,11 @@ def run_agent_with_retry(
|
|||
) as e:
|
||||
# Check if this is a BadRequestError (HTTP 400) which is unretryable
|
||||
error_str = str(e).lower()
|
||||
if ("400" in error_str or "bad request" in error_str) and isinstance(e, APIError):
|
||||
if (
|
||||
"400" in error_str or "bad request" in error_str
|
||||
) and isinstance(e, APIError):
|
||||
from ra_aid.agent_context import mark_agent_crashed
|
||||
|
||||
crash_message = f"Unretryable API error: {str(e)}"
|
||||
mark_agent_crashed(crash_message)
|
||||
logger.error("Agent has crashed: %s", crash_message)
|
||||
|
|
|
|||
|
|
@ -5,34 +5,29 @@ This package provides database functionality for the ra_aid application,
|
|||
including connection management, models, utility functions, and migrations.
|
||||
"""
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db,
|
||||
get_db,
|
||||
close_db,
|
||||
DatabaseManager
|
||||
from ra_aid.database.connection import DatabaseManager, close_db, get_db, init_db
|
||||
from ra_aid.database.migrations import (
|
||||
MigrationManager,
|
||||
create_new_migration,
|
||||
ensure_migrations_applied,
|
||||
get_migration_status,
|
||||
init_migrations,
|
||||
)
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created
|
||||
from ra_aid.database.migrations import (
|
||||
init_migrations,
|
||||
ensure_migrations_applied,
|
||||
create_new_migration,
|
||||
get_migration_status,
|
||||
MigrationManager
|
||||
)
|
||||
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||
|
||||
__all__ = [
|
||||
'init_db',
|
||||
'get_db',
|
||||
'close_db',
|
||||
'DatabaseManager',
|
||||
'BaseModel',
|
||||
'get_model_count',
|
||||
'truncate_table',
|
||||
'ensure_tables_created',
|
||||
'init_migrations',
|
||||
'ensure_migrations_applied',
|
||||
'create_new_migration',
|
||||
'get_migration_status',
|
||||
'MigrationManager',
|
||||
"init_db",
|
||||
"get_db",
|
||||
"close_db",
|
||||
"DatabaseManager",
|
||||
"BaseModel",
|
||||
"get_model_count",
|
||||
"truncate_table",
|
||||
"ensure_tables_created",
|
||||
"init_migrations",
|
||||
"ensure_migrations_applied",
|
||||
"create_new_migration",
|
||||
"get_migration_status",
|
||||
"MigrationManager",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ This module provides functions to initialize, get, and close database connection
|
|||
It also provides a context manager for database connections.
|
||||
"""
|
||||
|
||||
import os
|
||||
import contextvars
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import peewee
|
||||
|
||||
|
|
@ -54,8 +54,12 @@ class DatabaseManager:
|
|||
"""
|
||||
return init_db(in_memory=self.in_memory)
|
||||
|
||||
def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception],
|
||||
exc_tb: Optional[Any]) -> None:
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type],
|
||||
exc_val: Optional[Exception],
|
||||
exc_tb: Optional[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Close the database connection when exiting the context.
|
||||
|
||||
|
|
@ -69,6 +73,7 @@ class DatabaseManager:
|
|||
# Don't suppress exceptions
|
||||
return False
|
||||
|
||||
|
||||
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Initialize the database connection.
|
||||
|
|
@ -113,7 +118,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
||||
ra_aid_dir = Path(ra_aid_dir_str)
|
||||
ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path
|
||||
ra_aid_dir_str = str(ra_aid_dir) # Update string representation with absolute path
|
||||
ra_aid_dir_str = str(
|
||||
ra_aid_dir
|
||||
) # Update string representation with absolute path
|
||||
|
||||
logger.debug(f"Creating database directory at: {ra_aid_dir_str}")
|
||||
|
||||
|
|
@ -126,7 +133,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
try:
|
||||
logger.debug("Attempting directory creation with os.mkdir")
|
||||
os.mkdir(ra_aid_dir_str, mode=0o755)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
|
||||
ra_aid_dir_str
|
||||
)
|
||||
if directory_created:
|
||||
logger.debug("Directory created successfully with os.mkdir")
|
||||
except Exception as e:
|
||||
|
|
@ -142,7 +151,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
try:
|
||||
logger.debug("Attempting directory creation with os.makedirs")
|
||||
os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
|
||||
ra_aid_dir_str
|
||||
)
|
||||
if directory_created:
|
||||
logger.debug("Directory created successfully with os.makedirs")
|
||||
except Exception as e:
|
||||
|
|
@ -155,7 +166,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
try:
|
||||
logger.debug("Attempting directory creation with Path.mkdir")
|
||||
ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
|
||||
ra_aid_dir_str
|
||||
)
|
||||
if directory_created:
|
||||
logger.debug("Directory created successfully with Path.mkdir")
|
||||
except Exception as e:
|
||||
|
|
@ -168,7 +181,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
os_exists = os.path.exists(ra_aid_dir_str)
|
||||
is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False
|
||||
|
||||
logger.debug(f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
|
||||
logger.debug(
|
||||
f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}"
|
||||
)
|
||||
|
||||
# Check parent directory permissions and contents for debugging
|
||||
try:
|
||||
|
|
@ -190,7 +205,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
# Check directory permissions
|
||||
try:
|
||||
permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:]
|
||||
logger.debug(f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
|
||||
logger.debug(
|
||||
f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}"
|
||||
)
|
||||
|
||||
# List directory contents for debugging
|
||||
dir_contents = os.listdir(ra_aid_dir_str)
|
||||
|
|
@ -213,7 +230,7 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
if not db_file_exists:
|
||||
try:
|
||||
logger.debug(f"Creating empty database file at: {db_path}")
|
||||
with open(db_path, 'w') as f:
|
||||
with open(db_path, "w") as f:
|
||||
pass # Create empty file
|
||||
|
||||
# Verify the file was created
|
||||
|
|
@ -230,10 +247,10 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
db = peewee.SqliteDatabase(
|
||||
db_path,
|
||||
pragmas={
|
||||
'journal_mode': 'wal', # Write-Ahead Logging for better concurrency
|
||||
'foreign_keys': 1, # Enforce foreign key constraints
|
||||
'cache_size': -1024 * 32, # 32MB cache
|
||||
}
|
||||
"journal_mode": "wal", # Write-Ahead Logging for better concurrency
|
||||
"foreign_keys": 1, # Enforce foreign key constraints
|
||||
"cache_size": -1024 * 32, # 32MB cache
|
||||
},
|
||||
)
|
||||
|
||||
# Always explicitly connect to ensure the connection is established
|
||||
|
|
@ -256,13 +273,15 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
# Check if the database file exists after initialization
|
||||
db_file_exists = os.path.exists(db_path)
|
||||
db_file_size = os.path.getsize(db_path) if db_file_exists else 0
|
||||
logger.debug(f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
|
||||
logger.debug(
|
||||
f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Database verification failed: {str(e)}")
|
||||
# Continue anyway, as this is just a verification step
|
||||
|
||||
# Only show initialization message if it hasn't been shown before
|
||||
if not hasattr(db, '_message_shown') or not db._message_shown:
|
||||
if not hasattr(db, "_message_shown") or not db._message_shown:
|
||||
if in_memory:
|
||||
logger.debug("In-memory database connection initialized successfully")
|
||||
else:
|
||||
|
|
@ -280,6 +299,7 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_db() -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Get the current database connection.
|
||||
|
|
@ -309,13 +329,14 @@ def get_db() -> peewee.SqliteDatabase:
|
|||
# First, remove the old connection from the context var
|
||||
db_var.set(None)
|
||||
# Then initialize a new connection with the same in-memory setting
|
||||
in_memory = hasattr(db, '_is_in_memory') and db._is_in_memory
|
||||
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
|
||||
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
||||
# Create a completely new database object, don't reuse the old one
|
||||
return init_db(in_memory=in_memory)
|
||||
|
||||
return db
|
||||
|
||||
|
||||
def close_db() -> None:
|
||||
"""
|
||||
Close the current database connection if it exists.
|
||||
|
|
@ -332,7 +353,9 @@ def close_db() -> None:
|
|||
db.close()
|
||||
logger.info("Database connection closed successfully")
|
||||
else:
|
||||
logger.debug("Database connection was already closed (normal during shutdown)")
|
||||
logger.debug(
|
||||
"Database connection was already closed (normal during shutdown)"
|
||||
)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: Failed to close connection: {str(e)}")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -6,16 +6,14 @@ using peewee-migrate. It includes tools for creating, checking, and applying
|
|||
migrations automatically.
|
||||
"""
|
||||
|
||||
import os
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import peewee
|
||||
from peewee_migrate import Router
|
||||
from peewee_migrate.router import DEFAULT_MIGRATE_DIR
|
||||
|
||||
from ra_aid.database.connection import get_db, DatabaseManager
|
||||
from ra_aid.database.connection import DatabaseManager, get_db
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -33,7 +31,9 @@ class MigrationManager:
|
|||
pending migrations, apply migrations, and create new migrations.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None):
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the MigrationManager.
|
||||
|
||||
|
|
@ -96,7 +96,9 @@ class MigrationManager:
|
|||
Router: Configured peewee-migrate Router instance
|
||||
"""
|
||||
try:
|
||||
router = Router(self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE)
|
||||
router = Router(
|
||||
self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE
|
||||
)
|
||||
logger.debug(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
||||
return router
|
||||
except Exception as e:
|
||||
|
|
@ -120,7 +122,9 @@ class MigrationManager:
|
|||
# Calculate pending migrations
|
||||
pending = [m for m in all_migrations if m not in applied]
|
||||
|
||||
logger.debug(f"Found {len(applied)} applied migrations and {len(pending)} pending migrations")
|
||||
logger.debug(
|
||||
f"Found {len(applied)} applied migrations and {len(pending)} pending migrations"
|
||||
)
|
||||
return applied, pending
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check migrations: {str(e)}")
|
||||
|
|
@ -175,8 +179,8 @@ class MigrationManager:
|
|||
"""
|
||||
try:
|
||||
# Sanitize migration name
|
||||
safe_name = name.replace(' ', '_').lower()
|
||||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
safe_name = name.replace(" ", "_").lower()
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
migration_name = f"{timestamp}_{safe_name}"
|
||||
|
||||
logger.info(f"Creating new migration: {migration_name}")
|
||||
|
|
@ -209,7 +213,9 @@ class MigrationManager:
|
|||
}
|
||||
|
||||
|
||||
def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str] = None) -> MigrationManager:
|
||||
def init_migrations(
|
||||
db_path: Optional[str] = None, migrations_dir: Optional[str] = None
|
||||
) -> MigrationManager:
|
||||
"""
|
||||
Initialize the migration manager.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,16 +5,17 @@ This module defines the base model class that all models will inherit from.
|
|||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Any, Dict, Type, TypeVar
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
T = TypeVar('T', bound='BaseModel')
|
||||
T = TypeVar("T", bound="BaseModel")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseModel(peewee.Model):
|
||||
"""
|
||||
Base model class for all ra_aid models.
|
||||
|
|
@ -22,6 +23,7 @@ class BaseModel(peewee.Model):
|
|||
All models should inherit from this class to ensure consistent
|
||||
behavior and database connection.
|
||||
"""
|
||||
|
||||
created_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,14 +3,18 @@ Tests for the database connection module.
|
|||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db, get_db, close_db, db_var, DatabaseManager
|
||||
DatabaseManager,
|
||||
close_db,
|
||||
db_var,
|
||||
get_db,
|
||||
init_db,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -25,10 +29,10 @@ def cleanup_db():
|
|||
# Clean up after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
if hasattr(db, '_is_in_memory'):
|
||||
delattr(db, '_is_in_memory')
|
||||
if hasattr(db, '_message_shown'):
|
||||
delattr(db, '_message_shown')
|
||||
if hasattr(db, "_is_in_memory"):
|
||||
delattr(db, "_is_in_memory")
|
||||
if hasattr(db, "_message_shown"):
|
||||
delattr(db, "_message_shown")
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
|
||||
|
|
@ -69,7 +73,9 @@ def test_init_db_creates_directory(cleanup_db, tmp_path):
|
|||
os.chdir(tmp_path_str)
|
||||
current_cwd = os.getcwd()
|
||||
print(f"Changed working directory to: {current_cwd}")
|
||||
assert current_cwd == tmp_path_str, f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
||||
assert (
|
||||
current_cwd == tmp_path_str
|
||||
), f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
||||
|
||||
# Create the .ra-aid directory manually to ensure it exists
|
||||
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
|
||||
|
|
@ -77,13 +83,17 @@ def test_init_db_creates_directory(cleanup_db, tmp_path):
|
|||
os.makedirs(ra_aid_path_str, exist_ok=True)
|
||||
|
||||
# Verify the directory was created
|
||||
assert os.path.exists(ra_aid_path_str), f".ra-aid directory not found at {ra_aid_path_str}"
|
||||
assert os.path.isdir(ra_aid_path_str), f"{ra_aid_path_str} exists but is not a directory"
|
||||
assert os.path.exists(
|
||||
ra_aid_path_str
|
||||
), f".ra-aid directory not found at {ra_aid_path_str}"
|
||||
assert os.path.isdir(
|
||||
ra_aid_path_str
|
||||
), f"{ra_aid_path_str} exists but is not a directory"
|
||||
|
||||
# Create a test file to verify write permissions
|
||||
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
|
||||
print(f"Creating test file to verify write permissions: {test_file_path}")
|
||||
with open(test_file_path, 'w') as f:
|
||||
with open(test_file_path, "w") as f:
|
||||
f.write("Test write permissions")
|
||||
|
||||
# Verify the test file was created
|
||||
|
|
@ -92,11 +102,13 @@ def test_init_db_creates_directory(cleanup_db, tmp_path):
|
|||
# Create an empty database file to ensure it exists before init_db
|
||||
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
|
||||
print(f"Creating empty database file at: {db_file_str}")
|
||||
with open(db_file_str, 'w') as f:
|
||||
with open(db_file_str, "w") as f:
|
||||
f.write("") # Create empty file
|
||||
|
||||
# Verify the database file was created
|
||||
assert os.path.exists(db_file_str), f"Empty database file not created at {db_file_str}"
|
||||
assert os.path.exists(
|
||||
db_file_str
|
||||
), f"Empty database file not created at {db_file_str}"
|
||||
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
|
||||
|
||||
# Get directory permissions for debugging
|
||||
|
|
@ -157,7 +169,7 @@ def test_init_db_with_in_memory_mode(cleanup_db):
|
|||
# Check that the database connection is returned
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
|
|
@ -272,7 +284,7 @@ def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
|
|||
return original_connect(self, *args, **kwargs)
|
||||
|
||||
# Apply the patch
|
||||
monkeypatch.setattr(peewee.SqliteDatabase, 'connect', mock_connect)
|
||||
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
|
@ -321,7 +333,7 @@ def test_close_db_handles_already_closed_connection(cleanup_db):
|
|||
close_db()
|
||||
|
||||
|
||||
@patch('ra_aid.database.connection.peewee.SqliteDatabase.close')
|
||||
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
|
||||
def test_close_db_handles_error(mock_close, cleanup_db):
|
||||
"""
|
||||
Test that close_db handles errors when closing the connection.
|
||||
|
|
@ -362,7 +374,7 @@ def test_database_manager_with_in_memory_mode(cleanup_db):
|
|||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
|
|
@ -391,7 +403,7 @@ def test_init_db_sets_is_in_memory_attribute(cleanup_db):
|
|||
db = init_db(in_memory=False)
|
||||
|
||||
# Check that the _is_in_memory attribute is set to False
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Reset the contextvar
|
||||
|
|
@ -401,22 +413,17 @@ def test_init_db_sets_is_in_memory_attribute(cleanup_db):
|
|||
db = init_db(in_memory=True)
|
||||
|
||||
# Check that the _is_in_memory attribute is set to True
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
"""
|
||||
Tests for the database connection module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db, get_db, close_db,
|
||||
db_var, DatabaseManager, logger
|
||||
)
|
||||
import pytest
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -493,7 +500,7 @@ class TestInitDb:
|
|||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Verify the database file was created using both Path and os.path methods
|
||||
|
|
@ -513,7 +520,9 @@ class TestInitDb:
|
|||
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
|
||||
|
||||
# Use os.path for assertions to be more reliable
|
||||
assert os.path.exists(ra_aid_dir_str), f"Directory {ra_aid_dir_str} does not exist"
|
||||
assert os.path.exists(
|
||||
ra_aid_dir_str
|
||||
), f"Directory {ra_aid_dir_str} does not exist"
|
||||
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
|
||||
|
||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||
|
|
@ -526,7 +535,7 @@ class TestInitDb:
|
|||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
|
|
@ -559,7 +568,7 @@ class TestGetDb:
|
|||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
def test_get_db_reuses_connection(self, cleanup_db):
|
||||
|
|
@ -617,7 +626,7 @@ class TestDatabaseManager:
|
|||
with DatabaseManager() as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Verify the database file was created
|
||||
|
|
@ -633,7 +642,7 @@ class TestDatabaseManager:
|
|||
with DatabaseManager(in_memory=True) as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Database utility functions for ra_aid.
|
|||
This module provides utility functions for common database operations.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import List, Type
|
||||
|
||||
|
|
@ -16,6 +15,7 @@ from ra_aid.logging_config import get_logger
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
||||
"""
|
||||
Ensure that database tables for the specified models exist.
|
||||
|
|
@ -39,9 +39,11 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
|||
|
||||
# Find all classes in the module that inherit from BaseModel
|
||||
for name, obj in inspect.getmembers(models_module):
|
||||
if (inspect.isclass(obj) and
|
||||
issubclass(obj, BaseModel) and
|
||||
obj != BaseModel):
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseModel)
|
||||
and obj != BaseModel
|
||||
):
|
||||
models.append(obj)
|
||||
except ImportError as e:
|
||||
logger.warning(f"Error importing model modules: {e}")
|
||||
|
|
@ -61,6 +63,7 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
|||
logger.error(f"Error: Failed to create tables: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_model_count(model_class: Type[BaseModel]) -> int:
|
||||
"""
|
||||
Get the count of records for a specific model.
|
||||
|
|
@ -77,6 +80,7 @@ def get_model_count(model_class: Type[BaseModel]) -> int:
|
|||
logger.error(f"Database Error: Failed to count records: {str(e)}")
|
||||
return 0
|
||||
|
||||
|
||||
def truncate_table(model_class: Type[BaseModel]) -> None:
|
||||
"""
|
||||
Delete all records from a model's table.
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
"""Module for checking system dependencies required by RA.Aid."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ra_aid import print_error
|
||||
|
|
@ -23,9 +22,11 @@ class RipGrepDependency(Dependency):
|
|||
def check(self):
|
||||
"""Check if ripgrep is installed."""
|
||||
try:
|
||||
result = subprocess.run(['rg', '--version'],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL)
|
||||
result = subprocess.run(
|
||||
["rg", "--version"],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise FileNotFoundError()
|
||||
except (FileNotFoundError, subprocess.SubprocessError):
|
||||
|
|
|
|||
|
|
@ -122,10 +122,7 @@ def create_openrouter_client(
|
|||
is_expert: bool = False,
|
||||
) -> BaseChatModel:
|
||||
"""Create OpenRouter client with appropriate configuration."""
|
||||
default_headers = {
|
||||
"HTTP-Referer": "https://ra-aid.ai",
|
||||
"X-Title": "RA.Aid"
|
||||
}
|
||||
default_headers = {"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"}
|
||||
|
||||
if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
|
||||
return ChatDeepseekReasoner(
|
||||
|
|
@ -245,10 +242,7 @@ def create_llm_client(
|
|||
temp_kwargs = {}
|
||||
|
||||
if supports_thinking:
|
||||
temp_kwargs = {"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 12000
|
||||
}}
|
||||
temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
|
||||
|
||||
if provider == "deepseek":
|
||||
return create_deepseek_client(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import signal
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pyte
|
||||
from pyte.screens import HistoryScreen
|
||||
|
|
@ -33,7 +33,10 @@ else:
|
|||
|
||||
|
||||
def create_process(
|
||||
cmd: List[str], env: Optional[dict] = None, cols: Optional[int] = None, rows: Optional[int] = None
|
||||
cmd: List[str],
|
||||
env: Optional[dict] = None,
|
||||
cols: Optional[int] = None,
|
||||
rows: Optional[int] = None,
|
||||
) -> Tuple[subprocess.Popen, Optional[int]]:
|
||||
"""
|
||||
Create a subprocess with appropriate settings for the current platform.
|
||||
|
|
@ -403,7 +406,7 @@ def run_interactive_command(
|
|||
all_lines = []
|
||||
|
||||
# Add history.top lines (older history)
|
||||
if hasattr(screen.history.top, 'keys'):
|
||||
if hasattr(screen.history.top, "keys"):
|
||||
# Dictionary-like object
|
||||
for line_num in sorted(screen.history.top.keys()):
|
||||
line = screen.history.top[line_num]
|
||||
|
|
@ -417,7 +420,7 @@ def run_interactive_command(
|
|||
all_lines.extend([render_line(line, cols) for line in screen.display])
|
||||
|
||||
# Add history.bottom lines (newer history)
|
||||
if hasattr(screen.history.bottom, 'keys'):
|
||||
if hasattr(screen.history.bottom, "keys"):
|
||||
# Dictionary-like object
|
||||
for line_num in sorted(screen.history.bottom.keys()):
|
||||
line = screen.history.bottom[line_num]
|
||||
|
|
@ -437,12 +440,12 @@ def run_interactive_command(
|
|||
print(f"Warning: Error processing terminal output: {e}", file=sys.stderr)
|
||||
try:
|
||||
# Decode raw output, strip trailing whitespace from each line
|
||||
decoded = raw_output.decode('utf-8', errors='replace')
|
||||
decoded = raw_output.decode("utf-8", errors="replace")
|
||||
lines = [line.rstrip() for line in decoded.splitlines()]
|
||||
final_output = "\n".join(lines)
|
||||
except Exception:
|
||||
# Ultimate fallback if line processing fails
|
||||
final_output = raw_output.decode('utf-8', errors='replace').strip()
|
||||
final_output = raw_output.decode("utf-8", errors="replace").strip()
|
||||
|
||||
# Add timeout message if process was terminated due to timeout.
|
||||
if was_terminated:
|
||||
|
|
|
|||
|
|
@ -241,7 +241,9 @@ If you find this is an empty directory, you can stop research immediately and as
|
|||
|
||||
"""
|
||||
|
||||
RESEARCH_PROMPT = RESEARCH_COMMON_PROMPT_HEADER + """
|
||||
RESEARCH_PROMPT = (
|
||||
RESEARCH_COMMON_PROMPT_HEADER
|
||||
+ """
|
||||
|
||||
Project State Handling:
|
||||
For new/empty projects:
|
||||
|
|
@ -280,9 +282,12 @@ NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
|||
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATINO IS REQUIRED, CALL request_implementation.
|
||||
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
|
||||
"""
|
||||
)
|
||||
|
||||
# Research-only prompt - similar to research prompt but without implementation references
|
||||
RESEARCH_ONLY_PROMPT = RESEARCH_COMMON_PROMPT_HEADER + """
|
||||
RESEARCH_ONLY_PROMPT = (
|
||||
RESEARCH_COMMON_PROMPT_HEADER
|
||||
+ """
|
||||
|
||||
You have been spawned by a higher level research agent, so only spawn more research tasks sparingly if absolutely necessary. Keep your research *very* scoped and efficient.
|
||||
|
||||
|
|
@ -290,6 +295,7 @@ When you emit research notes, keep it extremely concise and relevant only to the
|
|||
|
||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||
"""
|
||||
)
|
||||
|
||||
# Web research prompt - guides web search and information gathering
|
||||
WEB_RESEARCH_PROMPT = """Current Date: {current_date}
|
||||
|
|
|
|||
|
|
@ -162,7 +162,9 @@ class AnthropicStrategy(ProviderStrategy):
|
|||
if not base_key:
|
||||
missing.append("ANTHROPIC_API_KEY environment variable is not set")
|
||||
else:
|
||||
missing.append("EXPERT_ANTHROPIC_API_KEY environment variable is not set")
|
||||
missing.append(
|
||||
"EXPERT_ANTHROPIC_API_KEY environment variable is not set"
|
||||
)
|
||||
else:
|
||||
key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not key:
|
||||
|
|
|
|||
|
|
@ -2,17 +2,17 @@
|
|||
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_task_completed,
|
||||
mark_plan_completed,
|
||||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
get_current_context,
|
||||
is_completed,
|
||||
mark_plan_completed,
|
||||
mark_task_completed,
|
||||
reset_completion_flags,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -192,22 +192,12 @@ class TestUtilityFunctions:
|
|||
# These should have safe default returns
|
||||
assert is_completed() is False
|
||||
assert get_completion_message() == ""
|
||||
|
||||
|
||||
"""Unit tests for the agent_context module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_task_completed,
|
||||
mark_plan_completed,
|
||||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class TestAgentContext:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
"""Unit tests for agent_should_exit functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_should_exit,
|
||||
should_exit,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ from ra_aid.tools import (
|
|||
emit_related_files,
|
||||
emit_research_notes,
|
||||
file_str_replace,
|
||||
put_complete_file_contents,
|
||||
fuzzy_find_project_files,
|
||||
list_directory_tree,
|
||||
put_complete_file_contents,
|
||||
read_file_tool,
|
||||
ripgrep_search,
|
||||
run_programming_task,
|
||||
|
|
@ -26,7 +26,7 @@ from ra_aid.tools.agent import (
|
|||
request_task_implementation,
|
||||
request_web_research,
|
||||
)
|
||||
from ra_aid.tools.memory import one_shot_completed, plan_implementation_completed
|
||||
from ra_aid.tools.memory import plan_implementation_completed
|
||||
|
||||
|
||||
def set_modification_tools(use_aider=False):
|
||||
|
|
@ -46,7 +46,9 @@ def set_modification_tools(use_aider=False):
|
|||
|
||||
# Read-only tools that don't modify system state
|
||||
def get_read_only_tools(
|
||||
human_interaction: bool = False, web_research_enabled: bool = False, use_aider: bool = False
|
||||
human_interaction: bool = False,
|
||||
web_research_enabled: bool = False,
|
||||
use_aider: bool = False,
|
||||
):
|
||||
"""Get the list of read-only tools, optionally including human interaction tools.
|
||||
|
||||
|
|
@ -100,6 +102,7 @@ def get_all_tools() -> list[BaseTool]:
|
|||
_config = {}
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
_config = _global_memory.get("config", {})
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
@ -137,15 +140,14 @@ def get_research_tools(
|
|||
use_aider = False
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Start with read-only tools
|
||||
tools = get_read_only_tools(
|
||||
human_interaction,
|
||||
web_research_enabled,
|
||||
use_aider=use_aider
|
||||
human_interaction, web_research_enabled, use_aider=use_aider
|
||||
).copy()
|
||||
|
||||
tools.extend(RESEARCH_TOOLS)
|
||||
|
|
@ -179,14 +181,14 @@ def get_planning_tools(
|
|||
use_aider = False
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Start with read-only tools
|
||||
tools = get_read_only_tools(
|
||||
web_research_enabled=web_research_enabled,
|
||||
use_aider=use_aider
|
||||
web_research_enabled=web_research_enabled, use_aider=use_aider
|
||||
).copy()
|
||||
|
||||
# Add planning-specific tools
|
||||
|
|
@ -218,14 +220,14 @@ def get_implementation_tools(
|
|||
use_aider = False
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Start with read-only tools
|
||||
tools = get_read_only_tools(
|
||||
web_research_enabled=web_research_enabled,
|
||||
use_aider=use_aider
|
||||
web_research_enabled=web_research_enabled, use_aider=use_aider
|
||||
).copy()
|
||||
|
||||
# Add modification tools since it's not research-only
|
||||
|
|
|
|||
|
|
@ -5,7 +5,12 @@ from typing import Any, Dict, List, Union
|
|||
from langchain_core.tools import tool
|
||||
from rich.console import Console
|
||||
|
||||
from ra_aid.agent_context import get_completion_message, get_crash_message, is_crashed, reset_completion_flags
|
||||
from ra_aid.agent_context import (
|
||||
get_completion_message,
|
||||
get_crash_message,
|
||||
is_crashed,
|
||||
reset_completion_flags,
|
||||
)
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
|
@ -85,7 +90,9 @@ def request_research(query: str) -> ResearchResult:
|
|||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
# Get completion message if available
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
completion_message = get_completion_message() or (
|
||||
"Task was completed successfully." if success else None
|
||||
)
|
||||
|
||||
work_log = get_work_log()
|
||||
|
||||
|
|
@ -149,7 +156,9 @@ def request_web_research(query: str) -> ResearchResult:
|
|||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
# Get completion message if available
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
completion_message = get_completion_message() or (
|
||||
"Task was completed successfully." if success else None
|
||||
)
|
||||
|
||||
work_log = get_work_log()
|
||||
|
||||
|
|
@ -215,7 +224,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
reason = f"error: {str(e)}"
|
||||
|
||||
# Get completion message if available
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
completion_message = get_completion_message() or (
|
||||
"Task was completed successfully." if success else None
|
||||
)
|
||||
|
||||
work_log = get_work_log()
|
||||
|
||||
|
|
@ -293,7 +304,9 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
reason = f"error: {str(e)}"
|
||||
|
||||
# Get completion message if available
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
completion_message = get_completion_message() or (
|
||||
"Task was completed successfully." if success else None
|
||||
)
|
||||
|
||||
# Get and reset work log if at root depth
|
||||
work_log = get_work_log()
|
||||
|
|
@ -324,15 +337,21 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
# Add header and completion message
|
||||
markdown_parts.append("# Task Implementation")
|
||||
if response_data.get("completion_message"):
|
||||
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
|
||||
markdown_parts.append(
|
||||
f"\n## Completion Message\n\n{response_data['completion_message']}"
|
||||
)
|
||||
|
||||
# Add crash information if applicable
|
||||
if response_data.get("agent_crashed"):
|
||||
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
|
||||
markdown_parts.append(
|
||||
f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}"
|
||||
)
|
||||
|
||||
# Add success status
|
||||
status = "Success" if response_data.get("success", False) else "Failed"
|
||||
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||
reason_text = (
|
||||
f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||
)
|
||||
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
|
||||
|
||||
# Add key facts
|
||||
|
|
@ -351,7 +370,9 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
# Add work log
|
||||
if response_data.get("work_log"):
|
||||
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
|
||||
markdown_parts.append(f"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**")
|
||||
markdown_parts.append(
|
||||
"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**"
|
||||
)
|
||||
|
||||
# Join all parts into a single markdown string
|
||||
markdown_output = "".join(markdown_parts)
|
||||
|
|
@ -403,7 +424,9 @@ def request_implementation(task_spec: str) -> str:
|
|||
reason = f"error: {str(e)}"
|
||||
|
||||
# Get completion message if available
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
completion_message = get_completion_message() or (
|
||||
"Task was completed successfully." if success else None
|
||||
)
|
||||
|
||||
# Get and reset work log if at root depth
|
||||
work_log = get_work_log()
|
||||
|
|
@ -434,15 +457,21 @@ def request_implementation(task_spec: str) -> str:
|
|||
# Add header and completion message
|
||||
markdown_parts.append("# Implementation Plan")
|
||||
if response_data.get("completion_message"):
|
||||
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
|
||||
markdown_parts.append(
|
||||
f"\n## Completion Message\n\n{response_data['completion_message']}"
|
||||
)
|
||||
|
||||
# Add crash information if applicable
|
||||
if response_data.get("agent_crashed"):
|
||||
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
|
||||
markdown_parts.append(
|
||||
f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}"
|
||||
)
|
||||
|
||||
# Add success status
|
||||
status = "Success" if response_data.get("success", False) else "Failed"
|
||||
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||
reason_text = (
|
||||
f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||
)
|
||||
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
|
||||
|
||||
# Add key facts
|
||||
|
|
@ -461,7 +490,9 @@ def request_implementation(task_spec: str) -> str:
|
|||
# Add work log
|
||||
if response_data.get("work_log"):
|
||||
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
|
||||
markdown_parts.append(f"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**")
|
||||
markdown_parts.append(
|
||||
"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**"
|
||||
)
|
||||
|
||||
# Join all parts into a single markdown string
|
||||
markdown_output = "".join(markdown_parts)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
import magic
|
||||
|
|
@ -12,7 +12,11 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit
|
||||
from ra_aid.agent_context import (
|
||||
mark_plan_completed,
|
||||
mark_should_exit,
|
||||
mark_task_completed,
|
||||
)
|
||||
|
||||
|
||||
class WorkLogEntry(TypedDict):
|
||||
|
|
@ -478,18 +482,36 @@ def is_binary_file(filepath):
|
|||
if magic:
|
||||
try:
|
||||
mime = magic.from_file(filepath, mime=True)
|
||||
return not mime.startswith('text/')
|
||||
except Exception:
|
||||
# Fallback if magic fails
|
||||
return False
|
||||
else:
|
||||
# Basic binary detection if magic is not available
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
f.read(1024) # Try to read as text
|
||||
file_type = magic.from_file(filepath)
|
||||
|
||||
if not mime.startswith("text/"):
|
||||
return True
|
||||
|
||||
if "ASCII text" in file_type:
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return _is_binary_fallback(filepath)
|
||||
else:
|
||||
return _is_binary_fallback(filepath)
|
||||
|
||||
|
||||
def _is_binary_fallback(filepath):
|
||||
"""Fallback method to detect binary files without using magic."""
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
chunk = f.read(1024)
|
||||
|
||||
# Check for null bytes which indicate binary content
|
||||
if "\0" in chunk:
|
||||
return True
|
||||
|
||||
# If we can read it as text without errors, it's probably not binary
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
# If we can't decode as UTF-8, it's likely binary
|
||||
return True
|
||||
|
||||
|
||||
def get_work_log() -> str:
|
||||
|
|
|
|||
|
|
@ -159,7 +159,9 @@ def run_programming_task(
|
|||
|
||||
# Return structured output
|
||||
return {
|
||||
"output": (truncate_output(result[0].decode()) + extra_ins) if result[0] else "",
|
||||
"output": (truncate_output(result[0].decode()) + extra_ins)
|
||||
if result[0]
|
||||
else "",
|
||||
"return_code": result[1],
|
||||
"success": result[1] == 0,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,9 +61,7 @@ def put_complete_file_contents(
|
|||
f"at {filepath} in {result['elapsed_time']:.3f}s"
|
||||
)
|
||||
|
||||
logging.debug(
|
||||
f"File write complete: {bytes_written} bytes in {elapsed:.2f}s"
|
||||
)
|
||||
logging.debug(f"File write complete: {bytes_written} bytes in {elapsed:.2f}s")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
|
|||
|
|
@ -1,23 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from typing import List
|
||||
import json
|
||||
import threading
|
||||
import queue
|
||||
import traceback
|
||||
import shutil
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.__stderr__) # Use the real stderr
|
||||
]
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,12 +24,12 @@ project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
|||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -55,10 +53,12 @@ app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
|||
# Store active WebSocket connections
|
||||
active_connections: List[WebSocket] = []
|
||||
|
||||
|
||||
def run_ra_aid(message_content, output_queue):
|
||||
"""Run ra-aid in a separate thread"""
|
||||
try:
|
||||
import ra_aid.__main__
|
||||
|
||||
logger.info("Successfully imported ra_aid.__main__")
|
||||
|
||||
# Override sys.argv
|
||||
|
|
@ -78,23 +78,23 @@ def run_ra_aid(message_content, output_queue):
|
|||
logger.debug(f"Raw output: {repr(text)}")
|
||||
|
||||
# Check if this is a box drawing character
|
||||
if any(c in text for c in '╭╮╰╯│─'):
|
||||
if any(c in text for c in "╭╮╰╯│─"):
|
||||
self.box_start = True
|
||||
self.buffer.append(text)
|
||||
elif self.box_start and text.strip():
|
||||
self.buffer.append(text)
|
||||
if '╯' in text: # End of box
|
||||
full_text = ''.join(self.buffer)
|
||||
if "╯" in text: # End of box
|
||||
full_text = "".join(self.buffer)
|
||||
# Extract content from inside the box
|
||||
lines = full_text.split('\n')
|
||||
lines = full_text.split("\n")
|
||||
content_lines = []
|
||||
for line in lines:
|
||||
# Remove box characters and leading/trailing spaces
|
||||
clean_line = line.strip('╭╮╰╯│─ ')
|
||||
clean_line = line.strip("╭╮╰╯│─ ")
|
||||
if clean_line:
|
||||
content_lines.append(clean_line)
|
||||
if content_lines:
|
||||
self.queue.put('\n'.join(content_lines))
|
||||
self.queue.put("\n".join(content_lines))
|
||||
self.buffer = []
|
||||
self.box_start = False
|
||||
elif not self.box_start and text.strip():
|
||||
|
|
@ -102,17 +102,17 @@ def run_ra_aid(message_content, output_queue):
|
|||
|
||||
def flush(self):
|
||||
if self.buffer:
|
||||
full_text = ''.join(self.buffer)
|
||||
full_text = "".join(self.buffer)
|
||||
# Extract content from partial box
|
||||
lines = full_text.split('\n')
|
||||
lines = full_text.split("\n")
|
||||
content_lines = []
|
||||
for line in lines:
|
||||
# Remove box characters and leading/trailing spaces
|
||||
clean_line = line.strip('╭╮╰╯│─ ')
|
||||
clean_line = line.strip("╭╮╰╯│─ ")
|
||||
if clean_line:
|
||||
content_lines.append(clean_line)
|
||||
if content_lines:
|
||||
self.queue.put('\n'.join(content_lines))
|
||||
self.queue.put("\n".join(content_lines))
|
||||
self.buffer = []
|
||||
self.box_start = False
|
||||
|
||||
|
|
@ -144,6 +144,7 @@ def run_ra_aid(message_content, output_queue):
|
|||
traceback.print_exc(file=sys.__stderr__)
|
||||
output_queue.put(f"Error: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def get_root(request: Request):
|
||||
"""Serve the index.html file with port parameter."""
|
||||
|
|
@ -151,6 +152,7 @@ async def get_root(request: Request):
|
|||
"index.html", {"request": request, "server_port": request.url.port or 8080}
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
|
@ -170,7 +172,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||
output_queue = queue.Queue()
|
||||
|
||||
# Create and start thread
|
||||
thread = threading.Thread(target=run_ra_aid, args=(content, output_queue))
|
||||
thread = threading.Thread(
|
||||
target=run_ra_aid, args=(content, output_queue)
|
||||
)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
|
|
@ -183,17 +187,21 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||
line = output_queue.get(timeout=0.1)
|
||||
if line and line.strip(): # Only send non-empty messages
|
||||
logger.debug(f"WebSocket sending: {repr(line)}")
|
||||
await websocket.send_json({
|
||||
"type": "chunk",
|
||||
"chunk": {
|
||||
"agent": {
|
||||
"messages": [{
|
||||
"content": line.strip(),
|
||||
"status": "info"
|
||||
}]
|
||||
}
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "chunk",
|
||||
"chunk": {
|
||||
"agent": {
|
||||
"messages": [
|
||||
{
|
||||
"content": line.strip(),
|
||||
"status": "info",
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
|
|
@ -211,10 +219,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||
except Exception as e:
|
||||
error_msg = f"Error running ra-aid: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": error_msg
|
||||
})
|
||||
await websocket.send_json({"type": "error", "message": error_msg})
|
||||
|
||||
logger.info("Waiting for message...")
|
||||
|
||||
|
|
@ -243,6 +248,7 @@ def run_server(host: str = "0.0.0.0", port: int = 8080):
|
|||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"
|
||||
|
|
|
|||
|
|
@ -1,18 +1,23 @@
|
|||
"""
|
||||
Tests for the database connection module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db, get_db, close_db,
|
||||
db_var, DatabaseManager, logger
|
||||
DatabaseManager,
|
||||
close_db,
|
||||
db_var,
|
||||
get_db,
|
||||
init_db,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
|
|
@ -48,20 +53,23 @@ def cleanup_db():
|
|||
# Log but don't fail if cleanup has issues
|
||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch('ra_aid.database.connection.logger') as mock:
|
||||
with patch("ra_aid.database.connection.logger") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for the init_db function."""
|
||||
|
||||
def test_init_db_default(self, cleanup_db):
|
||||
"""Test init_db with default parameters."""
|
||||
db = init_db()
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
|
|
@ -73,7 +81,7 @@ class TestInitDb:
|
|||
db = init_db(in_memory=True)
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
|
|
@ -91,8 +99,10 @@ class TestInitDb:
|
|||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
"""Tests for the get_db function."""
|
||||
|
||||
def test_get_db_creates_connection(self, cleanup_db):
|
||||
"""Test that get_db creates a new connection if none exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
|
|
@ -100,7 +110,7 @@ class TestGetDb:
|
|||
db = get_db()
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
def test_get_db_reuses_connection(self, cleanup_db):
|
||||
|
|
@ -118,8 +128,10 @@ class TestGetDb:
|
|||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
class TestCloseDb:
|
||||
"""Tests for the close_db function."""
|
||||
|
||||
def test_close_db(self, cleanup_db):
|
||||
"""Test that close_db closes an open connection."""
|
||||
db = init_db()
|
||||
|
|
@ -142,14 +154,16 @@ class TestCloseDb:
|
|||
# This should not raise an exception
|
||||
close_db()
|
||||
|
||||
|
||||
class TestDatabaseManager:
|
||||
"""Tests for the DatabaseManager class."""
|
||||
|
||||
def test_database_manager_default(self, cleanup_db):
|
||||
"""Test DatabaseManager with default parameters."""
|
||||
with DatabaseManager() as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
|
|
@ -163,7 +177,7 @@ class TestDatabaseManager:
|
|||
with DatabaseManager(in_memory=True) as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
|
|
|||
|
|
@ -5,22 +5,19 @@ Tests for the database migrations module.
|
|||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, call, PropertyMock
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
import peewee
|
||||
from peewee_migrate import Router
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.migrations import (
|
||||
MigrationManager,
|
||||
init_migrations,
|
||||
ensure_migrations_applied,
|
||||
create_new_migration,
|
||||
get_migration_status,
|
||||
MIGRATIONS_DIRNAME,
|
||||
MIGRATIONS_TABLE
|
||||
MIGRATIONS_TABLE,
|
||||
MigrationManager,
|
||||
create_new_migration,
|
||||
ensure_migrations_applied,
|
||||
get_migration_status,
|
||||
init_migrations,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -56,7 +53,7 @@ def cleanup_db():
|
|||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch('ra_aid.database.migrations.logger') as mock:
|
||||
with patch("ra_aid.database.migrations.logger") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
|
|
@ -83,7 +80,7 @@ def temp_migrations_dir(temp_dir):
|
|||
@pytest.fixture
|
||||
def mock_router():
|
||||
"""Mock the peewee_migrate Router class."""
|
||||
with patch('ra_aid.database.migrations.Router') as mock:
|
||||
with patch("ra_aid.database.migrations.Router") as mock:
|
||||
# Configure the mock router
|
||||
mock_instance = MagicMock()
|
||||
mock.return_value = mock_instance
|
||||
|
|
@ -114,8 +111,12 @@ class TestMigrationManager:
|
|||
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
||||
|
||||
# Verify router initialization was logged
|
||||
mock_logger.debug.assert_any_call(f"Using migrations directory: {migrations_dir}")
|
||||
mock_logger.debug.assert_any_call(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
||||
mock_logger.debug.assert_any_call(
|
||||
f"Using migrations directory: {migrations_dir}"
|
||||
)
|
||||
mock_logger.debug.assert_any_call(
|
||||
f"Initialized migration router with table: {MIGRATIONS_TABLE}"
|
||||
)
|
||||
|
||||
def test_ensure_migrations_dir(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test _ensure_migrations_dir creates directory if it doesn't exist."""
|
||||
|
|
@ -131,23 +132,29 @@ class TestMigrationManager:
|
|||
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
||||
|
||||
# Verify creation was logged
|
||||
mock_logger.debug.assert_any_call(f"Creating migrations directory at: {migrations_dir}")
|
||||
mock_logger.debug.assert_any_call(
|
||||
f"Creating migrations directory at: {migrations_dir}"
|
||||
)
|
||||
|
||||
def test_ensure_migrations_dir_error(self, cleanup_db, mock_logger):
|
||||
"""Test _ensure_migrations_dir handles errors."""
|
||||
# Mock os.makedirs to raise an exception
|
||||
with patch('pathlib.Path.mkdir', side_effect=PermissionError("Permission denied")):
|
||||
with patch(
|
||||
"pathlib.Path.mkdir", side_effect=PermissionError("Permission denied")
|
||||
):
|
||||
# Set up test paths - use a path that would require elevated permissions
|
||||
db_path = "/root/test.db"
|
||||
migrations_dir = "/root/migrations"
|
||||
|
||||
# Initialize manager should raise an exception
|
||||
with pytest.raises(Exception):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
manager = MigrationManager(
|
||||
db_path=db_path, migrations_dir=migrations_dir
|
||||
)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with(
|
||||
f"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
|
||||
"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
|
||||
)
|
||||
|
||||
def test_init_router(self, cleanup_db, temp_dir, mock_router):
|
||||
|
|
@ -160,7 +167,7 @@ class TestMigrationManager:
|
|||
os.makedirs(migrations_dir, exist_ok=True)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Verify router was initialized
|
||||
|
|
@ -173,7 +180,7 @@ class TestMigrationManager:
|
|||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call check_migrations
|
||||
|
|
@ -201,7 +208,7 @@ class TestMigrationManager:
|
|||
mock_router.done = []
|
||||
|
||||
# Initialize manager with the mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Directly call check_migrations on the manager with the mocked router
|
||||
|
|
@ -212,7 +219,9 @@ class TestMigrationManager:
|
|||
assert pending == []
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to check migrations: Test error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to check migrations: Test error"
|
||||
)
|
||||
|
||||
def test_apply_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||
"""Test apply_migrations applies pending migrations."""
|
||||
|
|
@ -221,7 +230,7 @@ class TestMigrationManager:
|
|||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call apply_migrations
|
||||
|
|
@ -236,7 +245,9 @@ class TestMigrationManager:
|
|||
# Verify logging
|
||||
mock_logger.info.assert_any_call("Applying 1 pending migrations...")
|
||||
mock_logger.info.assert_any_call("Applying migration: 002_add_users")
|
||||
mock_logger.info.assert_any_call("Successfully applied migration: 002_add_users")
|
||||
mock_logger.info.assert_any_call(
|
||||
"Successfully applied migration: 002_add_users"
|
||||
)
|
||||
mock_logger.info.assert_any_call("Successfully applied 1 migrations")
|
||||
|
||||
def test_apply_migrations_no_pending(self, cleanup_db, temp_dir, mock_logger):
|
||||
|
|
@ -251,7 +262,7 @@ class TestMigrationManager:
|
|||
mock_router.done = ["001_initial"]
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call apply_migrations
|
||||
|
|
@ -279,7 +290,7 @@ class TestMigrationManager:
|
|||
mock_router.run.side_effect = Exception("Migration error")
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call apply_migrations
|
||||
|
|
@ -300,7 +311,7 @@ class TestMigrationManager:
|
|||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call create_migration
|
||||
|
|
@ -315,7 +326,9 @@ class TestMigrationManager:
|
|||
|
||||
# Verify logging
|
||||
mock_logger.info.assert_any_call(f"Creating new migration: {result}")
|
||||
mock_logger.info.assert_any_call(f"Successfully created migration: {result}")
|
||||
mock_logger.info.assert_any_call(
|
||||
f"Successfully created migration: {result}"
|
||||
)
|
||||
|
||||
def test_create_migration_error(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test create_migration handles errors."""
|
||||
|
|
@ -328,7 +341,7 @@ class TestMigrationManager:
|
|||
mock_router.create.side_effect = Exception("Creation error")
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call create_migration
|
||||
|
|
@ -338,7 +351,9 @@ class TestMigrationManager:
|
|||
assert result is None
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to create migration: Creation error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to create migration: Creation error"
|
||||
)
|
||||
|
||||
def test_get_migration_status(self, cleanup_db, temp_dir, mock_router):
|
||||
"""Test get_migration_status returns correct status information."""
|
||||
|
|
@ -347,7 +362,7 @@ class TestMigrationManager:
|
|||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call get_migration_status
|
||||
|
|
@ -372,7 +387,7 @@ class TestMigrationFunctions:
|
|||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Call init_migrations
|
||||
with patch('ra_aid.database.migrations.MigrationManager') as mock_manager:
|
||||
with patch("ra_aid.database.migrations.MigrationManager") as mock_manager:
|
||||
mock_manager.return_value = MagicMock()
|
||||
|
||||
manager = init_migrations(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
|
@ -388,7 +403,9 @@ class TestMigrationFunctions:
|
|||
mock_manager.apply_migrations.return_value = True
|
||||
|
||||
# Call ensure_migrations_applied
|
||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
||||
with patch(
|
||||
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
|
||||
):
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify result
|
||||
|
|
@ -400,15 +417,19 @@ class TestMigrationFunctions:
|
|||
def test_ensure_migrations_applied_error(self, cleanup_db, mock_logger):
|
||||
"""Test ensure_migrations_applied handles errors."""
|
||||
# Call ensure_migrations_applied with an exception
|
||||
with patch('ra_aid.database.migrations.init_migrations',
|
||||
side_effect=Exception("Test error")):
|
||||
with patch(
|
||||
"ra_aid.database.migrations.init_migrations",
|
||||
side_effect=Exception("Test error"),
|
||||
):
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify result is False on error
|
||||
assert result is False
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to apply migrations: Test error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to apply migrations: Test error"
|
||||
)
|
||||
|
||||
def test_create_new_migration(self, cleanup_db, mock_logger):
|
||||
"""Test create_new_migration creates a new migration."""
|
||||
|
|
@ -417,27 +438,35 @@ class TestMigrationFunctions:
|
|||
mock_manager.create_migration.return_value = "20250226_123456_test_migration"
|
||||
|
||||
# Call create_new_migration
|
||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
||||
with patch(
|
||||
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
|
||||
):
|
||||
result = create_new_migration("test_migration", auto=True)
|
||||
|
||||
# Verify result
|
||||
assert result == "20250226_123456_test_migration"
|
||||
|
||||
# Verify migration was created
|
||||
mock_manager.create_migration.assert_called_once_with("test_migration", True)
|
||||
mock_manager.create_migration.assert_called_once_with(
|
||||
"test_migration", True
|
||||
)
|
||||
|
||||
def test_create_new_migration_error(self, cleanup_db, mock_logger):
|
||||
"""Test create_new_migration handles errors."""
|
||||
# Call create_new_migration with an exception
|
||||
with patch('ra_aid.database.migrations.init_migrations',
|
||||
side_effect=Exception("Test error")):
|
||||
with patch(
|
||||
"ra_aid.database.migrations.init_migrations",
|
||||
side_effect=Exception("Test error"),
|
||||
):
|
||||
result = create_new_migration("test_migration", auto=True)
|
||||
|
||||
# Verify result is None on error
|
||||
assert result is None
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to create migration: Test error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to create migration: Test error"
|
||||
)
|
||||
|
||||
def test_get_migration_status(self, cleanup_db, mock_logger):
|
||||
"""Test get_migration_status returns correct status information."""
|
||||
|
|
@ -449,11 +478,13 @@ class TestMigrationFunctions:
|
|||
"applied": ["001_initial", "002_add_users"],
|
||||
"pending": ["003_add_profiles"],
|
||||
"migrations_dir": "/test/migrations",
|
||||
"db_path": "/test/db.sqlite"
|
||||
"db_path": "/test/db.sqlite",
|
||||
}
|
||||
|
||||
# Call get_migration_status
|
||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
||||
with patch(
|
||||
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
|
||||
):
|
||||
status = get_migration_status()
|
||||
|
||||
# Verify status information
|
||||
|
|
@ -470,8 +501,10 @@ class TestMigrationFunctions:
|
|||
def test_get_migration_status_error(self, cleanup_db, mock_logger):
|
||||
"""Test get_migration_status handles errors."""
|
||||
# Call get_migration_status with an exception
|
||||
with patch('ra_aid.database.migrations.init_migrations',
|
||||
side_effect=Exception("Test error")):
|
||||
with patch(
|
||||
"ra_aid.database.migrations.init_migrations",
|
||||
side_effect=Exception("Test error"),
|
||||
):
|
||||
status = get_migration_status()
|
||||
|
||||
# Verify default status on error
|
||||
|
|
@ -482,7 +515,9 @@ class TestMigrationFunctions:
|
|||
assert status["pending"] == []
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to get migration status: Test error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to get migration status: Test error"
|
||||
)
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
|
|
@ -502,7 +537,9 @@ class TestIntegration:
|
|||
pass
|
||||
|
||||
# Initialize migration manager
|
||||
manager = MigrationManager(db_path=":memory:", migrations_dir=migrations_dir)
|
||||
manager = MigrationManager(
|
||||
db_path=":memory:", migrations_dir=migrations_dir
|
||||
)
|
||||
|
||||
# Create a test migration
|
||||
migration_name = manager.create_migration("test_migration", auto=False)
|
||||
|
|
@ -525,7 +562,9 @@ def rollback(migrator, database, fake=False, **kwargs):
|
|||
applied, pending = manager.check_migrations()
|
||||
assert len(applied) == 0
|
||||
assert len(pending) == 1
|
||||
assert migration_name in pending[0] # Instead of exact equality, check if name is contained
|
||||
assert (
|
||||
migration_name in pending[0]
|
||||
) # Instead of exact equality, check if name is contained
|
||||
|
||||
# Apply migrations
|
||||
result = manager.apply_migrations()
|
||||
|
|
@ -535,7 +574,9 @@ def rollback(migrator, database, fake=False, **kwargs):
|
|||
applied, pending = manager.check_migrations()
|
||||
assert len(applied) == 1
|
||||
assert len(pending) == 0
|
||||
assert migration_name in applied[0] # Instead of exact equality, check if name is contained
|
||||
assert (
|
||||
migration_name in applied[0]
|
||||
) # Instead of exact equality, check if name is contained
|
||||
|
||||
# Verify migration status
|
||||
status = manager.get_migration_status()
|
||||
|
|
|
|||
|
|
@ -4,13 +4,11 @@ Tests for the database models module.
|
|||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import peewee
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.connection import db_var, init_db
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.connection import (
|
||||
db_var, get_db, init_db, close_db
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -85,6 +83,7 @@ def test_base_model_save_updates_timestamps(setup_test_model):
|
|||
|
||||
# Wait a moment to ensure timestamps would be different
|
||||
import time
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
# Update the instance
|
||||
|
|
@ -117,13 +116,15 @@ def test_base_model_get_or_create(setup_test_model):
|
|||
assert instance3.id != instance.id
|
||||
|
||||
|
||||
@patch('ra_aid.database.models.logger')
|
||||
@patch("ra_aid.database.models.logger")
|
||||
def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model):
|
||||
"""Test that get_or_create handles database errors properly."""
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Mock the parent get_or_create to raise a DatabaseError
|
||||
with patch('peewee.Model.get_or_create', side_effect=peewee.DatabaseError("Test error")):
|
||||
with patch(
|
||||
"peewee.Model.get_or_create", side_effect=peewee.DatabaseError("Test error")
|
||||
):
|
||||
# Call should raise the error
|
||||
with pytest.raises(peewee.DatabaseError):
|
||||
TestModel.get_or_create(name="test")
|
||||
|
|
|
|||
|
|
@ -2,14 +2,12 @@
|
|||
Tests for the database utils module.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import peewee
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
db_var, get_db, init_db, close_db
|
||||
)
|
||||
from ra_aid.database.connection import db_var, init_db
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||
|
||||
|
|
@ -46,7 +44,7 @@ def cleanup_db():
|
|||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch('ra_aid.database.utils.logger') as mock:
|
||||
with patch("ra_aid.database.utils.logger") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
|
|
@ -101,22 +99,10 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
|||
assert count == 1
|
||||
|
||||
|
||||
def test_ensure_tables_created_no_models(cleanup_db, mock_logger):
|
||||
"""Test ensure_tables_created with no models."""
|
||||
# Initialize the database in memory
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Mock the import to simulate no models found
|
||||
with patch('ra_aid.database.utils.importlib.import_module', side_effect=ImportError("No module")):
|
||||
# Call ensure_tables_created with no models
|
||||
ensure_tables_created()
|
||||
|
||||
# Verify warning message was logged
|
||||
mock_logger.warning.assert_called_with("No models found to create tables for")
|
||||
|
||||
|
||||
@patch('ra_aid.database.utils.get_db')
|
||||
def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cleanup_db, mock_logger):
|
||||
@patch("ra_aid.database.utils.get_db")
|
||||
def test_ensure_tables_created_database_error(
|
||||
mock_get_db, setup_test_model, cleanup_db, mock_logger
|
||||
):
|
||||
"""Test ensure_tables_created handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
TestModel = setup_test_model
|
||||
|
|
@ -135,7 +121,9 @@ def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cle
|
|||
ensure_tables_created([TestModel])
|
||||
|
||||
# Verify error message was logged
|
||||
mock_logger.error.assert_called_with("Database Error: Failed to create tables: Test database error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Database Error: Failed to create tables: Test database error"
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_count(setup_test_model, mock_logger):
|
||||
|
|
@ -157,7 +145,7 @@ def test_get_model_count(setup_test_model, mock_logger):
|
|||
assert count == 2
|
||||
|
||||
|
||||
@patch('peewee.ModelSelect.count')
|
||||
@patch("peewee.ModelSelect.count")
|
||||
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
|
||||
"""Test get_model_count handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
|
|
@ -170,7 +158,9 @@ def test_get_model_count_database_error(mock_count, setup_test_model, mock_logge
|
|||
count = get_model_count(TestModel)
|
||||
|
||||
# Verify error message was logged
|
||||
mock_logger.error.assert_called_with("Database Error: Failed to count records: Test count error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Database Error: Failed to count records: Test count error"
|
||||
)
|
||||
|
||||
# Verify the function returns 0 on error
|
||||
assert count == 0
|
||||
|
|
@ -192,13 +182,15 @@ def test_truncate_table(setup_test_model, mock_logger):
|
|||
truncate_table(TestModel)
|
||||
|
||||
# Verify success message was logged
|
||||
mock_logger.info.assert_called_with(f"Successfully truncated table for {TestModel.__name__}")
|
||||
mock_logger.info.assert_called_with(
|
||||
f"Successfully truncated table for {TestModel.__name__}"
|
||||
)
|
||||
|
||||
# Verify all records were deleted
|
||||
assert TestModel.select().count() == 0
|
||||
|
||||
|
||||
@patch('ra_aid.database.models.BaseModel.delete')
|
||||
@patch("ra_aid.database.models.BaseModel.delete")
|
||||
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
|
||||
"""Test truncate_table handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
|
|
@ -217,4 +209,6 @@ def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logge
|
|||
truncate_table(TestModel)
|
||||
|
||||
# Verify error message was logged
|
||||
mock_logger.error.assert_called_with("Database Error: Failed to truncate table: Test delete error")
|
||||
mock_logger.error.assert_called_with(
|
||||
"Database Error: Failed to truncate table: Test delete error"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
This is a test ascii file.
|
||||
|
||||
ASCII text, with very long lines (885), with CRLF line terminators
|
||||
|
||||
Mumblecore craft beer taxidermy, flannel YOLO pug brunch ugh you probably haven't heard of them art party next level Pinterest squid pork belly. Yr next level Carles, Thundercats dreamcatcher scenester master cleanse bitters disrupt tote bag keffiyeh narwhal organic salvia cray. Whatever heirloom Vice art party pickled, try-hard Williamsburg. Authentic pickled pop-up, letterpress bicycle rights cornhole vinyl Etsy readymade disrupt shabby chic Pitchfork keffiyeh. Master cleanse small batch keytar biodiesel Brooklyn, meggings four loko try-hard McSweeney's vinyl tattooed. Cred Schlitz selvage, tousled Odd Future literally before they sold out synth cardigan retro banh mi next level jean shorts meggings fap. Pork belly four dollar toast quinoa, stumptown taxidermy sriracha whatever you probably haven't heard of them squid single-origin coffee freegan disrupt cliche cardigan.
|
||||
|
||||
Bicycle rights cold-pressed Pinterest, beard butcher pickled pop-up synth DIY hashtag. Austin fanny pack farm-to-table keytar, kitsch fap tousled trust fund swag irony +1. Viral fanny pack vinyl, master cleanse 3 wolf moon readymade occupy before they sold out YOLO meggings XOXO art party fap try-hard. Photo booth you probably haven't heard of them artisan pickled Brooklyn cred umami meh, heirloom cray raw denim tousled drinking vinegar. Gentrify Williamsburg iPhone messenger bag heirloom, swag quinoa ennui brunch. Selvage tofu hella gastropub Pinterest, bicycle rights church-key cardigan semiotics cornhole Shoreditch iPhone fixie biodiesel narwhal. Small batch kogi Shoreditch cliche YOLO.
|
||||
|
||||
Literally yr ugh Truffaut raw denim four loko. Vice chia mustache, Intelligentsia authentic taxidermy Truffaut synth health goth. Locavore semiotics occupy, synth 8-bit hoodie umami meh PBR&B Wes Anderson brunch shabby chic Helvetica quinoa. YOLO beard pop-up Neutra PBR&B vinyl fixie, stumptown shabby chic flexitarian umami. Cronut Blue Bottle scenester sriracha keytar PBR ennui flannel VHS swag. Dreamcatcher 3 wolf moon fanny pack, tattooed XOXO bitters High Life fixie 8-bit Austin lomo single-origin coffee put a bird on it. High Life Kickstarter twee Blue Bottle shabby chic, biodiesel heirloom.
|
||||
|
||||
Wayfarers tousled stumptown pop-up slow-carb. Aesthetic American Apparel hoodie irony YOLO. Meggings synth meh, normcore lomo tote bag post-ironic twee sartorial butcher occupy. Tilde photo booth +1 kogi Williamsburg. Pork belly keytar seitan, pug iPhone fingerstache bitters. Ennui Schlitz actually, cardigan fashion axe Helvetica vegan. Swag lumbersexual blog Carles, cred synth asymmetrical heirloom Tumblr bitters letterpress aesthetic.
|
||||
|
|
@ -1,12 +1,17 @@
|
|||
"""Tests for Windows-specific functionality."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.proc.interactive import (
|
||||
create_process,
|
||||
get_terminal_size,
|
||||
run_interactive_command,
|
||||
)
|
||||
|
||||
from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
|
||||
class TestWindowsCompatibility:
|
||||
|
|
@ -14,7 +19,7 @@ class TestWindowsCompatibility:
|
|||
|
||||
def test_get_terminal_size(self):
|
||||
"""Test terminal size detection on Windows."""
|
||||
with patch('shutil.get_terminal_size') as mock_get_size:
|
||||
with patch("shutil.get_terminal_size") as mock_get_size:
|
||||
mock_get_size.return_value = MagicMock(columns=120, lines=30)
|
||||
cols, rows = get_terminal_size()
|
||||
assert cols == 120
|
||||
|
|
@ -23,30 +28,31 @@ class TestWindowsCompatibility:
|
|||
|
||||
def test_create_process(self):
|
||||
"""Test process creation on Windows."""
|
||||
with patch('subprocess.Popen') as mock_popen:
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.returncode = 0
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
proc, _ = create_process(['echo', 'test'])
|
||||
proc, _ = create_process(["echo", "test"])
|
||||
|
||||
assert mock_popen.called
|
||||
args, kwargs = mock_popen.call_args
|
||||
assert kwargs['stdin'] == subprocess.PIPE
|
||||
assert kwargs['stdout'] == subprocess.PIPE
|
||||
assert kwargs['stderr'] == subprocess.STDOUT
|
||||
assert 'startupinfo' in kwargs
|
||||
assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW
|
||||
assert kwargs["stdin"] == subprocess.PIPE
|
||||
assert kwargs["stdout"] == subprocess.PIPE
|
||||
assert kwargs["stderr"] == subprocess.STDOUT
|
||||
assert "startupinfo" in kwargs
|
||||
assert kwargs["startupinfo"].dwFlags & subprocess.STARTF_USESHOWWINDOW
|
||||
|
||||
def test_run_interactive_command(self):
|
||||
"""Test running an interactive command on Windows."""
|
||||
test_output = "Test output\n"
|
||||
|
||||
with patch('subprocess.Popen') as mock_popen, \
|
||||
patch('pyte.Stream') as mock_stream, \
|
||||
patch('pyte.HistoryScreen') as mock_screen, \
|
||||
patch('threading.Thread') as mock_thread:
|
||||
|
||||
with (
|
||||
patch("subprocess.Popen") as mock_popen,
|
||||
patch("pyte.Stream") as mock_stream,
|
||||
patch("pyte.HistoryScreen") as mock_screen,
|
||||
patch("threading.Thread") as mock_thread,
|
||||
):
|
||||
# Setup mock process
|
||||
mock_process = MagicMock()
|
||||
mock_process.stdout = MagicMock()
|
||||
|
|
@ -67,7 +73,7 @@ class TestWindowsCompatibility:
|
|||
mock_thread.return_value = mock_thread_instance
|
||||
|
||||
# Run the command
|
||||
output, return_code = run_interactive_command(['echo', 'test'])
|
||||
output, return_code = run_interactive_command(["echo", "test"])
|
||||
|
||||
# Verify results
|
||||
assert return_code == 0
|
||||
|
|
@ -80,7 +86,6 @@ class TestWindowsCompatibility:
|
|||
def test_windows_dependencies(self):
|
||||
"""Test that required Windows dependencies are available."""
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
|
||||
# If we get here without ImportError, the test passes
|
||||
assert True
|
||||
|
|
@ -91,11 +96,12 @@ class TestWindowsCompatibility:
|
|||
pytest.skip("Windows-specific test")
|
||||
|
||||
# Test with multiple chunks of output to verify proper handling
|
||||
with patch('subprocess.Popen') as mock_popen, \
|
||||
patch('msvcrt.kbhit', return_value=False), \
|
||||
patch('threading.Thread') as mock_thread, \
|
||||
patch('time.sleep'): # Mock sleep to speed up test
|
||||
|
||||
with (
|
||||
patch("subprocess.Popen") as mock_popen,
|
||||
patch("msvcrt.kbhit", return_value=False),
|
||||
patch("threading.Thread") as mock_thread,
|
||||
patch("time.sleep"),
|
||||
): # Mock sleep to speed up test
|
||||
# Setup mock process
|
||||
mock_process = MagicMock()
|
||||
mock_process.stdout = MagicMock()
|
||||
|
|
@ -110,14 +116,14 @@ class TestWindowsCompatibility:
|
|||
b"First chunk\n",
|
||||
b"Second chunk\n",
|
||||
b"Third chunk with unicode \xe2\x9c\x93\n", # UTF-8 checkmark
|
||||
None # End of output
|
||||
None, # End of output
|
||||
]
|
||||
return MagicMock()
|
||||
|
||||
mock_thread.side_effect = side_effect
|
||||
|
||||
# Run the command
|
||||
output, return_code = run_interactive_command(['test', 'command'])
|
||||
output, return_code = run_interactive_command(["test", "command"])
|
||||
|
||||
# Verify results
|
||||
assert return_code == 0
|
||||
|
|
|
|||
|
|
@ -2,22 +2,19 @@
|
|||
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_task_completed,
|
||||
mark_plan_completed,
|
||||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
get_current_context,
|
||||
is_completed,
|
||||
mark_plan_completed,
|
||||
mark_should_exit,
|
||||
mark_task_completed,
|
||||
reset_completion_flags,
|
||||
should_exit,
|
||||
mark_agent_crashed,
|
||||
is_crashed,
|
||||
get_crash_message,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ import pytest
|
|||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags
|
||||
from ra_aid.agent_context import (
|
||||
agent_context,
|
||||
)
|
||||
from ra_aid.agent_utils import (
|
||||
AgentState,
|
||||
create_agent,
|
||||
|
|
@ -116,7 +118,10 @@ def test_create_agent_anthropic(mock_model, mock_memory):
|
|||
|
||||
assert agent == "react_agent"
|
||||
mock_react.assert_called_once_with(
|
||||
mock_model, [], version='v2', state_modifier=mock_react.call_args[1]["state_modifier"]
|
||||
mock_model,
|
||||
[],
|
||||
version="v2",
|
||||
state_modifier=mock_react.call_args[1]["state_modifier"],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -259,7 +264,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
|
|||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "react_agent"
|
||||
mock_react.assert_called_once_with(mock_model, [], version='v2')
|
||||
mock_react.assert_called_once_with(mock_model, [], version="v2")
|
||||
|
||||
|
||||
def test_get_model_token_limit_research(mock_memory):
|
||||
|
|
@ -459,21 +464,27 @@ def test_is_anthropic_claude():
|
|||
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
|
||||
|
||||
# Test OpenRouter provider cases
|
||||
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"})
|
||||
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-instant"})
|
||||
assert is_anthropic_claude(
|
||||
{"provider": "openrouter", "model": "anthropic/claude-2"}
|
||||
)
|
||||
assert is_anthropic_claude(
|
||||
{"provider": "openrouter", "model": "anthropic/claude-instant"}
|
||||
)
|
||||
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
|
||||
|
||||
# Test edge cases
|
||||
assert not is_anthropic_claude({}) # Empty config
|
||||
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
|
||||
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
|
||||
assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider
|
||||
assert not is_anthropic_claude(
|
||||
{"provider": "other", "model": "claude-2"}
|
||||
) # Wrong provider
|
||||
|
||||
|
||||
def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
||||
"""Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
from ra_aid.agent_context import agent_context, mark_agent_crashed
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
|
||||
# Setup mocks for dependencies to isolate our test
|
||||
dummy_agent = Mock()
|
||||
|
|
@ -504,16 +515,27 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
|||
|
||||
# Apply mocks
|
||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||
mock_restore_interrupt_handling,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
|
||||
# First, run without a crash - agent should be run
|
||||
with agent_context() as ctx:
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_context.get_crash_message", mock_get_crash_message
|
||||
)
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
assert mock_calls["run_agent_stream"] == 1
|
||||
|
||||
|
|
@ -524,7 +546,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
|||
with agent_context() as ctx:
|
||||
mark_agent_crashed("Test crash message")
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_context.get_crash_message", mock_get_crash_message
|
||||
)
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
# Verify _run_agent_stream was not called
|
||||
assert mock_calls["run_agent_stream"] == 0
|
||||
|
|
@ -534,9 +558,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
|||
|
||||
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
||||
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
||||
from ra_aid.agent_context import agent_context, is_crashed
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.agent_context import agent_context, is_crashed
|
||||
|
||||
# Setup mocks
|
||||
dummy_agent = Mock()
|
||||
|
|
@ -572,14 +596,25 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
|||
|
||||
# Apply mocks
|
||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||
mock_restore_interrupt_handling,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
|
||||
with agent_context() as ctx:
|
||||
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
|
|
@ -594,9 +629,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
|||
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
||||
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
||||
# Import APIError from anthropic module and patch it on the agent_utils module
|
||||
from anthropic import APIError as AnthropicAPIError
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
|
||||
from ra_aid.agent_context import agent_context, is_crashed
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
|
||||
# Setup mocks
|
||||
dummy_agent = Mock()
|
||||
|
|
@ -613,7 +648,9 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
|||
if run_count[0] == 1:
|
||||
# First call throws a 400 Bad Request APIError
|
||||
mock_error = MockAPIError("400 Bad Request")
|
||||
mock_error.__class__.__name__ = "APIError" # Make it look like Anthropic's APIError
|
||||
mock_error.__class__.__name__ = (
|
||||
"APIError" # Make it look like Anthropic's APIError
|
||||
)
|
||||
raise mock_error
|
||||
# If it's called again, it should run normally
|
||||
|
||||
|
|
@ -638,14 +675,25 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
|||
|
||||
# Apply mocks
|
||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||
mock_restore_interrupt_handling,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
||||
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
|
||||
with agent_context() as ctx:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -12,8 +12,12 @@ def clean_env():
|
|||
"""Remove relevant environment variables before each test."""
|
||||
# Save existing values
|
||||
saved_vars = {}
|
||||
for var in ['ANTHROPIC_API_KEY', 'EXPERT_ANTHROPIC_API_KEY',
|
||||
'ANTHROPIC_MODEL', 'EXPERT_ANTHROPIC_MODEL']:
|
||||
for var in [
|
||||
"ANTHROPIC_API_KEY",
|
||||
"EXPERT_ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_MODEL",
|
||||
"EXPERT_ANTHROPIC_MODEL",
|
||||
]:
|
||||
saved_vars[var] = os.environ.get(var)
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
|
@ -31,6 +35,7 @@ def clean_env():
|
|||
@dataclass
|
||||
class MockArgs:
|
||||
"""Mock arguments class for testing."""
|
||||
|
||||
expert_provider: str
|
||||
expert_model: Optional[str] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
"""Unit tests for crash propagation behavior in agent_context."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
mark_agent_crashed,
|
||||
is_crashed,
|
||||
get_crash_message,
|
||||
is_crashed,
|
||||
mark_agent_crashed,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -45,12 +45,23 @@ def test_default_anthropic_provider(clean_env, monkeypatch):
|
|||
"""Test that Anthropic is the default provider when no environment variables are set."""
|
||||
args = parse_arguments(["-m", "test message"])
|
||||
assert args.provider == "anthropic"
|
||||
assert args.model == "claude-3-7-sonnet-20250219" # Updated to match current default
|
||||
assert (
|
||||
args.model == "claude-3-7-sonnet-20250219"
|
||||
) # Updated to match current default
|
||||
|
||||
|
||||
def test_respects_user_specified_anthropic_model(clean_env):
|
||||
"""Test that user-specified Anthropic models are respected."""
|
||||
args = parse_arguments(["-m", "test message", "--provider", "anthropic", "--model", "claude-3-5-sonnet-20241022"])
|
||||
args = parse_arguments(
|
||||
[
|
||||
"-m",
|
||||
"test message",
|
||||
"--provider",
|
||||
"anthropic",
|
||||
"--model",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
]
|
||||
)
|
||||
assert args.provider == "anthropic"
|
||||
assert args.model == "claude-3-5-sonnet-20241022" # Should not be overridden
|
||||
|
||||
|
|
|
|||
|
|
@ -123,10 +123,7 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
|
|||
temperature=0,
|
||||
timeout=180,
|
||||
max_retries=5,
|
||||
default_headers={
|
||||
"HTTP-Referer": "https://ra-aid.ai",
|
||||
"X-Title": "RA.Aid"
|
||||
}
|
||||
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -225,10 +222,7 @@ def test_initialize_openrouter(clean_env, mock_openai):
|
|||
temperature=0.7,
|
||||
timeout=180,
|
||||
max_retries=5,
|
||||
default_headers={
|
||||
"HTTP-Referer": "https://ra-aid.ai",
|
||||
"X-Title": "RA.Aid"
|
||||
}
|
||||
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -372,10 +366,7 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
|
|||
temperature=test_temp,
|
||||
timeout=180,
|
||||
max_retries=5,
|
||||
default_headers={
|
||||
"HTTP-Referer": "https://ra-aid.ai",
|
||||
"X-Title": "RA.Aid"
|
||||
}
|
||||
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -586,7 +577,9 @@ def mock_deepseek_reasoner():
|
|||
yield mock
|
||||
|
||||
|
||||
def test_reasoning_effort_only_passed_to_supported_models(clean_env, mock_openai, monkeypatch):
|
||||
def test_reasoning_effort_only_passed_to_supported_models(
|
||||
clean_env, mock_openai, monkeypatch
|
||||
):
|
||||
"""Test that reasoning_effort is only passed to supported models."""
|
||||
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
||||
|
||||
|
|
@ -603,7 +596,9 @@ def test_reasoning_effort_only_passed_to_supported_models(clean_env, mock_openai
|
|||
)
|
||||
|
||||
|
||||
def test_reasoning_effort_passed_to_supported_models(clean_env, mock_openai, monkeypatch):
|
||||
def test_reasoning_effort_passed_to_supported_models(
|
||||
clean_env, mock_openai, monkeypatch
|
||||
):
|
||||
"""Test that reasoning_effort is passed to models that support it."""
|
||||
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
||||
|
||||
|
|
@ -664,8 +659,5 @@ def test_initialize_openrouter_deepseek(
|
|||
temperature=0.7,
|
||||
timeout=180,
|
||||
max_retries=5,
|
||||
default_headers={
|
||||
"HTTP-Referer": "https://ra-aid.ai",
|
||||
"X-Title": "RA.Aid"
|
||||
}
|
||||
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -204,9 +204,9 @@ def test_use_aider_flag(mock_dependencies):
|
|||
"""Test that use-aider flag is correctly stored in config."""
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
||||
|
||||
from ra_aid.__main__ import main
|
||||
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
||||
|
||||
_global_memory.clear()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from ra_aid.tool_configs import (
|
||||
MODIFICATION_TOOLS,
|
||||
get_implementation_tools,
|
||||
get_planning_tools,
|
||||
get_read_only_tools,
|
||||
get_research_tools,
|
||||
get_web_research_tools,
|
||||
set_modification_tools,
|
||||
MODIFICATION_TOOLS,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -736,17 +736,20 @@ def test_emit_related_files_binary_filtering(reset_memory, tmp_path, monkeypatch
|
|||
|
||||
# Apply the mock
|
||||
import ra_aid.tools.memory
|
||||
|
||||
monkeypatch.setattr(ra_aid.tools.memory, "is_binary_file", mock_is_binary_file)
|
||||
|
||||
# Call emit_related_files with mix of text and binary files
|
||||
result = emit_related_files.invoke({
|
||||
"files": [
|
||||
str(text_file1),
|
||||
str(binary_file1),
|
||||
str(text_file2),
|
||||
str(binary_file2)
|
||||
]
|
||||
})
|
||||
result = emit_related_files.invoke(
|
||||
{
|
||||
"files": [
|
||||
str(text_file1),
|
||||
str(binary_file1),
|
||||
str(text_file2),
|
||||
str(binary_file2),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Verify the result message mentions skipped binary files
|
||||
assert "Files noted." in result
|
||||
|
|
@ -764,3 +767,57 @@ def test_emit_related_files_binary_filtering(reset_memory, tmp_path, monkeypatch
|
|||
|
||||
# Verify counter is correct (only incremented for text files)
|
||||
assert _global_memory["related_file_id_counter"] == 2
|
||||
|
||||
|
||||
def test_is_binary_file_with_ascii(reset_memory, monkeypatch):
|
||||
"""Test that ASCII files are correctly identified as text files"""
|
||||
import os
|
||||
|
||||
import ra_aid.tools.memory
|
||||
|
||||
# Path to the mock ASCII file
|
||||
ascii_file_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "mocks", "ascii.txt"
|
||||
)
|
||||
|
||||
# Test with magic library if available
|
||||
if ra_aid.tools.memory.magic:
|
||||
# Test real implementation with ASCII file
|
||||
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
|
||||
assert not is_binary, "ASCII file should not be identified as binary"
|
||||
|
||||
# Test fallback implementation
|
||||
# Mock magic to be None to force fallback implementation
|
||||
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
|
||||
|
||||
# Test fallback with ASCII file
|
||||
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
|
||||
assert (
|
||||
not is_binary
|
||||
), "ASCII file should not be identified as binary with fallback method"
|
||||
|
||||
|
||||
def test_is_binary_file_with_null_bytes(reset_memory, tmp_path, monkeypatch):
|
||||
"""Test that files with null bytes are correctly identified as binary"""
|
||||
import ra_aid.tools.memory
|
||||
|
||||
# Create a file with null bytes (binary content)
|
||||
binary_file = tmp_path / "binary_with_nulls.bin"
|
||||
with open(binary_file, "wb") as f:
|
||||
f.write(b"Some text with \x00 null \x00 bytes")
|
||||
|
||||
# Test with magic library if available
|
||||
if ra_aid.tools.memory.magic:
|
||||
# Test real implementation with binary file
|
||||
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
|
||||
assert is_binary, "File with null bytes should be identified as binary"
|
||||
|
||||
# Test fallback implementation
|
||||
# Mock magic to be None to force fallback implementation
|
||||
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
|
||||
|
||||
# Test fallback with binary file
|
||||
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
|
||||
assert (
|
||||
is_binary
|
||||
), "File with null bytes should be identified as binary with fallback method"
|
||||
|
|
|
|||
Loading…
Reference in New Issue