Binary skipped ASCII filetype fix (#108)

* chore: refactor code for improved readability and maintainability

- Standardize variable naming conventions for consistency.
- Improve logging messages for better clarity and debugging.
- Remove unnecessary imports and clean up code structure.
- Enhance error handling and logging in various modules.
- Update comments and docstrings for better understanding.
- Optimize imports and organize them logically.
- Ensure consistent formatting across files for better readability.
- Refactor functions to reduce complexity and improve performance.
- Add missing type hints and annotations for better code clarity.
- Improve test coverage and organization in test files.

* style(tests): apply consistent formatting and spacing in test files for improved readability and maintainability

* chore(tests): remove redundant test for ensure_tables_created with no models to streamline the test suite and reduce maintenance overhead

* fix(memory.py): update the is_binary_file function to correctly identify binary files by returning True for non-text MIME types
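A minimal sketch of the fixed behavior, assuming MIME-type detection via the standard-library mimetypes module; the actual is_binary_file in ra_aid/tools/memory.py may probe types differently:

import mimetypes

def is_binary_file(file_path: str) -> bool:
    """Return True when the guessed MIME type is not a text type."""
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # Unknown type: fall back to a simple byte-level heuristic.
        try:
            with open(file_path, "rb") as f:
                return b"\x00" in f.read(8192)
        except OSError:
            return False
    return not mime_type.startswith("text/")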
Author: Ariel Frischer, 2025-02-28 03:47:35 -08:00 (committed by GitHub)
Parent: 429f854fb8
Commit: e960a68d29
41 changed files with 1365 additions and 1060 deletions

View File

@ -6,9 +6,8 @@ This script creates a baseline migration representing the current database schem
It serves as the foundation for future schema changes.
"""
import sys
import os
from pathlib import Path
import sys
# Add the project root to the Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@ -20,10 +19,11 @@ from ra_aid.logging_config import get_logger, setup_logging
setup_logging(verbose=True)
logger = get_logger(__name__)
def create_initial_migration():
"""
Create the initial migration for the current database schema.
Returns:
bool: True if migration was created successfully, False otherwise
"""
@ -31,11 +31,11 @@ def create_initial_migration():
with DatabaseManager() as db:
# Create a descriptive name for the initial migration
migration_name = "initial_schema"
# Create the migration
logger.info(f"Creating initial migration '{migration_name}'...")
result = create_new_migration(migration_name, auto=True)
if result:
logger.info(f"Successfully created initial migration: {result}")
print(f"✅ Initial migration created successfully: {result}")
@ -49,9 +49,10 @@ def create_initial_migration():
print(f"❌ Error creating initial migration: {str(e)}")
return False
if __name__ == "__main__":
print("Creating initial database migration...")
success = create_initial_migration()
# Exit with appropriate code
sys.exit(0 if success else 1)

View File

@ -17,7 +17,6 @@ from ra_aid.agent_utils import (
run_planning_agent,
run_research_agent,
)
from ra_aid.database import init_db, close_db, DatabaseManager, ensure_migrations_applied
from ra_aid.config import (
DEFAULT_MAX_TEST_CMD_RETRIES,
DEFAULT_RECURSION_LIMIT,
@ -25,6 +24,10 @@ from ra_aid.config import (
VALID_PROVIDERS,
)
from ra_aid.console.output import cpm
from ra_aid.database import (
DatabaseManager,
ensure_migrations_applied,
)
from ra_aid.dependencies import check_dependencies
from ra_aid.env import validate_environment
from ra_aid.exceptions import AgentInterrupt
@ -171,8 +174,9 @@ Examples:
"--aider-config", type=str, help="Specify the aider config file path"
)
parser.add_argument(
"--use-aider", action="store_true",
help="Use aider for code modifications instead of default file tools (file_str_replace, put_complete_file_contents)"
"--use-aider",
action="store_true",
help="Use aider for code modifications instead of default file tools (file_str_replace, put_complete_file_contents)",
)
parser.add_argument(
"--test-cmd",
@ -343,24 +347,37 @@ def main():
try:
migration_result = ensure_migrations_applied()
if not migration_result:
logger.warning("Database migrations failed but execution will continue")
logger.warning(
"Database migrations failed but execution will continue"
)
except Exception as e:
logger.error(f"Database migration error: {str(e)}")
# Check dependencies before proceeding
check_dependencies()
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
validate_environment(args)
) # Will exit if main env vars missing
(
expert_enabled,
expert_missing,
web_research_enabled,
web_research_missing,
) = validate_environment(args) # Will exit if main env vars missing
logger.debug("Environment validation successful")
# Validate model configuration early
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
model_config = models_params.get(args.provider, {}).get(
args.model or "", {}
)
supports_temperature = model_config.get(
"supports_temperature",
args.provider
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
in [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
],
)
if supports_temperature and args.temperature is None:
@ -377,7 +394,12 @@ def main():
status = build_status(args, expert_enabled, web_research_enabled)
console.print(
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
Panel(
status,
title=f"RA.Aid v{__version__}",
border_style="bright_blue",
padding=(0, 1),
)
)
# Handle chat mode
@ -386,7 +408,7 @@ def main():
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
sys.exit(1)
@ -429,7 +451,7 @@ def main():
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
@ -449,7 +471,9 @@ def main():
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
WEB_RESEARCH_PROMPT_SECTION_CHAT
if web_research_enabled
else ""
),
working_directory=working_directory,
current_date=current_date,
@ -502,11 +526,13 @@ def main():
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = args.research_model or args.model
_global_memory["config"]["research_model"] = (
args.research_model or args.model
)
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)

View File

@ -2,7 +2,7 @@
import threading
from contextlib import contextmanager
from typing import Dict, Optional, Set
from typing import Optional
# Thread-local storage for context variables
_thread_local = threading.local()
@ -19,7 +19,7 @@ class AgentContext:
"""
# Store reference to parent context
self.parent = parent_context
# Initialize completion flags
self.task_completed = False
self.plan_completed = False
@ -27,8 +27,8 @@ class AgentContext:
self.agent_should_exit = False
self.agent_has_crashed = False
self.agent_crashed_message = None
# Note: Completion flags (task_completed, plan_completed, completion_message,
# Note: Completion flags (task_completed, plan_completed, completion_message,
# agent_should_exit) are no longer inherited from parent contexts
def mark_task_completed(self, message: str) -> None:
@ -58,29 +58,29 @@ class AgentContext:
def mark_should_exit(self) -> None:
"""Mark that the agent should exit execution.
This propagates the exit state to all parent contexts.
"""
self.agent_should_exit = True
# Propagate to parent context if it exists
if self.parent:
self.parent.mark_should_exit()
def mark_agent_crashed(self, message: str) -> None:
"""Mark the agent as crashed with the given message.
Unlike exit state, crash state does not propagate to parent contexts.
Args:
message: Error message explaining the crash
"""
self.agent_has_crashed = True
self.agent_crashed_message = message
def is_crashed(self) -> bool:
"""Check if the agent has crashed.
Returns:
True if the agent has crashed, False otherwise
"""
@ -116,17 +116,17 @@ def agent_context(parent_context=None):
"""
# Save the previous context
previous_context = getattr(_thread_local, "current_context", None)
# Create a new context, inheriting from parent if provided
# If parent_context is None but previous_context exists, use previous_context as parent
if parent_context is None and previous_context is not None:
context = AgentContext(previous_context)
else:
context = AgentContext(parent_context)
# Set as current context
_thread_local.current_context = context
try:
yield context
finally:
@ -202,7 +202,7 @@ def mark_should_exit() -> None:
def is_crashed() -> bool:
"""Check if the current agent has crashed.
Returns:
True if the current agent has crashed, False otherwise
"""
@ -212,7 +212,7 @@ def is_crashed() -> bool:
def mark_agent_crashed(message: str) -> None:
"""Mark the current agent as crashed with the given message.
Args:
message: Error message explaining the crash
"""
@ -223,9 +223,9 @@ def mark_agent_crashed(message: str) -> None:
def get_crash_message() -> Optional[str]:
"""Get the crash message from the current context.
Returns:
The crash message or None if the agent has not crashed
"""
context = get_current_context()
return context.agent_crashed_message if context and context.is_crashed() else None
return context.agent_crashed_message if context and context.is_crashed() else None
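For orientation, a brief usage sketch of the context-manager API above (names taken from this file; the nesting and crash-propagation behavior follow the docstrings):

from ra_aid.agent_context import (
    agent_context,
    get_crash_message,
    is_crashed,
    mark_agent_crashed,
)

with agent_context() as outer:
    with agent_context() as inner:  # inherits outer as its parent
        mark_agent_crashed("unretryable 400 from provider")
        assert is_crashed()
        print(get_crash_message())
    # Crash state does not propagate upward, unlike mark_should_exit().
    assert not outer.agent_has_crashed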

View File

@ -7,7 +7,7 @@ import threading
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager
from typing import Any, Dict, List, Literal, Optional, Sequence
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@ -30,6 +30,12 @@ from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.agent_context import (
agent_context,
is_completed,
reset_completion_flags,
should_exit,
)
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
@ -72,14 +78,6 @@ from ra_aid.tool_configs import (
get_web_research_tools,
)
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.agent_context import (
agent_context,
get_current_context,
is_completed,
reset_completion_flags,
get_completion_message,
should_exit,
)
from ra_aid.tools.memory import (
_global_memory,
get_memory_value,
@ -250,8 +248,12 @@ def is_anthropic_claude(config: Dict[str, Any]) -> bool:
provider = config.get("provider", "")
model_name = config.get("model", "")
result = (
(provider.lower() == "anthropic" and model_name and "claude" in model_name.lower())
or (provider.lower() == "openrouter" and model_name.lower().startswith("anthropic/claude-"))
provider.lower() == "anthropic"
and model_name
and "claude" in model_name.lower()
) or (
provider.lower() == "openrouter"
and model_name.lower().startswith("anthropic/claude-")
)
return result
@ -953,14 +955,15 @@ def run_agent_with_retry(
for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt()
# Check if the agent has crashed before attempting to run it
from ra_aid.agent_context import is_crashed, get_crash_message
from ra_aid.agent_context import get_crash_message, is_crashed
if is_crashed():
crash_message = get_crash_message()
logger.error("Agent has crashed: %s", crash_message)
return f"Agent has crashed: {crash_message}"
try:
_run_agent_stream(agent, msg_list, config)
if fallback_handler:
@ -982,11 +985,12 @@ def run_agent_with_retry(
error_str = str(e).lower()
if "400" in error_str or "bad request" in error_str:
from ra_aid.agent_context import mark_agent_crashed
crash_message = f"Unretryable error: {str(e)}"
mark_agent_crashed(crash_message)
logger.error("Agent has crashed: %s", crash_message)
return f"Agent has crashed: {crash_message}"
_handle_fallback_response(e, fallback_handler, agent, msg_list)
continue
except FallbackToolExecutionError as e:
@ -1007,13 +1011,16 @@ def run_agent_with_retry(
) as e:
# Check if this is a BadRequestError (HTTP 400) which is unretryable
error_str = str(e).lower()
if ("400" in error_str or "bad request" in error_str) and isinstance(e, APIError):
if (
"400" in error_str or "bad request" in error_str
) and isinstance(e, APIError):
from ra_aid.agent_context import mark_agent_crashed
crash_message = f"Unretryable API error: {str(e)}"
mark_agent_crashed(crash_message)
logger.error("Agent has crashed: %s", crash_message)
return f"Agent has crashed: {crash_message}"
_handle_api_error(e, attempt, max_retries, base_delay)
finally:
_decrement_agent_depth()
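The retry loop above treats HTTP 400 / bad-request errors as unretryable crashes while backing off on everything else. A distilled sketch of that control flow, using generic names rather than the actual ra_aid implementation:

import time

def run_with_retry(step, max_retries=3, base_delay=1.0):
    for attempt in range(max_retries):
        try:
            return step()
        except Exception as e:
            msg = str(e).lower()
            if "400" in msg or "bad request" in msg:
                # Unretryable: report the crash and stop immediately.
                return f"Agent has crashed: {e}"
            # Retryable: exponential backoff before the next attempt.
            time.sleep(base_delay * 2**attempt)
    raise RuntimeError("retries exhausted")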

View File

@ -5,34 +5,29 @@ This package provides database functionality for the ra_aid application,
including connection management, models, utility functions, and migrations.
"""
from ra_aid.database.connection import (
init_db,
get_db,
close_db,
DatabaseManager
from ra_aid.database.connection import DatabaseManager, close_db, get_db, init_db
from ra_aid.database.migrations import (
MigrationManager,
create_new_migration,
ensure_migrations_applied,
get_migration_status,
init_migrations,
)
from ra_aid.database.models import BaseModel
from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created
from ra_aid.database.migrations import (
init_migrations,
ensure_migrations_applied,
create_new_migration,
get_migration_status,
MigrationManager
)
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
__all__ = [
'init_db',
'get_db',
'close_db',
'DatabaseManager',
'BaseModel',
'get_model_count',
'truncate_table',
'ensure_tables_created',
'init_migrations',
'ensure_migrations_applied',
'create_new_migration',
'get_migration_status',
'MigrationManager',
"init_db",
"get_db",
"close_db",
"DatabaseManager",
"BaseModel",
"get_model_count",
"truncate_table",
"ensure_tables_created",
"init_migrations",
"ensure_migrations_applied",
"create_new_migration",
"get_migration_status",
"MigrationManager",
]

View File

@ -5,10 +5,10 @@ This module provides functions to initialize, get, and close database connection
It also provides a context manager for database connections.
"""
import os
import contextvars
import os
from pathlib import Path
from typing import Optional, Any
from typing import Any, Optional
import peewee
@ -22,64 +22,69 @@ logger = get_logger(__name__)
class DatabaseManager:
"""
Context manager for database connections.
This class provides a context manager interface for database connections,
using the existing contextvars approach internally.
Example:
with DatabaseManager() as db:
# Use the database connection
db.execute_sql("SELECT * FROM table")
# Or with in-memory database:
with DatabaseManager(in_memory=True) as db:
# Use in-memory database
"""
def __init__(self, in_memory: bool = False):
"""
Initialize the DatabaseManager.
Args:
in_memory: Whether to use an in-memory database (default: False)
"""
self.in_memory = in_memory
def __enter__(self) -> peewee.SqliteDatabase:
"""
Initialize the database connection and return it.
Returns:
peewee.SqliteDatabase: The initialized database connection
"""
return init_db(in_memory=self.in_memory)
def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception],
exc_tb: Optional[Any]) -> None:
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[Any],
) -> None:
"""
Close the database connection when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
close_db()
# Don't suppress exceptions
return False
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
"""
Initialize the database connection.
Creates the .ra-aid directory if it doesn't exist and initializes
the SQLite database connection. If a database connection already exists,
returns the existing connection instead of creating a new one.
Args:
in_memory: Whether to use an in-memory database (default: False)
Returns:
peewee.SqliteDatabase: The initialized database connection
"""
@ -98,7 +103,7 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
else:
# Connection exists and is open, return it
return existing_db
# Set up database path
if in_memory:
# Use in-memory database
@ -108,25 +113,29 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
# Get current working directory and create .ra-aid directory if it doesn't exist
cwd = os.getcwd()
logger.debug(f"Current working directory: {cwd}")
# Define the .ra-aid directory path
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
ra_aid_dir = Path(ra_aid_dir_str)
ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path
ra_aid_dir_str = str(ra_aid_dir) # Update string representation with absolute path
ra_aid_dir_str = str(
ra_aid_dir
) # Update string representation with absolute path
logger.debug(f"Creating database directory at: {ra_aid_dir_str}")
# Multiple approaches to ensure directory creation
directory_created = False
error_messages = []
# Approach 1: Try os.mkdir directly
if not os.path.exists(ra_aid_dir_str):
try:
logger.debug("Attempting directory creation with os.mkdir")
os.mkdir(ra_aid_dir_str, mode=0o755)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
ra_aid_dir_str
)
if directory_created:
logger.debug("Directory created successfully with os.mkdir")
except Exception as e:
@ -136,40 +145,46 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
else:
logger.debug("Directory already exists, skipping creation")
directory_created = True
# Approach 2: Try os.makedirs if os.mkdir failed
if not directory_created:
try:
logger.debug("Attempting directory creation with os.makedirs")
os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
ra_aid_dir_str
)
if directory_created:
logger.debug("Directory created successfully with os.makedirs")
except Exception as e:
error_msg = f"os.makedirs failed: {str(e)}"
logger.debug(error_msg)
error_messages.append(error_msg)
# Approach 3: Try Path.mkdir if previous methods failed
if not directory_created:
try:
logger.debug("Attempting directory creation with Path.mkdir")
ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
ra_aid_dir_str
)
if directory_created:
logger.debug("Directory created successfully with Path.mkdir")
except Exception as e:
error_msg = f"Path.mkdir failed: {str(e)}"
logger.debug(error_msg)
error_messages.append(error_msg)
# Verify the directory was actually created
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False
logger.debug(f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
logger.debug(
f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}"
)
# Check parent directory permissions and contents for debugging
try:
parent_dir = os.path.dirname(ra_aid_dir_str)
@ -179,19 +194,21 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
logger.debug(f"Parent directory contents: {parent_contents}")
except Exception as e:
logger.debug(f"Could not check parent directory: {str(e)}")
if not os_exists or not is_dir:
error_msg = f"Directory does not exist or is not a directory after creation attempts: {ra_aid_dir_str}"
logger.error(error_msg)
if error_messages:
logger.error(f"Previous errors: {', '.join(error_messages)}")
raise FileNotFoundError(f"Failed to create directory: {ra_aid_dir_str}")
# Check directory permissions
try:
permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:]
logger.debug(f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
logger.debug(
f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}"
)
# List directory contents for debugging
dir_contents = os.listdir(ra_aid_dir_str)
logger.debug(f"Directory contents: {dir_contents}")
@ -201,21 +218,21 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
# Database path for file-based database - use os.path.join for maximum compatibility
db_path = os.path.join(ra_aid_dir_str, "pk.db")
logger.debug(f"Database path: {db_path}")
try:
# For file-based databases, ensure the file exists or can be created
if db_path != ":memory:":
# Check if the database file exists
db_file_exists = os.path.exists(db_path)
logger.debug(f"Database file exists check: {db_file_exists}")
# If the file doesn't exist, try to create an empty file to ensure we have write permissions
if not db_file_exists:
try:
logger.debug(f"Creating empty database file at: {db_path}")
with open(db_path, 'w') as f:
with open(db_path, "w") as f:
pass # Create empty file
# Verify the file was created
if os.path.exists(db_path):
logger.debug("Empty database file created successfully")
@ -224,51 +241,53 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
except Exception as e:
logger.error(f"Error creating database file: {str(e)}")
# Continue anyway, as SQLite might be able to create the file itself
# Initialize the database connection
logger.debug(f"Initializing SQLite database at: {db_path}")
db = peewee.SqliteDatabase(
db_path,
pragmas={
'journal_mode': 'wal', # Write-Ahead Logging for better concurrency
'foreign_keys': 1, # Enforce foreign key constraints
'cache_size': -1024 * 32, # 32MB cache
}
"journal_mode": "wal", # Write-Ahead Logging for better concurrency
"foreign_keys": 1, # Enforce foreign key constraints
"cache_size": -1024 * 32, # 32MB cache
},
)
# Always explicitly connect to ensure the connection is established
if db.is_closed():
logger.debug("Explicitly connecting to database")
db.connect()
# Store the database connection in the contextvar
db_var.set(db)
# Store whether this is an in-memory database (for backward compatibility)
db._is_in_memory = in_memory
# Verify the database is usable by executing a simple query
if not in_memory:
try:
db.execute_sql("SELECT 1")
logger.debug("Database connection verified with test query")
# Check if the database file exists after initialization
db_file_exists = os.path.exists(db_path)
db_file_size = os.path.getsize(db_path) if db_file_exists else 0
logger.debug(f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
logger.debug(
f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes"
)
except Exception as e:
logger.error(f"Database verification failed: {str(e)}")
# Continue anyway, as this is just a verification step
# Only show initialization message if it hasn't been shown before
if not hasattr(db, '_message_shown') or not db._message_shown:
if not hasattr(db, "_message_shown") or not db._message_shown:
if in_memory:
logger.debug("In-memory database connection initialized successfully")
else:
logger.debug("Database connection initialized successfully")
db._message_shown = True
return db
except peewee.OperationalError as e:
logger.error(f"Database Operational Error: {str(e)}")
@ -280,23 +299,24 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
logger.error(f"Failed to initialize database: {str(e)}")
raise
def get_db() -> peewee.SqliteDatabase:
"""
Get the current database connection.
If no connection exists, initializes a new one.
If connection exists but is closed, reopens it.
Returns:
peewee.SqliteDatabase: The current database connection
"""
db = db_var.get()
if db is None:
# No database connection exists, initialize one
# Use the default in-memory mode (False)
return init_db(in_memory=False)
# Check if connection is closed and reopen if needed
if db.is_closed():
try:
@ -309,30 +329,33 @@ def get_db() -> peewee.SqliteDatabase:
# First, remove the old connection from the context var
db_var.set(None)
# Then initialize a new connection with the same in-memory setting
in_memory = hasattr(db, '_is_in_memory') and db._is_in_memory
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
logger.debug(f"Creating new database connection (in_memory={in_memory})")
# Create a completely new database object, don't reuse the old one
return init_db(in_memory=in_memory)
return db
def close_db() -> None:
"""
Close the current database connection if it exists.
Handles various error conditions gracefully.
"""
db = db_var.get()
if db is None:
logger.warning("No database connection to close")
return
try:
if not db.is_closed():
db.close()
logger.info("Database connection closed successfully")
else:
logger.debug("Database connection was already closed (normal during shutdown)")
logger.debug(
"Database connection was already closed (normal during shutdown)"
)
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to close connection: {str(e)}")
except Exception as e:

View File

@ -6,16 +6,14 @@ using peewee-migrate. It includes tools for creating, checking, and applying
migrations automatically.
"""
import os
import datetime
import os
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
from typing import Any, Dict, List, Optional, Tuple
import peewee
from peewee_migrate import Router
from peewee_migrate.router import DEFAULT_MIGRATE_DIR
from ra_aid.database.connection import get_db, DatabaseManager
from ra_aid.database.connection import DatabaseManager, get_db
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -28,48 +26,50 @@ MIGRATIONS_TABLE = "migrationshistory"
class MigrationManager:
"""
Manages database migrations for the ra_aid application.
This class provides methods to initialize the migrator, check for
pending migrations, apply migrations, and create new migrations.
"""
def __init__(self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None):
def __init__(
self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None
):
"""
Initialize the MigrationManager.
Args:
db_path: Optional path to the database file. If None, uses the default.
migrations_dir: Optional path to the migrations directory. If None, uses default.
"""
self.db = get_db()
# Determine database path
if db_path is None:
# Get current working directory
cwd = os.getcwd()
ra_aid_dir = os.path.join(cwd, ".ra-aid")
db_path = os.path.join(ra_aid_dir, "pk.db")
self.db_path = db_path
# Determine migrations directory
if migrations_dir is None:
# Use a directory within .ra-aid
ra_aid_dir = os.path.dirname(self.db_path)
migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME)
self.migrations_dir = migrations_dir
# Ensure migrations directory exists
self._ensure_migrations_dir()
# Initialize router
self.router = self._init_router()
def _ensure_migrations_dir(self) -> None:
"""
Ensure that the migrations directory exists.
Creates the directory if it doesn't exist.
"""
try:
@ -77,75 +77,79 @@ class MigrationManager:
if not migrations_path.exists():
logger.debug(f"Creating migrations directory at: {self.migrations_dir}")
migrations_path.mkdir(parents=True, exist_ok=True)
# Create __init__.py to make it a proper package
init_file = migrations_path / "__init__.py"
if not init_file.exists():
init_file.touch()
logger.debug(f"Using migrations directory: {self.migrations_dir}")
except Exception as e:
logger.error(f"Failed to create migrations directory: {str(e)}")
raise
def _init_router(self) -> Router:
"""
Initialize the peewee-migrate Router.
Returns:
Router: Configured peewee-migrate Router instance
"""
try:
router = Router(self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE)
router = Router(
self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE
)
logger.debug(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
return router
except Exception as e:
logger.error(f"Failed to initialize migration router: {str(e)}")
raise
def check_migrations(self) -> Tuple[List[str], List[str]]:
"""
Check for pending migrations.
Returns:
Tuple[List[str], List[str]]: A tuple containing (applied_migrations, pending_migrations)
"""
try:
# Get all migrations
all_migrations = self.router.todo
# Get applied migrations
applied = self.router.done
# Calculate pending migrations
pending = [m for m in all_migrations if m not in applied]
logger.debug(f"Found {len(applied)} applied migrations and {len(pending)} pending migrations")
logger.debug(
f"Found {len(applied)} applied migrations and {len(pending)} pending migrations"
)
return applied, pending
except Exception as e:
logger.error(f"Failed to check migrations: {str(e)}")
return [], []
def apply_migrations(self, fake: bool = False) -> bool:
"""
Apply all pending migrations.
Args:
fake: If True, mark migrations as applied without running them
Returns:
bool: True if migrations were applied successfully, False otherwise
"""
try:
# Get pending migrations
_, pending = self.check_migrations()
if not pending:
logger.info("No pending migrations to apply")
return True
logger.info(f"Applying {len(pending)} pending migrations...")
# Apply migrations
for migration in pending:
try:
@ -155,50 +159,50 @@ class MigrationManager:
except Exception as e:
logger.error(f"Failed to apply migration {migration}: {str(e)}")
return False
logger.info(f"Successfully applied {len(pending)} migrations")
return True
except Exception as e:
logger.error(f"Failed to apply migrations: {str(e)}")
return False
def create_migration(self, name: str, auto: bool = True) -> Optional[str]:
"""
Create a new migration.
Args:
name: Name of the migration
auto: If True, automatically detect model changes
Returns:
Optional[str]: The name of the created migration, or None if creation failed
"""
try:
# Sanitize migration name
safe_name = name.replace(' ', '_').lower()
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
safe_name = name.replace(" ", "_").lower()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
migration_name = f"{timestamp}_{safe_name}"
logger.info(f"Creating new migration: {migration_name}")
# Create migration
self.router.create(migration_name, auto=auto)
logger.info(f"Successfully created migration: {migration_name}")
return migration_name
except Exception as e:
logger.error(f"Failed to create migration: {str(e)}")
return None
def get_migration_status(self) -> Dict[str, Any]:
"""
Get the current migration status.
Returns:
Dict[str, Any]: A dictionary containing migration status information
"""
applied, pending = self.check_migrations()
return {
"applied_count": len(applied),
"pending_count": len(pending),
@ -209,14 +213,16 @@ class MigrationManager:
}
def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str] = None) -> MigrationManager:
def init_migrations(
db_path: Optional[str] = None, migrations_dir: Optional[str] = None
) -> MigrationManager:
"""
Initialize the migration manager.
Args:
db_path: Optional path to the database file
migrations_dir: Optional path to the migrations directory
Returns:
MigrationManager: Initialized migration manager
"""
@ -226,10 +232,10 @@ def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str]
def ensure_migrations_applied() -> bool:
"""
Check for and apply any pending migrations.
This function should be called during application startup to ensure
the database schema is up to date.
Returns:
bool: True if migrations were applied successfully or none were pending
"""
@ -245,11 +251,11 @@ def ensure_migrations_applied() -> bool:
def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
"""
Create a new migration with the given name.
Args:
name: Name of the migration
auto: If True, automatically detect model changes
Returns:
Optional[str]: The name of the created migration, or None if creation failed
"""
@ -265,7 +271,7 @@ def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
def get_migration_status() -> Dict[str, Any]:
"""
Get the current migration status.
Returns:
Dict[str, Any]: A dictionary containing migration status information
"""

View File

@ -5,51 +5,53 @@ This module defines the base model class that all models will inherit from.
"""
import datetime
from typing import Any, Dict, Type, TypeVar
from typing import Any, Type, TypeVar
import peewee
from ra_aid.database.connection import get_db
from ra_aid.logging_config import get_logger
T = TypeVar('T', bound='BaseModel')
T = TypeVar("T", bound="BaseModel")
logger = get_logger(__name__)
class BaseModel(peewee.Model):
"""
Base model class for all ra_aid models.
All models should inherit from this class to ensure consistent
behavior and database connection.
"""
created_at = peewee.DateTimeField(default=datetime.datetime.now)
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
class Meta:
database = get_db()
def save(self, *args: Any, **kwargs: Any) -> int:
"""
Override save to update the updated_at field.
Args:
*args: Arguments to pass to the parent save method
**kwargs: Keyword arguments to pass to the parent save method
Returns:
int: The primary key of the saved instance
"""
self.updated_at = datetime.datetime.now()
return super().save(*args, **kwargs)
@classmethod
def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]:
"""
Get an instance or create it if it doesn't exist.
Args:
**kwargs: Fields to use for lookup and creation
Returns:
tuple: (instance, created) where created is a boolean indicating
whether a new instance was created
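A hypothetical model built on the BaseModel above; the class and field names are illustrative, not part of this diff:

import peewee

from ra_aid.database import ensure_tables_created
from ra_aid.database.models import BaseModel

class Snippet(BaseModel):
    # created_at/updated_at are inherited from BaseModel.
    name = peewee.CharField(unique=True)
    content = peewee.TextField(default="")

ensure_tables_created([Snippet])  # create the table if it is missing
snippet, created = Snippet.get_or_create(name="greeting")
snippet.content = "hello"
snippet.save()  # save() refreshes updated_at automatically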

View File

@ -3,14 +3,18 @@ Tests for the database connection module.
"""
import os
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from unittest.mock import patch
import peewee
import pytest
from ra_aid.database.connection import (
init_db, get_db, close_db, db_var, DatabaseManager
DatabaseManager,
close_db,
db_var,
get_db,
init_db,
)
@ -21,17 +25,17 @@ def cleanup_db():
"""
# Run the test
yield
# Clean up after the test
db = db_var.get()
if db is not None:
if hasattr(db, '_is_in_memory'):
delattr(db, '_is_in_memory')
if hasattr(db, '_message_shown'):
delattr(db, '_message_shown')
if hasattr(db, "_is_in_memory"):
delattr(db, "_is_in_memory")
if hasattr(db, "_message_shown"):
delattr(db, "_message_shown")
if not db.is_closed():
db.close()
# Reset the contextvar
db_var.set(None)
@ -43,10 +47,10 @@ def setup_in_memory_db():
"""
# Initialize in-memory database
db = init_db(in_memory=True)
# Run the test
yield db
# Clean up
if not db.is_closed():
db.close()
@ -60,60 +64,68 @@ def test_init_db_creates_directory(cleanup_db, tmp_path):
# Get and print the original working directory
original_cwd = os.getcwd()
print(f"Original working directory: {original_cwd}")
# Convert tmp_path to string for consistent handling
tmp_path_str = str(tmp_path.absolute())
print(f"Temporary directory path: {tmp_path_str}")
# Change to the temporary directory
os.chdir(tmp_path_str)
current_cwd = os.getcwd()
print(f"Changed working directory to: {current_cwd}")
assert current_cwd == tmp_path_str, f"Failed to change directory: {current_cwd} != {tmp_path_str}"
assert (
current_cwd == tmp_path_str
), f"Failed to change directory: {current_cwd} != {tmp_path_str}"
# Create the .ra-aid directory manually to ensure it exists
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
os.makedirs(ra_aid_path_str, exist_ok=True)
# Verify the directory was created
assert os.path.exists(ra_aid_path_str), f".ra-aid directory not found at {ra_aid_path_str}"
assert os.path.isdir(ra_aid_path_str), f"{ra_aid_path_str} exists but is not a directory"
assert os.path.exists(
ra_aid_path_str
), f".ra-aid directory not found at {ra_aid_path_str}"
assert os.path.isdir(
ra_aid_path_str
), f"{ra_aid_path_str} exists but is not a directory"
# Create a test file to verify write permissions
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
print(f"Creating test file to verify write permissions: {test_file_path}")
with open(test_file_path, 'w') as f:
with open(test_file_path, "w") as f:
f.write("Test write permissions")
# Verify the test file was created
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
# Create an empty database file to ensure it exists before init_db
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
print(f"Creating empty database file at: {db_file_str}")
with open(db_file_str, 'w') as f:
with open(db_file_str, "w") as f:
f.write("") # Create empty file
# Verify the database file was created
assert os.path.exists(db_file_str), f"Empty database file not created at {db_file_str}"
assert os.path.exists(
db_file_str
), f"Empty database file not created at {db_file_str}"
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
# Get directory permissions for debugging
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
print(f"Directory permissions: {dir_perms}")
# Initialize the database
print("Calling init_db()")
db = init_db()
print("init_db() returned successfully")
# List contents of the current directory for debugging
print(f"Contents of current directory: {os.listdir(current_cwd)}")
# List contents of the .ra-aid directory for debugging
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
# Check that the database file exists using os.path
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
@ -126,10 +138,10 @@ def test_init_db_creates_database_file(cleanup_db, tmp_path):
"""
# Change to the temporary directory
os.chdir(tmp_path)
# Initialize the database
init_db()
# Check that the database file was created
assert (tmp_path / ".ra-aid" / "pk.db").exists()
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
@ -141,7 +153,7 @@ def test_init_db_returns_database_connection(cleanup_db):
"""
# Initialize the database
db = init_db()
# Check that the database connection is returned
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
@ -153,11 +165,11 @@ def test_init_db_with_in_memory_mode(cleanup_db):
"""
# Initialize the database in in-memory mode
db = init_db(in_memory=True)
# Check that the database connection is returned
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
@ -167,10 +179,10 @@ def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
"""
# Change to the temporary directory
os.chdir(tmp_path)
# Initialize the database in in-memory mode
init_db(in_memory=True)
# Check that the .ra-aid directory was not created
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
# Instead, check that the database file was not created
@ -183,10 +195,10 @@ def test_init_db_returns_existing_connection(cleanup_db):
"""
# Initialize the database
db1 = init_db()
# Initialize the database again
db2 = init_db()
# Check that the same connection is returned
assert db1 is db2
@ -197,13 +209,13 @@ def test_init_db_reopens_closed_connection(cleanup_db):
"""
# Initialize the database
db1 = init_db()
# Close the connection
db1.close()
# Initialize the database again
db2 = init_db()
# Check that the same connection is returned and it's open
assert db1 is db2
assert not db1.is_closed()
@ -215,7 +227,7 @@ def test_get_db_initializes_connection(cleanup_db):
"""
# Get the database connection
db = get_db()
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
@ -227,10 +239,10 @@ def test_get_db_returns_existing_connection(cleanup_db):
"""
# Initialize the database
db1 = init_db()
# Get the database connection
db2 = get_db()
# Check that the same connection is returned
assert db1 is db2
@ -241,13 +253,13 @@ def test_get_db_reopens_closed_connection(cleanup_db):
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Get the database connection
db2 = get_db()
# Check that the same connection is returned and it's open
assert db is db2
assert not db.is_closed()
@ -259,24 +271,24 @@ def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Create a patched version of the connect method that raises an error
original_connect = peewee.SqliteDatabase.connect
def mock_connect(self, *args, **kwargs):
if self is db: # Only raise for the specific db instance
raise peewee.OperationalError("Test error")
return original_connect(self, *args, **kwargs)
# Apply the patch
monkeypatch.setattr(peewee.SqliteDatabase, 'connect', mock_connect)
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
# Get the database connection
db2 = get_db()
# Check that a new connection was initialized
assert db is not db2
assert not db2.is_closed()
@ -288,10 +300,10 @@ def test_close_db_closes_connection(cleanup_db):
"""
# Initialize the database
db = init_db()
# Close the connection
close_db()
# Check that the connection is closed
assert db.is_closed()
@ -302,7 +314,7 @@ def test_close_db_handles_no_connection():
"""
# Reset the contextvar
db_var.set(None)
# Close the connection (should not raise an error)
close_db()
@ -313,25 +325,25 @@ def test_close_db_handles_already_closed_connection(cleanup_db):
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Close the connection again (should not raise an error)
close_db()
@patch('ra_aid.database.connection.peewee.SqliteDatabase.close')
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
def test_close_db_handles_error(mock_close, cleanup_db):
"""
Test that close_db handles errors when closing the connection.
"""
# Initialize the database
init_db()
# Make close raise an error
mock_close.side_effect = peewee.DatabaseError("Test error")
# Close the connection (should not raise an error)
close_db()
@ -345,10 +357,10 @@ def test_database_manager_context_manager(cleanup_db):
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
# Store the connection for later
db_in_context = db
# Check that the connection is closed after exiting the context
assert db_in_context.is_closed()
@ -362,7 +374,7 @@ def test_database_manager_with_in_memory_mode(cleanup_db):
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
@ -372,13 +384,13 @@ def test_init_db_shows_message_only_once(cleanup_db, caplog):
"""
# Initialize the database
init_db(in_memory=True)
# Clear the log
caplog.clear()
# Initialize the database again
init_db(in_memory=True)
# Check that no message was logged
assert "database connection initialized" not in caplog.text.lower()
@ -389,41 +401,36 @@ def test_init_db_sets_is_in_memory_attribute(cleanup_db):
"""
# Initialize the database with in_memory=False
db = init_db(in_memory=False)
# Check that the _is_in_memory attribute is set to False
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Reset the contextvar
db_var.set(None)
# Initialize the database with in_memory=True
db = init_db(in_memory=True)
# Check that the _is_in_memory attribute is set to True
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
"""
Tests for the database connection module.
"""
import os
import shutil
from pathlib import Path
import pytest
import peewee
from ra_aid.database.connection import (
init_db, get_db, close_db,
db_var, DatabaseManager, logger
)
import pytest
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections and files between tests.
This fixture:
1. Closes any open database connection
2. Resets the contextvar
@ -431,117 +438,119 @@ def cleanup_db():
"""
# Store the current working directory
original_cwd = os.getcwd()
# Run the test
yield
# Clean up after the test
try:
# Close any open database connection
close_db()
# Reset the contextvar
db_var.set(None)
# Clean up the .ra-aid directory if it exists
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
ra_aid_dir_str = str(ra_aid_dir.absolute())
# Check using both methods
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
if os_exists:
# Only remove the database file, not the entire directory
db_file = os.path.join(ra_aid_dir_str, "pk.db")
if os.path.exists(db_file):
os.unlink(db_file)
# Remove WAL and SHM files if they exist
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
if os.path.exists(wal_file):
os.unlink(wal_file)
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
if os.path.exists(shm_file):
os.unlink(shm_file)
# List remaining contents for debugging
if os.path.exists(ra_aid_dir_str):
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
except Exception as e:
# Log but don't fail if cleanup has issues
print(f"Cleanup error (non-fatal): {str(e)}")
# Make sure we're back in the original directory
os.chdir(original_cwd)
class TestInitDb:
"""Tests for the init_db function."""
def test_init_db_default(self, cleanup_db):
"""Test init_db with default parameters."""
# Get the absolute path of the current working directory
cwd = os.getcwd()
print(f"Current working directory: {cwd}")
# Initialize the database
db = init_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created using both Path and os.path methods
ra_aid_dir = Path(cwd) / ".ra-aid"
ra_aid_dir_str = str(ra_aid_dir.absolute())
# Check directory existence using both methods
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
# List the contents of the current directory
print(f"Contents of {cwd}: {os.listdir(cwd)}")
# If the directory exists, list its contents
if os_exists:
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
# Use os.path for assertions to be more reliable
assert os.path.exists(ra_aid_dir_str), f"Directory {ra_aid_dir_str} does not exist"
assert os.path.exists(
ra_aid_dir_str
), f"Directory {ra_aid_dir_str} does not exist"
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
db_file = os.path.join(ra_aid_dir_str, "pk.db")
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
assert os.path.isfile(db_file), f"{db_file} is not a file"
def test_init_db_in_memory(self, cleanup_db):
"""Test init_db with in_memory=True."""
db = init_db(in_memory=True)
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
def test_init_db_reuses_connection(self, cleanup_db):
"""Test that init_db reuses an existing connection."""
db1 = init_db()
db2 = init_db()
assert db1 is db2
def test_init_db_reopens_closed_connection(self, cleanup_db):
"""Test that init_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = init_db()
assert db1 is db2
assert not db1.is_closed()
@ -549,32 +558,32 @@ class TestInitDb:
class TestGetDb:
"""Tests for the get_db function."""
def test_get_db_creates_connection(self, cleanup_db):
"""Test that get_db creates a new connection if none exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
db = get_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
def test_get_db_reuses_connection(self, cleanup_db):
"""Test that get_db reuses an existing connection."""
db1 = init_db()
db2 = get_db()
assert db1 is db2
def test_get_db_reopens_closed_connection(self, cleanup_db):
"""Test that get_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = get_db()
assert db1 is db2
assert not db1.is_closed()
@ -582,63 +591,63 @@ class TestGetDb:
class TestCloseDb:
"""Tests for the close_db function."""
def test_close_db(self, cleanup_db):
"""Test that close_db closes an open connection."""
db = init_db()
assert not db.is_closed()
close_db()
assert db.is_closed()
def test_close_db_no_connection(self, cleanup_db):
"""Test that close_db handles the case where no connection exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
# This should not raise an exception
close_db()
def test_close_db_already_closed(self, cleanup_db):
"""Test that close_db handles the case where the connection is already closed."""
db = init_db()
db.close()
assert db.is_closed()
# This should not raise an exception
close_db()
class TestDatabaseManager:
"""Tests for the DatabaseManager class."""
def test_database_manager_default(self, cleanup_db):
"""Test DatabaseManager with default parameters."""
with DatabaseManager() as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_in_memory(self, cleanup_db):
"""Test DatabaseManager with in_memory=True."""
with DatabaseManager(in_memory=True) as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_exception_handling(self, cleanup_db):
"""Test that DatabaseManager properly handles exceptions."""
try:
@ -648,6 +657,6 @@ class TestDatabaseManager:
except ValueError:
# The exception should be propagated
pass
# Verify the connection is closed even if an exception occurred
assert db.is_closed()

View File

@ -4,7 +4,6 @@ Database utility functions for ra_aid.
This module provides utility functions for common database operations.
"""
import importlib
import inspect
from typing import List, Type
@ -16,18 +15,19 @@ from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
"""
Ensure that database tables for the specified models exist.
If no models are specified, this function will attempt to discover
all models that inherit from BaseModel.
Args:
models: Optional list of model classes to create tables for
"""
db = get_db()
if models is None:
# If no models are specified, try to discover them
models = []
@ -36,20 +36,22 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
# This is a placeholder - in a real implementation, you would
# dynamically discover and import all modules that might contain models
from ra_aid.database import models as models_module
# Find all classes in the module that inherit from BaseModel
for name, obj in inspect.getmembers(models_module):
if (inspect.isclass(obj) and
issubclass(obj, BaseModel) and
obj != BaseModel):
if (
inspect.isclass(obj)
and issubclass(obj, BaseModel)
and obj != BaseModel
):
models.append(obj)
except ImportError as e:
logger.warning(f"Error importing model modules: {e}")
if not models:
logger.warning("No models found to create tables for")
return
try:
with db.atomic():
db.create_tables(models, safe=True)
@ -61,13 +63,14 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
logger.error(f"Error: Failed to create tables: {str(e)}")
raise
def get_model_count(model_class: Type[BaseModel]) -> int:
"""
Get the count of records for a specific model.
Args:
model_class: The model class to count records for
Returns:
int: The number of records for the model
"""
@ -77,10 +80,11 @@ def get_model_count(model_class: Type[BaseModel]) -> int:
logger.error(f"Database Error: Failed to count records: {str(e)}")
return 0
def truncate_table(model_class: Type[BaseModel]) -> None:
"""
Delete all records from a model's table.
Args:
model_class: The model class to truncate
"""

View File

@ -1,8 +1,7 @@
"""Module for checking system dependencies required by RA.Aid."""
import os
import sys
import subprocess
import sys
from abc import ABC, abstractmethod
from ra_aid import print_error
@ -23,9 +22,11 @@ class RipGrepDependency(Dependency):
def check(self):
"""Check if ripgrep is installed."""
try:
result = subprocess.run(['rg', '--version'],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL)
result = subprocess.run(
["rg", "--version"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
if result.returncode != 0:
raise FileNotFoundError()
except (FileNotFoundError, subprocess.SubprocessError):

View File

@ -122,11 +122,8 @@ def create_openrouter_client(
is_expert: bool = False,
) -> BaseChatModel:
"""Create OpenRouter client with appropriate configuration."""
default_headers = {
"HTTP-Referer": "https://ra-aid.ai",
"X-Title": "RA.Aid"
}
default_headers = {"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"}
if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
return ChatDeepseekReasoner(
api_key=api_key,
@ -243,12 +240,9 @@ def create_llm_client(
temp_kwargs = {"temperature": temperature}
else:
temp_kwargs = {}
if supports_thinking:
temp_kwargs = {"thinking": {
"type": "enabled",
"budget_tokens": 12000
}}
temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
if provider == "deepseek":
return create_deepseek_client(

View File

@ -17,7 +17,7 @@ import signal
import subprocess
import sys
import time
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import pyte
from pyte.screens import HistoryScreen
@ -33,17 +33,20 @@ else:
def create_process(
cmd: List[str], env: Optional[dict] = None, cols: Optional[int] = None, rows: Optional[int] = None
cmd: List[str],
env: Optional[dict] = None,
cols: Optional[int] = None,
rows: Optional[int] = None,
) -> Tuple[subprocess.Popen, Optional[int]]:
"""
Create a subprocess with appropriate settings for the current platform.
Args:
cmd: Command to execute as a list of strings
env: Environment variables dictionary, defaults to os.environ.copy()
cols: Number of columns for the terminal, defaults to current terminal width
rows: Number of rows for the terminal, defaults to current terminal height
Returns:
On Unix: (process, master_fd) where master_fd is the file descriptor for the pty master
On Windows: (process, None) as Windows doesn't use ptys
@ -61,7 +64,7 @@ def create_process(
# Windows-specific process creation
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
proc = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
@ -78,7 +81,7 @@ def create_process(
master_fd, slave_fd = os.openpty()
# Set master_fd to non-blocking to avoid indefinite blocking
os.set_blocking(master_fd, False)
proc = subprocess.Popen(
cmd,
stdin=slave_fd,
@ -90,18 +93,18 @@ def create_process(
preexec_fn=os.setsid, # Create new process group for proper signal handling
)
os.close(slave_fd) # Close slave end in the parent process
return proc, master_fd
def get_terminal_size() -> Tuple[int, int]:
"""
Get the current terminal size in a cross-platform way.
This function works on both Unix and Windows systems, using shutil.get_terminal_size()
which is available in Python 3.3+. If the terminal size cannot be determined
(e.g., when running in a non-interactive environment), it falls back to default values.
Returns:
A tuple of (columns, rows) representing the terminal dimensions.
"""
@ -117,11 +120,11 @@ def render_line(line, columns: int) -> str:
"""Render a single screen line from the pyte buffer (a mapping of column to Char)."""
if not line:
return ""
# Handle string lines directly (from screen.display)
if isinstance(line, str):
return line
# Handle dictionary-style lines (from history)
try:
return "".join(line[x].data for x in range(columns) if x in line)
@ -135,21 +138,21 @@ def run_interactive_command(
) -> Tuple[bytes, int]:
"""
Runs an interactive command with output capture, capturing final scrollback history.
This function provides a cross-platform way to run interactive commands with:
- Full terminal emulation using pyte's HistoryScreen
- Real-time display of command output
- Input forwarding when running in an interactive terminal
- Timeout handling to prevent runaway processes
- Comprehensive output capture including ANSI escape sequences
The implementation differs significantly between Windows and Unix:
On Windows:
- Uses threading to handle I/O operations
- Relies on msvcrt for keyboard input detection
- Uses pipes for process communication
On Unix:
- Uses pseudo-terminals (PTY) for full terminal emulation
- Uses select() for non-blocking I/O
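As a hedged aside, the Unix path this docstring describes boils down to roughly the following: spawn the child on a pseudo-terminal and drain the non-blocking master fd with select(). An illustration under those assumptions, not the module's exact code:

import os
import select
import subprocess

master_fd, slave_fd = os.openpty()
os.set_blocking(master_fd, False)  # avoid indefinite blocking reads
proc = subprocess.Popen(
    ["echo", "hello"],
    stdin=slave_fd, stdout=slave_fd, stderr=slave_fd,
    close_fds=True, preexec_fn=os.setsid,
)
os.close(slave_fd)  # the parent only needs the master end
chunks = []
while True:
    ready, _, _ = select.select([master_fd], [], [], 0.1)
    if ready:
        try:
            data = os.read(master_fd, 4096)
        except OSError:  # e.g. EIO on Linux once the slave side closes
            break
        if data:
            chunks.append(data)
    if proc.poll() is not None and not ready:
        break  # child exited and nothing left to read
print(b"".join(chunks).decode("utf-8", errors="ignore"))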
@ -230,7 +233,7 @@ def run_interactive_command(
# Windows implementation using threads for I/O
running = True
stdin_thread = None
def read_stdout():
nonlocal running
while running and proc.poll() is None:
@ -246,7 +249,7 @@ def run_interactive_command(
except Exception as e:
print(f"Error reading stdout: {e}", file=sys.stderr)
break
def read_stderr():
nonlocal running
while running and proc.poll() is None:
@ -262,7 +265,7 @@ def run_interactive_command(
except Exception as e:
print(f"Error reading stderr: {e}", file=sys.stderr)
break
def handle_input():
nonlocal running
try:
@ -276,7 +279,7 @@ def run_interactive_command(
pass
except Exception as e:
print(f"Error handling input: {e}", file=sys.stderr)
# Start I/O threads
stdout_thread = threading.Thread(target=read_stdout)
stderr_thread = threading.Thread(target=read_stderr)
@ -284,13 +287,13 @@ def run_interactive_command(
stderr_thread.daemon = True
stdout_thread.start()
stderr_thread.start()
# Only start stdin thread if we're in an interactive terminal
if sys.stdin.isatty():
stdin_thread = threading.Thread(target=handle_input)
stdin_thread.daemon = True
stdin_thread.start()
try:
# Main thread monitors timeout
while proc.poll() is None:
@ -307,7 +310,7 @@ def run_interactive_command(
stderr_thread.join(1.0)
if stdin_thread:
stdin_thread.join(1.0)
# Close pipes
if proc.stdout:
proc.stdout.close()
@ -387,23 +390,23 @@ def run_interactive_command(
# Ensure we have captured data even if the screen processing failed
raw_output = b"".join(captured_data)
# Process the captured output through a fresh screen
try:
# Create a new screen and stream for final processing
screen = HistoryScreen(cols, rows, history=2000, ratio=0.5)
stream = pyte.Stream(screen)
# Feed all captured data at once to get the final state
raw_output = b"".join(captured_data)
decoded = raw_output.decode("utf-8", errors="ignore")
stream.feed(decoded)
# Get all history lines (top and bottom) and current display
all_lines = []
# Add history.top lines (older history)
if hasattr(screen.history.top, 'keys'):
if hasattr(screen.history.top, "keys"):
# Dictionary-like object
for line_num in sorted(screen.history.top.keys()):
line = screen.history.top[line_num]
@ -412,12 +415,12 @@ def run_interactive_command(
# Deque or other iterable
for i, line in enumerate(screen.history.top):
all_lines.append(render_line(line, cols))
# Add current display lines
all_lines.extend([render_line(line, cols) for line in screen.display])
# Add history.bottom lines (newer history)
if hasattr(screen.history.bottom, 'keys'):
if hasattr(screen.history.bottom, "keys"):
# Dictionary-like object
for line_num in sorted(screen.history.bottom.keys()):
line = screen.history.bottom[line_num]
@ -426,23 +429,23 @@ def run_interactive_command(
# Deque or other iterable
for i, line in enumerate(screen.history.bottom):
all_lines.append(render_line(line, cols))
# Trim out empty lines to get only meaningful lines
# Also strip trailing whitespace from each line
trimmed_lines = [line.rstrip() for line in all_lines if line and line.strip()]
final_output = "\n".join(trimmed_lines)
except Exception as e:
# If anything goes wrong with screen processing, fall back to raw output
print(f"Warning: Error processing terminal output: {e}", file=sys.stderr)
try:
# Decode raw output, strip trailing whitespace from each line
decoded = raw_output.decode('utf-8', errors='replace')
decoded = raw_output.decode("utf-8", errors="replace")
lines = [line.rstrip() for line in decoded.splitlines()]
final_output = "\n".join(lines)
except Exception:
# Ultimate fallback if line processing fails
final_output = raw_output.decode('utf-8', errors='replace').strip()
final_output = raw_output.decode("utf-8", errors="replace").strip()
# Add timeout message if process was terminated due to timeout.
if was_terminated:
@ -458,7 +461,7 @@ def run_interactive_command(
else:
# Handle any unexpected type
final_output = str(final_output)[-8000:].encode("utf-8")
return final_output, proc.returncode
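For reference, the pyte pattern used in the final-output processing above can be reproduced standalone (pyte and HistoryScreen are the real library APIs; the feed text here is made up):

import pyte
from pyte.screens import HistoryScreen

# Same parameters as the diff: 2000 lines of scrollback history
screen = HistoryScreen(80, 24, history=2000, ratio=0.5)
stream = pyte.Stream(screen)
stream.feed("$ echo hello\r\nhello\r\n")
# screen.display holds the current viewport as plain strings
lines = [line.rstrip() for line in screen.display if line.strip()]
print("\n".join(lines))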

View File

@ -241,7 +241,9 @@ If you find this is an empty directory, you can stop research immediately and as
"""
RESEARCH_PROMPT = RESEARCH_COMMON_PROMPT_HEADER + """
RESEARCH_PROMPT = (
RESEARCH_COMMON_PROMPT_HEADER
+ """
Project State Handling:
For new/empty projects:
@ -280,9 +282,12 @@ NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATION IS REQUIRED, CALL request_implementation.
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
"""
)
# Research-only prompt - similar to research prompt but without implementation references
RESEARCH_ONLY_PROMPT = RESEARCH_COMMON_PROMPT_HEADER + """
RESEARCH_ONLY_PROMPT = (
RESEARCH_COMMON_PROMPT_HEADER
+ """
You have been spawned by a higher level research agent, so only spawn more research tasks sparingly if absolutely necessary. Keep your research *very* scoped and efficient.
@ -290,6 +295,7 @@ When you emit research notes, keep it extremely concise and relevant only to the
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
"""
)
# Web research prompt - guides web search and information gathering
WEB_RESEARCH_PROMPT = """Current Date: {current_date}

View File

@ -162,7 +162,9 @@ class AnthropicStrategy(ProviderStrategy):
if not base_key:
missing.append("ANTHROPIC_API_KEY environment variable is not set")
else:
missing.append("EXPERT_ANTHROPIC_API_KEY environment variable is not set")
missing.append(
"EXPERT_ANTHROPIC_API_KEY environment variable is not set"
)
else:
key = os.environ.get("ANTHROPIC_API_KEY")
if not key:

View File

@ -2,17 +2,17 @@
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
get_current_context,
is_completed,
mark_plan_completed,
mark_task_completed,
reset_completion_flags,
)
@ -30,7 +30,7 @@ class TestAgentContext:
"""Test that child contexts inherit state from parent contexts."""
parent = AgentContext()
parent.mark_task_completed("Parent task completed")
child = AgentContext(parent_context=parent)
assert child.task_completed is True
assert child.completion_message == "Parent task completed"
@ -39,7 +39,7 @@ class TestAgentContext:
"""Test marking a task as completed."""
context = AgentContext()
context.mark_task_completed("Task done")
assert context.task_completed is True
assert context.plan_completed is False
assert context.completion_message == "Task done"
@ -48,7 +48,7 @@ class TestAgentContext:
"""Test marking a plan as completed."""
context = AgentContext()
context.mark_plan_completed("Plan done")
assert context.task_completed is True
assert context.plan_completed is True
assert context.completion_message == "Plan done"
@ -57,7 +57,7 @@ class TestAgentContext:
"""Test resetting completion flags."""
context = AgentContext()
context.mark_task_completed("Task done")
context.reset_completion_flags()
assert context.task_completed is False
assert context.plan_completed is False
@ -67,13 +67,13 @@ class TestAgentContext:
"""Test the is_completed property."""
context = AgentContext()
assert context.is_completed is False
context.mark_task_completed("Task done")
assert context.is_completed is True
context.reset_completion_flags()
assert context.is_completed is False
context.mark_plan_completed("Plan done")
assert context.is_completed is True
@ -84,29 +84,29 @@ class TestContextManager:
def test_context_manager_basic(self):
"""Test basic context manager functionality."""
assert get_current_context() is None
with agent_context() as ctx:
assert get_current_context() is ctx
assert ctx.task_completed is False
assert get_current_context() is None
def test_nested_context_managers(self):
"""Test nested context managers."""
with agent_context() as outer_ctx:
assert get_current_context() is outer_ctx
with agent_context() as inner_ctx:
assert get_current_context() is inner_ctx
assert inner_ctx is not outer_ctx
assert get_current_context() is outer_ctx
def test_context_manager_with_parent(self):
"""Test context manager with explicit parent context."""
parent = AgentContext()
parent.mark_task_completed("Parent task")
with agent_context(parent_context=parent) as ctx:
assert ctx.task_completed is True
assert ctx.completion_message == "Parent task"
@ -115,13 +115,13 @@ class TestContextManager:
"""Test that nested contexts inherit from outer contexts by default."""
with agent_context() as outer:
outer.mark_task_completed("Outer task")
with agent_context() as inner:
assert inner.task_completed is True
assert inner.completion_message == "Outer task"
inner.mark_plan_completed("Inner plan")
# Outer context should not be affected by inner context changes
assert outer.task_completed is True
assert outer.plan_completed is False
@ -134,23 +134,23 @@ class TestThreadIsolation:
def test_thread_isolation(self):
"""Test that contexts are isolated between threads."""
results = {}
def thread_func(thread_id):
with agent_context() as ctx:
ctx.mark_task_completed(f"Thread {thread_id}")
time.sleep(0.1) # Give other threads time to run
# Store the context's message for verification
results[thread_id] = get_completion_message()
threads = []
for i in range(3):
t = threading.Thread(target=thread_func, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
# Each thread should have its own message
assert results[0] == "Thread 0"
assert results[1] == "Thread 1"
@ -188,26 +188,16 @@ class TestUtilityFunctions:
mark_task_completed("No context")
mark_plan_completed("No context")
reset_completion_flags()
# These should have safe default returns
assert is_completed() is False
assert get_completion_message() == ""
"""Unit tests for the agent_context module."""
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
)
class TestAgentContext:

View File

@ -1,11 +1,9 @@
"""Unit tests for agent_should_exit functionality."""
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_should_exit,
should_exit,
)
@ -18,7 +16,7 @@ class TestAgentShouldExit:
"""Test basic mark_should_exit functionality."""
context = AgentContext()
assert context.agent_should_exit is False
context.mark_should_exit()
assert context.agent_should_exit is True
@ -34,10 +32,10 @@ class TestAgentShouldExit:
"""Test that mark_should_exit propagates to parent contexts."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Mark child as should exit
child.mark_should_exit()
# Verify both child and parent are marked
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
@ -49,10 +47,10 @@ class TestAgentShouldExit:
# Initially both should be False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark inner as should exit
inner.mark_should_exit()
# Both should now be True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
assert outer.agent_should_exit is True
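A condensed reading of the propagation contract these tests pin down (a hypothetical shape; the real class carries more state):

class AgentContext:
    def __init__(self, parent_context=None):
        self.parent_context = parent_context
        self.agent_should_exit = False

    def mark_should_exit(self):
        self.agent_should_exit = True
        if self.parent_context is not None:
            # Propagate upward so outer agents stop as well
            self.parent_context.mark_should_exit()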

View File

@ -9,9 +9,9 @@ from ra_aid.tools import (
emit_related_files,
emit_research_notes,
file_str_replace,
put_complete_file_contents,
fuzzy_find_project_files,
list_directory_tree,
put_complete_file_contents,
read_file_tool,
ripgrep_search,
run_programming_task,
@ -26,12 +26,12 @@ from ra_aid.tools.agent import (
request_task_implementation,
request_web_research,
)
from ra_aid.tools.memory import one_shot_completed, plan_implementation_completed
from ra_aid.tools.memory import plan_implementation_completed
def set_modification_tools(use_aider=False):
"""Set the MODIFICATION_TOOLS list based on configuration.
Args:
use_aider: Whether to use run_programming_task (True) or file modification tools (False)
"""
@ -46,7 +46,9 @@ def set_modification_tools(use_aider=False):
# Read-only tools that don't modify system state
def get_read_only_tools(
human_interaction: bool = False, web_research_enabled: bool = False, use_aider: bool = False
human_interaction: bool = False,
web_research_enabled: bool = False,
use_aider: bool = False,
):
"""Get the list of read-only tools, optionally including human interaction tools.
@ -100,6 +102,7 @@ def get_all_tools() -> list[BaseTool]:
_config = {}
try:
from ra_aid.tools.memory import _global_memory
_config = _global_memory.get("config", {})
except ImportError:
pass
@ -137,15 +140,14 @@ def get_research_tools(
use_aider = False
try:
from ra_aid.tools.memory import _global_memory
use_aider = _global_memory.get("config", {}).get("use_aider", False)
except ImportError:
pass
# Start with read-only tools
tools = get_read_only_tools(
human_interaction,
web_research_enabled,
use_aider=use_aider
human_interaction, web_research_enabled, use_aider=use_aider
).copy()
tools.extend(RESEARCH_TOOLS)
@ -179,14 +181,14 @@ def get_planning_tools(
use_aider = False
try:
from ra_aid.tools.memory import _global_memory
use_aider = _global_memory.get("config", {}).get("use_aider", False)
except ImportError:
pass
# Start with read-only tools
tools = get_read_only_tools(
web_research_enabled=web_research_enabled,
use_aider=use_aider
web_research_enabled=web_research_enabled, use_aider=use_aider
).copy()
# Add planning-specific tools
@ -218,14 +220,14 @@ def get_implementation_tools(
use_aider = False
try:
from ra_aid.tools.memory import _global_memory
use_aider = _global_memory.get("config", {}).get("use_aider", False)
except ImportError:
pass
# Start with read-only tools
tools = get_read_only_tools(
web_research_enabled=web_research_enabled,
use_aider=use_aider
web_research_enabled=web_research_enabled, use_aider=use_aider
).copy()
# Add modification tools since it's not research-only
@ -283,4 +285,4 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal
if web_research_enabled:
tools.append(request_web_research)
return tools
return tools
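A hedged condensation of the use_aider toggle wired through this file: the modification tool set is either the aider-backed runner or the direct file tools (tool names are taken from the diff; the wiring shown is illustrative):

def select_modification_tools(use_aider: bool = False) -> list[str]:
    if use_aider:
        return ["run_programming_task"]
    return ["file_str_replace", "put_complete_file_contents"]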

View File

@ -5,7 +5,12 @@ from typing import Any, Dict, List, Union
from langchain_core.tools import tool
from rich.console import Console
from ra_aid.agent_context import get_completion_message, get_crash_message, is_crashed, reset_completion_flags
from ra_aid.agent_context import (
get_completion_message,
get_crash_message,
is_crashed,
reset_completion_flags,
)
from ra_aid.console.formatting import print_error
from ra_aid.exceptions import AgentInterrupt
from ra_aid.tools.memory import _global_memory
@ -85,7 +90,9 @@ def request_research(query: str) -> ResearchResult:
reason = f"error: {str(e)}"
finally:
# Get completion message if available
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
completion_message = get_completion_message() or (
"Task was completed successfully." if success else None
)
work_log = get_work_log()
@ -149,7 +156,9 @@ def request_web_research(query: str) -> ResearchResult:
reason = f"error: {str(e)}"
finally:
# Get completion message if available
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
completion_message = get_completion_message() or (
"Task was completed successfully." if success else None
)
work_log = get_work_log()
@ -215,7 +224,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
reason = f"error: {str(e)}"
# Get completion message if available
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
completion_message = get_completion_message() or (
"Task was completed successfully." if success else None
)
work_log = get_work_log()
@ -293,7 +304,9 @@ def request_task_implementation(task_spec: str) -> str:
reason = f"error: {str(e)}"
# Get completion message if available
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
completion_message = get_completion_message() or (
"Task was completed successfully." if success else None
)
# Get and reset work log if at root depth
work_log = get_work_log()
@ -317,45 +330,53 @@ def request_task_implementation(task_spec: str) -> str:
}
if work_log is not None:
response_data["work_log"] = work_log
# Convert the response data to a markdown string
markdown_parts = []
# Add header and completion message
markdown_parts.append("# Task Implementation")
if response_data.get("completion_message"):
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
markdown_parts.append(
f"\n## Completion Message\n\n{response_data['completion_message']}"
)
# Add crash information if applicable
if response_data.get("agent_crashed"):
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
markdown_parts.append(
f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}"
)
# Add success status
status = "Success" if response_data.get("success", False) else "Failed"
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
reason_text = (
f": {response_data.get('reason')}" if response_data.get("reason") else ""
)
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
# Add key facts
if response_data.get("key_facts"):
markdown_parts.append(f"\n## Key Facts\n\n{response_data['key_facts']}")
# Add related files
if response_data.get("related_files"):
files_list = "\n".join([f"- {file}" for file in response_data["related_files"]])
markdown_parts.append(f"\n## Related Files\n\n{files_list}")
# Add key snippets
if response_data.get("key_snippets"):
markdown_parts.append(f"\n## Key Snippets\n\n{response_data['key_snippets']}")
# Add work log
if response_data.get("work_log"):
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
markdown_parts.append(f"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**")
markdown_parts.append(
"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**"
)
# Join all parts into a single markdown string
markdown_output = "".join(markdown_parts)
return markdown_output
@ -403,7 +424,9 @@ def request_implementation(task_spec: str) -> str:
reason = f"error: {str(e)}"
# Get completion message if available
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
completion_message = get_completion_message() or (
"Task was completed successfully." if success else None
)
# Get and reset work log if at root depth
work_log = get_work_log()
@ -427,43 +450,51 @@ def request_implementation(task_spec: str) -> str:
}
if work_log is not None:
response_data["work_log"] = work_log
# Convert the response data to a markdown string
markdown_parts = []
# Add header and completion message
markdown_parts.append("# Implementation Plan")
if response_data.get("completion_message"):
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
markdown_parts.append(
f"\n## Completion Message\n\n{response_data['completion_message']}"
)
# Add crash information if applicable
if response_data.get("agent_crashed"):
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
markdown_parts.append(
f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}"
)
# Add success status
status = "Success" if response_data.get("success", False) else "Failed"
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
reason_text = (
f": {response_data.get('reason')}" if response_data.get("reason") else ""
)
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
# Add key facts
if response_data.get("key_facts"):
markdown_parts.append(f"\n## Key Facts\n\n{response_data['key_facts']}")
# Add related files
if response_data.get("related_files"):
files_list = "\n".join([f"- {file}" for file in response_data["related_files"]])
markdown_parts.append(f"\n## Related Files\n\n{files_list}")
# Add key snippets
if response_data.get("key_snippets"):
markdown_parts.append(f"\n## Key Snippets\n\n{response_data['key_snippets']}")
# Add work log
if response_data.get("work_log"):
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
markdown_parts.append(f"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**")
markdown_parts.append(
"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**"
)
# Join all parts into a single markdown string
markdown_output = "".join(markdown_parts)
return markdown_output
return markdown_output
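The markdown assembly in both functions above follows the same shape; a stripped-down sketch (the response_data keys mirror the diff, everything else is assumed):

def to_markdown(response_data: dict, header: str = "# Implementation Plan") -> str:
    parts = [header]
    if response_data.get("completion_message"):
        parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
    status = "Success" if response_data.get("success") else "Failed"
    reason = f": {response_data['reason']}" if response_data.get("reason") else ""
    parts.append(f"\n## Status\n\n**{status}**{reason}")
    if response_data.get("work_log"):
        parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
    return "".join(parts)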

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional
try:
import magic
@ -12,7 +12,11 @@ from rich.markdown import Markdown
from rich.panel import Panel
from typing_extensions import TypedDict
from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit
from ra_aid.agent_context import (
mark_plan_completed,
mark_should_exit,
mark_task_completed,
)
class WorkLogEntry(TypedDict):
@ -389,7 +393,7 @@ def emit_related_files(files: List[str]) -> str:
invalid_paths.append(file)
results.append(f"Error: Path '{file}' exists but is not a regular file")
continue
# Check if it's a binary file
if is_binary_file(file):
binary_files.append(file)
@ -430,7 +434,7 @@ def emit_related_files(files: List[str]) -> str:
border_style="green",
)
)
# Display skipped binary files
if binary_files:
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
@ -478,18 +482,36 @@ def is_binary_file(filepath):
if magic:
try:
mime = magic.from_file(filepath, mime=True)
return not mime.startswith('text/')
except Exception:
# Fallback if magic fails
return False
else:
# Basic binary detection if magic is not available
try:
with open(filepath, 'r', encoding='utf-8') as f:
f.read(1024) # Try to read as text
file_type = magic.from_file(filepath)
if not mime.startswith("text/"):
return True
if "ASCII text" in file_type:
return False
except UnicodeDecodeError:
return True
except Exception:
return _is_binary_fallback(filepath)
else:
return _is_binary_fallback(filepath)
def _is_binary_fallback(filepath):
"""Fallback method to detect binary files without using magic."""
try:
with open(filepath, "r", encoding="utf-8") as f:
chunk = f.read(1024)
# Check for null bytes which indicate binary content
if "\0" in chunk:
return True
# If we can read it as text without errors, it's probably not binary
return False
except UnicodeDecodeError:
# If we can't decode as UTF-8, it's likely binary
return True
def get_work_log() -> str:
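Standalone, the detection order this fix establishes looks roughly like the following (assumes python-magic is installed; the logic mirrors the diff, with the fallback inlined for illustration):

import magic

def looks_binary(filepath: str) -> bool:
    try:
        mime = magic.from_file(filepath, mime=True)   # e.g. "text/plain"
        description = magic.from_file(filepath)       # e.g. "ASCII text"
        if not mime.startswith("text/"):
            return True   # non-text mime: binary
        if "ASCII text" in description:
            return False  # plain ASCII: definitely text
        return False      # other text/* mimes: treat as text
    except Exception:
        # Fall back to a UTF-8 read with a null-byte check
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                return "\0" in f.read(1024)
        except UnicodeDecodeError:
            return True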

View File

@ -159,7 +159,9 @@ def run_programming_task(
# Return structured output
return {
"output": (truncate_output(result[0].decode()) + extra_ins) if result[0] else "",
"output": (truncate_output(result[0].decode()) + extra_ins)
if result[0]
else "",
"return_code": result[1],
"success": result[1] == 0,
}

View File

@ -138,7 +138,7 @@ def ripgrep_search(
params.append(f"**Before Context Lines**: {before_context_lines}")
if after_context_lines is not None:
params.append(f"**After Context Lines**: {after_context_lines}")
if include_hidden:
params.append("**Including Hidden Files**: yes")
if follow_links:

View File

@ -61,9 +61,7 @@ def put_complete_file_contents(
f"at {filepath} in {result['elapsed_time']:.3f}s"
)
logging.debug(
f"File write complete: {bytes_written} bytes in {elapsed:.2f}s"
)
logging.debug(f"File write complete: {bytes_written} bytes in {elapsed:.2f}s")
console.print(
Panel(

View File

@ -1,23 +1,21 @@
#!/usr/bin/env python3
import sys
import os
from pathlib import Path
import asyncio
from typing import List
import json
import threading
import queue
import traceback
import shutil
import logging
import os
import queue
import sys
import threading
import traceback
from pathlib import Path
from typing import List
# Configure logging
logging.basicConfig(
level=logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s',
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.StreamHandler(sys.__stderr__) # Use the real stderr
]
],
)
logger = logging.getLogger(__name__)
@ -26,12 +24,12 @@ project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
import uvicorn
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
app = FastAPI()
@ -55,10 +53,12 @@ app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
# Store active WebSocket connections
active_connections: List[WebSocket] = []
def run_ra_aid(message_content, output_queue):
"""Run ra-aid in a separate thread"""
try:
import ra_aid.__main__
logger.info("Successfully imported ra_aid.__main__")
# Override sys.argv
@ -72,47 +72,47 @@ def run_ra_aid(message_content, output_queue):
self.buffer = []
self.box_start = False
self._real_stderr = sys.__stderr__
def write(self, text):
# Always log raw output for debugging
logger.debug(f"Raw output: {repr(text)}")
# Check if this is a box drawing character
if any(c in text for c in '╭╮╰╯│─'):
if any(c in text for c in "╭╮╰╯│─"):
self.box_start = True
self.buffer.append(text)
elif self.box_start and text.strip():
self.buffer.append(text)
if '╯' in text: # End of box
full_text = ''.join(self.buffer)
if "╯" in text: # End of box
full_text = "".join(self.buffer)
# Extract content from inside the box
lines = full_text.split('\n')
lines = full_text.split("\n")
content_lines = []
for line in lines:
# Remove box characters and leading/trailing spaces
clean_line = line.strip('╭╮╰╯│─ ')
clean_line = line.strip("╭╮╰╯│─ ")
if clean_line:
content_lines.append(clean_line)
if content_lines:
self.queue.put('\n'.join(content_lines))
self.queue.put("\n".join(content_lines))
self.buffer = []
self.box_start = False
elif not self.box_start and text.strip():
self.queue.put(text.strip())
def flush(self):
if self.buffer:
full_text = ''.join(self.buffer)
full_text = "".join(self.buffer)
# Extract content from partial box
lines = full_text.split('\n')
lines = full_text.split("\n")
content_lines = []
for line in lines:
# Remove box characters and leading/trailing spaces
clean_line = line.strip('╭╮╰╯│─ ')
clean_line = line.strip("╭╮╰╯│─ ")
if clean_line:
content_lines.append(clean_line)
if content_lines:
self.queue.put('\n'.join(content_lines))
self.queue.put("\n".join(content_lines))
self.buffer = []
self.box_start = False
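The box-filtering idea in the capture class above, reduced to its core (an illustrative helper, not the file's actual code):

BOX_CHARS = "╭╮╰╯│─ "

def strip_box(text: str) -> str:
    """Drop rich box-drawing borders, keeping the inner content."""
    kept = [line.strip(BOX_CHARS) for line in text.split("\n")]
    return "\n".join(line for line in kept if line)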
@ -144,6 +144,7 @@ def run_ra_aid(message_content, output_queue):
traceback.print_exc(file=sys.__stderr__)
output_queue.put(f"Error: {str(e)}")
@app.get("/", response_class=HTMLResponse)
async def get_root(request: Request):
"""Serve the index.html file with port parameter."""
@ -151,6 +152,7 @@ async def get_root(request: Request):
"index.html", {"request": request, "server_port": request.url.port or 8080}
)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
@ -170,7 +172,9 @@ async def websocket_endpoint(websocket: WebSocket):
output_queue = queue.Queue()
# Create and start thread
thread = threading.Thread(target=run_ra_aid, args=(content, output_queue))
thread = threading.Thread(
target=run_ra_aid, args=(content, output_queue)
)
thread.start()
try:
@ -183,17 +187,21 @@ async def websocket_endpoint(websocket: WebSocket):
line = output_queue.get(timeout=0.1)
if line and line.strip(): # Only send non-empty messages
logger.debug(f"WebSocket sending: {repr(line)}")
await websocket.send_json({
"type": "chunk",
"chunk": {
"agent": {
"messages": [{
"content": line.strip(),
"status": "info"
}]
}
await websocket.send_json(
{
"type": "chunk",
"chunk": {
"agent": {
"messages": [
{
"content": line.strip(),
"status": "info",
}
]
}
},
}
})
)
except queue.Empty:
await asyncio.sleep(0.1)
except Exception as e:
@ -211,10 +219,7 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e:
error_msg = f"Error running ra-aid: {str(e)}"
logger.error(error_msg)
await websocket.send_json({
"type": "error",
"message": error_msg
})
await websocket.send_json({"type": "error", "message": error_msg})
logger.info("Waiting for message...")
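For reference, the two message shapes sent over the socket above, written out flat (field names from the diff; values illustrative):

chunk_message = {
    "type": "chunk",
    "chunk": {"agent": {"messages": [{"content": "...", "status": "info"}]}},
}
error_message = {"type": "error", "message": "Error running ra-aid: ..."}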
@ -243,6 +248,7 @@ def run_server(host: str = "0.0.0.0", port: int = 8080):
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
parser.add_argument(
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"

View File

@ -1,18 +1,23 @@
"""
Tests for the database connection module.
"""
import os
import shutil
from pathlib import Path
import pytest
from unittest.mock import patch
import peewee
from unittest.mock import patch, MagicMock
import pytest
from ra_aid.database.connection import (
init_db, get_db, close_db,
db_var, DatabaseManager, logger
DatabaseManager,
close_db,
db_var,
get_db,
init_db,
)
@pytest.fixture
def cleanup_db():
"""
@ -48,20 +53,23 @@ def cleanup_db():
# Log but don't fail if cleanup has issues
print(f"Cleanup error (non-fatal): {str(e)}")
@pytest.fixture
def mock_logger():
"""Mock the logger to test for output messages."""
with patch('ra_aid.database.connection.logger') as mock:
with patch("ra_aid.database.connection.logger") as mock:
yield mock
class TestInitDb:
"""Tests for the init_db function."""
def test_init_db_default(self, cleanup_db):
"""Test init_db with default parameters."""
db = init_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
@ -73,7 +81,7 @@ class TestInitDb:
db = init_db(in_memory=True)
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
def test_init_db_reuses_connection(self, cleanup_db):
@ -91,8 +99,10 @@ class TestInitDb:
assert db1 is db2
assert not db1.is_closed()
class TestGetDb:
"""Tests for the get_db function."""
def test_get_db_creates_connection(self, cleanup_db):
"""Test that get_db creates a new connection if none exists."""
# Reset the contextvar to ensure no connection exists
@ -100,7 +110,7 @@ class TestGetDb:
db = get_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
def test_get_db_reuses_connection(self, cleanup_db):
@ -118,8 +128,10 @@ class TestGetDb:
assert db1 is db2
assert not db1.is_closed()
class TestCloseDb:
"""Tests for the close_db function."""
def test_close_db(self, cleanup_db):
"""Test that close_db closes an open connection."""
db = init_db()
@ -142,14 +154,16 @@ class TestCloseDb:
# This should not raise an exception
close_db()
class TestDatabaseManager:
"""Tests for the DatabaseManager class."""
def test_database_manager_default(self, cleanup_db):
"""Test DatabaseManager with default parameters."""
with DatabaseManager() as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
@ -163,7 +177,7 @@ class TestDatabaseManager:
with DatabaseManager(in_memory=True) as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
# Verify the connection is closed after exiting the context
assert db.is_closed()

View File

@ -5,22 +5,19 @@ Tests for the database migrations module.
import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock, call, PropertyMock
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
import peewee
from peewee_migrate import Router
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.migrations import (
MigrationManager,
init_migrations,
ensure_migrations_applied,
create_new_migration,
get_migration_status,
MIGRATIONS_DIRNAME,
MIGRATIONS_TABLE
MIGRATIONS_TABLE,
MigrationManager,
create_new_migration,
ensure_migrations_applied,
get_migration_status,
init_migrations,
)
@ -37,10 +34,10 @@ def cleanup_db():
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
@ -56,7 +53,7 @@ def cleanup_db():
@pytest.fixture
def mock_logger():
"""Mock the logger to test for output messages."""
with patch('ra_aid.database.migrations.logger') as mock:
with patch("ra_aid.database.migrations.logger") as mock:
yield mock
@ -83,276 +80,294 @@ def temp_migrations_dir(temp_dir):
@pytest.fixture
def mock_router():
"""Mock the peewee_migrate Router class."""
with patch('ra_aid.database.migrations.Router') as mock:
with patch("ra_aid.database.migrations.Router") as mock:
# Configure the mock router
mock_instance = MagicMock()
mock.return_value = mock_instance
# Set up router properties
mock_instance.todo = ["001_initial", "002_add_users"]
mock_instance.done = ["001_initial"]
yield mock_instance
class TestMigrationManager:
"""Tests for the MigrationManager class."""
def test_init(self, cleanup_db, temp_dir, mock_logger):
"""Test MigrationManager initialization."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Initialize manager
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Verify initialization
assert manager.db_path == db_path
assert manager.migrations_dir == migrations_dir
assert os.path.exists(migrations_dir)
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
# Verify router initialization was logged
mock_logger.debug.assert_any_call(f"Using migrations directory: {migrations_dir}")
mock_logger.debug.assert_any_call(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
mock_logger.debug.assert_any_call(
f"Using migrations directory: {migrations_dir}"
)
mock_logger.debug.assert_any_call(
f"Initialized migration router with table: {MIGRATIONS_TABLE}"
)
def test_ensure_migrations_dir(self, cleanup_db, temp_dir, mock_logger):
"""Test _ensure_migrations_dir creates directory if it doesn't exist."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, "nonexistent_dir", MIGRATIONS_DIRNAME)
# Initialize manager
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Verify directory was created
assert os.path.exists(migrations_dir)
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
# Verify creation was logged
mock_logger.debug.assert_any_call(f"Creating migrations directory at: {migrations_dir}")
mock_logger.debug.assert_any_call(
f"Creating migrations directory at: {migrations_dir}"
)
def test_ensure_migrations_dir_error(self, cleanup_db, mock_logger):
"""Test _ensure_migrations_dir handles errors."""
# Mock os.makedirs to raise an exception
with patch('pathlib.Path.mkdir', side_effect=PermissionError("Permission denied")):
with patch(
"pathlib.Path.mkdir", side_effect=PermissionError("Permission denied")
):
# Set up test paths - use a path that would require elevated permissions
db_path = "/root/test.db"
migrations_dir = "/root/migrations"
# Initialize manager should raise an exception
with pytest.raises(Exception):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
manager = MigrationManager(
db_path=db_path, migrations_dir=migrations_dir
)
# Verify error was logged
mock_logger.error.assert_called_with(
f"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
)
def test_init_router(self, cleanup_db, temp_dir, mock_router):
"""Test _init_router initializes the Router correctly."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Create the migrations directory
os.makedirs(migrations_dir, exist_ok=True)
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Verify router was initialized
assert manager.router == mock_router
def test_check_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
"""Test check_migrations returns correct migration lists."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call check_migrations
applied, pending = manager.check_migrations()
# Verify results
assert applied == ["001_initial"]
assert pending == ["002_add_users"]
# Verify logging
mock_logger.debug.assert_called_with(
"Found 1 applied migrations and 1 pending migrations"
)
def test_check_migrations_error(self, cleanup_db, temp_dir, mock_logger):
"""Test check_migrations handles errors."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Create a mock router with a property that raises an exception
mock_router = MagicMock()
# Configure the todo property to raise an exception when accessed
type(mock_router).todo = PropertyMock(side_effect=Exception("Test error"))
mock_router.done = []
# Initialize manager with the mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Directly call check_migrations on the manager with the mocked router
applied, pending = manager.check_migrations()
# Verify empty results are returned on error
assert applied == []
assert pending == []
# Verify error was logged
mock_logger.error.assert_called_with("Failed to check migrations: Test error")
mock_logger.error.assert_called_with(
"Failed to check migrations: Test error"
)
def test_apply_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
"""Test apply_migrations applies pending migrations."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call apply_migrations
result = manager.apply_migrations()
# Verify result
assert result is True
# Verify migrations were applied
mock_router.run.assert_called_once_with("002_add_users", fake=False)
# Verify logging
mock_logger.info.assert_any_call("Applying 1 pending migrations...")
mock_logger.info.assert_any_call("Applying migration: 002_add_users")
mock_logger.info.assert_any_call("Successfully applied migration: 002_add_users")
mock_logger.info.assert_any_call(
"Successfully applied migration: 002_add_users"
)
mock_logger.info.assert_any_call("Successfully applied 1 migrations")
def test_apply_migrations_no_pending(self, cleanup_db, temp_dir, mock_logger):
"""Test apply_migrations when no migrations are pending."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Create a mock router with no pending migrations
mock_router = MagicMock()
mock_router.todo = ["001_initial"]
mock_router.done = ["001_initial"]
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call apply_migrations
result = manager.apply_migrations()
# Verify result
assert result is True
# Verify no migrations were applied
mock_router.run.assert_not_called()
# Verify logging
mock_logger.info.assert_called_with("No pending migrations to apply")
def test_apply_migrations_error(self, cleanup_db, temp_dir, mock_logger):
"""Test apply_migrations handles errors during migration."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Create a mock router that raises an exception during run
mock_router = MagicMock()
mock_router.todo = ["001_initial", "002_add_users"]
mock_router.done = ["001_initial"]
mock_router.run.side_effect = Exception("Migration error")
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call apply_migrations
result = manager.apply_migrations()
# Verify result
assert result is False
# Verify error was logged
mock_logger.error.assert_called_with(
"Failed to apply migration 002_add_users: Migration error"
)
def test_create_migration(self, cleanup_db, temp_dir, mock_router, mock_logger):
"""Test create_migration creates a new migration."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call create_migration
result = manager.create_migration("add_users", auto=True)
# Verify result contains timestamp and name
assert result is not None
assert "add_users" in result
# Verify migration was created
mock_router.create.assert_called_once()
# Verify logging
mock_logger.info.assert_any_call(f"Creating new migration: {result}")
mock_logger.info.assert_any_call(f"Successfully created migration: {result}")
mock_logger.info.assert_any_call(
f"Successfully created migration: {result}"
)
def test_create_migration_error(self, cleanup_db, temp_dir, mock_logger):
"""Test create_migration handles errors."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Create a mock router that raises an exception during create
mock_router = MagicMock()
mock_router.create.side_effect = Exception("Creation error")
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call create_migration
result = manager.create_migration("add_users", auto=True)
# Verify result is None on error
assert result is None
# Verify error was logged
mock_logger.error.assert_called_with("Failed to create migration: Creation error")
mock_logger.error.assert_called_with(
"Failed to create migration: Creation error"
)
def test_get_migration_status(self, cleanup_db, temp_dir, mock_router):
"""Test get_migration_status returns correct status information."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Initialize manager with mocked Router
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
# Call get_migration_status
status = manager.get_migration_status()
# Verify status information
assert status["applied_count"] == 1
assert status["pending_count"] == 1
@ -364,81 +379,95 @@ class TestMigrationManager:
class TestMigrationFunctions:
"""Tests for the migration utility functions."""
def test_init_migrations(self, cleanup_db, temp_dir):
"""Test init_migrations returns a MigrationManager instance."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
# Call init_migrations
with patch('ra_aid.database.migrations.MigrationManager') as mock_manager:
with patch("ra_aid.database.migrations.MigrationManager") as mock_manager:
mock_manager.return_value = MagicMock()
manager = init_migrations(db_path=db_path, migrations_dir=migrations_dir)
# Verify MigrationManager was initialized with correct parameters
mock_manager.assert_called_once_with(db_path, migrations_dir)
assert manager == mock_manager.return_value
def test_ensure_migrations_applied(self, cleanup_db, mock_logger):
"""Test ensure_migrations_applied applies pending migrations."""
# Mock MigrationManager
mock_manager = MagicMock()
mock_manager.apply_migrations.return_value = True
# Call ensure_migrations_applied
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
with patch(
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
):
result = ensure_migrations_applied()
# Verify result
assert result is True
# Verify migrations were applied
mock_manager.apply_migrations.assert_called_once()
def test_ensure_migrations_applied_error(self, cleanup_db, mock_logger):
"""Test ensure_migrations_applied handles errors."""
# Call ensure_migrations_applied with an exception
with patch('ra_aid.database.migrations.init_migrations',
side_effect=Exception("Test error")):
with patch(
"ra_aid.database.migrations.init_migrations",
side_effect=Exception("Test error"),
):
result = ensure_migrations_applied()
# Verify result is False on error
assert result is False
# Verify error was logged
mock_logger.error.assert_called_with("Failed to apply migrations: Test error")
mock_logger.error.assert_called_with(
"Failed to apply migrations: Test error"
)
def test_create_new_migration(self, cleanup_db, mock_logger):
"""Test create_new_migration creates a new migration."""
# Mock MigrationManager
mock_manager = MagicMock()
mock_manager.create_migration.return_value = "20250226_123456_test_migration"
# Call create_new_migration
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
with patch(
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
):
result = create_new_migration("test_migration", auto=True)
# Verify result
assert result == "20250226_123456_test_migration"
# Verify migration was created
mock_manager.create_migration.assert_called_once_with("test_migration", True)
mock_manager.create_migration.assert_called_once_with(
"test_migration", True
)
def test_create_new_migration_error(self, cleanup_db, mock_logger):
"""Test create_new_migration handles errors."""
# Call create_new_migration with an exception
with patch('ra_aid.database.migrations.init_migrations',
side_effect=Exception("Test error")):
with patch(
"ra_aid.database.migrations.init_migrations",
side_effect=Exception("Test error"),
):
result = create_new_migration("test_migration", auto=True)
# Verify result is None on error
assert result is None
# Verify error was logged
mock_logger.error.assert_called_with("Failed to create migration: Test error")
mock_logger.error.assert_called_with(
"Failed to create migration: Test error"
)
def test_get_migration_status(self, cleanup_db, mock_logger):
"""Test get_migration_status returns correct status information."""
# Mock MigrationManager
@ -449,13 +478,15 @@ class TestMigrationFunctions:
"applied": ["001_initial", "002_add_users"],
"pending": ["003_add_profiles"],
"migrations_dir": "/test/migrations",
"db_path": "/test/db.sqlite"
"db_path": "/test/db.sqlite",
}
# Call get_migration_status
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
with patch(
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
):
status = get_migration_status()
# Verify status information
assert status["applied_count"] == 2
assert status["pending_count"] == 1
@ -463,31 +494,35 @@ class TestMigrationFunctions:
assert status["pending"] == ["003_add_profiles"]
assert status["migrations_dir"] == "/test/migrations"
assert status["db_path"] == "/test/db.sqlite"
# Verify migration status was retrieved
mock_manager.get_migration_status.assert_called_once()
def test_get_migration_status_error(self, cleanup_db, mock_logger):
"""Test get_migration_status handles errors."""
# Call get_migration_status with an exception
with patch('ra_aid.database.migrations.init_migrations',
side_effect=Exception("Test error")):
with patch(
"ra_aid.database.migrations.init_migrations",
side_effect=Exception("Test error"),
):
status = get_migration_status()
# Verify default status on error
assert status["error"] == "Test error"
assert status["applied_count"] == 0
assert status["pending_count"] == 0
assert status["applied"] == []
assert status["pending"] == []
# Verify error was logged
mock_logger.error.assert_called_with("Failed to get migration status: Test error")
mock_logger.error.assert_called_with(
"Failed to get migration status: Test error"
)
class TestIntegration:
"""Integration tests for the migrations module."""
def test_in_memory_migrations(self, cleanup_db):
"""Test migrations with in-memory database."""
# Initialize in-memory database
@ -496,17 +531,19 @@ class TestIntegration:
with tempfile.TemporaryDirectory() as temp_dir:
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
os.makedirs(migrations_dir, exist_ok=True)
# Create __init__.py to make it a proper package
with open(os.path.join(migrations_dir, "__init__.py"), "w") as f:
pass
# Initialize migration manager
manager = MigrationManager(db_path=":memory:", migrations_dir=migrations_dir)
manager = MigrationManager(
db_path=":memory:", migrations_dir=migrations_dir
)
# Create a test migration
migration_name = manager.create_migration("test_migration", auto=False)
# Write a simple migration file
migration_path = os.path.join(migrations_dir, f"{migration_name}.py")
with open(migration_path, "w") as f:
@ -520,23 +557,27 @@ def migrate(migrator, database, fake=False, **kwargs):
def rollback(migrator, database, fake=False, **kwargs):
migrator.drop_table('test_table')
""")
# Check migrations
applied, pending = manager.check_migrations()
assert len(applied) == 0
assert len(pending) == 1
assert migration_name in pending[0] # Instead of exact equality, check if name is contained
assert (
migration_name in pending[0]
) # Instead of exact equality, check if name is contained
# Apply migrations
result = manager.apply_migrations()
assert result is True
# Check migrations again
applied, pending = manager.check_migrations()
assert len(applied) == 1
assert len(pending) == 0
assert migration_name in applied[0] # Instead of exact equality, check if name is contained
assert (
migration_name in applied[0]
) # Instead of exact equality, check if name is contained
# Verify migration status
status = manager.get_migration_status()
assert status["applied_count"] == 1
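The mocked Router in these tests stands in for peewee_migrate's real one; a minimal direct-driving sketch using only the attributes exercised here (todo, done, run):

import peewee
from peewee_migrate import Router

db = peewee.SqliteDatabase(":memory:")
router = Router(db, migrate_dir="migrations")
pending = [name for name in router.todo if name not in router.done]
for name in pending:
    router.run(name, fake=False)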

View File

@ -4,13 +4,11 @@ Tests for the database models module.
from unittest.mock import patch
import pytest
import peewee
import pytest
from ra_aid.database.connection import db_var, init_db
from ra_aid.database.models import BaseModel
from ra_aid.database.connection import (
db_var, get_db, init_db, close_db
)
@pytest.fixture
@ -26,10 +24,10 @@ def cleanup_db():
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
@ -47,21 +45,21 @@ def setup_test_model(cleanup_db):
"""Set up a test model class for testing."""
# Initialize an in-memory database connection
db = init_db(in_memory=True)
# Define a test model
class TestModel(BaseModel):
name = peewee.CharField()
value = peewee.IntegerField(null=True)
class Meta:
database = db
# Create the table
with db.atomic():
db.create_tables([TestModel], safe=True)
yield TestModel
# Clean up
with db.atomic():
TestModel.drop_table(safe=True)
@ -70,27 +68,28 @@ def setup_test_model(cleanup_db):
def test_base_model_save_updates_timestamps(setup_test_model):
"""Test that saving a model updates the timestamps."""
TestModel = setup_test_model
# Create a new instance
instance = TestModel(name="test", value=42)
instance.save()
# Check that created_at and updated_at are set
assert instance.created_at is not None
assert instance.updated_at is not None
# Store the original timestamps
original_created_at = instance.created_at
original_updated_at = instance.updated_at
# Wait a moment to ensure timestamps would be different
import time
time.sleep(0.001)
# Update the instance
instance.value = 43
instance.save()
# Check that updated_at changed but created_at didn't
assert instance.created_at == original_created_at
assert instance.updated_at != original_updated_at
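One plausible BaseModel.save() consistent with the timestamp behavior tested above (an assumption; the project's actual model may differ):

import datetime
import peewee

db = peewee.SqliteDatabase(":memory:")

class BaseModel(peewee.Model):
    created_at = peewee.DateTimeField(default=datetime.datetime.now)
    updated_at = peewee.DateTimeField(default=datetime.datetime.now)

    class Meta:
        database = db

    def save(self, *args, **kwargs):
        self.updated_at = datetime.datetime.now()  # refresh on every save
        return super().save(*args, **kwargs)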
@ -99,34 +98,36 @@ def test_base_model_save_updates_timestamps(setup_test_model):
def test_base_model_get_or_create(setup_test_model):
"""Test the get_or_create method."""
TestModel = setup_test_model
# First call should create a new instance
instance, created = TestModel.get_or_create(name="test", value=42)
assert created is True
assert instance.name == "test"
assert instance.value == 42
# Second call with same parameters should return existing instance
instance2, created2 = TestModel.get_or_create(name="test", value=42)
assert created2 is False
assert instance2.id == instance.id
# Call with different parameters should create a new instance
instance3, created3 = TestModel.get_or_create(name="test2", value=43)
assert created3 is True
assert instance3.id != instance.id
@patch('ra_aid.database.models.logger')
@patch("ra_aid.database.models.logger")
def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model):
"""Test that get_or_create handles database errors properly."""
TestModel = setup_test_model
# Mock the parent get_or_create to raise a DatabaseError
with patch('peewee.Model.get_or_create', side_effect=peewee.DatabaseError("Test error")):
with patch(
"peewee.Model.get_or_create", side_effect=peewee.DatabaseError("Test error")
):
# Call should raise the error
with pytest.raises(peewee.DatabaseError):
TestModel.get_or_create(name="test")
# Verify error was logged
mock_logger.error.assert_called_with("Failed in get_or_create: Test error")

View File

@ -2,14 +2,12 @@
Tests for the database utils module.
"""
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
import pytest
import peewee
import pytest
from ra_aid.database.connection import (
db_var, get_db, init_db, close_db
)
from ra_aid.database.connection import db_var, init_db
from ra_aid.database.models import BaseModel
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
@ -27,10 +25,10 @@ def cleanup_db():
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
@ -46,7 +44,7 @@ def cleanup_db():
@pytest.fixture
def mock_logger():
"""Mock the logger to test for output messages."""
with patch('ra_aid.database.utils.logger') as mock:
with patch("ra_aid.database.utils.logger") as mock:
yield mock
@ -55,22 +53,22 @@ def setup_test_model(cleanup_db):
"""Set up a test model for database tests."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Define a test model class
class TestModel(BaseModel):
name = peewee.CharField(max_length=100)
value = peewee.IntegerField(default=0)
class Meta:
database = db
# Create the test table in a transaction
with db.atomic():
db.create_tables([TestModel], safe=True)
# Yield control to the test
yield TestModel
# Clean up: drop the test table
with db.atomic():
db.drop_tables([TestModel], safe=True)
@ -80,98 +78,90 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
"""Test ensure_tables_created with explicit models."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Define a test model that uses this database
class TestModel(BaseModel):
name = peewee.CharField(max_length=100)
value = peewee.IntegerField(default=0)
class Meta:
database = db
# Call ensure_tables_created with explicit models
ensure_tables_created([TestModel])
# Verify success message was logged
mock_logger.info.assert_called_with("Successfully created tables for 1 models")
# Verify the table exists by trying to use it
TestModel.create(name="test", value=42)
count = TestModel.select().count()
assert count == 1
def test_ensure_tables_created_no_models(cleanup_db, mock_logger):
"""Test ensure_tables_created with no models."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Mock the import to simulate no models found
with patch('ra_aid.database.utils.importlib.import_module', side_effect=ImportError("No module")):
# Call ensure_tables_created with no models
ensure_tables_created()
# Verify warning message was logged
mock_logger.warning.assert_called_with("No models found to create tables for")
@patch('ra_aid.database.utils.get_db')
def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cleanup_db, mock_logger):
@patch("ra_aid.database.utils.get_db")
def test_ensure_tables_created_database_error(
mock_get_db, setup_test_model, cleanup_db, mock_logger
):
"""Test ensure_tables_created handles database errors."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Create a mock database with a create_tables method that raises an error
mock_db = MagicMock()
mock_db.atomic.return_value.__enter__.return_value = None
mock_db.atomic.return_value.__exit__.return_value = None
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
# Configure get_db to return our mock
mock_get_db.return_value = mock_db
# Call ensure_tables_created and expect an exception
with pytest.raises(peewee.DatabaseError):
ensure_tables_created([TestModel])
# Verify error message was logged
mock_logger.error.assert_called_with("Database Error: Failed to create tables: Test database error")
mock_logger.error.assert_called_with(
"Database Error: Failed to create tables: Test database error"
)
def test_get_model_count(setup_test_model, mock_logger):
"""Test get_model_count returns the correct count."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# First ensure the table is empty
TestModel.delete().execute()
# Create some test records
TestModel.create(name="test1", value=1)
TestModel.create(name="test2", value=2)
# Call get_model_count
count = get_model_count(TestModel)
# Verify the count is correct
assert count == 2
@patch('peewee.ModelSelect.count')
@patch("peewee.ModelSelect.count")
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
"""Test get_model_count handles database errors."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Configure the mock to raise a DatabaseError
mock_count.side_effect = peewee.DatabaseError("Test count error")
# Call get_model_count
count = get_model_count(TestModel)
# Verify error message was logged
mock_logger.error.assert_called_with("Database Error: Failed to count records: Test count error")
mock_logger.error.assert_called_with(
"Database Error: Failed to count records: Test count error"
)
# Verify the function returns 0 on error
assert count == 0
@ -180,41 +170,45 @@ def test_truncate_table(setup_test_model, mock_logger):
"""Test truncate_table deletes all records."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Create some test records
TestModel.create(name="test1", value=1)
TestModel.create(name="test2", value=2)
# Verify records exist
assert TestModel.select().count() == 2
# Call truncate_table
truncate_table(TestModel)
# Verify success message was logged
mock_logger.info.assert_called_with(f"Successfully truncated table for {TestModel.__name__}")
mock_logger.info.assert_called_with(
f"Successfully truncated table for {TestModel.__name__}"
)
# Verify all records were deleted
assert TestModel.select().count() == 0
@patch('ra_aid.database.models.BaseModel.delete')
@patch("ra_aid.database.models.BaseModel.delete")
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
"""Test truncate_table handles database errors."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Create a test record
TestModel.create(name="test", value=42)
# Configure the mock to return a mock query with execute that raises an error
mock_query = MagicMock()
mock_query.execute.side_effect = peewee.DatabaseError("Test delete error")
mock_delete.return_value = mock_query
# Call truncate_table and expect an exception
with pytest.raises(peewee.DatabaseError):
truncate_table(TestModel)
# Verify error message was logged
mock_logger.error.assert_called_with("Database Error: Failed to truncate table: Test delete error")
mock_logger.error.assert_called_with(
"Database Error: Failed to truncate table: Test delete error"
)
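The three utils exercised above share one pattern: log a descriptive message on peewee.DatabaseError, then either re-raise (table creation, truncation) or fail soft (counting). A sketch consistent with the asserted log strings; get_db and the model-discovery detail are assumptions:

import peewee

from ra_aid.database.connection import get_db
from ra_aid.logging_config import get_logger

logger = get_logger(__name__)


def ensure_tables_created(models=None):
    if not models:
        # Model discovery (via importlib) found nothing to create.
        logger.warning("No models found to create tables for")
        return
    try:
        db = get_db()
        with db.atomic():
            db.create_tables(models, safe=True)
        logger.info(f"Successfully created tables for {len(models)} models")
    except peewee.DatabaseError as e:
        logger.error(f"Database Error: Failed to create tables: {e}")
        raise


def get_model_count(model):
    try:
        return model.select().count()
    except peewee.DatabaseError as e:
        logger.error(f"Database Error: Failed to count records: {e}")
        return 0  # counting is non-critical, so fail soft


def truncate_table(model):
    try:
        model.delete().execute()
        logger.info(f"Successfully truncated table for {model.__name__}")
    except peewee.DatabaseError as e:
        logger.error(f"Database Error: Failed to truncate table: {e}")
        raise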

View File

@ -0,0 +1,12 @@
This is a test ascii file.
ASCII text, with very long lines (885), with CRLF line terminators
Mumblecore craft beer taxidermy, flannel YOLO pug brunch ugh you probably haven't heard of them art party next level Pinterest squid pork belly. Yr next level Carles, Thundercats dreamcatcher scenester master cleanse bitters disrupt tote bag keffiyeh narwhal organic salvia cray. Whatever heirloom Vice art party pickled, try-hard Williamsburg. Authentic pickled pop-up, letterpress bicycle rights cornhole vinyl Etsy readymade disrupt shabby chic Pitchfork keffiyeh. Master cleanse small batch keytar biodiesel Brooklyn, meggings four loko try-hard McSweeney's vinyl tattooed. Cred Schlitz selvage, tousled Odd Future literally before they sold out synth cardigan retro banh mi next level jean shorts meggings fap. Pork belly four dollar toast quinoa, stumptown taxidermy sriracha whatever you probably haven't heard of them squid single-origin coffee freegan disrupt cliche cardigan.
Bicycle rights cold-pressed Pinterest, beard butcher pickled pop-up synth DIY hashtag. Austin fanny pack farm-to-table keytar, kitsch fap tousled trust fund swag irony +1. Viral fanny pack vinyl, master cleanse 3 wolf moon readymade occupy before they sold out YOLO meggings XOXO art party fap try-hard. Photo booth you probably haven't heard of them artisan pickled Brooklyn cred umami meh, heirloom cray raw denim tousled drinking vinegar. Gentrify Williamsburg iPhone messenger bag heirloom, swag quinoa ennui brunch. Selvage tofu hella gastropub Pinterest, bicycle rights church-key cardigan semiotics cornhole Shoreditch iPhone fixie biodiesel narwhal. Small batch kogi Shoreditch cliche YOLO.
Literally yr ugh Truffaut raw denim four loko. Vice chia mustache, Intelligentsia authentic taxidermy Truffaut synth health goth. Locavore semiotics occupy, synth 8-bit hoodie umami meh PBR&B Wes Anderson brunch shabby chic Helvetica quinoa. YOLO beard pop-up Neutra PBR&B vinyl fixie, stumptown shabby chic flexitarian umami. Cronut Blue Bottle scenester sriracha keytar PBR ennui flannel VHS swag. Dreamcatcher 3 wolf moon fanny pack, tattooed XOXO bitters High Life fixie 8-bit Austin lomo single-origin coffee put a bird on it. High Life Kickstarter twee Blue Bottle shabby chic, biodiesel heirloom.
Wayfarers tousled stumptown pop-up slow-carb. Aesthetic American Apparel hoodie irony YOLO. Meggings synth meh, normcore lomo tote bag post-ironic twee sartorial butcher occupy. Tilde photo booth +1 kogi Williamsburg. Pork belly keytar seitan, pug iPhone fingerstache bitters. Ennui Schlitz actually, cardigan fashion axe Helvetica vegan. Swag lumbersexual blog Carles, cred synth asymmetrical heirloom Tumblr bitters letterpress aesthetic.

View File

@ -176,7 +176,7 @@ def test_strip_trailing_whitespace():
# Create a command that outputs text with trailing whitespace
cmd = 'echo "Line with spaces at end "; echo "Another trailing space line "; echo "Line with tabs at end\t\t"'
output, retcode = run_interactive_command(["/bin/bash", "-c", cmd])
# Check that the output contains the lines without trailing whitespace
lines = output.splitlines()
assert b"Line with spaces at end" in lines[0]

View File

@ -1,12 +1,17 @@
"""Tests for Windows-specific functionality."""
import os
import sys
import subprocess
import pytest
from unittest.mock import patch, MagicMock
import sys
from unittest.mock import MagicMock, patch
import pytest
from ra_aid.proc.interactive import (
create_process,
get_terminal_size,
run_interactive_command,
)
from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
class TestWindowsCompatibility:
@ -14,7 +19,7 @@ class TestWindowsCompatibility:
def test_get_terminal_size(self):
"""Test terminal size detection on Windows."""
with patch('shutil.get_terminal_size') as mock_get_size:
with patch("shutil.get_terminal_size") as mock_get_size:
mock_get_size.return_value = MagicMock(columns=120, lines=30)
cols, rows = get_terminal_size()
assert cols == 120
@ -23,30 +28,31 @@ class TestWindowsCompatibility:
def test_create_process(self):
"""Test process creation on Windows."""
with patch('subprocess.Popen') as mock_popen:
with patch("subprocess.Popen") as mock_popen:
mock_process = MagicMock()
mock_process.returncode = 0
mock_popen.return_value = mock_process
proc, _ = create_process(['echo', 'test'])
proc, _ = create_process(["echo", "test"])
assert mock_popen.called
args, kwargs = mock_popen.call_args
assert kwargs['stdin'] == subprocess.PIPE
assert kwargs['stdout'] == subprocess.PIPE
assert kwargs['stderr'] == subprocess.STDOUT
assert 'startupinfo' in kwargs
assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW
assert kwargs["stdin"] == subprocess.PIPE
assert kwargs["stdout"] == subprocess.PIPE
assert kwargs["stderr"] == subprocess.STDOUT
assert "startupinfo" in kwargs
assert kwargs["startupinfo"].dwFlags & subprocess.STARTF_USESHOWWINDOW
def test_run_interactive_command(self):
"""Test running an interactive command on Windows."""
test_output = "Test output\n"
with patch('subprocess.Popen') as mock_popen, \
patch('pyte.Stream') as mock_stream, \
patch('pyte.HistoryScreen') as mock_screen, \
patch('threading.Thread') as mock_thread:
with (
patch("subprocess.Popen") as mock_popen,
patch("pyte.Stream") as mock_stream,
patch("pyte.HistoryScreen") as mock_screen,
patch("threading.Thread") as mock_thread,
):
# Setup mock process
mock_process = MagicMock()
mock_process.stdout = MagicMock()
@ -54,25 +60,25 @@ class TestWindowsCompatibility:
mock_process.poll.side_effect = [None, 0] # First None, then return 0
mock_process.returncode = 0
mock_popen.return_value = mock_process
# Setup mock screen with history
mock_screen_instance = MagicMock()
mock_screen_instance.history.top = []
mock_screen_instance.history.bottom = []
mock_screen_instance.display = ["Test output"]
mock_screen.return_value = mock_screen_instance
# Setup mock thread
mock_thread_instance = MagicMock()
mock_thread.return_value = mock_thread_instance
# Run the command
output, return_code = run_interactive_command(['echo', 'test'])
output, return_code = run_interactive_command(["echo", "test"])
# Verify results
assert return_code == 0
assert "Test output" in output.decode()
# Verify the thread was started and joined
mock_thread_instance.start.assert_called()
mock_thread_instance.join.assert_called()
@ -80,29 +86,29 @@ class TestWindowsCompatibility:
def test_windows_dependencies(self):
"""Test that required Windows dependencies are available."""
if sys.platform == "win32":
import msvcrt
# If we get here without ImportError, the test passes
assert True
def test_windows_output_handling(self):
"""Test handling of multi-chunk output on Windows."""
if sys.platform != "win32":
pytest.skip("Windows-specific test")
# Test with multiple chunks of output to verify proper handling
with patch('subprocess.Popen') as mock_popen, \
patch('msvcrt.kbhit', return_value=False), \
patch('threading.Thread') as mock_thread, \
patch('time.sleep'): # Mock sleep to speed up test
with (
patch("subprocess.Popen") as mock_popen,
patch("msvcrt.kbhit", return_value=False),
patch("threading.Thread") as mock_thread,
patch("time.sleep"),
): # Mock sleep to speed up test
# Setup mock process
mock_process = MagicMock()
mock_process.stdout = MagicMock()
mock_process.poll.return_value = 0
mock_process.returncode = 0
mock_popen.return_value = mock_process
# Setup mock thread to simulate output collection
def side_effect(*args, **kwargs):
# Simulate thread collecting output
@ -110,15 +116,15 @@ class TestWindowsCompatibility:
b"First chunk\n",
b"Second chunk\n",
b"Third chunk with unicode \xe2\x9c\x93\n", # UTF-8 checkmark
None # End of output
None, # End of output
]
return MagicMock()
mock_thread.side_effect = side_effect
# Run the command
output, return_code = run_interactive_command(['test', 'command'])
output, return_code = run_interactive_command(["test", "command"])
# Verify results
assert return_code == 0
# We can't verify exact output content in this test since we're mocking the thread
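On Windows, the create_process test above expects piped stdio plus a STARTUPINFO that suppresses the console window. A sketch of the call shape using the Windows-only stdlib APIs (subprocess.STARTUPINFO / STARTF_USESHOWWINDOW); the real function presumably branches for POSIX as well:

import subprocess


def create_process(cmd):
    # Windows-only: hide the child's console window.
    startupinfo = subprocess.STARTUPINFO()
    startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
    proc = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        startupinfo=startupinfo,
    )
    return proc, None  # the tests unpack a (process, extra) pair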

View File

@ -2,22 +2,19 @@
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
get_current_context,
is_completed,
mark_plan_completed,
mark_should_exit,
mark_task_completed,
reset_completion_flags,
should_exit,
mark_agent_crashed,
is_crashed,
get_crash_message,
)
@ -120,42 +117,42 @@ class TestContextManager:
class TestExitPropagation:
"""Test cases for the agent_should_exit flag propagation."""
def test_mark_should_exit_propagation(self):
"""Test that mark_should_exit propagates to parent contexts."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Initially both contexts should have agent_should_exit as False
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
# Mark the child context as should exit
child.mark_should_exit()
# Both child and parent should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
def test_nested_should_exit_propagation(self):
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent)
# Initially all contexts should have agent_should_exit as False
assert grandparent.agent_should_exit is False
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
# Mark the child context as should exit
child.mark_should_exit()
# All contexts should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
assert grandparent.agent_should_exit is True
def test_context_manager_should_exit_propagation(self):
"""Test that mark_should_exit propagates when using context managers."""
with agent_context() as outer:
@ -163,10 +160,10 @@ class TestExitPropagation:
# Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark the inner context as should exit
inner.mark_should_exit()
# Both inner and outer should now have agent_should_exit as True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
@ -174,39 +171,39 @@ class TestExitPropagation:
class TestCrashPropagation:
"""Test cases for the agent_has_crashed flag non-propagation."""
def test_mark_agent_crashed_no_propagation(self):
"""Test that mark_agent_crashed does not propagate to parent contexts."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Initially both contexts should have agent_has_crashed as False
assert parent.is_crashed() is False
assert child.is_crashed() is False
# Mark the child context as crashed
child.mark_agent_crashed("Child crashed")
# Child should be crashed, but parent should not
assert child.is_crashed() is True
assert parent.is_crashed() is False
assert child.agent_crashed_message == "Child crashed"
assert parent.agent_crashed_message is None
def test_nested_crash_no_propagation(self):
"""Test that crash states don't propagate through multiple levels of parent contexts."""
grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent)
# Initially all contexts should have agent_has_crashed as False
assert grandparent.is_crashed() is False
assert parent.is_crashed() is False
assert child.is_crashed() is False
# Mark the child context as crashed
child.mark_agent_crashed("Child crashed")
# Only child should be crashed, parent and grandparent should not
assert child.is_crashed() is True
assert parent.is_crashed() is False
@ -214,7 +211,7 @@ class TestCrashPropagation:
assert child.agent_crashed_message == "Child crashed"
assert parent.agent_crashed_message is None
assert grandparent.agent_crashed_message is None
def test_context_manager_crash_no_propagation(self):
"""Test that crash state doesn't propagate when using context managers."""
with agent_context() as outer:
@ -222,27 +219,27 @@ class TestCrashPropagation:
# Initially both contexts should have agent_has_crashed as False
assert outer.is_crashed() is False
assert inner.is_crashed() is False
# Mark the inner context as crashed
inner.mark_agent_crashed("Inner crashed")
# Inner should be crashed, but outer should not
assert inner.is_crashed() is True
assert outer.is_crashed() is False
assert inner.agent_crashed_message == "Inner crashed"
assert outer.agent_crashed_message is None
def test_crash_state_not_inherited(self):
"""Test that new child contexts don't inherit crash states from parent contexts."""
parent = AgentContext()
# Mark the parent as crashed
parent.mark_agent_crashed("Parent crashed")
assert parent.is_crashed() is True
# Create a child context with the crashed parent as parent_context
child = AgentContext(parent_context=parent)
# Child should not be crashed even though parent is
assert parent.is_crashed() is True
assert child.is_crashed() is False
@ -313,18 +310,18 @@ class TestUtilityFunctions:
# These should have safe default returns
assert is_completed() is False
assert get_completion_message() == ""
def test_mark_should_exit_utility(self):
"""Test the mark_should_exit utility function."""
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert should_exit() is False
# Mark the current context (inner) as should exit
mark_should_exit()
# Both inner and outer should now have agent_should_exit as True
assert should_exit() is True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
assert outer.agent_should_exit is True
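These tests encode a deliberate asymmetry: should-exit propagates up the parent chain so the whole agent stack winds down, while crash state stays local and is never inherited. A sketch of AgentContext consistent with that, using the attribute names asserted above:

class AgentContext:
    def __init__(self, parent_context=None):
        self.parent_context = parent_context
        self.agent_should_exit = False
        self.agent_has_crashed = False
        self.agent_crashed_message = None

    def mark_should_exit(self):
        # Propagates upward: every ancestor also stops.
        self.agent_should_exit = True
        if self.parent_context is not None:
            self.parent_context.mark_should_exit()

    def mark_agent_crashed(self, message):
        # Deliberately local: a crashed child must not poison its
        # parent, and new children never inherit a crash.
        self.agent_has_crashed = True
        self.agent_crashed_message = message

    def is_crashed(self):
        return self.agent_has_crashed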

View File

@ -8,7 +8,9 @@ import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags
from ra_aid.agent_context import (
agent_context,
)
from ra_aid.agent_utils import (
AgentState,
create_agent,
@ -116,7 +118,10 @@ def test_create_agent_anthropic(mock_model, mock_memory):
assert agent == "react_agent"
mock_react.assert_called_once_with(
mock_model, [], version='v2', state_modifier=mock_react.call_args[1]["state_modifier"]
mock_model,
[],
version="v2",
state_modifier=mock_react.call_args[1]["state_modifier"],
)
@ -259,7 +264,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
agent = create_agent(mock_model, [])
assert agent == "react_agent"
mock_react.assert_called_once_with(mock_model, [], version='v2')
mock_react.assert_called_once_with(mock_model, [], version="v2")
def test_get_model_token_limit_research(mock_memory):
@ -339,7 +344,7 @@ def test_run_agent_stream(monkeypatch):
ctx.plan_completed = True
ctx.task_completed = True
ctx.completion_message = "existing"
call_flag = {"called": False}
def fake_print_agent_output(
@ -352,7 +357,7 @@ def test_run_agent_stream(monkeypatch):
)
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
assert call_flag["called"]
with agent_context() as ctx:
assert ctx.plan_completed is False
assert ctx.task_completed is False
@ -457,74 +462,93 @@ def test_is_anthropic_claude():
assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
# Test OpenRouter provider cases
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"})
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-instant"})
assert is_anthropic_claude(
{"provider": "openrouter", "model": "anthropic/claude-2"}
)
assert is_anthropic_claude(
{"provider": "openrouter", "model": "anthropic/claude-instant"}
)
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
# Test edge cases
assert not is_anthropic_claude({}) # Empty config
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider
assert not is_anthropic_claude(
{"provider": "other", "model": "claude-2"}
) # Wrong provider
def test_run_agent_with_retry_checks_crash_status(monkeypatch):
"""Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
from ra_aid.agent_utils import run_agent_with_retry
from ra_aid.agent_context import agent_context, mark_agent_crashed
from ra_aid.agent_utils import run_agent_with_retry
# Setup mocks for dependencies to isolate our test
dummy_agent = Mock()
# Track function calls
mock_calls = {"run_agent_stream": 0}
def mock_run_agent_stream(*args, **kwargs):
mock_calls["run_agent_stream"] += 1
def mock_setup_interrupt_handling():
return None
def mock_restore_interrupt_handling(handler):
pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_is_crashed():
return ctx.is_crashed() if ctx else False
def mock_get_crash_message():
return ctx.agent_crashed_message if ctx and ctx.is_crashed() else None
# Apply mocks
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
monkeypatch.setattr(
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
)
monkeypatch.setattr(
"ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling,
)
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
# First, run without a crash - agent should be run
with agent_context() as ctx:
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
monkeypatch.setattr(
"ra_aid.agent_context.get_crash_message", mock_get_crash_message
)
result = run_agent_with_retry(dummy_agent, "test prompt", {})
assert mock_calls["run_agent_stream"] == 1
# Reset call counter
mock_calls["run_agent_stream"] = 0
# Now run with a crash - agent should not be run
with agent_context() as ctx:
mark_agent_crashed("Test crash message")
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
monkeypatch.setattr(
"ra_aid.agent_context.get_crash_message", mock_get_crash_message
)
result = run_agent_with_retry(dummy_agent, "test prompt", {})
# Verify _run_agent_stream was not called
assert mock_calls["run_agent_stream"] == 0
@ -534,54 +558,65 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
from ra_aid.agent_context import agent_context, is_crashed
from ra_aid.agent_utils import run_agent_with_retry
from ra_aid.exceptions import ToolExecutionError
from ra_aid.agent_context import agent_context, is_crashed
# Setup mocks
dummy_agent = Mock()
# Track function calls and simulate BadRequestError
run_count = [0]
def mock_run_agent_stream(*args, **kwargs):
run_count[0] += 1
if run_count[0] == 1:
# First call throws a 400 BadRequestError
raise ToolExecutionError("400 Bad Request: Invalid input")
# If it's called again, it should run normally
def mock_setup_interrupt_handling():
return None
def mock_restore_interrupt_handling(handler):
pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_mark_agent_crashed(message):
ctx.agent_has_crashed = True
ctx.agent_crashed_message = message
def mock_is_crashed():
return ctx.is_crashed() if ctx else False
# Apply mocks
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
monkeypatch.setattr(
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
)
monkeypatch.setattr(
"ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling,
)
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
with agent_context() as ctx:
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
monkeypatch.setattr(
"ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed
)
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
result = run_agent_with_retry(dummy_agent, "test prompt", {})
# Verify the agent was only run once and not retried
assert run_count[0] == 1
@ -594,60 +629,73 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
# Import APIError from anthropic module and patch it on the agent_utils module
from anthropic import APIError as AnthropicAPIError
from ra_aid.agent_utils import run_agent_with_retry
from ra_aid.agent_context import agent_context, is_crashed
from ra_aid.agent_utils import run_agent_with_retry
# Setup mocks
dummy_agent = Mock()
# Track function calls and simulate BadRequestError
run_count = [0]
# Create a mock APIError class that simulates Anthropic's APIError
class MockAPIError(Exception):
pass
def mock_run_agent_stream(*args, **kwargs):
run_count[0] += 1
if run_count[0] == 1:
# First call throws a 400 Bad Request APIError
mock_error = MockAPIError("400 Bad Request")
mock_error.__class__.__name__ = "APIError" # Make it look like Anthropic's APIError
mock_error.__class__.__name__ = (
"APIError" # Make it look like Anthropic's APIError
)
raise mock_error
# If it's called again, it should run normally
def mock_setup_interrupt_handling():
return None
def mock_restore_interrupt_handling(handler):
pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_mark_agent_crashed(message):
ctx.agent_has_crashed = True
ctx.agent_crashed_message = message
def mock_is_crashed():
return ctx.is_crashed() if ctx else False
# Apply mocks
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
monkeypatch.setattr(
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
)
monkeypatch.setattr(
"ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling,
)
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
monkeypatch.setattr(
"ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed
)
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
with agent_context() as ctx:
result = run_agent_with_retry(dummy_agent, "test prompt", {})
# Verify the agent was only run once and not retried
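The retry tests hinge on two rules: a context already marked crashed is never re-run, and 400-style errors crash the context instead of retrying. A sketch of the loop shape, with setup/teardown helper names taken from the monkeypatched targets and MAX_RETRIES assumed:

def run_agent_with_retry(agent, prompt, config):
    interrupt_handler = _setup_interrupt_handling()
    _increment_agent_depth()
    try:
        for _attempt in range(MAX_RETRIES):  # MAX_RETRIES is assumed
            check_interrupt()
            if is_crashed():
                # Crashes are permanent: report, never re-run.
                return f"Agent has crashed: {get_crash_message()}"
            try:
                _run_agent_stream(agent, [HumanMessage(prompt)], config)
                return "Agent run completed successfully"
            except ToolExecutionError as e:
                if "400" in str(e):
                    # Bad requests cannot succeed on retry.
                    mark_agent_crashed(f"Unretryable error: {e}")
                    return f"Agent has crashed: {e}"
                # Otherwise fall through and retry.
    finally:
        _decrement_agent_depth()
        _restore_interrupt_handling(interrupt_handler)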

View File

@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional
import pytest
@ -12,14 +12,18 @@ def clean_env():
"""Remove relevant environment variables before each test."""
# Save existing values
saved_vars = {}
for var in ['ANTHROPIC_API_KEY', 'EXPERT_ANTHROPIC_API_KEY',
'ANTHROPIC_MODEL', 'EXPERT_ANTHROPIC_MODEL']:
for var in [
"ANTHROPIC_API_KEY",
"EXPERT_ANTHROPIC_API_KEY",
"ANTHROPIC_MODEL",
"EXPERT_ANTHROPIC_MODEL",
]:
saved_vars[var] = os.environ.get(var)
if var in os.environ:
del os.environ[var]
yield
# Restore saved values
for var, value in saved_vars.items():
if value is not None:
@ -31,6 +35,7 @@ def clean_env():
@dataclass
class MockArgs:
"""Mock arguments class for testing."""
expert_provider: str
expert_model: Optional[str] = None
@ -39,9 +44,9 @@ def test_anthropic_expert_validation_message(clean_env):
"""Test that validation message refers to base key when neither key exists."""
strategy = AnthropicStrategy()
args = MockArgs(expert_provider="anthropic")
result = strategy.validate(args)
assert not result.valid
assert len(result.missing_vars) > 0
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars[0]

View File

@ -1,56 +1,55 @@
"""Unit tests for crash propagation behavior in agent_context."""
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
mark_agent_crashed,
is_crashed,
get_crash_message,
is_crashed,
mark_agent_crashed,
)
class TestCrashPropagation:
"""Test cases for crash state propagation behavior."""
def test_mark_agent_crashed_no_propagation(self):
"""Test that mark_agent_crashed does not propagate to parent contexts."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Initially both contexts should have is_crashed as False
assert parent.is_crashed() is False
assert child.is_crashed() is False
# Mark the child context as crashed
child.mark_agent_crashed("Child crashed")
# Child should be crashed but parent should not
assert child.is_crashed() is True
assert child.agent_crashed_message == "Child crashed"
assert parent.is_crashed() is False
assert parent.agent_crashed_message is None
def test_nested_crash_no_propagation(self):
"""Test that crash state doesn't propagate through multiple levels of parent contexts."""
grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent)
# Initially all contexts should have is_crashed as False
assert grandparent.is_crashed() is False
assert parent.is_crashed() is False
assert child.is_crashed() is False
# Mark the child context as crashed
child.mark_agent_crashed("Child crashed")
# Only child should be crashed
assert child.is_crashed() is True
assert parent.is_crashed() is False
assert grandparent.is_crashed() is False
def test_context_manager_crash_no_propagation(self):
"""Test that crash states don't propagate when using context managers."""
with agent_context() as outer:
@ -58,14 +57,14 @@ class TestCrashPropagation:
# Initially both contexts should have is_crashed as False
assert outer.is_crashed() is False
assert inner.is_crashed() is False
# Mark the inner context as crashed
inner.mark_agent_crashed("Inner crashed")
# Inner should be crashed but outer should not
assert inner.is_crashed() is True
assert outer.is_crashed() is False
def test_utility_functions_for_crash_state(self):
"""Test utility functions for crash state."""
with agent_context() as outer:
@ -73,12 +72,12 @@ class TestCrashPropagation:
# Initially both contexts should have is_crashed as False
assert is_crashed() is False
assert get_crash_message() is None
# Mark the current context (inner) as crashed
mark_agent_crashed("Utility function crash")
# Current context should be crashed but outer should not
assert is_crashed() is True
assert get_crash_message() == "Utility function crash"
assert inner.is_crashed() is True
assert outer.is_crashed() is False
assert outer.is_crashed() is False

View File

@ -45,12 +45,23 @@ def test_default_anthropic_provider(clean_env, monkeypatch):
"""Test that Anthropic is the default provider when no environment variables are set."""
args = parse_arguments(["-m", "test message"])
assert args.provider == "anthropic"
assert args.model == "claude-3-7-sonnet-20250219" # Updated to match current default
assert (
args.model == "claude-3-7-sonnet-20250219"
) # Updated to match current default
def test_respects_user_specified_anthropic_model(clean_env):
"""Test that user-specified Anthropic models are respected."""
args = parse_arguments(["-m", "test message", "--provider", "anthropic", "--model", "claude-3-5-sonnet-20241022"])
args = parse_arguments(
[
"-m",
"test message",
"--provider",
"anthropic",
"--model",
"claude-3-5-sonnet-20241022",
]
)
assert args.provider == "anthropic"
assert args.model == "claude-3-5-sonnet-20241022" # Should not be overridden

View File

@ -101,7 +101,7 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
# Check that mock_anthropic was called
assert mock_anthropic.called
# Verify essential parameters
kwargs = mock_anthropic.call_args.kwargs
assert kwargs["api_key"] == "test-key"
@ -123,10 +123,7 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
temperature=0,
timeout=180,
max_retries=5,
default_headers={
"HTTP-Referer": "https://ra-aid.ai",
"X-Title": "RA.Aid"
}
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
)
@ -203,7 +200,7 @@ def test_initialize_anthropic(clean_env, mock_anthropic):
# Check that mock_anthropic was called
assert mock_anthropic.called
# Verify essential parameters
kwargs = mock_anthropic.call_args.kwargs
assert kwargs["api_key"] == "test-key"
@ -225,10 +222,7 @@ def test_initialize_openrouter(clean_env, mock_openai):
temperature=0.7,
timeout=180,
max_retries=5,
default_headers={
"HTTP-Referer": "https://ra-aid.ai",
"X-Title": "RA.Aid"
}
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
)
@ -285,7 +279,7 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin
)
initialize_llm("anthropic", "test-model")
# Verify essential parameters for Anthropic
kwargs = mock_anthropic.call_args.kwargs
assert kwargs["api_key"] == "test-key"
@ -354,7 +348,7 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
# Test Anthropic
initialize_llm("anthropic", "test-model", temperature=test_temp)
# Verify essential parameters for Anthropic
kwargs = mock_anthropic.call_args.kwargs
assert kwargs["api_key"] == "test-key"
@ -372,10 +366,7 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
temperature=test_temp,
timeout=180,
max_retries=5,
default_headers={
"HTTP-Referer": "https://ra-aid.ai",
"X-Title": "RA.Aid"
}
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
)
@ -482,7 +473,7 @@ def test_initialize_llm_cross_provider(
# Initialize Anthropic
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
_llm2 = initialize_llm("anthropic", "claude-3", temperature=0.7)
# Verify essential parameters for Anthropic
kwargs = mock_anthropic.call_args.kwargs
assert kwargs["api_key"] == "anthropic-key"
@ -586,13 +577,15 @@ def mock_deepseek_reasoner():
yield mock
def test_reasoning_effort_only_passed_to_supported_models(clean_env, mock_openai, monkeypatch):
def test_reasoning_effort_only_passed_to_supported_models(
clean_env, mock_openai, monkeypatch
):
"""Test that reasoning_effort is only passed to supported models."""
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
# Initialize expert LLM with GPT-4 (which doesn't support reasoning_effort)
_llm = initialize_expert_llm("openai", "gpt-4")
# Verify reasoning_effort was not included in kwargs
mock_openai.assert_called_with(
api_key="test-key",
@ -603,13 +596,15 @@ def test_reasoning_effort_only_passed_to_supported_models(clean_env, mock_openai
)
def test_reasoning_effort_passed_to_supported_models(clean_env, mock_openai, monkeypatch):
def test_reasoning_effort_passed_to_supported_models(
clean_env, mock_openai, monkeypatch
):
"""Test that reasoning_effort is passed to models that support it."""
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
# Initialize expert LLM with o1 (which supports reasoning_effort)
_llm = initialize_expert_llm("openai", "o1")
# Verify reasoning_effort was included in kwargs
mock_openai.assert_called_with(
api_key="test-key",
@ -664,8 +659,5 @@ def test_initialize_openrouter_deepseek(
temperature=0.7,
timeout=180,
max_retries=5,
default_headers={
"HTTP-Referer": "https://ra-aid.ai",
"X-Title": "RA.Aid"
}
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
)
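The reasoning_effort pair of tests implies a capability gate when assembling model kwargs; a sketch, where the predicate itself is an assumption (only OpenAI's o-series reasoning models accept the parameter):

def supports_reasoning_effort(model: str) -> bool:
    # gpt-4 must not receive reasoning_effort; o1-style models may.
    return model.startswith("o1") or model.startswith("o3")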

View File

@ -204,15 +204,15 @@ def test_use_aider_flag(mock_dependencies):
"""Test that use-aider flag is correctly stored in config."""
import sys
from unittest.mock import patch
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
from ra_aid.__main__ import main
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
_global_memory.clear()
# Reset to default state
set_modification_tools(False)
# Check default behavior (use_aider=False)
with patch.object(
sys,
@ -222,15 +222,15 @@ def test_use_aider_flag(mock_dependencies):
main()
config = _global_memory["config"]
assert config.get("use_aider") is False
# Check that file tools are enabled by default
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" in tool_names
assert "put_complete_file_contents" in tool_names
assert "run_programming_task" not in tool_names
_global_memory.clear()
# Check with --use-aider flag
with patch.object(
sys,
@ -240,12 +240,12 @@ def test_use_aider_flag(mock_dependencies):
main()
config = _global_memory["config"]
assert config.get("use_aider") is True
# Check that run_programming_task is enabled
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" not in tool_names
assert "put_complete_file_contents" not in tool_names
assert "run_programming_task" in tool_names
# Reset to default state for other tests
set_modification_tools(False)

View File

@ -1,11 +1,11 @@
from ra_aid.tool_configs import (
MODIFICATION_TOOLS,
get_implementation_tools,
get_planning_tools,
get_read_only_tools,
get_research_tools,
get_web_research_tools,
set_modification_tools,
MODIFICATION_TOOLS,
)
@ -14,16 +14,16 @@ def test_get_read_only_tools():
tools = get_read_only_tools(human_interaction=False, use_aider=False)
assert len(tools) > 0
assert all(callable(tool) for tool in tools)
# Check emit_related_files is not included when use_aider is False
tool_names = [tool.name for tool in tools]
assert "emit_related_files" not in tool_names
# Test with use_aider=True
tools_with_aider = get_read_only_tools(human_interaction=False, use_aider=True)
tool_names_with_aider = [tool.name for tool in tools_with_aider]
assert "emit_related_files" in tool_names_with_aider
# Test with human interaction
tools_with_human = get_read_only_tools(human_interaction=True, use_aider=False)
assert len(tools_with_human) == len(tools) + 1
@ -102,13 +102,13 @@ def test_set_modification_tools():
assert "file_str_replace" in tool_names
assert "put_complete_file_contents" in tool_names
assert "run_programming_task" not in tool_names
# Test with use_aider=True
set_modification_tools(use_aider=True)
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" not in tool_names
assert "put_complete_file_contents" not in tool_names
assert "run_programming_task" in tool_names
# Reset to default for other tests
set_modification_tools(use_aider=False)
set_modification_tools(use_aider=False)
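set_modification_tools has to mutate MODIFICATION_TOOLS in place rather than rebind it, so modules that already imported the list observe the switch. A sketch under that assumption (the tool objects are the ones named in the assertions):

def set_modification_tools(use_aider: bool = False) -> None:
    MODIFICATION_TOOLS.clear()
    if use_aider:
        # Aider mode: one delegating tool.
        MODIFICATION_TOOLS.append(run_programming_task)
    else:
        # Default mode: direct file-editing tools.
        MODIFICATION_TOOLS.extend([file_str_replace, put_complete_file_contents])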

View File

@ -222,7 +222,7 @@ def test_emit_key_snippet(reset_memory):
# Verify counter incremented correctly
assert _global_memory["key_snippet_id_counter"] == 1
# Test snippet without description
snippet2 = {
"filepath": "main.py",
@ -230,16 +230,16 @@ def test_emit_key_snippet(reset_memory):
"snippet": "print('hello')",
"description": None,
}
# Emit second snippet
result = emit_key_snippet.invoke({"snippet_info": snippet2})
# Verify return message
assert result == "Snippet #1 stored."
# Verify snippet stored correctly
assert _global_memory["key_snippets"][1] == snippet2
# Verify counter incremented correctly
assert _global_memory["key_snippet_id_counter"] == 2
@ -723,37 +723,40 @@ def test_emit_related_files_binary_filtering(reset_memory, tmp_path, monkeypatch
text_file1.write_text("Text file 1 content")
text_file2 = tmp_path / "text2.txt"
text_file2.write_text("Text file 2 content")
# Create test "binary" files
binary_file1 = tmp_path / "binary1.bin"
binary_file1.write_text("Binary file 1 content")
binary_file2 = tmp_path / "binary2.bin"
binary_file2.write_text("Binary file 2 content")
# Mock the is_binary_file function to identify our "binary" files
def mock_is_binary_file(filepath):
return ".bin" in str(filepath)
# Apply the mock
import ra_aid.tools.memory
monkeypatch.setattr(ra_aid.tools.memory, "is_binary_file", mock_is_binary_file)
# Call emit_related_files with mix of text and binary files
result = emit_related_files.invoke({
"files": [
str(text_file1),
str(binary_file1),
str(text_file2),
str(binary_file2)
]
})
result = emit_related_files.invoke(
{
"files": [
str(text_file1),
str(binary_file1),
str(text_file2),
str(binary_file2),
]
}
)
# Verify the result message mentions skipped binary files
assert "Files noted." in result
assert "Binary files skipped:" in result
assert f"'{binary_file1}'" in result
assert f"'{binary_file2}'" in result
# Verify only text files were added to related_files
assert len(_global_memory["related_files"]) == 2
file_values = list(_global_memory["related_files"].values())
@ -761,6 +764,60 @@ def test_emit_related_files_binary_filtering(reset_memory, tmp_path, monkeypatch
assert str(text_file2) in file_values
assert str(binary_file1) not in file_values
assert str(binary_file2) not in file_values
# Verify counter is correct (only incremented for text files)
assert _global_memory["related_file_id_counter"] == 2
def test_is_binary_file_with_ascii(reset_memory, monkeypatch):
"""Test that ASCII files are correctly identified as text files"""
import os
import ra_aid.tools.memory
# Path to the mock ASCII file
ascii_file_path = os.path.join(
os.path.dirname(__file__), "..", "mocks", "ascii.txt"
)
# Test with magic library if available
if ra_aid.tools.memory.magic:
# Test real implementation with ASCII file
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
assert not is_binary, "ASCII file should not be identified as binary"
# Test fallback implementation
# Mock magic to be None to force fallback implementation
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
# Test fallback with ASCII file
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
assert (
not is_binary
), "ASCII file should not be identified as binary with fallback method"
def test_is_binary_file_with_null_bytes(reset_memory, tmp_path, monkeypatch):
"""Test that files with null bytes are correctly identified as binary"""
import ra_aid.tools.memory
# Create a file with null bytes (binary content)
binary_file = tmp_path / "binary_with_nulls.bin"
with open(binary_file, "wb") as f:
f.write(b"Some text with \x00 null \x00 bytes")
# Test with magic library if available
if ra_aid.tools.memory.magic:
# Test real implementation with binary file
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
assert is_binary, "File with null bytes should be identified as binary"
# Test fallback implementation
# Mock magic to be None to force fallback implementation
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
# Test fallback with binary file
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
assert (
is_binary
), "File with null bytes should be identified as binary with fallback method"

View File

@ -138,13 +138,13 @@ def test_verify_fix(tmp_path):
# Create a .ra-aid directory inside the temporary directory
ra_aid_dir = tmp_path / ".ra-aid"
ra_aid_dir.mkdir()
# Check that is_new_project() returns True (only .ra-aid directory)
assert is_new_project(str(tmp_path)) is True
# Add a README.md file to the directory
readme_file = tmp_path / "README.md"
readme_file.write_text("# Test Project")
# Check that is_new_project() now returns False (has actual content)
assert is_new_project(str(tmp_path)) is False
assert is_new_project(str(tmp_path)) is False
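The fix verified here: is_new_project must ignore ra-aid's own metadata directory when deciding whether a project is empty. A sketch assuming a plain directory listing:

import os


def is_new_project(directory: str) -> bool:
    # A project is "new" when nothing exists besides the .ra-aid dir.
    entries = [e for e in os.listdir(directory) if e != ".ra-aid"]
    return len(entries) == 0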