Binary Skipped ascii filetype fix (#108)
* chore: refactor code for improved readability and maintainability - Standardize variable naming conventions for consistency. - Improve logging messages for better clarity and debugging. - Remove unnecessary imports and clean up code structure. - Enhance error handling and logging in various modules. - Update comments and docstrings for better understanding. - Optimize imports and organize them logically. - Ensure consistent formatting across files for better readability. - Refactor functions to reduce complexity and improve performance. - Add missing type hints and annotations for better code clarity. - Improve test coverage and organization in test files. style(tests): apply consistent formatting and spacing in test files for improved readability and maintainability * chore(tests): remove redundant test for ensure_tables_created with no models to streamline test suite and reduce maintenance overhead * fix(memory.py): update is_binary_file function to correctly identify binary files by returning True for non-text mime types
This commit is contained in:
parent
429f854fb8
commit
e960a68d29
|
|
@ -6,9 +6,8 @@ This script creates a baseline migration representing the current database schem
|
||||||
It serves as the foundation for future schema changes.
|
It serves as the foundation for future schema changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
import sys
|
||||||
|
|
||||||
# Add the project root to the Python path
|
# Add the project root to the Python path
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
@ -20,10 +19,11 @@ from ra_aid.logging_config import get_logger, setup_logging
|
||||||
setup_logging(verbose=True)
|
setup_logging(verbose=True)
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_initial_migration():
|
def create_initial_migration():
|
||||||
"""
|
"""
|
||||||
Create the initial migration for the current database schema.
|
Create the initial migration for the current database schema.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if migration was created successfully, False otherwise
|
bool: True if migration was created successfully, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
@ -31,11 +31,11 @@ def create_initial_migration():
|
||||||
with DatabaseManager() as db:
|
with DatabaseManager() as db:
|
||||||
# Create a descriptive name for the initial migration
|
# Create a descriptive name for the initial migration
|
||||||
migration_name = "initial_schema"
|
migration_name = "initial_schema"
|
||||||
|
|
||||||
# Create the migration
|
# Create the migration
|
||||||
logger.info(f"Creating initial migration '{migration_name}'...")
|
logger.info(f"Creating initial migration '{migration_name}'...")
|
||||||
result = create_new_migration(migration_name, auto=True)
|
result = create_new_migration(migration_name, auto=True)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
logger.info(f"Successfully created initial migration: {result}")
|
logger.info(f"Successfully created initial migration: {result}")
|
||||||
print(f"✅ Initial migration created successfully: {result}")
|
print(f"✅ Initial migration created successfully: {result}")
|
||||||
|
|
@ -49,9 +49,10 @@ def create_initial_migration():
|
||||||
print(f"❌ Error creating initial migration: {str(e)}")
|
print(f"❌ Error creating initial migration: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Creating initial database migration...")
|
print("Creating initial database migration...")
|
||||||
success = create_initial_migration()
|
success = create_initial_migration()
|
||||||
|
|
||||||
# Exit with appropriate code
|
# Exit with appropriate code
|
||||||
sys.exit(0 if success else 1)
|
sys.exit(0 if success else 1)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ from ra_aid.agent_utils import (
|
||||||
run_planning_agent,
|
run_planning_agent,
|
||||||
run_research_agent,
|
run_research_agent,
|
||||||
)
|
)
|
||||||
from ra_aid.database import init_db, close_db, DatabaseManager, ensure_migrations_applied
|
|
||||||
from ra_aid.config import (
|
from ra_aid.config import (
|
||||||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||||
DEFAULT_RECURSION_LIMIT,
|
DEFAULT_RECURSION_LIMIT,
|
||||||
|
|
@ -25,6 +24,10 @@ from ra_aid.config import (
|
||||||
VALID_PROVIDERS,
|
VALID_PROVIDERS,
|
||||||
)
|
)
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
|
from ra_aid.database import (
|
||||||
|
DatabaseManager,
|
||||||
|
ensure_migrations_applied,
|
||||||
|
)
|
||||||
from ra_aid.dependencies import check_dependencies
|
from ra_aid.dependencies import check_dependencies
|
||||||
from ra_aid.env import validate_environment
|
from ra_aid.env import validate_environment
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
|
|
@ -171,8 +174,9 @@ Examples:
|
||||||
"--aider-config", type=str, help="Specify the aider config file path"
|
"--aider-config", type=str, help="Specify the aider config file path"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-aider", action="store_true",
|
"--use-aider",
|
||||||
help="Use aider for code modifications instead of default file tools (file_str_replace, put_complete_file_contents)"
|
action="store_true",
|
||||||
|
help="Use aider for code modifications instead of default file tools (file_str_replace, put_complete_file_contents)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test-cmd",
|
"--test-cmd",
|
||||||
|
|
@ -343,24 +347,37 @@ def main():
|
||||||
try:
|
try:
|
||||||
migration_result = ensure_migrations_applied()
|
migration_result = ensure_migrations_applied()
|
||||||
if not migration_result:
|
if not migration_result:
|
||||||
logger.warning("Database migrations failed but execution will continue")
|
logger.warning(
|
||||||
|
"Database migrations failed but execution will continue"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Database migration error: {str(e)}")
|
logger.error(f"Database migration error: {str(e)}")
|
||||||
|
|
||||||
# Check dependencies before proceeding
|
# Check dependencies before proceeding
|
||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
|
(
|
||||||
validate_environment(args)
|
expert_enabled,
|
||||||
) # Will exit if main env vars missing
|
expert_missing,
|
||||||
|
web_research_enabled,
|
||||||
|
web_research_missing,
|
||||||
|
) = validate_environment(args) # Will exit if main env vars missing
|
||||||
logger.debug("Environment validation successful")
|
logger.debug("Environment validation successful")
|
||||||
|
|
||||||
# Validate model configuration early
|
# Validate model configuration early
|
||||||
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
|
model_config = models_params.get(args.provider, {}).get(
|
||||||
|
args.model or "", {}
|
||||||
|
)
|
||||||
supports_temperature = model_config.get(
|
supports_temperature = model_config.get(
|
||||||
"supports_temperature",
|
"supports_temperature",
|
||||||
args.provider
|
args.provider
|
||||||
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
|
in [
|
||||||
|
"anthropic",
|
||||||
|
"openai",
|
||||||
|
"openrouter",
|
||||||
|
"openai-compatible",
|
||||||
|
"deepseek",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if supports_temperature and args.temperature is None:
|
if supports_temperature and args.temperature is None:
|
||||||
|
|
@ -377,7 +394,12 @@ def main():
|
||||||
status = build_status(args, expert_enabled, web_research_enabled)
|
status = build_status(args, expert_enabled, web_research_enabled)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
|
Panel(
|
||||||
|
status,
|
||||||
|
title=f"RA.Aid v{__version__}",
|
||||||
|
border_style="bright_blue",
|
||||||
|
padding=(0, 1),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle chat mode
|
# Handle chat mode
|
||||||
|
|
@ -386,7 +408,7 @@ def main():
|
||||||
chat_model = initialize_llm(
|
chat_model = initialize_llm(
|
||||||
args.provider, args.model, temperature=args.temperature
|
args.provider, args.model, temperature=args.temperature
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.research_only:
|
if args.research_only:
|
||||||
print_error("Chat mode cannot be used with --research-only")
|
print_error("Chat mode cannot be used with --research-only")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
@ -429,7 +451,7 @@ def main():
|
||||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||||
_global_memory["config"]["expert_model"] = args.expert_model
|
_global_memory["config"]["expert_model"] = args.expert_model
|
||||||
_global_memory["config"]["temperature"] = args.temperature
|
_global_memory["config"]["temperature"] = args.temperature
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
||||||
|
|
@ -449,7 +471,9 @@ def main():
|
||||||
CHAT_PROMPT.format(
|
CHAT_PROMPT.format(
|
||||||
initial_request=initial_request,
|
initial_request=initial_request,
|
||||||
web_research_section=(
|
web_research_section=(
|
||||||
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
|
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||||
|
if web_research_enabled
|
||||||
|
else ""
|
||||||
),
|
),
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
|
|
@ -502,11 +526,13 @@ def main():
|
||||||
_global_memory["config"]["research_provider"] = (
|
_global_memory["config"]["research_provider"] = (
|
||||||
args.research_provider or args.provider
|
args.research_provider or args.provider
|
||||||
)
|
)
|
||||||
_global_memory["config"]["research_model"] = args.research_model or args.model
|
_global_memory["config"]["research_model"] = (
|
||||||
|
args.research_model or args.model
|
||||||
|
)
|
||||||
|
|
||||||
# Store temperature in global config
|
# Store temperature in global config
|
||||||
_global_memory["config"]["temperature"] = args.temperature
|
_global_memory["config"]["temperature"] = args.temperature
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Dict, Optional, Set
|
from typing import Optional
|
||||||
|
|
||||||
# Thread-local storage for context variables
|
# Thread-local storage for context variables
|
||||||
_thread_local = threading.local()
|
_thread_local = threading.local()
|
||||||
|
|
@ -19,7 +19,7 @@ class AgentContext:
|
||||||
"""
|
"""
|
||||||
# Store reference to parent context
|
# Store reference to parent context
|
||||||
self.parent = parent_context
|
self.parent = parent_context
|
||||||
|
|
||||||
# Initialize completion flags
|
# Initialize completion flags
|
||||||
self.task_completed = False
|
self.task_completed = False
|
||||||
self.plan_completed = False
|
self.plan_completed = False
|
||||||
|
|
@ -27,8 +27,8 @@ class AgentContext:
|
||||||
self.agent_should_exit = False
|
self.agent_should_exit = False
|
||||||
self.agent_has_crashed = False
|
self.agent_has_crashed = False
|
||||||
self.agent_crashed_message = None
|
self.agent_crashed_message = None
|
||||||
|
|
||||||
# Note: Completion flags (task_completed, plan_completed, completion_message,
|
# Note: Completion flags (task_completed, plan_completed, completion_message,
|
||||||
# agent_should_exit) are no longer inherited from parent contexts
|
# agent_should_exit) are no longer inherited from parent contexts
|
||||||
|
|
||||||
def mark_task_completed(self, message: str) -> None:
|
def mark_task_completed(self, message: str) -> None:
|
||||||
|
|
@ -58,29 +58,29 @@ class AgentContext:
|
||||||
|
|
||||||
def mark_should_exit(self) -> None:
|
def mark_should_exit(self) -> None:
|
||||||
"""Mark that the agent should exit execution.
|
"""Mark that the agent should exit execution.
|
||||||
|
|
||||||
This propagates the exit state to all parent contexts.
|
This propagates the exit state to all parent contexts.
|
||||||
"""
|
"""
|
||||||
self.agent_should_exit = True
|
self.agent_should_exit = True
|
||||||
|
|
||||||
# Propagate to parent context if it exists
|
# Propagate to parent context if it exists
|
||||||
if self.parent:
|
if self.parent:
|
||||||
self.parent.mark_should_exit()
|
self.parent.mark_should_exit()
|
||||||
|
|
||||||
def mark_agent_crashed(self, message: str) -> None:
|
def mark_agent_crashed(self, message: str) -> None:
|
||||||
"""Mark the agent as crashed with the given message.
|
"""Mark the agent as crashed with the given message.
|
||||||
|
|
||||||
Unlike exit state, crash state does not propagate to parent contexts.
|
Unlike exit state, crash state does not propagate to parent contexts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: Error message explaining the crash
|
message: Error message explaining the crash
|
||||||
"""
|
"""
|
||||||
self.agent_has_crashed = True
|
self.agent_has_crashed = True
|
||||||
self.agent_crashed_message = message
|
self.agent_crashed_message = message
|
||||||
|
|
||||||
def is_crashed(self) -> bool:
|
def is_crashed(self) -> bool:
|
||||||
"""Check if the agent has crashed.
|
"""Check if the agent has crashed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the agent has crashed, False otherwise
|
True if the agent has crashed, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
@ -116,17 +116,17 @@ def agent_context(parent_context=None):
|
||||||
"""
|
"""
|
||||||
# Save the previous context
|
# Save the previous context
|
||||||
previous_context = getattr(_thread_local, "current_context", None)
|
previous_context = getattr(_thread_local, "current_context", None)
|
||||||
|
|
||||||
# Create a new context, inheriting from parent if provided
|
# Create a new context, inheriting from parent if provided
|
||||||
# If parent_context is None but previous_context exists, use previous_context as parent
|
# If parent_context is None but previous_context exists, use previous_context as parent
|
||||||
if parent_context is None and previous_context is not None:
|
if parent_context is None and previous_context is not None:
|
||||||
context = AgentContext(previous_context)
|
context = AgentContext(previous_context)
|
||||||
else:
|
else:
|
||||||
context = AgentContext(parent_context)
|
context = AgentContext(parent_context)
|
||||||
|
|
||||||
# Set as current context
|
# Set as current context
|
||||||
_thread_local.current_context = context
|
_thread_local.current_context = context
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield context
|
yield context
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -202,7 +202,7 @@ def mark_should_exit() -> None:
|
||||||
|
|
||||||
def is_crashed() -> bool:
|
def is_crashed() -> bool:
|
||||||
"""Check if the current agent has crashed.
|
"""Check if the current agent has crashed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the current agent has crashed, False otherwise
|
True if the current agent has crashed, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
@ -212,7 +212,7 @@ def is_crashed() -> bool:
|
||||||
|
|
||||||
def mark_agent_crashed(message: str) -> None:
|
def mark_agent_crashed(message: str) -> None:
|
||||||
"""Mark the current agent as crashed with the given message.
|
"""Mark the current agent as crashed with the given message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: Error message explaining the crash
|
message: Error message explaining the crash
|
||||||
"""
|
"""
|
||||||
|
|
@ -223,9 +223,9 @@ def mark_agent_crashed(message: str) -> None:
|
||||||
|
|
||||||
def get_crash_message() -> Optional[str]:
|
def get_crash_message() -> Optional[str]:
|
||||||
"""Get the crash message from the current context.
|
"""Get the crash message from the current context.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The crash message or None if the agent has not crashed
|
The crash message or None if the agent has not crashed
|
||||||
"""
|
"""
|
||||||
context = get_current_context()
|
context = get_current_context()
|
||||||
return context.agent_crashed_message if context and context.is_crashed() else None
|
return context.agent_crashed_message if context and context.is_crashed() else None
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||||
|
|
@ -30,6 +30,12 @@ from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from ra_aid.agent_context import (
|
||||||
|
agent_context,
|
||||||
|
is_completed,
|
||||||
|
reset_completion_flags,
|
||||||
|
should_exit,
|
||||||
|
)
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.agents_alias import RAgents
|
from ra_aid.agents_alias import RAgents
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||||
|
|
@ -72,14 +78,6 @@ from ra_aid.tool_configs import (
|
||||||
get_web_research_tools,
|
get_web_research_tools,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
from ra_aid.agent_context import (
|
|
||||||
agent_context,
|
|
||||||
get_current_context,
|
|
||||||
is_completed,
|
|
||||||
reset_completion_flags,
|
|
||||||
get_completion_message,
|
|
||||||
should_exit,
|
|
||||||
)
|
|
||||||
from ra_aid.tools.memory import (
|
from ra_aid.tools.memory import (
|
||||||
_global_memory,
|
_global_memory,
|
||||||
get_memory_value,
|
get_memory_value,
|
||||||
|
|
@ -250,8 +248,12 @@ def is_anthropic_claude(config: Dict[str, Any]) -> bool:
|
||||||
provider = config.get("provider", "")
|
provider = config.get("provider", "")
|
||||||
model_name = config.get("model", "")
|
model_name = config.get("model", "")
|
||||||
result = (
|
result = (
|
||||||
(provider.lower() == "anthropic" and model_name and "claude" in model_name.lower())
|
provider.lower() == "anthropic"
|
||||||
or (provider.lower() == "openrouter" and model_name.lower().startswith("anthropic/claude-"))
|
and model_name
|
||||||
|
and "claude" in model_name.lower()
|
||||||
|
) or (
|
||||||
|
provider.lower() == "openrouter"
|
||||||
|
and model_name.lower().startswith("anthropic/claude-")
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -953,14 +955,15 @@ def run_agent_with_retry(
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
|
|
||||||
# Check if the agent has crashed before attempting to run it
|
# Check if the agent has crashed before attempting to run it
|
||||||
from ra_aid.agent_context import is_crashed, get_crash_message
|
from ra_aid.agent_context import get_crash_message, is_crashed
|
||||||
|
|
||||||
if is_crashed():
|
if is_crashed():
|
||||||
crash_message = get_crash_message()
|
crash_message = get_crash_message()
|
||||||
logger.error("Agent has crashed: %s", crash_message)
|
logger.error("Agent has crashed: %s", crash_message)
|
||||||
return f"Agent has crashed: {crash_message}"
|
return f"Agent has crashed: {crash_message}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_run_agent_stream(agent, msg_list, config)
|
_run_agent_stream(agent, msg_list, config)
|
||||||
if fallback_handler:
|
if fallback_handler:
|
||||||
|
|
@ -982,11 +985,12 @@ def run_agent_with_retry(
|
||||||
error_str = str(e).lower()
|
error_str = str(e).lower()
|
||||||
if "400" in error_str or "bad request" in error_str:
|
if "400" in error_str or "bad request" in error_str:
|
||||||
from ra_aid.agent_context import mark_agent_crashed
|
from ra_aid.agent_context import mark_agent_crashed
|
||||||
|
|
||||||
crash_message = f"Unretryable error: {str(e)}"
|
crash_message = f"Unretryable error: {str(e)}"
|
||||||
mark_agent_crashed(crash_message)
|
mark_agent_crashed(crash_message)
|
||||||
logger.error("Agent has crashed: %s", crash_message)
|
logger.error("Agent has crashed: %s", crash_message)
|
||||||
return f"Agent has crashed: {crash_message}"
|
return f"Agent has crashed: {crash_message}"
|
||||||
|
|
||||||
_handle_fallback_response(e, fallback_handler, agent, msg_list)
|
_handle_fallback_response(e, fallback_handler, agent, msg_list)
|
||||||
continue
|
continue
|
||||||
except FallbackToolExecutionError as e:
|
except FallbackToolExecutionError as e:
|
||||||
|
|
@ -1007,13 +1011,16 @@ def run_agent_with_retry(
|
||||||
) as e:
|
) as e:
|
||||||
# Check if this is a BadRequestError (HTTP 400) which is unretryable
|
# Check if this is a BadRequestError (HTTP 400) which is unretryable
|
||||||
error_str = str(e).lower()
|
error_str = str(e).lower()
|
||||||
if ("400" in error_str or "bad request" in error_str) and isinstance(e, APIError):
|
if (
|
||||||
|
"400" in error_str or "bad request" in error_str
|
||||||
|
) and isinstance(e, APIError):
|
||||||
from ra_aid.agent_context import mark_agent_crashed
|
from ra_aid.agent_context import mark_agent_crashed
|
||||||
|
|
||||||
crash_message = f"Unretryable API error: {str(e)}"
|
crash_message = f"Unretryable API error: {str(e)}"
|
||||||
mark_agent_crashed(crash_message)
|
mark_agent_crashed(crash_message)
|
||||||
logger.error("Agent has crashed: %s", crash_message)
|
logger.error("Agent has crashed: %s", crash_message)
|
||||||
return f"Agent has crashed: {crash_message}"
|
return f"Agent has crashed: {crash_message}"
|
||||||
|
|
||||||
_handle_api_error(e, attempt, max_retries, base_delay)
|
_handle_api_error(e, attempt, max_retries, base_delay)
|
||||||
finally:
|
finally:
|
||||||
_decrement_agent_depth()
|
_decrement_agent_depth()
|
||||||
|
|
|
||||||
|
|
@ -5,34 +5,29 @@ This package provides database functionality for the ra_aid application,
|
||||||
including connection management, models, utility functions, and migrations.
|
including connection management, models, utility functions, and migrations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from ra_aid.database.connection import (
|
from ra_aid.database.connection import DatabaseManager, close_db, get_db, init_db
|
||||||
init_db,
|
from ra_aid.database.migrations import (
|
||||||
get_db,
|
MigrationManager,
|
||||||
close_db,
|
create_new_migration,
|
||||||
DatabaseManager
|
ensure_migrations_applied,
|
||||||
|
get_migration_status,
|
||||||
|
init_migrations,
|
||||||
)
|
)
|
||||||
from ra_aid.database.models import BaseModel
|
from ra_aid.database.models import BaseModel
|
||||||
from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created
|
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||||
from ra_aid.database.migrations import (
|
|
||||||
init_migrations,
|
|
||||||
ensure_migrations_applied,
|
|
||||||
create_new_migration,
|
|
||||||
get_migration_status,
|
|
||||||
MigrationManager
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'init_db',
|
"init_db",
|
||||||
'get_db',
|
"get_db",
|
||||||
'close_db',
|
"close_db",
|
||||||
'DatabaseManager',
|
"DatabaseManager",
|
||||||
'BaseModel',
|
"BaseModel",
|
||||||
'get_model_count',
|
"get_model_count",
|
||||||
'truncate_table',
|
"truncate_table",
|
||||||
'ensure_tables_created',
|
"ensure_tables_created",
|
||||||
'init_migrations',
|
"init_migrations",
|
||||||
'ensure_migrations_applied',
|
"ensure_migrations_applied",
|
||||||
'create_new_migration',
|
"create_new_migration",
|
||||||
'get_migration_status',
|
"get_migration_status",
|
||||||
'MigrationManager',
|
"MigrationManager",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,10 @@ This module provides functions to initialize, get, and close database connection
|
||||||
It also provides a context manager for database connections.
|
It also provides a context manager for database connections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import contextvars
|
import contextvars
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
|
|
@ -22,64 +22,69 @@ logger = get_logger(__name__)
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
"""
|
"""
|
||||||
Context manager for database connections.
|
Context manager for database connections.
|
||||||
|
|
||||||
This class provides a context manager interface for database connections,
|
This class provides a context manager interface for database connections,
|
||||||
using the existing contextvars approach internally.
|
using the existing contextvars approach internally.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
with DatabaseManager() as db:
|
with DatabaseManager() as db:
|
||||||
# Use the database connection
|
# Use the database connection
|
||||||
db.execute_sql("SELECT * FROM table")
|
db.execute_sql("SELECT * FROM table")
|
||||||
|
|
||||||
# Or with in-memory database:
|
# Or with in-memory database:
|
||||||
with DatabaseManager(in_memory=True) as db:
|
with DatabaseManager(in_memory=True) as db:
|
||||||
# Use in-memory database
|
# Use in-memory database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_memory: bool = False):
|
def __init__(self, in_memory: bool = False):
|
||||||
"""
|
"""
|
||||||
Initialize the DatabaseManager.
|
Initialize the DatabaseManager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_memory: Whether to use an in-memory database (default: False)
|
in_memory: Whether to use an in-memory database (default: False)
|
||||||
"""
|
"""
|
||||||
self.in_memory = in_memory
|
self.in_memory = in_memory
|
||||||
|
|
||||||
def __enter__(self) -> peewee.SqliteDatabase:
|
def __enter__(self) -> peewee.SqliteDatabase:
|
||||||
"""
|
"""
|
||||||
Initialize the database connection and return it.
|
Initialize the database connection and return it.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
peewee.SqliteDatabase: The initialized database connection
|
peewee.SqliteDatabase: The initialized database connection
|
||||||
"""
|
"""
|
||||||
return init_db(in_memory=self.in_memory)
|
return init_db(in_memory=self.in_memory)
|
||||||
|
|
||||||
def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception],
|
def __exit__(
|
||||||
exc_tb: Optional[Any]) -> None:
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[Any],
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Close the database connection when exiting the context.
|
Close the database connection when exiting the context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
exc_type: The exception type if an exception was raised
|
exc_type: The exception type if an exception was raised
|
||||||
exc_val: The exception value if an exception was raised
|
exc_val: The exception value if an exception was raised
|
||||||
exc_tb: The traceback if an exception was raised
|
exc_tb: The traceback if an exception was raised
|
||||||
"""
|
"""
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
# Don't suppress exceptions
|
# Don't suppress exceptions
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
"""
|
"""
|
||||||
Initialize the database connection.
|
Initialize the database connection.
|
||||||
|
|
||||||
Creates the .ra-aid directory if it doesn't exist and initializes
|
Creates the .ra-aid directory if it doesn't exist and initializes
|
||||||
the SQLite database connection. If a database connection already exists,
|
the SQLite database connection. If a database connection already exists,
|
||||||
returns the existing connection instead of creating a new one.
|
returns the existing connection instead of creating a new one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_memory: Whether to use an in-memory database (default: False)
|
in_memory: Whether to use an in-memory database (default: False)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
peewee.SqliteDatabase: The initialized database connection
|
peewee.SqliteDatabase: The initialized database connection
|
||||||
"""
|
"""
|
||||||
|
|
@ -98,7 +103,7 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
else:
|
else:
|
||||||
# Connection exists and is open, return it
|
# Connection exists and is open, return it
|
||||||
return existing_db
|
return existing_db
|
||||||
|
|
||||||
# Set up database path
|
# Set up database path
|
||||||
if in_memory:
|
if in_memory:
|
||||||
# Use in-memory database
|
# Use in-memory database
|
||||||
|
|
@ -108,25 +113,29 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
# Get current working directory and create .ra-aid directory if it doesn't exist
|
# Get current working directory and create .ra-aid directory if it doesn't exist
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
logger.debug(f"Current working directory: {cwd}")
|
logger.debug(f"Current working directory: {cwd}")
|
||||||
|
|
||||||
# Define the .ra-aid directory path
|
# Define the .ra-aid directory path
|
||||||
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
||||||
ra_aid_dir = Path(ra_aid_dir_str)
|
ra_aid_dir = Path(ra_aid_dir_str)
|
||||||
ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path
|
ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path
|
||||||
ra_aid_dir_str = str(ra_aid_dir) # Update string representation with absolute path
|
ra_aid_dir_str = str(
|
||||||
|
ra_aid_dir
|
||||||
|
) # Update string representation with absolute path
|
||||||
|
|
||||||
logger.debug(f"Creating database directory at: {ra_aid_dir_str}")
|
logger.debug(f"Creating database directory at: {ra_aid_dir_str}")
|
||||||
|
|
||||||
# Multiple approaches to ensure directory creation
|
# Multiple approaches to ensure directory creation
|
||||||
directory_created = False
|
directory_created = False
|
||||||
error_messages = []
|
error_messages = []
|
||||||
|
|
||||||
# Approach 1: Try os.mkdir directly
|
# Approach 1: Try os.mkdir directly
|
||||||
if not os.path.exists(ra_aid_dir_str):
|
if not os.path.exists(ra_aid_dir_str):
|
||||||
try:
|
try:
|
||||||
logger.debug("Attempting directory creation with os.mkdir")
|
logger.debug("Attempting directory creation with os.mkdir")
|
||||||
os.mkdir(ra_aid_dir_str, mode=0o755)
|
os.mkdir(ra_aid_dir_str, mode=0o755)
|
||||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
|
||||||
|
ra_aid_dir_str
|
||||||
|
)
|
||||||
if directory_created:
|
if directory_created:
|
||||||
logger.debug("Directory created successfully with os.mkdir")
|
logger.debug("Directory created successfully with os.mkdir")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -136,40 +145,46 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
else:
|
else:
|
||||||
logger.debug("Directory already exists, skipping creation")
|
logger.debug("Directory already exists, skipping creation")
|
||||||
directory_created = True
|
directory_created = True
|
||||||
|
|
||||||
# Approach 2: Try os.makedirs if os.mkdir failed
|
# Approach 2: Try os.makedirs if os.mkdir failed
|
||||||
if not directory_created:
|
if not directory_created:
|
||||||
try:
|
try:
|
||||||
logger.debug("Attempting directory creation with os.makedirs")
|
logger.debug("Attempting directory creation with os.makedirs")
|
||||||
os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755)
|
os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755)
|
||||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
|
||||||
|
ra_aid_dir_str
|
||||||
|
)
|
||||||
if directory_created:
|
if directory_created:
|
||||||
logger.debug("Directory created successfully with os.makedirs")
|
logger.debug("Directory created successfully with os.makedirs")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"os.makedirs failed: {str(e)}"
|
error_msg = f"os.makedirs failed: {str(e)}"
|
||||||
logger.debug(error_msg)
|
logger.debug(error_msg)
|
||||||
error_messages.append(error_msg)
|
error_messages.append(error_msg)
|
||||||
|
|
||||||
# Approach 3: Try Path.mkdir if previous methods failed
|
# Approach 3: Try Path.mkdir if previous methods failed
|
||||||
if not directory_created:
|
if not directory_created:
|
||||||
try:
|
try:
|
||||||
logger.debug("Attempting directory creation with Path.mkdir")
|
logger.debug("Attempting directory creation with Path.mkdir")
|
||||||
ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
|
ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
|
||||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(
|
||||||
|
ra_aid_dir_str
|
||||||
|
)
|
||||||
if directory_created:
|
if directory_created:
|
||||||
logger.debug("Directory created successfully with Path.mkdir")
|
logger.debug("Directory created successfully with Path.mkdir")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Path.mkdir failed: {str(e)}"
|
error_msg = f"Path.mkdir failed: {str(e)}"
|
||||||
logger.debug(error_msg)
|
logger.debug(error_msg)
|
||||||
error_messages.append(error_msg)
|
error_messages.append(error_msg)
|
||||||
|
|
||||||
# Verify the directory was actually created
|
# Verify the directory was actually created
|
||||||
path_exists = ra_aid_dir.exists()
|
path_exists = ra_aid_dir.exists()
|
||||||
os_exists = os.path.exists(ra_aid_dir_str)
|
os_exists = os.path.exists(ra_aid_dir_str)
|
||||||
is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False
|
is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False
|
||||||
|
|
||||||
logger.debug(f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
|
logger.debug(
|
||||||
|
f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check parent directory permissions and contents for debugging
|
# Check parent directory permissions and contents for debugging
|
||||||
try:
|
try:
|
||||||
parent_dir = os.path.dirname(ra_aid_dir_str)
|
parent_dir = os.path.dirname(ra_aid_dir_str)
|
||||||
|
|
@ -179,19 +194,21 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
logger.debug(f"Parent directory contents: {parent_contents}")
|
logger.debug(f"Parent directory contents: {parent_contents}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Could not check parent directory: {str(e)}")
|
logger.debug(f"Could not check parent directory: {str(e)}")
|
||||||
|
|
||||||
if not os_exists or not is_dir:
|
if not os_exists or not is_dir:
|
||||||
error_msg = f"Directory does not exist or is not a directory after creation attempts: {ra_aid_dir_str}"
|
error_msg = f"Directory does not exist or is not a directory after creation attempts: {ra_aid_dir_str}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
if error_messages:
|
if error_messages:
|
||||||
logger.error(f"Previous errors: {', '.join(error_messages)}")
|
logger.error(f"Previous errors: {', '.join(error_messages)}")
|
||||||
raise FileNotFoundError(f"Failed to create directory: {ra_aid_dir_str}")
|
raise FileNotFoundError(f"Failed to create directory: {ra_aid_dir_str}")
|
||||||
|
|
||||||
# Check directory permissions
|
# Check directory permissions
|
||||||
try:
|
try:
|
||||||
permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:]
|
permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:]
|
||||||
logger.debug(f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
|
logger.debug(
|
||||||
|
f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}"
|
||||||
|
)
|
||||||
|
|
||||||
# List directory contents for debugging
|
# List directory contents for debugging
|
||||||
dir_contents = os.listdir(ra_aid_dir_str)
|
dir_contents = os.listdir(ra_aid_dir_str)
|
||||||
logger.debug(f"Directory contents: {dir_contents}")
|
logger.debug(f"Directory contents: {dir_contents}")
|
||||||
|
|
@ -201,21 +218,21 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
# Database path for file-based database - use os.path.join for maximum compatibility
|
# Database path for file-based database - use os.path.join for maximum compatibility
|
||||||
db_path = os.path.join(ra_aid_dir_str, "pk.db")
|
db_path = os.path.join(ra_aid_dir_str, "pk.db")
|
||||||
logger.debug(f"Database path: {db_path}")
|
logger.debug(f"Database path: {db_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# For file-based databases, ensure the file exists or can be created
|
# For file-based databases, ensure the file exists or can be created
|
||||||
if db_path != ":memory:":
|
if db_path != ":memory:":
|
||||||
# Check if the database file exists
|
# Check if the database file exists
|
||||||
db_file_exists = os.path.exists(db_path)
|
db_file_exists = os.path.exists(db_path)
|
||||||
logger.debug(f"Database file exists check: {db_file_exists}")
|
logger.debug(f"Database file exists check: {db_file_exists}")
|
||||||
|
|
||||||
# If the file doesn't exist, try to create an empty file to ensure we have write permissions
|
# If the file doesn't exist, try to create an empty file to ensure we have write permissions
|
||||||
if not db_file_exists:
|
if not db_file_exists:
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Creating empty database file at: {db_path}")
|
logger.debug(f"Creating empty database file at: {db_path}")
|
||||||
with open(db_path, 'w') as f:
|
with open(db_path, "w") as f:
|
||||||
pass # Create empty file
|
pass # Create empty file
|
||||||
|
|
||||||
# Verify the file was created
|
# Verify the file was created
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
logger.debug("Empty database file created successfully")
|
logger.debug("Empty database file created successfully")
|
||||||
|
|
@ -224,51 +241,53 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating database file: {str(e)}")
|
logger.error(f"Error creating database file: {str(e)}")
|
||||||
# Continue anyway, as SQLite might be able to create the file itself
|
# Continue anyway, as SQLite might be able to create the file itself
|
||||||
|
|
||||||
# Initialize the database connection
|
# Initialize the database connection
|
||||||
logger.debug(f"Initializing SQLite database at: {db_path}")
|
logger.debug(f"Initializing SQLite database at: {db_path}")
|
||||||
db = peewee.SqliteDatabase(
|
db = peewee.SqliteDatabase(
|
||||||
db_path,
|
db_path,
|
||||||
pragmas={
|
pragmas={
|
||||||
'journal_mode': 'wal', # Write-Ahead Logging for better concurrency
|
"journal_mode": "wal", # Write-Ahead Logging for better concurrency
|
||||||
'foreign_keys': 1, # Enforce foreign key constraints
|
"foreign_keys": 1, # Enforce foreign key constraints
|
||||||
'cache_size': -1024 * 32, # 32MB cache
|
"cache_size": -1024 * 32, # 32MB cache
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Always explicitly connect to ensure the connection is established
|
# Always explicitly connect to ensure the connection is established
|
||||||
if db.is_closed():
|
if db.is_closed():
|
||||||
logger.debug("Explicitly connecting to database")
|
logger.debug("Explicitly connecting to database")
|
||||||
db.connect()
|
db.connect()
|
||||||
|
|
||||||
# Store the database connection in the contextvar
|
# Store the database connection in the contextvar
|
||||||
db_var.set(db)
|
db_var.set(db)
|
||||||
|
|
||||||
# Store whether this is an in-memory database (for backward compatibility)
|
# Store whether this is an in-memory database (for backward compatibility)
|
||||||
db._is_in_memory = in_memory
|
db._is_in_memory = in_memory
|
||||||
|
|
||||||
# Verify the database is usable by executing a simple query
|
# Verify the database is usable by executing a simple query
|
||||||
if not in_memory:
|
if not in_memory:
|
||||||
try:
|
try:
|
||||||
db.execute_sql("SELECT 1")
|
db.execute_sql("SELECT 1")
|
||||||
logger.debug("Database connection verified with test query")
|
logger.debug("Database connection verified with test query")
|
||||||
|
|
||||||
# Check if the database file exists after initialization
|
# Check if the database file exists after initialization
|
||||||
db_file_exists = os.path.exists(db_path)
|
db_file_exists = os.path.exists(db_path)
|
||||||
db_file_size = os.path.getsize(db_path) if db_file_exists else 0
|
db_file_size = os.path.getsize(db_path) if db_file_exists else 0
|
||||||
logger.debug(f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
|
logger.debug(
|
||||||
|
f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Database verification failed: {str(e)}")
|
logger.error(f"Database verification failed: {str(e)}")
|
||||||
# Continue anyway, as this is just a verification step
|
# Continue anyway, as this is just a verification step
|
||||||
|
|
||||||
# Only show initialization message if it hasn't been shown before
|
# Only show initialization message if it hasn't been shown before
|
||||||
if not hasattr(db, '_message_shown') or not db._message_shown:
|
if not hasattr(db, "_message_shown") or not db._message_shown:
|
||||||
if in_memory:
|
if in_memory:
|
||||||
logger.debug("In-memory database connection initialized successfully")
|
logger.debug("In-memory database connection initialized successfully")
|
||||||
else:
|
else:
|
||||||
logger.debug("Database connection initialized successfully")
|
logger.debug("Database connection initialized successfully")
|
||||||
db._message_shown = True
|
db._message_shown = True
|
||||||
|
|
||||||
return db
|
return db
|
||||||
except peewee.OperationalError as e:
|
except peewee.OperationalError as e:
|
||||||
logger.error(f"Database Operational Error: {str(e)}")
|
logger.error(f"Database Operational Error: {str(e)}")
|
||||||
|
|
@ -280,23 +299,24 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
logger.error(f"Failed to initialize database: {str(e)}")
|
logger.error(f"Failed to initialize database: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> peewee.SqliteDatabase:
|
def get_db() -> peewee.SqliteDatabase:
|
||||||
"""
|
"""
|
||||||
Get the current database connection.
|
Get the current database connection.
|
||||||
|
|
||||||
If no connection exists, initializes a new one.
|
If no connection exists, initializes a new one.
|
||||||
If connection exists but is closed, reopens it.
|
If connection exists but is closed, reopens it.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
peewee.SqliteDatabase: The current database connection
|
peewee.SqliteDatabase: The current database connection
|
||||||
"""
|
"""
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
|
|
||||||
if db is None:
|
if db is None:
|
||||||
# No database connection exists, initialize one
|
# No database connection exists, initialize one
|
||||||
# Use the default in-memory mode (False)
|
# Use the default in-memory mode (False)
|
||||||
return init_db(in_memory=False)
|
return init_db(in_memory=False)
|
||||||
|
|
||||||
# Check if connection is closed and reopen if needed
|
# Check if connection is closed and reopen if needed
|
||||||
if db.is_closed():
|
if db.is_closed():
|
||||||
try:
|
try:
|
||||||
|
|
@ -309,30 +329,33 @@ def get_db() -> peewee.SqliteDatabase:
|
||||||
# First, remove the old connection from the context var
|
# First, remove the old connection from the context var
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
# Then initialize a new connection with the same in-memory setting
|
# Then initialize a new connection with the same in-memory setting
|
||||||
in_memory = hasattr(db, '_is_in_memory') and db._is_in_memory
|
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
|
||||||
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
||||||
# Create a completely new database object, don't reuse the old one
|
# Create a completely new database object, don't reuse the old one
|
||||||
return init_db(in_memory=in_memory)
|
return init_db(in_memory=in_memory)
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
|
||||||
|
|
||||||
def close_db() -> None:
|
def close_db() -> None:
|
||||||
"""
|
"""
|
||||||
Close the current database connection if it exists.
|
Close the current database connection if it exists.
|
||||||
|
|
||||||
Handles various error conditions gracefully.
|
Handles various error conditions gracefully.
|
||||||
"""
|
"""
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
if db is None:
|
if db is None:
|
||||||
logger.warning("No database connection to close")
|
logger.warning("No database connection to close")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
logger.info("Database connection closed successfully")
|
logger.info("Database connection closed successfully")
|
||||||
else:
|
else:
|
||||||
logger.debug("Database connection was already closed (normal during shutdown)")
|
logger.debug(
|
||||||
|
"Database connection was already closed (normal during shutdown)"
|
||||||
|
)
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Database Error: Failed to close connection: {str(e)}")
|
logger.error(f"Database Error: Failed to close connection: {str(e)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -6,16 +6,14 @@ using peewee-migrate. It includes tools for creating, checking, and applying
|
||||||
migrations automatically.
|
migrations automatically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Dict, Any
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import peewee
|
|
||||||
from peewee_migrate import Router
|
from peewee_migrate import Router
|
||||||
from peewee_migrate.router import DEFAULT_MIGRATE_DIR
|
|
||||||
|
|
||||||
from ra_aid.database.connection import get_db, DatabaseManager
|
from ra_aid.database.connection import DatabaseManager, get_db
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -28,48 +26,50 @@ MIGRATIONS_TABLE = "migrationshistory"
|
||||||
class MigrationManager:
|
class MigrationManager:
|
||||||
"""
|
"""
|
||||||
Manages database migrations for the ra_aid application.
|
Manages database migrations for the ra_aid application.
|
||||||
|
|
||||||
This class provides methods to initialize the migrator, check for
|
This class provides methods to initialize the migrator, check for
|
||||||
pending migrations, apply migrations, and create new migrations.
|
pending migrations, apply migrations, and create new migrations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None):
|
def __init__(
|
||||||
|
self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the MigrationManager.
|
Initialize the MigrationManager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_path: Optional path to the database file. If None, uses the default.
|
db_path: Optional path to the database file. If None, uses the default.
|
||||||
migrations_dir: Optional path to the migrations directory. If None, uses default.
|
migrations_dir: Optional path to the migrations directory. If None, uses default.
|
||||||
"""
|
"""
|
||||||
self.db = get_db()
|
self.db = get_db()
|
||||||
|
|
||||||
# Determine database path
|
# Determine database path
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get current working directory
|
# Get current working directory
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
ra_aid_dir = os.path.join(cwd, ".ra-aid")
|
ra_aid_dir = os.path.join(cwd, ".ra-aid")
|
||||||
db_path = os.path.join(ra_aid_dir, "pk.db")
|
db_path = os.path.join(ra_aid_dir, "pk.db")
|
||||||
|
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|
||||||
# Determine migrations directory
|
# Determine migrations directory
|
||||||
if migrations_dir is None:
|
if migrations_dir is None:
|
||||||
# Use a directory within .ra-aid
|
# Use a directory within .ra-aid
|
||||||
ra_aid_dir = os.path.dirname(self.db_path)
|
ra_aid_dir = os.path.dirname(self.db_path)
|
||||||
migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
self.migrations_dir = migrations_dir
|
self.migrations_dir = migrations_dir
|
||||||
|
|
||||||
# Ensure migrations directory exists
|
# Ensure migrations directory exists
|
||||||
self._ensure_migrations_dir()
|
self._ensure_migrations_dir()
|
||||||
|
|
||||||
# Initialize router
|
# Initialize router
|
||||||
self.router = self._init_router()
|
self.router = self._init_router()
|
||||||
|
|
||||||
def _ensure_migrations_dir(self) -> None:
|
def _ensure_migrations_dir(self) -> None:
|
||||||
"""
|
"""
|
||||||
Ensure that the migrations directory exists.
|
Ensure that the migrations directory exists.
|
||||||
|
|
||||||
Creates the directory if it doesn't exist.
|
Creates the directory if it doesn't exist.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -77,75 +77,79 @@ class MigrationManager:
|
||||||
if not migrations_path.exists():
|
if not migrations_path.exists():
|
||||||
logger.debug(f"Creating migrations directory at: {self.migrations_dir}")
|
logger.debug(f"Creating migrations directory at: {self.migrations_dir}")
|
||||||
migrations_path.mkdir(parents=True, exist_ok=True)
|
migrations_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Create __init__.py to make it a proper package
|
# Create __init__.py to make it a proper package
|
||||||
init_file = migrations_path / "__init__.py"
|
init_file = migrations_path / "__init__.py"
|
||||||
if not init_file.exists():
|
if not init_file.exists():
|
||||||
init_file.touch()
|
init_file.touch()
|
||||||
|
|
||||||
logger.debug(f"Using migrations directory: {self.migrations_dir}")
|
logger.debug(f"Using migrations directory: {self.migrations_dir}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create migrations directory: {str(e)}")
|
logger.error(f"Failed to create migrations directory: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _init_router(self) -> Router:
|
def _init_router(self) -> Router:
|
||||||
"""
|
"""
|
||||||
Initialize the peewee-migrate Router.
|
Initialize the peewee-migrate Router.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Router: Configured peewee-migrate Router instance
|
Router: Configured peewee-migrate Router instance
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
router = Router(self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE)
|
router = Router(
|
||||||
|
self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE
|
||||||
|
)
|
||||||
logger.debug(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
logger.debug(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
||||||
return router
|
return router
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize migration router: {str(e)}")
|
logger.error(f"Failed to initialize migration router: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def check_migrations(self) -> Tuple[List[str], List[str]]:
|
def check_migrations(self) -> Tuple[List[str], List[str]]:
|
||||||
"""
|
"""
|
||||||
Check for pending migrations.
|
Check for pending migrations.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[List[str], List[str]]: A tuple containing (applied_migrations, pending_migrations)
|
Tuple[List[str], List[str]]: A tuple containing (applied_migrations, pending_migrations)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all migrations
|
# Get all migrations
|
||||||
all_migrations = self.router.todo
|
all_migrations = self.router.todo
|
||||||
|
|
||||||
# Get applied migrations
|
# Get applied migrations
|
||||||
applied = self.router.done
|
applied = self.router.done
|
||||||
|
|
||||||
# Calculate pending migrations
|
# Calculate pending migrations
|
||||||
pending = [m for m in all_migrations if m not in applied]
|
pending = [m for m in all_migrations if m not in applied]
|
||||||
|
|
||||||
logger.debug(f"Found {len(applied)} applied migrations and {len(pending)} pending migrations")
|
logger.debug(
|
||||||
|
f"Found {len(applied)} applied migrations and {len(pending)} pending migrations"
|
||||||
|
)
|
||||||
return applied, pending
|
return applied, pending
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to check migrations: {str(e)}")
|
logger.error(f"Failed to check migrations: {str(e)}")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
def apply_migrations(self, fake: bool = False) -> bool:
|
def apply_migrations(self, fake: bool = False) -> bool:
|
||||||
"""
|
"""
|
||||||
Apply all pending migrations.
|
Apply all pending migrations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fake: If True, mark migrations as applied without running them
|
fake: If True, mark migrations as applied without running them
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if migrations were applied successfully, False otherwise
|
bool: True if migrations were applied successfully, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get pending migrations
|
# Get pending migrations
|
||||||
_, pending = self.check_migrations()
|
_, pending = self.check_migrations()
|
||||||
|
|
||||||
if not pending:
|
if not pending:
|
||||||
logger.info("No pending migrations to apply")
|
logger.info("No pending migrations to apply")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.info(f"Applying {len(pending)} pending migrations...")
|
logger.info(f"Applying {len(pending)} pending migrations...")
|
||||||
|
|
||||||
# Apply migrations
|
# Apply migrations
|
||||||
for migration in pending:
|
for migration in pending:
|
||||||
try:
|
try:
|
||||||
|
|
@ -155,50 +159,50 @@ class MigrationManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to apply migration {migration}: {str(e)}")
|
logger.error(f"Failed to apply migration {migration}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
logger.info(f"Successfully applied {len(pending)} migrations")
|
logger.info(f"Successfully applied {len(pending)} migrations")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to apply migrations: {str(e)}")
|
logger.error(f"Failed to apply migrations: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def create_migration(self, name: str, auto: bool = True) -> Optional[str]:
|
def create_migration(self, name: str, auto: bool = True) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Create a new migration.
|
Create a new migration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of the migration
|
name: Name of the migration
|
||||||
auto: If True, automatically detect model changes
|
auto: If True, automatically detect model changes
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: The name of the created migration, or None if creation failed
|
Optional[str]: The name of the created migration, or None if creation failed
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Sanitize migration name
|
# Sanitize migration name
|
||||||
safe_name = name.replace(' ', '_').lower()
|
safe_name = name.replace(" ", "_").lower()
|
||||||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
migration_name = f"{timestamp}_{safe_name}"
|
migration_name = f"{timestamp}_{safe_name}"
|
||||||
|
|
||||||
logger.info(f"Creating new migration: {migration_name}")
|
logger.info(f"Creating new migration: {migration_name}")
|
||||||
|
|
||||||
# Create migration
|
# Create migration
|
||||||
self.router.create(migration_name, auto=auto)
|
self.router.create(migration_name, auto=auto)
|
||||||
|
|
||||||
logger.info(f"Successfully created migration: {migration_name}")
|
logger.info(f"Successfully created migration: {migration_name}")
|
||||||
return migration_name
|
return migration_name
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create migration: {str(e)}")
|
logger.error(f"Failed to create migration: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_migration_status(self) -> Dict[str, Any]:
|
def get_migration_status(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get the current migration status.
|
Get the current migration status.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: A dictionary containing migration status information
|
Dict[str, Any]: A dictionary containing migration status information
|
||||||
"""
|
"""
|
||||||
applied, pending = self.check_migrations()
|
applied, pending = self.check_migrations()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"applied_count": len(applied),
|
"applied_count": len(applied),
|
||||||
"pending_count": len(pending),
|
"pending_count": len(pending),
|
||||||
|
|
@ -209,14 +213,16 @@ class MigrationManager:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str] = None) -> MigrationManager:
|
def init_migrations(
|
||||||
|
db_path: Optional[str] = None, migrations_dir: Optional[str] = None
|
||||||
|
) -> MigrationManager:
|
||||||
"""
|
"""
|
||||||
Initialize the migration manager.
|
Initialize the migration manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_path: Optional path to the database file
|
db_path: Optional path to the database file
|
||||||
migrations_dir: Optional path to the migrations directory
|
migrations_dir: Optional path to the migrations directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MigrationManager: Initialized migration manager
|
MigrationManager: Initialized migration manager
|
||||||
"""
|
"""
|
||||||
|
|
@ -226,10 +232,10 @@ def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str]
|
||||||
def ensure_migrations_applied() -> bool:
|
def ensure_migrations_applied() -> bool:
|
||||||
"""
|
"""
|
||||||
Check for and apply any pending migrations.
|
Check for and apply any pending migrations.
|
||||||
|
|
||||||
This function should be called during application startup to ensure
|
This function should be called during application startup to ensure
|
||||||
the database schema is up to date.
|
the database schema is up to date.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if migrations were applied successfully or none were pending
|
bool: True if migrations were applied successfully or none were pending
|
||||||
"""
|
"""
|
||||||
|
|
@ -245,11 +251,11 @@ def ensure_migrations_applied() -> bool:
|
||||||
def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
|
def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Create a new migration with the given name.
|
Create a new migration with the given name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of the migration
|
name: Name of the migration
|
||||||
auto: If True, automatically detect model changes
|
auto: If True, automatically detect model changes
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: The name of the created migration, or None if creation failed
|
Optional[str]: The name of the created migration, or None if creation failed
|
||||||
"""
|
"""
|
||||||
|
|
@ -265,7 +271,7 @@ def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
|
||||||
def get_migration_status() -> Dict[str, Any]:
|
def get_migration_status() -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get the current migration status.
|
Get the current migration status.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: A dictionary containing migration status information
|
Dict[str, Any]: A dictionary containing migration status information
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -5,51 +5,53 @@ This module defines the base model class that all models will inherit from.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Any, Dict, Type, TypeVar
|
from typing import Any, Type, TypeVar
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import get_db
|
from ra_aid.database.connection import get_db
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
T = TypeVar('T', bound='BaseModel')
|
T = TypeVar("T", bound="BaseModel")
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(peewee.Model):
|
class BaseModel(peewee.Model):
|
||||||
"""
|
"""
|
||||||
Base model class for all ra_aid models.
|
Base model class for all ra_aid models.
|
||||||
|
|
||||||
All models should inherit from this class to ensure consistent
|
All models should inherit from this class to ensure consistent
|
||||||
behavior and database connection.
|
behavior and database connection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
created_at = peewee.DateTimeField(default=datetime.datetime.now)
|
created_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||||
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = get_db()
|
database = get_db()
|
||||||
|
|
||||||
def save(self, *args: Any, **kwargs: Any) -> int:
|
def save(self, *args: Any, **kwargs: Any) -> int:
|
||||||
"""
|
"""
|
||||||
Override save to update the updated_at field.
|
Override save to update the updated_at field.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: Arguments to pass to the parent save method
|
*args: Arguments to pass to the parent save method
|
||||||
**kwargs: Keyword arguments to pass to the parent save method
|
**kwargs: Keyword arguments to pass to the parent save method
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: The primary key of the saved instance
|
int: The primary key of the saved instance
|
||||||
"""
|
"""
|
||||||
self.updated_at = datetime.datetime.now()
|
self.updated_at = datetime.datetime.now()
|
||||||
return super().save(*args, **kwargs)
|
return super().save(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]:
|
def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]:
|
||||||
"""
|
"""
|
||||||
Get an instance or create it if it doesn't exist.
|
Get an instance or create it if it doesn't exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Fields to use for lookup and creation
|
**kwargs: Fields to use for lookup and creation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (instance, created) where created is a boolean indicating
|
tuple: (instance, created) where created is a boolean indicating
|
||||||
whether a new instance was created
|
whether a new instance was created
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,18 @@ Tests for the database connection module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ra_aid.database.connection import (
|
from ra_aid.database.connection import (
|
||||||
init_db, get_db, close_db, db_var, DatabaseManager
|
DatabaseManager,
|
||||||
|
close_db,
|
||||||
|
db_var,
|
||||||
|
get_db,
|
||||||
|
init_db,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,17 +25,17 @@ def cleanup_db():
|
||||||
"""
|
"""
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Clean up after the test
|
# Clean up after the test
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
if hasattr(db, '_is_in_memory'):
|
if hasattr(db, "_is_in_memory"):
|
||||||
delattr(db, '_is_in_memory')
|
delattr(db, "_is_in_memory")
|
||||||
if hasattr(db, '_message_shown'):
|
if hasattr(db, "_message_shown"):
|
||||||
delattr(db, '_message_shown')
|
delattr(db, "_message_shown")
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# Reset the contextvar
|
# Reset the contextvar
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
|
|
@ -43,10 +47,10 @@ def setup_in_memory_db():
|
||||||
"""
|
"""
|
||||||
# Initialize in-memory database
|
# Initialize in-memory database
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
yield db
|
yield db
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
@ -60,60 +64,68 @@ def test_init_db_creates_directory(cleanup_db, tmp_path):
|
||||||
# Get and print the original working directory
|
# Get and print the original working directory
|
||||||
original_cwd = os.getcwd()
|
original_cwd = os.getcwd()
|
||||||
print(f"Original working directory: {original_cwd}")
|
print(f"Original working directory: {original_cwd}")
|
||||||
|
|
||||||
# Convert tmp_path to string for consistent handling
|
# Convert tmp_path to string for consistent handling
|
||||||
tmp_path_str = str(tmp_path.absolute())
|
tmp_path_str = str(tmp_path.absolute())
|
||||||
print(f"Temporary directory path: {tmp_path_str}")
|
print(f"Temporary directory path: {tmp_path_str}")
|
||||||
|
|
||||||
# Change to the temporary directory
|
# Change to the temporary directory
|
||||||
os.chdir(tmp_path_str)
|
os.chdir(tmp_path_str)
|
||||||
current_cwd = os.getcwd()
|
current_cwd = os.getcwd()
|
||||||
print(f"Changed working directory to: {current_cwd}")
|
print(f"Changed working directory to: {current_cwd}")
|
||||||
assert current_cwd == tmp_path_str, f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
assert (
|
||||||
|
current_cwd == tmp_path_str
|
||||||
|
), f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
||||||
|
|
||||||
# Create the .ra-aid directory manually to ensure it exists
|
# Create the .ra-aid directory manually to ensure it exists
|
||||||
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
|
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
|
||||||
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
|
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
|
||||||
os.makedirs(ra_aid_path_str, exist_ok=True)
|
os.makedirs(ra_aid_path_str, exist_ok=True)
|
||||||
|
|
||||||
# Verify the directory was created
|
# Verify the directory was created
|
||||||
assert os.path.exists(ra_aid_path_str), f".ra-aid directory not found at {ra_aid_path_str}"
|
assert os.path.exists(
|
||||||
assert os.path.isdir(ra_aid_path_str), f"{ra_aid_path_str} exists but is not a directory"
|
ra_aid_path_str
|
||||||
|
), f".ra-aid directory not found at {ra_aid_path_str}"
|
||||||
|
assert os.path.isdir(
|
||||||
|
ra_aid_path_str
|
||||||
|
), f"{ra_aid_path_str} exists but is not a directory"
|
||||||
|
|
||||||
# Create a test file to verify write permissions
|
# Create a test file to verify write permissions
|
||||||
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
|
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
|
||||||
print(f"Creating test file to verify write permissions: {test_file_path}")
|
print(f"Creating test file to verify write permissions: {test_file_path}")
|
||||||
with open(test_file_path, 'w') as f:
|
with open(test_file_path, "w") as f:
|
||||||
f.write("Test write permissions")
|
f.write("Test write permissions")
|
||||||
|
|
||||||
# Verify the test file was created
|
# Verify the test file was created
|
||||||
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
|
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
|
||||||
|
|
||||||
# Create an empty database file to ensure it exists before init_db
|
# Create an empty database file to ensure it exists before init_db
|
||||||
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
|
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
|
||||||
print(f"Creating empty database file at: {db_file_str}")
|
print(f"Creating empty database file at: {db_file_str}")
|
||||||
with open(db_file_str, 'w') as f:
|
with open(db_file_str, "w") as f:
|
||||||
f.write("") # Create empty file
|
f.write("") # Create empty file
|
||||||
|
|
||||||
# Verify the database file was created
|
# Verify the database file was created
|
||||||
assert os.path.exists(db_file_str), f"Empty database file not created at {db_file_str}"
|
assert os.path.exists(
|
||||||
|
db_file_str
|
||||||
|
), f"Empty database file not created at {db_file_str}"
|
||||||
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
|
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
|
||||||
|
|
||||||
# Get directory permissions for debugging
|
# Get directory permissions for debugging
|
||||||
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
|
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
|
||||||
print(f"Directory permissions: {dir_perms}")
|
print(f"Directory permissions: {dir_perms}")
|
||||||
|
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
print("Calling init_db()")
|
print("Calling init_db()")
|
||||||
db = init_db()
|
db = init_db()
|
||||||
print("init_db() returned successfully")
|
print("init_db() returned successfully")
|
||||||
|
|
||||||
# List contents of the current directory for debugging
|
# List contents of the current directory for debugging
|
||||||
print(f"Contents of current directory: {os.listdir(current_cwd)}")
|
print(f"Contents of current directory: {os.listdir(current_cwd)}")
|
||||||
|
|
||||||
# List contents of the .ra-aid directory for debugging
|
# List contents of the .ra-aid directory for debugging
|
||||||
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
|
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
|
||||||
|
|
||||||
# Check that the database file exists using os.path
|
# Check that the database file exists using os.path
|
||||||
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
|
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
|
||||||
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
|
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
|
||||||
|
|
@ -126,10 +138,10 @@ def test_init_db_creates_database_file(cleanup_db, tmp_path):
|
||||||
"""
|
"""
|
||||||
# Change to the temporary directory
|
# Change to the temporary directory
|
||||||
os.chdir(tmp_path)
|
os.chdir(tmp_path)
|
||||||
|
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
# Check that the database file was created
|
# Check that the database file was created
|
||||||
assert (tmp_path / ".ra-aid" / "pk.db").exists()
|
assert (tmp_path / ".ra-aid" / "pk.db").exists()
|
||||||
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
|
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
|
||||||
|
|
@ -141,7 +153,7 @@ def test_init_db_returns_database_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
||||||
# Check that the database connection is returned
|
# Check that the database connection is returned
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
|
|
@ -153,11 +165,11 @@ def test_init_db_with_in_memory_mode(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database in in-memory mode
|
# Initialize the database in in-memory mode
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Check that the database connection is returned
|
# Check that the database connection is returned
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -167,10 +179,10 @@ def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
|
||||||
"""
|
"""
|
||||||
# Change to the temporary directory
|
# Change to the temporary directory
|
||||||
os.chdir(tmp_path)
|
os.chdir(tmp_path)
|
||||||
|
|
||||||
# Initialize the database in in-memory mode
|
# Initialize the database in in-memory mode
|
||||||
init_db(in_memory=True)
|
init_db(in_memory=True)
|
||||||
|
|
||||||
# Check that the .ra-aid directory was not created
|
# Check that the .ra-aid directory was not created
|
||||||
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
|
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
|
||||||
# Instead, check that the database file was not created
|
# Instead, check that the database file was not created
|
||||||
|
|
@ -183,10 +195,10 @@ def test_init_db_returns_existing_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
|
|
||||||
# Initialize the database again
|
# Initialize the database again
|
||||||
db2 = init_db()
|
db2 = init_db()
|
||||||
|
|
||||||
# Check that the same connection is returned
|
# Check that the same connection is returned
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
|
|
||||||
|
|
@ -197,13 +209,13 @@ def test_init_db_reopens_closed_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
|
|
||||||
# Close the connection
|
# Close the connection
|
||||||
db1.close()
|
db1.close()
|
||||||
|
|
||||||
# Initialize the database again
|
# Initialize the database again
|
||||||
db2 = init_db()
|
db2 = init_db()
|
||||||
|
|
||||||
# Check that the same connection is returned and it's open
|
# Check that the same connection is returned and it's open
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
assert not db1.is_closed()
|
assert not db1.is_closed()
|
||||||
|
|
@ -215,7 +227,7 @@ def test_get_db_initializes_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Get the database connection
|
# Get the database connection
|
||||||
db = get_db()
|
db = get_db()
|
||||||
|
|
||||||
# Check that a connection was initialized
|
# Check that a connection was initialized
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
|
|
@ -227,10 +239,10 @@ def test_get_db_returns_existing_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
|
|
||||||
# Get the database connection
|
# Get the database connection
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
|
|
||||||
# Check that the same connection is returned
|
# Check that the same connection is returned
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
|
|
||||||
|
|
@ -241,13 +253,13 @@ def test_get_db_reopens_closed_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
||||||
# Close the connection
|
# Close the connection
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# Get the database connection
|
# Get the database connection
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
|
|
||||||
# Check that the same connection is returned and it's open
|
# Check that the same connection is returned and it's open
|
||||||
assert db is db2
|
assert db is db2
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
|
|
@ -259,24 +271,24 @@ def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
||||||
# Close the connection
|
# Close the connection
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# Create a patched version of the connect method that raises an error
|
# Create a patched version of the connect method that raises an error
|
||||||
original_connect = peewee.SqliteDatabase.connect
|
original_connect = peewee.SqliteDatabase.connect
|
||||||
|
|
||||||
def mock_connect(self, *args, **kwargs):
|
def mock_connect(self, *args, **kwargs):
|
||||||
if self is db: # Only raise for the specific db instance
|
if self is db: # Only raise for the specific db instance
|
||||||
raise peewee.OperationalError("Test error")
|
raise peewee.OperationalError("Test error")
|
||||||
return original_connect(self, *args, **kwargs)
|
return original_connect(self, *args, **kwargs)
|
||||||
|
|
||||||
# Apply the patch
|
# Apply the patch
|
||||||
monkeypatch.setattr(peewee.SqliteDatabase, 'connect', mock_connect)
|
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
|
||||||
|
|
||||||
# Get the database connection
|
# Get the database connection
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
|
|
||||||
# Check that a new connection was initialized
|
# Check that a new connection was initialized
|
||||||
assert db is not db2
|
assert db is not db2
|
||||||
assert not db2.is_closed()
|
assert not db2.is_closed()
|
||||||
|
|
@ -288,10 +300,10 @@ def test_close_db_closes_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
||||||
# Close the connection
|
# Close the connection
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
# Check that the connection is closed
|
# Check that the connection is closed
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
||||||
|
|
@ -302,7 +314,7 @@ def test_close_db_handles_no_connection():
|
||||||
"""
|
"""
|
||||||
# Reset the contextvar
|
# Reset the contextvar
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# Close the connection (should not raise an error)
|
# Close the connection (should not raise an error)
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
|
|
@ -313,25 +325,25 @@ def test_close_db_handles_already_closed_connection(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
||||||
# Close the connection
|
# Close the connection
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# Close the connection again (should not raise an error)
|
# Close the connection again (should not raise an error)
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
|
|
||||||
@patch('ra_aid.database.connection.peewee.SqliteDatabase.close')
|
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
|
||||||
def test_close_db_handles_error(mock_close, cleanup_db):
|
def test_close_db_handles_error(mock_close, cleanup_db):
|
||||||
"""
|
"""
|
||||||
Test that close_db handles errors when closing the connection.
|
Test that close_db handles errors when closing the connection.
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
# Make close raise an error
|
# Make close raise an error
|
||||||
mock_close.side_effect = peewee.DatabaseError("Test error")
|
mock_close.side_effect = peewee.DatabaseError("Test error")
|
||||||
|
|
||||||
# Close the connection (should not raise an error)
|
# Close the connection (should not raise an error)
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
|
|
@ -345,10 +357,10 @@ def test_database_manager_context_manager(cleanup_db):
|
||||||
# Check that a connection was initialized
|
# Check that a connection was initialized
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
|
|
||||||
# Store the connection for later
|
# Store the connection for later
|
||||||
db_in_context = db
|
db_in_context = db
|
||||||
|
|
||||||
# Check that the connection is closed after exiting the context
|
# Check that the connection is closed after exiting the context
|
||||||
assert db_in_context.is_closed()
|
assert db_in_context.is_closed()
|
||||||
|
|
||||||
|
|
@ -362,7 +374,7 @@ def test_database_manager_with_in_memory_mode(cleanup_db):
|
||||||
# Check that a connection was initialized
|
# Check that a connection was initialized
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -372,13 +384,13 @@ def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
||||||
"""
|
"""
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
init_db(in_memory=True)
|
init_db(in_memory=True)
|
||||||
|
|
||||||
# Clear the log
|
# Clear the log
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
|
|
||||||
# Initialize the database again
|
# Initialize the database again
|
||||||
init_db(in_memory=True)
|
init_db(in_memory=True)
|
||||||
|
|
||||||
# Check that no message was logged
|
# Check that no message was logged
|
||||||
assert "database connection initialized" not in caplog.text.lower()
|
assert "database connection initialized" not in caplog.text.lower()
|
||||||
|
|
||||||
|
|
@ -389,41 +401,36 @@ def test_init_db_sets_is_in_memory_attribute(cleanup_db):
|
||||||
"""
|
"""
|
||||||
# Initialize the database with in_memory=False
|
# Initialize the database with in_memory=False
|
||||||
db = init_db(in_memory=False)
|
db = init_db(in_memory=False)
|
||||||
|
|
||||||
# Check that the _is_in_memory attribute is set to False
|
# Check that the _is_in_memory attribute is set to False
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
# Reset the contextvar
|
# Reset the contextvar
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# Initialize the database with in_memory=True
|
# Initialize the database with in_memory=True
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Check that the _is_in_memory attribute is set to True
|
# Check that the _is_in_memory attribute is set to True
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Tests for the database connection module.
|
Tests for the database connection module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
import pytest
|
|
||||||
import peewee
|
|
||||||
|
|
||||||
from ra_aid.database.connection import (
|
import pytest
|
||||||
init_db, get_db, close_db,
|
|
||||||
db_var, DatabaseManager, logger
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def cleanup_db():
|
def cleanup_db():
|
||||||
"""
|
"""
|
||||||
Fixture to clean up database connections and files between tests.
|
Fixture to clean up database connections and files between tests.
|
||||||
|
|
||||||
This fixture:
|
This fixture:
|
||||||
1. Closes any open database connection
|
1. Closes any open database connection
|
||||||
2. Resets the contextvar
|
2. Resets the contextvar
|
||||||
|
|
@ -431,117 +438,119 @@ def cleanup_db():
|
||||||
"""
|
"""
|
||||||
# Store the current working directory
|
# Store the current working directory
|
||||||
original_cwd = os.getcwd()
|
original_cwd = os.getcwd()
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Clean up after the test
|
# Clean up after the test
|
||||||
try:
|
try:
|
||||||
# Close any open database connection
|
# Close any open database connection
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
# Reset the contextvar
|
# Reset the contextvar
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# Clean up the .ra-aid directory if it exists
|
# Clean up the .ra-aid directory if it exists
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
||||||
|
|
||||||
# Check using both methods
|
# Check using both methods
|
||||||
path_exists = ra_aid_dir.exists()
|
path_exists = ra_aid_dir.exists()
|
||||||
os_exists = os.path.exists(ra_aid_dir_str)
|
os_exists = os.path.exists(ra_aid_dir_str)
|
||||||
|
|
||||||
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
||||||
|
|
||||||
if os_exists:
|
if os_exists:
|
||||||
# Only remove the database file, not the entire directory
|
# Only remove the database file, not the entire directory
|
||||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||||
if os.path.exists(db_file):
|
if os.path.exists(db_file):
|
||||||
os.unlink(db_file)
|
os.unlink(db_file)
|
||||||
|
|
||||||
# Remove WAL and SHM files if they exist
|
# Remove WAL and SHM files if they exist
|
||||||
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
|
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
|
||||||
if os.path.exists(wal_file):
|
if os.path.exists(wal_file):
|
||||||
os.unlink(wal_file)
|
os.unlink(wal_file)
|
||||||
|
|
||||||
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
|
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
|
||||||
if os.path.exists(shm_file):
|
if os.path.exists(shm_file):
|
||||||
os.unlink(shm_file)
|
os.unlink(shm_file)
|
||||||
|
|
||||||
# List remaining contents for debugging
|
# List remaining contents for debugging
|
||||||
if os.path.exists(ra_aid_dir_str):
|
if os.path.exists(ra_aid_dir_str):
|
||||||
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
|
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log but don't fail if cleanup has issues
|
# Log but don't fail if cleanup has issues
|
||||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||||
|
|
||||||
# Make sure we're back in the original directory
|
# Make sure we're back in the original directory
|
||||||
os.chdir(original_cwd)
|
os.chdir(original_cwd)
|
||||||
|
|
||||||
|
|
||||||
class TestInitDb:
|
class TestInitDb:
|
||||||
"""Tests for the init_db function."""
|
"""Tests for the init_db function."""
|
||||||
|
|
||||||
def test_init_db_default(self, cleanup_db):
|
def test_init_db_default(self, cleanup_db):
|
||||||
"""Test init_db with default parameters."""
|
"""Test init_db with default parameters."""
|
||||||
# Get the absolute path of the current working directory
|
# Get the absolute path of the current working directory
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
print(f"Current working directory: {cwd}")
|
print(f"Current working directory: {cwd}")
|
||||||
|
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
# Verify the database file was created using both Path and os.path methods
|
# Verify the database file was created using both Path and os.path methods
|
||||||
ra_aid_dir = Path(cwd) / ".ra-aid"
|
ra_aid_dir = Path(cwd) / ".ra-aid"
|
||||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
||||||
|
|
||||||
# Check directory existence using both methods
|
# Check directory existence using both methods
|
||||||
path_exists = ra_aid_dir.exists()
|
path_exists = ra_aid_dir.exists()
|
||||||
os_exists = os.path.exists(ra_aid_dir_str)
|
os_exists = os.path.exists(ra_aid_dir_str)
|
||||||
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
||||||
|
|
||||||
# List the contents of the current directory
|
# List the contents of the current directory
|
||||||
print(f"Contents of {cwd}: {os.listdir(cwd)}")
|
print(f"Contents of {cwd}: {os.listdir(cwd)}")
|
||||||
|
|
||||||
# If the directory exists, list its contents
|
# If the directory exists, list its contents
|
||||||
if os_exists:
|
if os_exists:
|
||||||
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
|
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
|
||||||
|
|
||||||
# Use os.path for assertions to be more reliable
|
# Use os.path for assertions to be more reliable
|
||||||
assert os.path.exists(ra_aid_dir_str), f"Directory {ra_aid_dir_str} does not exist"
|
assert os.path.exists(
|
||||||
|
ra_aid_dir_str
|
||||||
|
), f"Directory {ra_aid_dir_str} does not exist"
|
||||||
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
|
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
|
||||||
|
|
||||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||||
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
|
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
|
||||||
assert os.path.isfile(db_file), f"{db_file} is not a file"
|
assert os.path.isfile(db_file), f"{db_file} is not a file"
|
||||||
|
|
||||||
def test_init_db_in_memory(self, cleanup_db):
|
def test_init_db_in_memory(self, cleanup_db):
|
||||||
"""Test init_db with in_memory=True."""
|
"""Test init_db with in_memory=True."""
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
def test_init_db_reuses_connection(self, cleanup_db):
|
def test_init_db_reuses_connection(self, cleanup_db):
|
||||||
"""Test that init_db reuses an existing connection."""
|
"""Test that init_db reuses an existing connection."""
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
db2 = init_db()
|
db2 = init_db()
|
||||||
|
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
|
|
||||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||||
"""Test that init_db reopens a closed connection."""
|
"""Test that init_db reopens a closed connection."""
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
db1.close()
|
db1.close()
|
||||||
assert db1.is_closed()
|
assert db1.is_closed()
|
||||||
|
|
||||||
db2 = init_db()
|
db2 = init_db()
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
assert not db1.is_closed()
|
assert not db1.is_closed()
|
||||||
|
|
@ -549,32 +558,32 @@ class TestInitDb:
|
||||||
|
|
||||||
class TestGetDb:
|
class TestGetDb:
|
||||||
"""Tests for the get_db function."""
|
"""Tests for the get_db function."""
|
||||||
|
|
||||||
def test_get_db_creates_connection(self, cleanup_db):
|
def test_get_db_creates_connection(self, cleanup_db):
|
||||||
"""Test that get_db creates a new connection if none exists."""
|
"""Test that get_db creates a new connection if none exists."""
|
||||||
# Reset the contextvar to ensure no connection exists
|
# Reset the contextvar to ensure no connection exists
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
db = get_db()
|
db = get_db()
|
||||||
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
def test_get_db_reuses_connection(self, cleanup_db):
|
def test_get_db_reuses_connection(self, cleanup_db):
|
||||||
"""Test that get_db reuses an existing connection."""
|
"""Test that get_db reuses an existing connection."""
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
|
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
|
|
||||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||||
"""Test that get_db reopens a closed connection."""
|
"""Test that get_db reopens a closed connection."""
|
||||||
db1 = init_db()
|
db1 = init_db()
|
||||||
db1.close()
|
db1.close()
|
||||||
assert db1.is_closed()
|
assert db1.is_closed()
|
||||||
|
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
assert not db1.is_closed()
|
assert not db1.is_closed()
|
||||||
|
|
@ -582,63 +591,63 @@ class TestGetDb:
|
||||||
|
|
||||||
class TestCloseDb:
|
class TestCloseDb:
|
||||||
"""Tests for the close_db function."""
|
"""Tests for the close_db function."""
|
||||||
|
|
||||||
def test_close_db(self, cleanup_db):
|
def test_close_db(self, cleanup_db):
|
||||||
"""Test that close_db closes an open connection."""
|
"""Test that close_db closes an open connection."""
|
||||||
db = init_db()
|
db = init_db()
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
|
|
||||||
close_db()
|
close_db()
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
||||||
def test_close_db_no_connection(self, cleanup_db):
|
def test_close_db_no_connection(self, cleanup_db):
|
||||||
"""Test that close_db handles the case where no connection exists."""
|
"""Test that close_db handles the case where no connection exists."""
|
||||||
# Reset the contextvar to ensure no connection exists
|
# Reset the contextvar to ensure no connection exists
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# This should not raise an exception
|
# This should not raise an exception
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
def test_close_db_already_closed(self, cleanup_db):
|
def test_close_db_already_closed(self, cleanup_db):
|
||||||
"""Test that close_db handles the case where the connection is already closed."""
|
"""Test that close_db handles the case where the connection is already closed."""
|
||||||
db = init_db()
|
db = init_db()
|
||||||
db.close()
|
db.close()
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
||||||
# This should not raise an exception
|
# This should not raise an exception
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseManager:
|
class TestDatabaseManager:
|
||||||
"""Tests for the DatabaseManager class."""
|
"""Tests for the DatabaseManager class."""
|
||||||
|
|
||||||
def test_database_manager_default(self, cleanup_db):
|
def test_database_manager_default(self, cleanup_db):
|
||||||
"""Test DatabaseManager with default parameters."""
|
"""Test DatabaseManager with default parameters."""
|
||||||
with DatabaseManager() as db:
|
with DatabaseManager() as db:
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
# Verify the database file was created
|
# Verify the database file was created
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||||
assert ra_aid_dir.exists()
|
assert ra_aid_dir.exists()
|
||||||
assert (ra_aid_dir / "pk.db").exists()
|
assert (ra_aid_dir / "pk.db").exists()
|
||||||
|
|
||||||
# Verify the connection is closed after exiting the context
|
# Verify the connection is closed after exiting the context
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
||||||
def test_database_manager_in_memory(self, cleanup_db):
|
def test_database_manager_in_memory(self, cleanup_db):
|
||||||
"""Test DatabaseManager with in_memory=True."""
|
"""Test DatabaseManager with in_memory=True."""
|
||||||
with DatabaseManager(in_memory=True) as db:
|
with DatabaseManager(in_memory=True) as db:
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
# Verify the connection is closed after exiting the context
|
# Verify the connection is closed after exiting the context
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
||||||
def test_database_manager_exception_handling(self, cleanup_db):
|
def test_database_manager_exception_handling(self, cleanup_db):
|
||||||
"""Test that DatabaseManager properly handles exceptions."""
|
"""Test that DatabaseManager properly handles exceptions."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -648,6 +657,6 @@ class TestDatabaseManager:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# The exception should be propagated
|
# The exception should be propagated
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Verify the connection is closed even if an exception occurred
|
# Verify the connection is closed even if an exception occurred
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ Database utility functions for ra_aid.
|
||||||
This module provides utility functions for common database operations.
|
This module provides utility functions for common database operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
|
|
@ -16,18 +15,19 @@ from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Ensure that database tables for the specified models exist.
|
Ensure that database tables for the specified models exist.
|
||||||
|
|
||||||
If no models are specified, this function will attempt to discover
|
If no models are specified, this function will attempt to discover
|
||||||
all models that inherit from BaseModel.
|
all models that inherit from BaseModel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
models: Optional list of model classes to create tables for
|
models: Optional list of model classes to create tables for
|
||||||
"""
|
"""
|
||||||
db = get_db()
|
db = get_db()
|
||||||
|
|
||||||
if models is None:
|
if models is None:
|
||||||
# If no models are specified, try to discover them
|
# If no models are specified, try to discover them
|
||||||
models = []
|
models = []
|
||||||
|
|
@ -36,20 +36,22 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
||||||
# This is a placeholder - in a real implementation, you would
|
# This is a placeholder - in a real implementation, you would
|
||||||
# dynamically discover and import all modules that might contain models
|
# dynamically discover and import all modules that might contain models
|
||||||
from ra_aid.database import models as models_module
|
from ra_aid.database import models as models_module
|
||||||
|
|
||||||
# Find all classes in the module that inherit from BaseModel
|
# Find all classes in the module that inherit from BaseModel
|
||||||
for name, obj in inspect.getmembers(models_module):
|
for name, obj in inspect.getmembers(models_module):
|
||||||
if (inspect.isclass(obj) and
|
if (
|
||||||
issubclass(obj, BaseModel) and
|
inspect.isclass(obj)
|
||||||
obj != BaseModel):
|
and issubclass(obj, BaseModel)
|
||||||
|
and obj != BaseModel
|
||||||
|
):
|
||||||
models.append(obj)
|
models.append(obj)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Error importing model modules: {e}")
|
logger.warning(f"Error importing model modules: {e}")
|
||||||
|
|
||||||
if not models:
|
if not models:
|
||||||
logger.warning("No models found to create tables for")
|
logger.warning("No models found to create tables for")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
db.create_tables(models, safe=True)
|
db.create_tables(models, safe=True)
|
||||||
|
|
@ -61,13 +63,14 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
||||||
logger.error(f"Error: Failed to create tables: {str(e)}")
|
logger.error(f"Error: Failed to create tables: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_model_count(model_class: Type[BaseModel]) -> int:
|
def get_model_count(model_class: Type[BaseModel]) -> int:
|
||||||
"""
|
"""
|
||||||
Get the count of records for a specific model.
|
Get the count of records for a specific model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_class: The model class to count records for
|
model_class: The model class to count records for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: The number of records for the model
|
int: The number of records for the model
|
||||||
"""
|
"""
|
||||||
|
|
@ -77,10 +80,11 @@ def get_model_count(model_class: Type[BaseModel]) -> int:
|
||||||
logger.error(f"Database Error: Failed to count records: {str(e)}")
|
logger.error(f"Database Error: Failed to count records: {str(e)}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def truncate_table(model_class: Type[BaseModel]) -> None:
|
def truncate_table(model_class: Type[BaseModel]) -> None:
|
||||||
"""
|
"""
|
||||||
Delete all records from a model's table.
|
Delete all records from a model's table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_class: The model class to truncate
|
model_class: The model class to truncate
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
"""Module for checking system dependencies required by RA.Aid."""
|
"""Module for checking system dependencies required by RA.Aid."""
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from ra_aid import print_error
|
from ra_aid import print_error
|
||||||
|
|
@ -23,9 +22,11 @@ class RipGrepDependency(Dependency):
|
||||||
def check(self):
|
def check(self):
|
||||||
"""Check if ripgrep is installed."""
|
"""Check if ripgrep is installed."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(['rg', '--version'],
|
result = subprocess.run(
|
||||||
stdout=subprocess.DEVNULL,
|
["rg", "--version"],
|
||||||
stderr=subprocess.DEVNULL)
|
stdout=subprocess.DEVNULL,
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
except (FileNotFoundError, subprocess.SubprocessError):
|
except (FileNotFoundError, subprocess.SubprocessError):
|
||||||
|
|
|
||||||
|
|
@ -122,11 +122,8 @@ def create_openrouter_client(
|
||||||
is_expert: bool = False,
|
is_expert: bool = False,
|
||||||
) -> BaseChatModel:
|
) -> BaseChatModel:
|
||||||
"""Create OpenRouter client with appropriate configuration."""
|
"""Create OpenRouter client with appropriate configuration."""
|
||||||
default_headers = {
|
default_headers = {"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"}
|
||||||
"HTTP-Referer": "https://ra-aid.ai",
|
|
||||||
"X-Title": "RA.Aid"
|
|
||||||
}
|
|
||||||
|
|
||||||
if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
|
if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
|
||||||
return ChatDeepseekReasoner(
|
return ChatDeepseekReasoner(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
@ -243,12 +240,9 @@ def create_llm_client(
|
||||||
temp_kwargs = {"temperature": temperature}
|
temp_kwargs = {"temperature": temperature}
|
||||||
else:
|
else:
|
||||||
temp_kwargs = {}
|
temp_kwargs = {}
|
||||||
|
|
||||||
if supports_thinking:
|
if supports_thinking:
|
||||||
temp_kwargs = {"thinking": {
|
temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
|
||||||
"type": "enabled",
|
|
||||||
"budget_tokens": 12000
|
|
||||||
}}
|
|
||||||
|
|
||||||
if provider == "deepseek":
|
if provider == "deepseek":
|
||||||
return create_deepseek_client(
|
return create_deepseek_client(
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import pyte
|
import pyte
|
||||||
from pyte.screens import HistoryScreen
|
from pyte.screens import HistoryScreen
|
||||||
|
|
@ -33,17 +33,20 @@ else:
|
||||||
|
|
||||||
|
|
||||||
def create_process(
|
def create_process(
|
||||||
cmd: List[str], env: Optional[dict] = None, cols: Optional[int] = None, rows: Optional[int] = None
|
cmd: List[str],
|
||||||
|
env: Optional[dict] = None,
|
||||||
|
cols: Optional[int] = None,
|
||||||
|
rows: Optional[int] = None,
|
||||||
) -> Tuple[subprocess.Popen, Optional[int]]:
|
) -> Tuple[subprocess.Popen, Optional[int]]:
|
||||||
"""
|
"""
|
||||||
Create a subprocess with appropriate settings for the current platform.
|
Create a subprocess with appropriate settings for the current platform.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cmd: Command to execute as a list of strings
|
cmd: Command to execute as a list of strings
|
||||||
env: Environment variables dictionary, defaults to os.environ.copy()
|
env: Environment variables dictionary, defaults to os.environ.copy()
|
||||||
cols: Number of columns for the terminal, defaults to current terminal width
|
cols: Number of columns for the terminal, defaults to current terminal width
|
||||||
rows: Number of rows for the terminal, defaults to current terminal height
|
rows: Number of rows for the terminal, defaults to current terminal height
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
On Unix: (process, master_fd) where master_fd is the file descriptor for the pty master
|
On Unix: (process, master_fd) where master_fd is the file descriptor for the pty master
|
||||||
On Windows: (process, None) as Windows doesn't use ptys
|
On Windows: (process, None) as Windows doesn't use ptys
|
||||||
|
|
@ -61,7 +64,7 @@ def create_process(
|
||||||
# Windows-specific process creation
|
# Windows-specific process creation
|
||||||
startupinfo = subprocess.STARTUPINFO()
|
startupinfo = subprocess.STARTUPINFO()
|
||||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||||
|
|
||||||
proc = subprocess.Popen(
|
proc = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
stdin=subprocess.PIPE,
|
stdin=subprocess.PIPE,
|
||||||
|
|
@ -78,7 +81,7 @@ def create_process(
|
||||||
master_fd, slave_fd = os.openpty()
|
master_fd, slave_fd = os.openpty()
|
||||||
# Set master_fd to non-blocking to avoid indefinite blocking
|
# Set master_fd to non-blocking to avoid indefinite blocking
|
||||||
os.set_blocking(master_fd, False)
|
os.set_blocking(master_fd, False)
|
||||||
|
|
||||||
proc = subprocess.Popen(
|
proc = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
stdin=slave_fd,
|
stdin=slave_fd,
|
||||||
|
|
@ -90,18 +93,18 @@ def create_process(
|
||||||
preexec_fn=os.setsid, # Create new process group for proper signal handling
|
preexec_fn=os.setsid, # Create new process group for proper signal handling
|
||||||
)
|
)
|
||||||
os.close(slave_fd) # Close slave end in the parent process
|
os.close(slave_fd) # Close slave end in the parent process
|
||||||
|
|
||||||
return proc, master_fd
|
return proc, master_fd
|
||||||
|
|
||||||
|
|
||||||
def get_terminal_size() -> Tuple[int, int]:
|
def get_terminal_size() -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
Get the current terminal size in a cross-platform way.
|
Get the current terminal size in a cross-platform way.
|
||||||
|
|
||||||
This function works on both Unix and Windows systems, using shutil.get_terminal_size()
|
This function works on both Unix and Windows systems, using shutil.get_terminal_size()
|
||||||
which is available in Python 3.3+. If the terminal size cannot be determined
|
which is available in Python 3.3+. If the terminal size cannot be determined
|
||||||
(e.g., when running in a non-interactive environment), it falls back to default values.
|
(e.g., when running in a non-interactive environment), it falls back to default values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (columns, rows) representing the terminal dimensions.
|
A tuple of (columns, rows) representing the terminal dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
@ -117,11 +120,11 @@ def render_line(line, columns: int) -> str:
|
||||||
"""Render a single screen line from the pyte buffer (a mapping of column to Char)."""
|
"""Render a single screen line from the pyte buffer (a mapping of column to Char)."""
|
||||||
if not line:
|
if not line:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Handle string lines directly (from screen.display)
|
# Handle string lines directly (from screen.display)
|
||||||
if isinstance(line, str):
|
if isinstance(line, str):
|
||||||
return line
|
return line
|
||||||
|
|
||||||
# Handle dictionary-style lines (from history)
|
# Handle dictionary-style lines (from history)
|
||||||
try:
|
try:
|
||||||
return "".join(line[x].data for x in range(columns) if x in line)
|
return "".join(line[x].data for x in range(columns) if x in line)
|
||||||
|
|
@ -135,21 +138,21 @@ def run_interactive_command(
|
||||||
) -> Tuple[bytes, int]:
|
) -> Tuple[bytes, int]:
|
||||||
"""
|
"""
|
||||||
Runs an interactive command with output capture, capturing final scrollback history.
|
Runs an interactive command with output capture, capturing final scrollback history.
|
||||||
|
|
||||||
This function provides a cross-platform way to run interactive commands with:
|
This function provides a cross-platform way to run interactive commands with:
|
||||||
- Full terminal emulation using pyte's HistoryScreen
|
- Full terminal emulation using pyte's HistoryScreen
|
||||||
- Real-time display of command output
|
- Real-time display of command output
|
||||||
- Input forwarding when running in an interactive terminal
|
- Input forwarding when running in an interactive terminal
|
||||||
- Timeout handling to prevent runaway processes
|
- Timeout handling to prevent runaway processes
|
||||||
- Comprehensive output capture including ANSI escape sequences
|
- Comprehensive output capture including ANSI escape sequences
|
||||||
|
|
||||||
The implementation differs significantly between Windows and Unix:
|
The implementation differs significantly between Windows and Unix:
|
||||||
|
|
||||||
On Windows:
|
On Windows:
|
||||||
- Uses threading to handle I/O operations
|
- Uses threading to handle I/O operations
|
||||||
- Relies on msvcrt for keyboard input detection
|
- Relies on msvcrt for keyboard input detection
|
||||||
- Uses pipes for process communication
|
- Uses pipes for process communication
|
||||||
|
|
||||||
On Unix:
|
On Unix:
|
||||||
- Uses pseudo-terminals (PTY) for full terminal emulation
|
- Uses pseudo-terminals (PTY) for full terminal emulation
|
||||||
- Uses select() for non-blocking I/O
|
- Uses select() for non-blocking I/O
|
||||||
|
|
@ -230,7 +233,7 @@ def run_interactive_command(
|
||||||
# Windows implementation using threads for I/O
|
# Windows implementation using threads for I/O
|
||||||
running = True
|
running = True
|
||||||
stdin_thread = None
|
stdin_thread = None
|
||||||
|
|
||||||
def read_stdout():
|
def read_stdout():
|
||||||
nonlocal running
|
nonlocal running
|
||||||
while running and proc.poll() is None:
|
while running and proc.poll() is None:
|
||||||
|
|
@ -246,7 +249,7 @@ def run_interactive_command(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading stdout: {e}", file=sys.stderr)
|
print(f"Error reading stdout: {e}", file=sys.stderr)
|
||||||
break
|
break
|
||||||
|
|
||||||
def read_stderr():
|
def read_stderr():
|
||||||
nonlocal running
|
nonlocal running
|
||||||
while running and proc.poll() is None:
|
while running and proc.poll() is None:
|
||||||
|
|
@ -262,7 +265,7 @@ def run_interactive_command(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading stderr: {e}", file=sys.stderr)
|
print(f"Error reading stderr: {e}", file=sys.stderr)
|
||||||
break
|
break
|
||||||
|
|
||||||
def handle_input():
|
def handle_input():
|
||||||
nonlocal running
|
nonlocal running
|
||||||
try:
|
try:
|
||||||
|
|
@ -276,7 +279,7 @@ def run_interactive_command(
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling input: {e}", file=sys.stderr)
|
print(f"Error handling input: {e}", file=sys.stderr)
|
||||||
|
|
||||||
# Start I/O threads
|
# Start I/O threads
|
||||||
stdout_thread = threading.Thread(target=read_stdout)
|
stdout_thread = threading.Thread(target=read_stdout)
|
||||||
stderr_thread = threading.Thread(target=read_stderr)
|
stderr_thread = threading.Thread(target=read_stderr)
|
||||||
|
|
@ -284,13 +287,13 @@ def run_interactive_command(
|
||||||
stderr_thread.daemon = True
|
stderr_thread.daemon = True
|
||||||
stdout_thread.start()
|
stdout_thread.start()
|
||||||
stderr_thread.start()
|
stderr_thread.start()
|
||||||
|
|
||||||
# Only start stdin thread if we're in an interactive terminal
|
# Only start stdin thread if we're in an interactive terminal
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
stdin_thread = threading.Thread(target=handle_input)
|
stdin_thread = threading.Thread(target=handle_input)
|
||||||
stdin_thread.daemon = True
|
stdin_thread.daemon = True
|
||||||
stdin_thread.start()
|
stdin_thread.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Main thread monitors timeout
|
# Main thread monitors timeout
|
||||||
while proc.poll() is None:
|
while proc.poll() is None:
|
||||||
|
|
@ -307,7 +310,7 @@ def run_interactive_command(
|
||||||
stderr_thread.join(1.0)
|
stderr_thread.join(1.0)
|
||||||
if stdin_thread:
|
if stdin_thread:
|
||||||
stdin_thread.join(1.0)
|
stdin_thread.join(1.0)
|
||||||
|
|
||||||
# Close pipes
|
# Close pipes
|
||||||
if proc.stdout:
|
if proc.stdout:
|
||||||
proc.stdout.close()
|
proc.stdout.close()
|
||||||
|
|
@ -387,23 +390,23 @@ def run_interactive_command(
|
||||||
|
|
||||||
# Ensure we have captured data even if the screen processing failed
|
# Ensure we have captured data even if the screen processing failed
|
||||||
raw_output = b"".join(captured_data)
|
raw_output = b"".join(captured_data)
|
||||||
|
|
||||||
# Process the captured output through a fresh screen
|
# Process the captured output through a fresh screen
|
||||||
try:
|
try:
|
||||||
# Create a new screen and stream for final processing
|
# Create a new screen and stream for final processing
|
||||||
screen = HistoryScreen(cols, rows, history=2000, ratio=0.5)
|
screen = HistoryScreen(cols, rows, history=2000, ratio=0.5)
|
||||||
stream = pyte.Stream(screen)
|
stream = pyte.Stream(screen)
|
||||||
|
|
||||||
# Feed all captured data at once to get the final state
|
# Feed all captured data at once to get the final state
|
||||||
raw_output = b"".join(captured_data)
|
raw_output = b"".join(captured_data)
|
||||||
decoded = raw_output.decode("utf-8", errors="ignore")
|
decoded = raw_output.decode("utf-8", errors="ignore")
|
||||||
stream.feed(decoded)
|
stream.feed(decoded)
|
||||||
|
|
||||||
# Get all history lines (top and bottom) and current display
|
# Get all history lines (top and bottom) and current display
|
||||||
all_lines = []
|
all_lines = []
|
||||||
|
|
||||||
# Add history.top lines (older history)
|
# Add history.top lines (older history)
|
||||||
if hasattr(screen.history.top, 'keys'):
|
if hasattr(screen.history.top, "keys"):
|
||||||
# Dictionary-like object
|
# Dictionary-like object
|
||||||
for line_num in sorted(screen.history.top.keys()):
|
for line_num in sorted(screen.history.top.keys()):
|
||||||
line = screen.history.top[line_num]
|
line = screen.history.top[line_num]
|
||||||
|
|
@ -412,12 +415,12 @@ def run_interactive_command(
|
||||||
# Deque or other iterable
|
# Deque or other iterable
|
||||||
for i, line in enumerate(screen.history.top):
|
for i, line in enumerate(screen.history.top):
|
||||||
all_lines.append(render_line(line, cols))
|
all_lines.append(render_line(line, cols))
|
||||||
|
|
||||||
# Add current display lines
|
# Add current display lines
|
||||||
all_lines.extend([render_line(line, cols) for line in screen.display])
|
all_lines.extend([render_line(line, cols) for line in screen.display])
|
||||||
|
|
||||||
# Add history.bottom lines (newer history)
|
# Add history.bottom lines (newer history)
|
||||||
if hasattr(screen.history.bottom, 'keys'):
|
if hasattr(screen.history.bottom, "keys"):
|
||||||
# Dictionary-like object
|
# Dictionary-like object
|
||||||
for line_num in sorted(screen.history.bottom.keys()):
|
for line_num in sorted(screen.history.bottom.keys()):
|
||||||
line = screen.history.bottom[line_num]
|
line = screen.history.bottom[line_num]
|
||||||
|
|
@ -426,23 +429,23 @@ def run_interactive_command(
|
||||||
# Deque or other iterable
|
# Deque or other iterable
|
||||||
for i, line in enumerate(screen.history.bottom):
|
for i, line in enumerate(screen.history.bottom):
|
||||||
all_lines.append(render_line(line, cols))
|
all_lines.append(render_line(line, cols))
|
||||||
|
|
||||||
# Trim out empty lines to get only meaningful lines
|
# Trim out empty lines to get only meaningful lines
|
||||||
# Also strip trailing whitespace from each line
|
# Also strip trailing whitespace from each line
|
||||||
trimmed_lines = [line.rstrip() for line in all_lines if line and line.strip()]
|
trimmed_lines = [line.rstrip() for line in all_lines if line and line.strip()]
|
||||||
|
|
||||||
final_output = "\n".join(trimmed_lines)
|
final_output = "\n".join(trimmed_lines)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If anything goes wrong with screen processing, fall back to raw output
|
# If anything goes wrong with screen processing, fall back to raw output
|
||||||
print(f"Warning: Error processing terminal output: {e}", file=sys.stderr)
|
print(f"Warning: Error processing terminal output: {e}", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
# Decode raw output, strip trailing whitespace from each line
|
# Decode raw output, strip trailing whitespace from each line
|
||||||
decoded = raw_output.decode('utf-8', errors='replace')
|
decoded = raw_output.decode("utf-8", errors="replace")
|
||||||
lines = [line.rstrip() for line in decoded.splitlines()]
|
lines = [line.rstrip() for line in decoded.splitlines()]
|
||||||
final_output = "\n".join(lines)
|
final_output = "\n".join(lines)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Ultimate fallback if line processing fails
|
# Ultimate fallback if line processing fails
|
||||||
final_output = raw_output.decode('utf-8', errors='replace').strip()
|
final_output = raw_output.decode("utf-8", errors="replace").strip()
|
||||||
|
|
||||||
# Add timeout message if process was terminated due to timeout.
|
# Add timeout message if process was terminated due to timeout.
|
||||||
if was_terminated:
|
if was_terminated:
|
||||||
|
|
@ -458,7 +461,7 @@ def run_interactive_command(
|
||||||
else:
|
else:
|
||||||
# Handle any unexpected type
|
# Handle any unexpected type
|
||||||
final_output = str(final_output)[-8000:].encode("utf-8")
|
final_output = str(final_output)[-8000:].encode("utf-8")
|
||||||
|
|
||||||
return final_output, proc.returncode
|
return final_output, proc.returncode
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,9 @@ If you find this is an empty directory, you can stop research immediately and as
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RESEARCH_PROMPT = RESEARCH_COMMON_PROMPT_HEADER + """
|
RESEARCH_PROMPT = (
|
||||||
|
RESEARCH_COMMON_PROMPT_HEADER
|
||||||
|
+ """
|
||||||
|
|
||||||
Project State Handling:
|
Project State Handling:
|
||||||
For new/empty projects:
|
For new/empty projects:
|
||||||
|
|
@ -280,9 +282,12 @@ NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||||
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATINO IS REQUIRED, CALL request_implementation.
|
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATINO IS REQUIRED, CALL request_implementation.
|
||||||
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
|
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Research-only prompt - similar to research prompt but without implementation references
|
# Research-only prompt - similar to research prompt but without implementation references
|
||||||
RESEARCH_ONLY_PROMPT = RESEARCH_COMMON_PROMPT_HEADER + """
|
RESEARCH_ONLY_PROMPT = (
|
||||||
|
RESEARCH_COMMON_PROMPT_HEADER
|
||||||
|
+ """
|
||||||
|
|
||||||
You have been spawned by a higher level research agent, so only spawn more research tasks sparingly if absolutely necessary. Keep your research *very* scoped and efficient.
|
You have been spawned by a higher level research agent, so only spawn more research tasks sparingly if absolutely necessary. Keep your research *very* scoped and efficient.
|
||||||
|
|
||||||
|
|
@ -290,6 +295,7 @@ When you emit research notes, keep it extremely concise and relevant only to the
|
||||||
|
|
||||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Web research prompt - guides web search and information gathering
|
# Web research prompt - guides web search and information gathering
|
||||||
WEB_RESEARCH_PROMPT = """Current Date: {current_date}
|
WEB_RESEARCH_PROMPT = """Current Date: {current_date}
|
||||||
|
|
|
||||||
|
|
@ -162,7 +162,9 @@ class AnthropicStrategy(ProviderStrategy):
|
||||||
if not base_key:
|
if not base_key:
|
||||||
missing.append("ANTHROPIC_API_KEY environment variable is not set")
|
missing.append("ANTHROPIC_API_KEY environment variable is not set")
|
||||||
else:
|
else:
|
||||||
missing.append("EXPERT_ANTHROPIC_API_KEY environment variable is not set")
|
missing.append(
|
||||||
|
"EXPERT_ANTHROPIC_API_KEY environment variable is not set"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
key = os.environ.get("ANTHROPIC_API_KEY")
|
key = os.environ.get("ANTHROPIC_API_KEY")
|
||||||
if not key:
|
if not key:
|
||||||
|
|
|
||||||
|
|
@ -2,17 +2,17 @@
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
AgentContext,
|
AgentContext,
|
||||||
agent_context,
|
agent_context,
|
||||||
get_current_context,
|
|
||||||
mark_task_completed,
|
|
||||||
mark_plan_completed,
|
|
||||||
reset_completion_flags,
|
|
||||||
is_completed,
|
|
||||||
get_completion_message,
|
get_completion_message,
|
||||||
|
get_current_context,
|
||||||
|
is_completed,
|
||||||
|
mark_plan_completed,
|
||||||
|
mark_task_completed,
|
||||||
|
reset_completion_flags,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,7 +30,7 @@ class TestAgentContext:
|
||||||
"""Test that child contexts inherit state from parent contexts."""
|
"""Test that child contexts inherit state from parent contexts."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
parent.mark_task_completed("Parent task completed")
|
parent.mark_task_completed("Parent task completed")
|
||||||
|
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
assert child.task_completed is True
|
assert child.task_completed is True
|
||||||
assert child.completion_message == "Parent task completed"
|
assert child.completion_message == "Parent task completed"
|
||||||
|
|
@ -39,7 +39,7 @@ class TestAgentContext:
|
||||||
"""Test marking a task as completed."""
|
"""Test marking a task as completed."""
|
||||||
context = AgentContext()
|
context = AgentContext()
|
||||||
context.mark_task_completed("Task done")
|
context.mark_task_completed("Task done")
|
||||||
|
|
||||||
assert context.task_completed is True
|
assert context.task_completed is True
|
||||||
assert context.plan_completed is False
|
assert context.plan_completed is False
|
||||||
assert context.completion_message == "Task done"
|
assert context.completion_message == "Task done"
|
||||||
|
|
@ -48,7 +48,7 @@ class TestAgentContext:
|
||||||
"""Test marking a plan as completed."""
|
"""Test marking a plan as completed."""
|
||||||
context = AgentContext()
|
context = AgentContext()
|
||||||
context.mark_plan_completed("Plan done")
|
context.mark_plan_completed("Plan done")
|
||||||
|
|
||||||
assert context.task_completed is True
|
assert context.task_completed is True
|
||||||
assert context.plan_completed is True
|
assert context.plan_completed is True
|
||||||
assert context.completion_message == "Plan done"
|
assert context.completion_message == "Plan done"
|
||||||
|
|
@ -57,7 +57,7 @@ class TestAgentContext:
|
||||||
"""Test resetting completion flags."""
|
"""Test resetting completion flags."""
|
||||||
context = AgentContext()
|
context = AgentContext()
|
||||||
context.mark_task_completed("Task done")
|
context.mark_task_completed("Task done")
|
||||||
|
|
||||||
context.reset_completion_flags()
|
context.reset_completion_flags()
|
||||||
assert context.task_completed is False
|
assert context.task_completed is False
|
||||||
assert context.plan_completed is False
|
assert context.plan_completed is False
|
||||||
|
|
@ -67,13 +67,13 @@ class TestAgentContext:
|
||||||
"""Test the is_completed property."""
|
"""Test the is_completed property."""
|
||||||
context = AgentContext()
|
context = AgentContext()
|
||||||
assert context.is_completed is False
|
assert context.is_completed is False
|
||||||
|
|
||||||
context.mark_task_completed("Task done")
|
context.mark_task_completed("Task done")
|
||||||
assert context.is_completed is True
|
assert context.is_completed is True
|
||||||
|
|
||||||
context.reset_completion_flags()
|
context.reset_completion_flags()
|
||||||
assert context.is_completed is False
|
assert context.is_completed is False
|
||||||
|
|
||||||
context.mark_plan_completed("Plan done")
|
context.mark_plan_completed("Plan done")
|
||||||
assert context.is_completed is True
|
assert context.is_completed is True
|
||||||
|
|
||||||
|
|
@ -84,29 +84,29 @@ class TestContextManager:
|
||||||
def test_context_manager_basic(self):
|
def test_context_manager_basic(self):
|
||||||
"""Test basic context manager functionality."""
|
"""Test basic context manager functionality."""
|
||||||
assert get_current_context() is None
|
assert get_current_context() is None
|
||||||
|
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
assert get_current_context() is ctx
|
assert get_current_context() is ctx
|
||||||
assert ctx.task_completed is False
|
assert ctx.task_completed is False
|
||||||
|
|
||||||
assert get_current_context() is None
|
assert get_current_context() is None
|
||||||
|
|
||||||
def test_nested_context_managers(self):
|
def test_nested_context_managers(self):
|
||||||
"""Test nested context managers."""
|
"""Test nested context managers."""
|
||||||
with agent_context() as outer_ctx:
|
with agent_context() as outer_ctx:
|
||||||
assert get_current_context() is outer_ctx
|
assert get_current_context() is outer_ctx
|
||||||
|
|
||||||
with agent_context() as inner_ctx:
|
with agent_context() as inner_ctx:
|
||||||
assert get_current_context() is inner_ctx
|
assert get_current_context() is inner_ctx
|
||||||
assert inner_ctx is not outer_ctx
|
assert inner_ctx is not outer_ctx
|
||||||
|
|
||||||
assert get_current_context() is outer_ctx
|
assert get_current_context() is outer_ctx
|
||||||
|
|
||||||
def test_context_manager_with_parent(self):
|
def test_context_manager_with_parent(self):
|
||||||
"""Test context manager with explicit parent context."""
|
"""Test context manager with explicit parent context."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
parent.mark_task_completed("Parent task")
|
parent.mark_task_completed("Parent task")
|
||||||
|
|
||||||
with agent_context(parent_context=parent) as ctx:
|
with agent_context(parent_context=parent) as ctx:
|
||||||
assert ctx.task_completed is True
|
assert ctx.task_completed is True
|
||||||
assert ctx.completion_message == "Parent task"
|
assert ctx.completion_message == "Parent task"
|
||||||
|
|
@ -115,13 +115,13 @@ class TestContextManager:
|
||||||
"""Test that nested contexts inherit from outer contexts by default."""
|
"""Test that nested contexts inherit from outer contexts by default."""
|
||||||
with agent_context() as outer:
|
with agent_context() as outer:
|
||||||
outer.mark_task_completed("Outer task")
|
outer.mark_task_completed("Outer task")
|
||||||
|
|
||||||
with agent_context() as inner:
|
with agent_context() as inner:
|
||||||
assert inner.task_completed is True
|
assert inner.task_completed is True
|
||||||
assert inner.completion_message == "Outer task"
|
assert inner.completion_message == "Outer task"
|
||||||
|
|
||||||
inner.mark_plan_completed("Inner plan")
|
inner.mark_plan_completed("Inner plan")
|
||||||
|
|
||||||
# Outer context should not be affected by inner context changes
|
# Outer context should not be affected by inner context changes
|
||||||
assert outer.task_completed is True
|
assert outer.task_completed is True
|
||||||
assert outer.plan_completed is False
|
assert outer.plan_completed is False
|
||||||
|
|
@ -134,23 +134,23 @@ class TestThreadIsolation:
|
||||||
def test_thread_isolation(self):
|
def test_thread_isolation(self):
|
||||||
"""Test that contexts are isolated between threads."""
|
"""Test that contexts are isolated between threads."""
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
def thread_func(thread_id):
|
def thread_func(thread_id):
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
ctx.mark_task_completed(f"Thread {thread_id}")
|
ctx.mark_task_completed(f"Thread {thread_id}")
|
||||||
time.sleep(0.1) # Give other threads time to run
|
time.sleep(0.1) # Give other threads time to run
|
||||||
# Store the context's message for verification
|
# Store the context's message for verification
|
||||||
results[thread_id] = get_completion_message()
|
results[thread_id] = get_completion_message()
|
||||||
|
|
||||||
threads = []
|
threads = []
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
t = threading.Thread(target=thread_func, args=(i,))
|
t = threading.Thread(target=thread_func, args=(i,))
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
# Each thread should have its own message
|
# Each thread should have its own message
|
||||||
assert results[0] == "Thread 0"
|
assert results[0] == "Thread 0"
|
||||||
assert results[1] == "Thread 1"
|
assert results[1] == "Thread 1"
|
||||||
|
|
@ -188,26 +188,16 @@ class TestUtilityFunctions:
|
||||||
mark_task_completed("No context")
|
mark_task_completed("No context")
|
||||||
mark_plan_completed("No context")
|
mark_plan_completed("No context")
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
|
|
||||||
# These should have safe default returns
|
# These should have safe default returns
|
||||||
assert is_completed() is False
|
assert is_completed() is False
|
||||||
assert get_completion_message() == ""
|
assert get_completion_message() == ""
|
||||||
|
|
||||||
|
|
||||||
"""Unit tests for the agent_context module."""
|
"""Unit tests for the agent_context module."""
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
|
||||||
AgentContext,
|
|
||||||
agent_context,
|
|
||||||
get_current_context,
|
|
||||||
mark_task_completed,
|
|
||||||
mark_plan_completed,
|
|
||||||
reset_completion_flags,
|
|
||||||
is_completed,
|
|
||||||
get_completion_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentContext:
|
class TestAgentContext:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
"""Unit tests for agent_should_exit functionality."""
|
"""Unit tests for agent_should_exit functionality."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
AgentContext,
|
AgentContext,
|
||||||
agent_context,
|
agent_context,
|
||||||
get_current_context,
|
|
||||||
mark_should_exit,
|
mark_should_exit,
|
||||||
should_exit,
|
should_exit,
|
||||||
)
|
)
|
||||||
|
|
@ -18,7 +16,7 @@ class TestAgentShouldExit:
|
||||||
"""Test basic mark_should_exit functionality."""
|
"""Test basic mark_should_exit functionality."""
|
||||||
context = AgentContext()
|
context = AgentContext()
|
||||||
assert context.agent_should_exit is False
|
assert context.agent_should_exit is False
|
||||||
|
|
||||||
context.mark_should_exit()
|
context.mark_should_exit()
|
||||||
assert context.agent_should_exit is True
|
assert context.agent_should_exit is True
|
||||||
|
|
||||||
|
|
@ -34,10 +32,10 @@ class TestAgentShouldExit:
|
||||||
"""Test that mark_should_exit propagates to parent contexts."""
|
"""Test that mark_should_exit propagates to parent contexts."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Mark child as should exit
|
# Mark child as should exit
|
||||||
child.mark_should_exit()
|
child.mark_should_exit()
|
||||||
|
|
||||||
# Verify both child and parent are marked
|
# Verify both child and parent are marked
|
||||||
assert child.agent_should_exit is True
|
assert child.agent_should_exit is True
|
||||||
assert parent.agent_should_exit is True
|
assert parent.agent_should_exit is True
|
||||||
|
|
@ -49,10 +47,10 @@ class TestAgentShouldExit:
|
||||||
# Initially both should be False
|
# Initially both should be False
|
||||||
assert outer.agent_should_exit is False
|
assert outer.agent_should_exit is False
|
||||||
assert inner.agent_should_exit is False
|
assert inner.agent_should_exit is False
|
||||||
|
|
||||||
# Mark inner as should exit
|
# Mark inner as should exit
|
||||||
inner.mark_should_exit()
|
inner.mark_should_exit()
|
||||||
|
|
||||||
# Both should now be True
|
# Both should now be True
|
||||||
assert inner.agent_should_exit is True
|
assert inner.agent_should_exit is True
|
||||||
assert outer.agent_should_exit is True
|
assert outer.agent_should_exit is True
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,9 @@ from ra_aid.tools import (
|
||||||
emit_related_files,
|
emit_related_files,
|
||||||
emit_research_notes,
|
emit_research_notes,
|
||||||
file_str_replace,
|
file_str_replace,
|
||||||
put_complete_file_contents,
|
|
||||||
fuzzy_find_project_files,
|
fuzzy_find_project_files,
|
||||||
list_directory_tree,
|
list_directory_tree,
|
||||||
|
put_complete_file_contents,
|
||||||
read_file_tool,
|
read_file_tool,
|
||||||
ripgrep_search,
|
ripgrep_search,
|
||||||
run_programming_task,
|
run_programming_task,
|
||||||
|
|
@ -26,12 +26,12 @@ from ra_aid.tools.agent import (
|
||||||
request_task_implementation,
|
request_task_implementation,
|
||||||
request_web_research,
|
request_web_research,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.memory import one_shot_completed, plan_implementation_completed
|
from ra_aid.tools.memory import plan_implementation_completed
|
||||||
|
|
||||||
|
|
||||||
def set_modification_tools(use_aider=False):
|
def set_modification_tools(use_aider=False):
|
||||||
"""Set the MODIFICATION_TOOLS list based on configuration.
|
"""Set the MODIFICATION_TOOLS list based on configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_aider: Whether to use run_programming_task (True) or file modification tools (False)
|
use_aider: Whether to use run_programming_task (True) or file modification tools (False)
|
||||||
"""
|
"""
|
||||||
|
|
@ -46,7 +46,9 @@ def set_modification_tools(use_aider=False):
|
||||||
|
|
||||||
# Read-only tools that don't modify system state
|
# Read-only tools that don't modify system state
|
||||||
def get_read_only_tools(
|
def get_read_only_tools(
|
||||||
human_interaction: bool = False, web_research_enabled: bool = False, use_aider: bool = False
|
human_interaction: bool = False,
|
||||||
|
web_research_enabled: bool = False,
|
||||||
|
use_aider: bool = False,
|
||||||
):
|
):
|
||||||
"""Get the list of read-only tools, optionally including human interaction tools.
|
"""Get the list of read-only tools, optionally including human interaction tools.
|
||||||
|
|
||||||
|
|
@ -100,6 +102,7 @@ def get_all_tools() -> list[BaseTool]:
|
||||||
_config = {}
|
_config = {}
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
|
||||||
_config = _global_memory.get("config", {})
|
_config = _global_memory.get("config", {})
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
@ -137,15 +140,14 @@ def get_research_tools(
|
||||||
use_aider = False
|
use_aider = False
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
|
||||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
tools = get_read_only_tools(
|
tools = get_read_only_tools(
|
||||||
human_interaction,
|
human_interaction, web_research_enabled, use_aider=use_aider
|
||||||
web_research_enabled,
|
|
||||||
use_aider=use_aider
|
|
||||||
).copy()
|
).copy()
|
||||||
|
|
||||||
tools.extend(RESEARCH_TOOLS)
|
tools.extend(RESEARCH_TOOLS)
|
||||||
|
|
@ -179,14 +181,14 @@ def get_planning_tools(
|
||||||
use_aider = False
|
use_aider = False
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
|
||||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
tools = get_read_only_tools(
|
tools = get_read_only_tools(
|
||||||
web_research_enabled=web_research_enabled,
|
web_research_enabled=web_research_enabled, use_aider=use_aider
|
||||||
use_aider=use_aider
|
|
||||||
).copy()
|
).copy()
|
||||||
|
|
||||||
# Add planning-specific tools
|
# Add planning-specific tools
|
||||||
|
|
@ -218,14 +220,14 @@ def get_implementation_tools(
|
||||||
use_aider = False
|
use_aider = False
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
|
||||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
tools = get_read_only_tools(
|
tools = get_read_only_tools(
|
||||||
web_research_enabled=web_research_enabled,
|
web_research_enabled=web_research_enabled, use_aider=use_aider
|
||||||
use_aider=use_aider
|
|
||||||
).copy()
|
).copy()
|
||||||
|
|
||||||
# Add modification tools since it's not research-only
|
# Add modification tools since it's not research-only
|
||||||
|
|
@ -283,4 +285,4 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal
|
||||||
if web_research_enabled:
|
if web_research_enabled:
|
||||||
tools.append(request_web_research)
|
tools.append(request_web_research)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,12 @@ from typing import Any, Dict, List, Union
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from ra_aid.agent_context import get_completion_message, get_crash_message, is_crashed, reset_completion_flags
|
from ra_aid.agent_context import (
|
||||||
|
get_completion_message,
|
||||||
|
get_crash_message,
|
||||||
|
is_crashed,
|
||||||
|
reset_completion_flags,
|
||||||
|
)
|
||||||
from ra_aid.console.formatting import print_error
|
from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
|
@ -85,7 +90,9 @@ def request_research(query: str) -> ResearchResult:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
completion_message = get_completion_message() or (
|
||||||
|
"Task was completed successfully." if success else None
|
||||||
|
)
|
||||||
|
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
|
|
@ -149,7 +156,9 @@ def request_web_research(query: str) -> ResearchResult:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
completion_message = get_completion_message() or (
|
||||||
|
"Task was completed successfully." if success else None
|
||||||
|
)
|
||||||
|
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
|
|
@ -215,7 +224,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
completion_message = get_completion_message() or (
|
||||||
|
"Task was completed successfully." if success else None
|
||||||
|
)
|
||||||
|
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
|
|
@ -293,7 +304,9 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
completion_message = get_completion_message() or (
|
||||||
|
"Task was completed successfully." if success else None
|
||||||
|
)
|
||||||
|
|
||||||
# Get and reset work log if at root depth
|
# Get and reset work log if at root depth
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
@ -317,45 +330,53 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
}
|
}
|
||||||
if work_log is not None:
|
if work_log is not None:
|
||||||
response_data["work_log"] = work_log
|
response_data["work_log"] = work_log
|
||||||
|
|
||||||
# Convert the response data to a markdown string
|
# Convert the response data to a markdown string
|
||||||
markdown_parts = []
|
markdown_parts = []
|
||||||
|
|
||||||
# Add header and completion message
|
# Add header and completion message
|
||||||
markdown_parts.append("# Task Implementation")
|
markdown_parts.append("# Task Implementation")
|
||||||
if response_data.get("completion_message"):
|
if response_data.get("completion_message"):
|
||||||
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
|
markdown_parts.append(
|
||||||
|
f"\n## Completion Message\n\n{response_data['completion_message']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add crash information if applicable
|
# Add crash information if applicable
|
||||||
if response_data.get("agent_crashed"):
|
if response_data.get("agent_crashed"):
|
||||||
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
|
markdown_parts.append(
|
||||||
|
f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add success status
|
# Add success status
|
||||||
status = "Success" if response_data.get("success", False) else "Failed"
|
status = "Success" if response_data.get("success", False) else "Failed"
|
||||||
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
reason_text = (
|
||||||
|
f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||||
|
)
|
||||||
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
|
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
|
||||||
|
|
||||||
# Add key facts
|
# Add key facts
|
||||||
if response_data.get("key_facts"):
|
if response_data.get("key_facts"):
|
||||||
markdown_parts.append(f"\n## Key Facts\n\n{response_data['key_facts']}")
|
markdown_parts.append(f"\n## Key Facts\n\n{response_data['key_facts']}")
|
||||||
|
|
||||||
# Add related files
|
# Add related files
|
||||||
if response_data.get("related_files"):
|
if response_data.get("related_files"):
|
||||||
files_list = "\n".join([f"- {file}" for file in response_data["related_files"]])
|
files_list = "\n".join([f"- {file}" for file in response_data["related_files"]])
|
||||||
markdown_parts.append(f"\n## Related Files\n\n{files_list}")
|
markdown_parts.append(f"\n## Related Files\n\n{files_list}")
|
||||||
|
|
||||||
# Add key snippets
|
# Add key snippets
|
||||||
if response_data.get("key_snippets"):
|
if response_data.get("key_snippets"):
|
||||||
markdown_parts.append(f"\n## Key Snippets\n\n{response_data['key_snippets']}")
|
markdown_parts.append(f"\n## Key Snippets\n\n{response_data['key_snippets']}")
|
||||||
|
|
||||||
# Add work log
|
# Add work log
|
||||||
if response_data.get("work_log"):
|
if response_data.get("work_log"):
|
||||||
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
|
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
|
||||||
markdown_parts.append(f"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**")
|
markdown_parts.append(
|
||||||
|
"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**"
|
||||||
|
)
|
||||||
|
|
||||||
# Join all parts into a single markdown string
|
# Join all parts into a single markdown string
|
||||||
markdown_output = "".join(markdown_parts)
|
markdown_output = "".join(markdown_parts)
|
||||||
|
|
||||||
return markdown_output
|
return markdown_output
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -403,7 +424,9 @@ def request_implementation(task_spec: str) -> str:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
completion_message = get_completion_message() or (
|
||||||
|
"Task was completed successfully." if success else None
|
||||||
|
)
|
||||||
|
|
||||||
# Get and reset work log if at root depth
|
# Get and reset work log if at root depth
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
@ -427,43 +450,51 @@ def request_implementation(task_spec: str) -> str:
|
||||||
}
|
}
|
||||||
if work_log is not None:
|
if work_log is not None:
|
||||||
response_data["work_log"] = work_log
|
response_data["work_log"] = work_log
|
||||||
|
|
||||||
# Convert the response data to a markdown string
|
# Convert the response data to a markdown string
|
||||||
markdown_parts = []
|
markdown_parts = []
|
||||||
|
|
||||||
# Add header and completion message
|
# Add header and completion message
|
||||||
markdown_parts.append("# Implementation Plan")
|
markdown_parts.append("# Implementation Plan")
|
||||||
if response_data.get("completion_message"):
|
if response_data.get("completion_message"):
|
||||||
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
|
markdown_parts.append(
|
||||||
|
f"\n## Completion Message\n\n{response_data['completion_message']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add crash information if applicable
|
# Add crash information if applicable
|
||||||
if response_data.get("agent_crashed"):
|
if response_data.get("agent_crashed"):
|
||||||
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
|
markdown_parts.append(
|
||||||
|
f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add success status
|
# Add success status
|
||||||
status = "Success" if response_data.get("success", False) else "Failed"
|
status = "Success" if response_data.get("success", False) else "Failed"
|
||||||
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
reason_text = (
|
||||||
|
f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||||
|
)
|
||||||
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
|
markdown_parts.append(f"\n## Status\n\n**{status}**{reason_text}")
|
||||||
|
|
||||||
# Add key facts
|
# Add key facts
|
||||||
if response_data.get("key_facts"):
|
if response_data.get("key_facts"):
|
||||||
markdown_parts.append(f"\n## Key Facts\n\n{response_data['key_facts']}")
|
markdown_parts.append(f"\n## Key Facts\n\n{response_data['key_facts']}")
|
||||||
|
|
||||||
# Add related files
|
# Add related files
|
||||||
if response_data.get("related_files"):
|
if response_data.get("related_files"):
|
||||||
files_list = "\n".join([f"- {file}" for file in response_data["related_files"]])
|
files_list = "\n".join([f"- {file}" for file in response_data["related_files"]])
|
||||||
markdown_parts.append(f"\n## Related Files\n\n{files_list}")
|
markdown_parts.append(f"\n## Related Files\n\n{files_list}")
|
||||||
|
|
||||||
# Add key snippets
|
# Add key snippets
|
||||||
if response_data.get("key_snippets"):
|
if response_data.get("key_snippets"):
|
||||||
markdown_parts.append(f"\n## Key Snippets\n\n{response_data['key_snippets']}")
|
markdown_parts.append(f"\n## Key Snippets\n\n{response_data['key_snippets']}")
|
||||||
|
|
||||||
# Add work log
|
# Add work log
|
||||||
if response_data.get("work_log"):
|
if response_data.get("work_log"):
|
||||||
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
|
markdown_parts.append(f"\n## Work Log\n\n{response_data['work_log']}")
|
||||||
markdown_parts.append(f"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**")
|
markdown_parts.append(
|
||||||
|
"\n\nTHE ABOVE WORK HAS ALREADY BEEN COMPLETED --**DO NOT REQUEST IMPLEMENTATION OF IT AGAIN**"
|
||||||
|
)
|
||||||
|
|
||||||
# Join all parts into a single markdown string
|
# Join all parts into a single markdown string
|
||||||
markdown_output = "".join(markdown_parts)
|
markdown_output = "".join(markdown_parts)
|
||||||
|
|
||||||
return markdown_output
|
return markdown_output
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import magic
|
import magic
|
||||||
|
|
@ -12,7 +12,11 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit
|
from ra_aid.agent_context import (
|
||||||
|
mark_plan_completed,
|
||||||
|
mark_should_exit,
|
||||||
|
mark_task_completed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkLogEntry(TypedDict):
|
class WorkLogEntry(TypedDict):
|
||||||
|
|
@ -389,7 +393,7 @@ def emit_related_files(files: List[str]) -> str:
|
||||||
invalid_paths.append(file)
|
invalid_paths.append(file)
|
||||||
results.append(f"Error: Path '{file}' exists but is not a regular file")
|
results.append(f"Error: Path '{file}' exists but is not a regular file")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if it's a binary file
|
# Check if it's a binary file
|
||||||
if is_binary_file(file):
|
if is_binary_file(file):
|
||||||
binary_files.append(file)
|
binary_files.append(file)
|
||||||
|
|
@ -430,7 +434,7 @@ def emit_related_files(files: List[str]) -> str:
|
||||||
border_style="green",
|
border_style="green",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display skipped binary files
|
# Display skipped binary files
|
||||||
if binary_files:
|
if binary_files:
|
||||||
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
|
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
|
||||||
|
|
@ -478,18 +482,36 @@ def is_binary_file(filepath):
|
||||||
if magic:
|
if magic:
|
||||||
try:
|
try:
|
||||||
mime = magic.from_file(filepath, mime=True)
|
mime = magic.from_file(filepath, mime=True)
|
||||||
return not mime.startswith('text/')
|
file_type = magic.from_file(filepath)
|
||||||
except Exception:
|
|
||||||
# Fallback if magic fails
|
if not mime.startswith("text/"):
|
||||||
return False
|
return True
|
||||||
else:
|
|
||||||
# Basic binary detection if magic is not available
|
if "ASCII text" in file_type:
|
||||||
try:
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
|
||||||
f.read(1024) # Try to read as text
|
|
||||||
return False
|
return False
|
||||||
except UnicodeDecodeError:
|
|
||||||
return True
|
return True
|
||||||
|
except Exception:
|
||||||
|
return _is_binary_fallback(filepath)
|
||||||
|
else:
|
||||||
|
return _is_binary_fallback(filepath)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_binary_fallback(filepath):
|
||||||
|
"""Fallback method to detect binary files without using magic."""
|
||||||
|
try:
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
chunk = f.read(1024)
|
||||||
|
|
||||||
|
# Check for null bytes which indicate binary content
|
||||||
|
if "\0" in chunk:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If we can read it as text without errors, it's probably not binary
|
||||||
|
return False
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
# If we can't decode as UTF-8, it's likely binary
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_work_log() -> str:
|
def get_work_log() -> str:
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,9 @@ def run_programming_task(
|
||||||
|
|
||||||
# Return structured output
|
# Return structured output
|
||||||
return {
|
return {
|
||||||
"output": (truncate_output(result[0].decode()) + extra_ins) if result[0] else "",
|
"output": (truncate_output(result[0].decode()) + extra_ins)
|
||||||
|
if result[0]
|
||||||
|
else "",
|
||||||
"return_code": result[1],
|
"return_code": result[1],
|
||||||
"success": result[1] == 0,
|
"success": result[1] == 0,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,7 @@ def ripgrep_search(
|
||||||
params.append(f"**Before Context Lines**: {before_context_lines}")
|
params.append(f"**Before Context Lines**: {before_context_lines}")
|
||||||
if after_context_lines is not None:
|
if after_context_lines is not None:
|
||||||
params.append(f"**After Context Lines**: {after_context_lines}")
|
params.append(f"**After Context Lines**: {after_context_lines}")
|
||||||
|
|
||||||
if include_hidden:
|
if include_hidden:
|
||||||
params.append("**Including Hidden Files**: yes")
|
params.append("**Including Hidden Files**: yes")
|
||||||
if follow_links:
|
if follow_links:
|
||||||
|
|
|
||||||
|
|
@ -61,9 +61,7 @@ def put_complete_file_contents(
|
||||||
f"at {filepath} in {result['elapsed_time']:.3f}s"
|
f"at {filepath} in {result['elapsed_time']:.3f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(f"File write complete: {bytes_written} bytes in {elapsed:.2f}s")
|
||||||
f"File write complete: {bytes_written} bytes in {elapsed:.2f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,21 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
import queue
|
|
||||||
import traceback
|
|
||||||
import shutil
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.WARNING,
|
level=logging.WARNING,
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.StreamHandler(sys.__stderr__) # Use the real stderr
|
logging.StreamHandler(sys.__stderr__) # Use the real stderr
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -26,12 +24,12 @@ project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
if project_root not in sys.path:
|
if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
@ -55,10 +53,12 @@ app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
||||||
# Store active WebSocket connections
|
# Store active WebSocket connections
|
||||||
active_connections: List[WebSocket] = []
|
active_connections: List[WebSocket] = []
|
||||||
|
|
||||||
|
|
||||||
def run_ra_aid(message_content, output_queue):
|
def run_ra_aid(message_content, output_queue):
|
||||||
"""Run ra-aid in a separate thread"""
|
"""Run ra-aid in a separate thread"""
|
||||||
try:
|
try:
|
||||||
import ra_aid.__main__
|
import ra_aid.__main__
|
||||||
|
|
||||||
logger.info("Successfully imported ra_aid.__main__")
|
logger.info("Successfully imported ra_aid.__main__")
|
||||||
|
|
||||||
# Override sys.argv
|
# Override sys.argv
|
||||||
|
|
@ -72,47 +72,47 @@ def run_ra_aid(message_content, output_queue):
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.box_start = False
|
self.box_start = False
|
||||||
self._real_stderr = sys.__stderr__
|
self._real_stderr = sys.__stderr__
|
||||||
|
|
||||||
def write(self, text):
|
def write(self, text):
|
||||||
# Always log raw output for debugging
|
# Always log raw output for debugging
|
||||||
logger.debug(f"Raw output: {repr(text)}")
|
logger.debug(f"Raw output: {repr(text)}")
|
||||||
|
|
||||||
# Check if this is a box drawing character
|
# Check if this is a box drawing character
|
||||||
if any(c in text for c in '╭╮╰╯│─'):
|
if any(c in text for c in "╭╮╰╯│─"):
|
||||||
self.box_start = True
|
self.box_start = True
|
||||||
self.buffer.append(text)
|
self.buffer.append(text)
|
||||||
elif self.box_start and text.strip():
|
elif self.box_start and text.strip():
|
||||||
self.buffer.append(text)
|
self.buffer.append(text)
|
||||||
if '╯' in text: # End of box
|
if "╯" in text: # End of box
|
||||||
full_text = ''.join(self.buffer)
|
full_text = "".join(self.buffer)
|
||||||
# Extract content from inside the box
|
# Extract content from inside the box
|
||||||
lines = full_text.split('\n')
|
lines = full_text.split("\n")
|
||||||
content_lines = []
|
content_lines = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
# Remove box characters and leading/trailing spaces
|
# Remove box characters and leading/trailing spaces
|
||||||
clean_line = line.strip('╭╮╰╯│─ ')
|
clean_line = line.strip("╭╮╰╯│─ ")
|
||||||
if clean_line:
|
if clean_line:
|
||||||
content_lines.append(clean_line)
|
content_lines.append(clean_line)
|
||||||
if content_lines:
|
if content_lines:
|
||||||
self.queue.put('\n'.join(content_lines))
|
self.queue.put("\n".join(content_lines))
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.box_start = False
|
self.box_start = False
|
||||||
elif not self.box_start and text.strip():
|
elif not self.box_start and text.strip():
|
||||||
self.queue.put(text.strip())
|
self.queue.put(text.strip())
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
if self.buffer:
|
if self.buffer:
|
||||||
full_text = ''.join(self.buffer)
|
full_text = "".join(self.buffer)
|
||||||
# Extract content from partial box
|
# Extract content from partial box
|
||||||
lines = full_text.split('\n')
|
lines = full_text.split("\n")
|
||||||
content_lines = []
|
content_lines = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
# Remove box characters and leading/trailing spaces
|
# Remove box characters and leading/trailing spaces
|
||||||
clean_line = line.strip('╭╮╰╯│─ ')
|
clean_line = line.strip("╭╮╰╯│─ ")
|
||||||
if clean_line:
|
if clean_line:
|
||||||
content_lines.append(clean_line)
|
content_lines.append(clean_line)
|
||||||
if content_lines:
|
if content_lines:
|
||||||
self.queue.put('\n'.join(content_lines))
|
self.queue.put("\n".join(content_lines))
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.box_start = False
|
self.box_start = False
|
||||||
|
|
||||||
|
|
@ -144,6 +144,7 @@ def run_ra_aid(message_content, output_queue):
|
||||||
traceback.print_exc(file=sys.__stderr__)
|
traceback.print_exc(file=sys.__stderr__)
|
||||||
output_queue.put(f"Error: {str(e)}")
|
output_queue.put(f"Error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/", response_class=HTMLResponse)
|
@app.get("/", response_class=HTMLResponse)
|
||||||
async def get_root(request: Request):
|
async def get_root(request: Request):
|
||||||
"""Serve the index.html file with port parameter."""
|
"""Serve the index.html file with port parameter."""
|
||||||
|
|
@ -151,6 +152,7 @@ async def get_root(request: Request):
|
||||||
"index.html", {"request": request, "server_port": request.url.port or 8080}
|
"index.html", {"request": request, "server_port": request.url.port or 8080}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/ws")
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
@ -170,7 +172,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
output_queue = queue.Queue()
|
output_queue = queue.Queue()
|
||||||
|
|
||||||
# Create and start thread
|
# Create and start thread
|
||||||
thread = threading.Thread(target=run_ra_aid, args=(content, output_queue))
|
thread = threading.Thread(
|
||||||
|
target=run_ra_aid, args=(content, output_queue)
|
||||||
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -183,17 +187,21 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
line = output_queue.get(timeout=0.1)
|
line = output_queue.get(timeout=0.1)
|
||||||
if line and line.strip(): # Only send non-empty messages
|
if line and line.strip(): # Only send non-empty messages
|
||||||
logger.debug(f"WebSocket sending: {repr(line)}")
|
logger.debug(f"WebSocket sending: {repr(line)}")
|
||||||
await websocket.send_json({
|
await websocket.send_json(
|
||||||
"type": "chunk",
|
{
|
||||||
"chunk": {
|
"type": "chunk",
|
||||||
"agent": {
|
"chunk": {
|
||||||
"messages": [{
|
"agent": {
|
||||||
"content": line.strip(),
|
"messages": [
|
||||||
"status": "info"
|
{
|
||||||
}]
|
"content": line.strip(),
|
||||||
}
|
"status": "info",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -211,10 +219,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error running ra-aid: {str(e)}"
|
error_msg = f"Error running ra-aid: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
await websocket.send_json({
|
await websocket.send_json({"type": "error", "message": error_msg})
|
||||||
"type": "error",
|
|
||||||
"message": error_msg
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info("Waiting for message...")
|
logger.info("Waiting for message...")
|
||||||
|
|
||||||
|
|
@ -243,6 +248,7 @@ def run_server(host: str = "0.0.0.0", port: int = 8080):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
|
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"
|
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,23 @@
|
||||||
"""
|
"""
|
||||||
Tests for the database connection module.
|
Tests for the database connection module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pytest
|
from unittest.mock import patch
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
from unittest.mock import patch, MagicMock
|
import pytest
|
||||||
|
|
||||||
from ra_aid.database.connection import (
|
from ra_aid.database.connection import (
|
||||||
init_db, get_db, close_db,
|
DatabaseManager,
|
||||||
db_var, DatabaseManager, logger
|
close_db,
|
||||||
|
db_var,
|
||||||
|
get_db,
|
||||||
|
init_db,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def cleanup_db():
|
def cleanup_db():
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,20 +53,23 @@ def cleanup_db():
|
||||||
# Log but don't fail if cleanup has issues
|
# Log but don't fail if cleanup has issues
|
||||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_logger():
|
def mock_logger():
|
||||||
"""Mock the logger to test for output messages."""
|
"""Mock the logger to test for output messages."""
|
||||||
with patch('ra_aid.database.connection.logger') as mock:
|
with patch("ra_aid.database.connection.logger") as mock:
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
class TestInitDb:
|
class TestInitDb:
|
||||||
"""Tests for the init_db function."""
|
"""Tests for the init_db function."""
|
||||||
|
|
||||||
def test_init_db_default(self, cleanup_db):
|
def test_init_db_default(self, cleanup_db):
|
||||||
"""Test init_db with default parameters."""
|
"""Test init_db with default parameters."""
|
||||||
db = init_db()
|
db = init_db()
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
# Verify the database file was created
|
# Verify the database file was created
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||||
|
|
@ -73,7 +81,7 @@ class TestInitDb:
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
def test_init_db_reuses_connection(self, cleanup_db):
|
def test_init_db_reuses_connection(self, cleanup_db):
|
||||||
|
|
@ -91,8 +99,10 @@ class TestInitDb:
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
assert not db1.is_closed()
|
assert not db1.is_closed()
|
||||||
|
|
||||||
|
|
||||||
class TestGetDb:
|
class TestGetDb:
|
||||||
"""Tests for the get_db function."""
|
"""Tests for the get_db function."""
|
||||||
|
|
||||||
def test_get_db_creates_connection(self, cleanup_db):
|
def test_get_db_creates_connection(self, cleanup_db):
|
||||||
"""Test that get_db creates a new connection if none exists."""
|
"""Test that get_db creates a new connection if none exists."""
|
||||||
# Reset the contextvar to ensure no connection exists
|
# Reset the contextvar to ensure no connection exists
|
||||||
|
|
@ -100,7 +110,7 @@ class TestGetDb:
|
||||||
db = get_db()
|
db = get_db()
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
def test_get_db_reuses_connection(self, cleanup_db):
|
def test_get_db_reuses_connection(self, cleanup_db):
|
||||||
|
|
@ -118,8 +128,10 @@ class TestGetDb:
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
assert not db1.is_closed()
|
assert not db1.is_closed()
|
||||||
|
|
||||||
|
|
||||||
class TestCloseDb:
|
class TestCloseDb:
|
||||||
"""Tests for the close_db function."""
|
"""Tests for the close_db function."""
|
||||||
|
|
||||||
def test_close_db(self, cleanup_db):
|
def test_close_db(self, cleanup_db):
|
||||||
"""Test that close_db closes an open connection."""
|
"""Test that close_db closes an open connection."""
|
||||||
db = init_db()
|
db = init_db()
|
||||||
|
|
@ -142,14 +154,16 @@ class TestCloseDb:
|
||||||
# This should not raise an exception
|
# This should not raise an exception
|
||||||
close_db()
|
close_db()
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseManager:
|
class TestDatabaseManager:
|
||||||
"""Tests for the DatabaseManager class."""
|
"""Tests for the DatabaseManager class."""
|
||||||
|
|
||||||
def test_database_manager_default(self, cleanup_db):
|
def test_database_manager_default(self, cleanup_db):
|
||||||
"""Test DatabaseManager with default parameters."""
|
"""Test DatabaseManager with default parameters."""
|
||||||
with DatabaseManager() as db:
|
with DatabaseManager() as db:
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False
|
||||||
# Verify the database file was created
|
# Verify the database file was created
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||||
|
|
@ -163,7 +177,7 @@ class TestDatabaseManager:
|
||||||
with DatabaseManager(in_memory=True) as db:
|
with DatabaseManager(in_memory=True) as db:
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, '_is_in_memory')
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
# Verify the connection is closed after exiting the context
|
# Verify the connection is closed after exiting the context
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
|
||||||
|
|
@ -5,22 +5,19 @@ Tests for the database migrations module.
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from unittest.mock import MagicMock, PropertyMock, patch
|
||||||
from unittest.mock import patch, MagicMock, call, PropertyMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import peewee
|
|
||||||
from peewee_migrate import Router
|
|
||||||
|
|
||||||
from ra_aid.database.connection import DatabaseManager, db_var
|
from ra_aid.database.connection import DatabaseManager, db_var
|
||||||
from ra_aid.database.migrations import (
|
from ra_aid.database.migrations import (
|
||||||
MigrationManager,
|
|
||||||
init_migrations,
|
|
||||||
ensure_migrations_applied,
|
|
||||||
create_new_migration,
|
|
||||||
get_migration_status,
|
|
||||||
MIGRATIONS_DIRNAME,
|
MIGRATIONS_DIRNAME,
|
||||||
MIGRATIONS_TABLE
|
MIGRATIONS_TABLE,
|
||||||
|
MigrationManager,
|
||||||
|
create_new_migration,
|
||||||
|
ensure_migrations_applied,
|
||||||
|
get_migration_status,
|
||||||
|
init_migrations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,10 +34,10 @@ def cleanup_db():
|
||||||
# Ignore errors when closing the database
|
# Ignore errors when closing the database
|
||||||
pass
|
pass
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Reset after the test
|
# Reset after the test
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
|
|
@ -56,7 +53,7 @@ def cleanup_db():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_logger():
|
def mock_logger():
|
||||||
"""Mock the logger to test for output messages."""
|
"""Mock the logger to test for output messages."""
|
||||||
with patch('ra_aid.database.migrations.logger') as mock:
|
with patch("ra_aid.database.migrations.logger") as mock:
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,276 +80,294 @@ def temp_migrations_dir(temp_dir):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_router():
|
def mock_router():
|
||||||
"""Mock the peewee_migrate Router class."""
|
"""Mock the peewee_migrate Router class."""
|
||||||
with patch('ra_aid.database.migrations.Router') as mock:
|
with patch("ra_aid.database.migrations.Router") as mock:
|
||||||
# Configure the mock router
|
# Configure the mock router
|
||||||
mock_instance = MagicMock()
|
mock_instance = MagicMock()
|
||||||
mock.return_value = mock_instance
|
mock.return_value = mock_instance
|
||||||
|
|
||||||
# Set up router properties
|
# Set up router properties
|
||||||
mock_instance.todo = ["001_initial", "002_add_users"]
|
mock_instance.todo = ["001_initial", "002_add_users"]
|
||||||
mock_instance.done = ["001_initial"]
|
mock_instance.done = ["001_initial"]
|
||||||
|
|
||||||
yield mock_instance
|
yield mock_instance
|
||||||
|
|
||||||
|
|
||||||
class TestMigrationManager:
|
class TestMigrationManager:
|
||||||
"""Tests for the MigrationManager class."""
|
"""Tests for the MigrationManager class."""
|
||||||
|
|
||||||
def test_init(self, cleanup_db, temp_dir, mock_logger):
|
def test_init(self, cleanup_db, temp_dir, mock_logger):
|
||||||
"""Test MigrationManager initialization."""
|
"""Test MigrationManager initialization."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Initialize manager
|
# Initialize manager
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Verify initialization
|
# Verify initialization
|
||||||
assert manager.db_path == db_path
|
assert manager.db_path == db_path
|
||||||
assert manager.migrations_dir == migrations_dir
|
assert manager.migrations_dir == migrations_dir
|
||||||
assert os.path.exists(migrations_dir)
|
assert os.path.exists(migrations_dir)
|
||||||
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
||||||
|
|
||||||
# Verify router initialization was logged
|
# Verify router initialization was logged
|
||||||
mock_logger.debug.assert_any_call(f"Using migrations directory: {migrations_dir}")
|
mock_logger.debug.assert_any_call(
|
||||||
mock_logger.debug.assert_any_call(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
f"Using migrations directory: {migrations_dir}"
|
||||||
|
)
|
||||||
|
mock_logger.debug.assert_any_call(
|
||||||
|
f"Initialized migration router with table: {MIGRATIONS_TABLE}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_ensure_migrations_dir(self, cleanup_db, temp_dir, mock_logger):
|
def test_ensure_migrations_dir(self, cleanup_db, temp_dir, mock_logger):
|
||||||
"""Test _ensure_migrations_dir creates directory if it doesn't exist."""
|
"""Test _ensure_migrations_dir creates directory if it doesn't exist."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, "nonexistent_dir", MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, "nonexistent_dir", MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Initialize manager
|
# Initialize manager
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Verify directory was created
|
# Verify directory was created
|
||||||
assert os.path.exists(migrations_dir)
|
assert os.path.exists(migrations_dir)
|
||||||
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
||||||
|
|
||||||
# Verify creation was logged
|
# Verify creation was logged
|
||||||
mock_logger.debug.assert_any_call(f"Creating migrations directory at: {migrations_dir}")
|
mock_logger.debug.assert_any_call(
|
||||||
|
f"Creating migrations directory at: {migrations_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_ensure_migrations_dir_error(self, cleanup_db, mock_logger):
|
def test_ensure_migrations_dir_error(self, cleanup_db, mock_logger):
|
||||||
"""Test _ensure_migrations_dir handles errors."""
|
"""Test _ensure_migrations_dir handles errors."""
|
||||||
# Mock os.makedirs to raise an exception
|
# Mock os.makedirs to raise an exception
|
||||||
with patch('pathlib.Path.mkdir', side_effect=PermissionError("Permission denied")):
|
with patch(
|
||||||
|
"pathlib.Path.mkdir", side_effect=PermissionError("Permission denied")
|
||||||
|
):
|
||||||
# Set up test paths - use a path that would require elevated permissions
|
# Set up test paths - use a path that would require elevated permissions
|
||||||
db_path = "/root/test.db"
|
db_path = "/root/test.db"
|
||||||
migrations_dir = "/root/migrations"
|
migrations_dir = "/root/migrations"
|
||||||
|
|
||||||
# Initialize manager should raise an exception
|
# Initialize manager should raise an exception
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(
|
||||||
|
db_path=db_path, migrations_dir=migrations_dir
|
||||||
|
)
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with(
|
mock_logger.error.assert_called_with(
|
||||||
f"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
|
"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_init_router(self, cleanup_db, temp_dir, mock_router):
|
def test_init_router(self, cleanup_db, temp_dir, mock_router):
|
||||||
"""Test _init_router initializes the Router correctly."""
|
"""Test _init_router initializes the Router correctly."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Create the migrations directory
|
# Create the migrations directory
|
||||||
os.makedirs(migrations_dir, exist_ok=True)
|
os.makedirs(migrations_dir, exist_ok=True)
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Verify router was initialized
|
# Verify router was initialized
|
||||||
assert manager.router == mock_router
|
assert manager.router == mock_router
|
||||||
|
|
||||||
def test_check_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
def test_check_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||||
"""Test check_migrations returns correct migration lists."""
|
"""Test check_migrations returns correct migration lists."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call check_migrations
|
# Call check_migrations
|
||||||
applied, pending = manager.check_migrations()
|
applied, pending = manager.check_migrations()
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert applied == ["001_initial"]
|
assert applied == ["001_initial"]
|
||||||
assert pending == ["002_add_users"]
|
assert pending == ["002_add_users"]
|
||||||
|
|
||||||
# Verify logging
|
# Verify logging
|
||||||
mock_logger.debug.assert_called_with(
|
mock_logger.debug.assert_called_with(
|
||||||
"Found 1 applied migrations and 1 pending migrations"
|
"Found 1 applied migrations and 1 pending migrations"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_check_migrations_error(self, cleanup_db, temp_dir, mock_logger):
|
def test_check_migrations_error(self, cleanup_db, temp_dir, mock_logger):
|
||||||
"""Test check_migrations handles errors."""
|
"""Test check_migrations handles errors."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Create a mock router with a property that raises an exception
|
# Create a mock router with a property that raises an exception
|
||||||
mock_router = MagicMock()
|
mock_router = MagicMock()
|
||||||
# Configure the todo property to raise an exception when accessed
|
# Configure the todo property to raise an exception when accessed
|
||||||
type(mock_router).todo = PropertyMock(side_effect=Exception("Test error"))
|
type(mock_router).todo = PropertyMock(side_effect=Exception("Test error"))
|
||||||
mock_router.done = []
|
mock_router.done = []
|
||||||
|
|
||||||
# Initialize manager with the mocked Router
|
# Initialize manager with the mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Directly call check_migrations on the manager with the mocked router
|
# Directly call check_migrations on the manager with the mocked router
|
||||||
applied, pending = manager.check_migrations()
|
applied, pending = manager.check_migrations()
|
||||||
|
|
||||||
# Verify empty results are returned on error
|
# Verify empty results are returned on error
|
||||||
assert applied == []
|
assert applied == []
|
||||||
assert pending == []
|
assert pending == []
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with("Failed to check migrations: Test error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Failed to check migrations: Test error"
|
||||||
|
)
|
||||||
|
|
||||||
def test_apply_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
def test_apply_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||||
"""Test apply_migrations applies pending migrations."""
|
"""Test apply_migrations applies pending migrations."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call apply_migrations
|
# Call apply_migrations
|
||||||
result = manager.apply_migrations()
|
result = manager.apply_migrations()
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Verify migrations were applied
|
# Verify migrations were applied
|
||||||
mock_router.run.assert_called_once_with("002_add_users", fake=False)
|
mock_router.run.assert_called_once_with("002_add_users", fake=False)
|
||||||
|
|
||||||
# Verify logging
|
# Verify logging
|
||||||
mock_logger.info.assert_any_call("Applying 1 pending migrations...")
|
mock_logger.info.assert_any_call("Applying 1 pending migrations...")
|
||||||
mock_logger.info.assert_any_call("Applying migration: 002_add_users")
|
mock_logger.info.assert_any_call("Applying migration: 002_add_users")
|
||||||
mock_logger.info.assert_any_call("Successfully applied migration: 002_add_users")
|
mock_logger.info.assert_any_call(
|
||||||
|
"Successfully applied migration: 002_add_users"
|
||||||
|
)
|
||||||
mock_logger.info.assert_any_call("Successfully applied 1 migrations")
|
mock_logger.info.assert_any_call("Successfully applied 1 migrations")
|
||||||
|
|
||||||
def test_apply_migrations_no_pending(self, cleanup_db, temp_dir, mock_logger):
|
def test_apply_migrations_no_pending(self, cleanup_db, temp_dir, mock_logger):
|
||||||
"""Test apply_migrations when no migrations are pending."""
|
"""Test apply_migrations when no migrations are pending."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Create a mock router with no pending migrations
|
# Create a mock router with no pending migrations
|
||||||
mock_router = MagicMock()
|
mock_router = MagicMock()
|
||||||
mock_router.todo = ["001_initial"]
|
mock_router.todo = ["001_initial"]
|
||||||
mock_router.done = ["001_initial"]
|
mock_router.done = ["001_initial"]
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call apply_migrations
|
# Call apply_migrations
|
||||||
result = manager.apply_migrations()
|
result = manager.apply_migrations()
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Verify no migrations were applied
|
# Verify no migrations were applied
|
||||||
mock_router.run.assert_not_called()
|
mock_router.run.assert_not_called()
|
||||||
|
|
||||||
# Verify logging
|
# Verify logging
|
||||||
mock_logger.info.assert_called_with("No pending migrations to apply")
|
mock_logger.info.assert_called_with("No pending migrations to apply")
|
||||||
|
|
||||||
def test_apply_migrations_error(self, cleanup_db, temp_dir, mock_logger):
|
def test_apply_migrations_error(self, cleanup_db, temp_dir, mock_logger):
|
||||||
"""Test apply_migrations handles errors during migration."""
|
"""Test apply_migrations handles errors during migration."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Create a mock router that raises an exception during run
|
# Create a mock router that raises an exception during run
|
||||||
mock_router = MagicMock()
|
mock_router = MagicMock()
|
||||||
mock_router.todo = ["001_initial", "002_add_users"]
|
mock_router.todo = ["001_initial", "002_add_users"]
|
||||||
mock_router.done = ["001_initial"]
|
mock_router.done = ["001_initial"]
|
||||||
mock_router.run.side_effect = Exception("Migration error")
|
mock_router.run.side_effect = Exception("Migration error")
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call apply_migrations
|
# Call apply_migrations
|
||||||
result = manager.apply_migrations()
|
result = manager.apply_migrations()
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with(
|
mock_logger.error.assert_called_with(
|
||||||
"Failed to apply migration 002_add_users: Migration error"
|
"Failed to apply migration 002_add_users: Migration error"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_migration(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
def test_create_migration(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||||
"""Test create_migration creates a new migration."""
|
"""Test create_migration creates a new migration."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call create_migration
|
# Call create_migration
|
||||||
result = manager.create_migration("add_users", auto=True)
|
result = manager.create_migration("add_users", auto=True)
|
||||||
|
|
||||||
# Verify result contains timestamp and name
|
# Verify result contains timestamp and name
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "add_users" in result
|
assert "add_users" in result
|
||||||
|
|
||||||
# Verify migration was created
|
# Verify migration was created
|
||||||
mock_router.create.assert_called_once()
|
mock_router.create.assert_called_once()
|
||||||
|
|
||||||
# Verify logging
|
# Verify logging
|
||||||
mock_logger.info.assert_any_call(f"Creating new migration: {result}")
|
mock_logger.info.assert_any_call(f"Creating new migration: {result}")
|
||||||
mock_logger.info.assert_any_call(f"Successfully created migration: {result}")
|
mock_logger.info.assert_any_call(
|
||||||
|
f"Successfully created migration: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_create_migration_error(self, cleanup_db, temp_dir, mock_logger):
|
def test_create_migration_error(self, cleanup_db, temp_dir, mock_logger):
|
||||||
"""Test create_migration handles errors."""
|
"""Test create_migration handles errors."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Create a mock router that raises an exception during create
|
# Create a mock router that raises an exception during create
|
||||||
mock_router = MagicMock()
|
mock_router = MagicMock()
|
||||||
mock_router.create.side_effect = Exception("Creation error")
|
mock_router.create.side_effect = Exception("Creation error")
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call create_migration
|
# Call create_migration
|
||||||
result = manager.create_migration("add_users", auto=True)
|
result = manager.create_migration("add_users", auto=True)
|
||||||
|
|
||||||
# Verify result is None on error
|
# Verify result is None on error
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with("Failed to create migration: Creation error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Failed to create migration: Creation error"
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_migration_status(self, cleanup_db, temp_dir, mock_router):
|
def test_get_migration_status(self, cleanup_db, temp_dir, mock_router):
|
||||||
"""Test get_migration_status returns correct status information."""
|
"""Test get_migration_status returns correct status information."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Initialize manager with mocked Router
|
# Initialize manager with mocked Router
|
||||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
with patch("ra_aid.database.migrations.Router", return_value=mock_router):
|
||||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Call get_migration_status
|
# Call get_migration_status
|
||||||
status = manager.get_migration_status()
|
status = manager.get_migration_status()
|
||||||
|
|
||||||
# Verify status information
|
# Verify status information
|
||||||
assert status["applied_count"] == 1
|
assert status["applied_count"] == 1
|
||||||
assert status["pending_count"] == 1
|
assert status["pending_count"] == 1
|
||||||
|
|
@ -364,81 +379,95 @@ class TestMigrationManager:
|
||||||
|
|
||||||
class TestMigrationFunctions:
|
class TestMigrationFunctions:
|
||||||
"""Tests for the migration utility functions."""
|
"""Tests for the migration utility functions."""
|
||||||
|
|
||||||
def test_init_migrations(self, cleanup_db, temp_dir):
|
def test_init_migrations(self, cleanup_db, temp_dir):
|
||||||
"""Test init_migrations returns a MigrationManager instance."""
|
"""Test init_migrations returns a MigrationManager instance."""
|
||||||
# Set up test paths
|
# Set up test paths
|
||||||
db_path = os.path.join(temp_dir, "test.db")
|
db_path = os.path.join(temp_dir, "test.db")
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
|
|
||||||
# Call init_migrations
|
# Call init_migrations
|
||||||
with patch('ra_aid.database.migrations.MigrationManager') as mock_manager:
|
with patch("ra_aid.database.migrations.MigrationManager") as mock_manager:
|
||||||
mock_manager.return_value = MagicMock()
|
mock_manager.return_value = MagicMock()
|
||||||
|
|
||||||
manager = init_migrations(db_path=db_path, migrations_dir=migrations_dir)
|
manager = init_migrations(db_path=db_path, migrations_dir=migrations_dir)
|
||||||
|
|
||||||
# Verify MigrationManager was initialized with correct parameters
|
# Verify MigrationManager was initialized with correct parameters
|
||||||
mock_manager.assert_called_once_with(db_path, migrations_dir)
|
mock_manager.assert_called_once_with(db_path, migrations_dir)
|
||||||
assert manager == mock_manager.return_value
|
assert manager == mock_manager.return_value
|
||||||
|
|
||||||
def test_ensure_migrations_applied(self, cleanup_db, mock_logger):
|
def test_ensure_migrations_applied(self, cleanup_db, mock_logger):
|
||||||
"""Test ensure_migrations_applied applies pending migrations."""
|
"""Test ensure_migrations_applied applies pending migrations."""
|
||||||
# Mock MigrationManager
|
# Mock MigrationManager
|
||||||
mock_manager = MagicMock()
|
mock_manager = MagicMock()
|
||||||
mock_manager.apply_migrations.return_value = True
|
mock_manager.apply_migrations.return_value = True
|
||||||
|
|
||||||
# Call ensure_migrations_applied
|
# Call ensure_migrations_applied
|
||||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
with patch(
|
||||||
|
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
|
||||||
|
):
|
||||||
result = ensure_migrations_applied()
|
result = ensure_migrations_applied()
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Verify migrations were applied
|
# Verify migrations were applied
|
||||||
mock_manager.apply_migrations.assert_called_once()
|
mock_manager.apply_migrations.assert_called_once()
|
||||||
|
|
||||||
def test_ensure_migrations_applied_error(self, cleanup_db, mock_logger):
|
def test_ensure_migrations_applied_error(self, cleanup_db, mock_logger):
|
||||||
"""Test ensure_migrations_applied handles errors."""
|
"""Test ensure_migrations_applied handles errors."""
|
||||||
# Call ensure_migrations_applied with an exception
|
# Call ensure_migrations_applied with an exception
|
||||||
with patch('ra_aid.database.migrations.init_migrations',
|
with patch(
|
||||||
side_effect=Exception("Test error")):
|
"ra_aid.database.migrations.init_migrations",
|
||||||
|
side_effect=Exception("Test error"),
|
||||||
|
):
|
||||||
result = ensure_migrations_applied()
|
result = ensure_migrations_applied()
|
||||||
|
|
||||||
# Verify result is False on error
|
# Verify result is False on error
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with("Failed to apply migrations: Test error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Failed to apply migrations: Test error"
|
||||||
|
)
|
||||||
|
|
||||||
def test_create_new_migration(self, cleanup_db, mock_logger):
|
def test_create_new_migration(self, cleanup_db, mock_logger):
|
||||||
"""Test create_new_migration creates a new migration."""
|
"""Test create_new_migration creates a new migration."""
|
||||||
# Mock MigrationManager
|
# Mock MigrationManager
|
||||||
mock_manager = MagicMock()
|
mock_manager = MagicMock()
|
||||||
mock_manager.create_migration.return_value = "20250226_123456_test_migration"
|
mock_manager.create_migration.return_value = "20250226_123456_test_migration"
|
||||||
|
|
||||||
# Call create_new_migration
|
# Call create_new_migration
|
||||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
with patch(
|
||||||
|
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
|
||||||
|
):
|
||||||
result = create_new_migration("test_migration", auto=True)
|
result = create_new_migration("test_migration", auto=True)
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result == "20250226_123456_test_migration"
|
assert result == "20250226_123456_test_migration"
|
||||||
|
|
||||||
# Verify migration was created
|
# Verify migration was created
|
||||||
mock_manager.create_migration.assert_called_once_with("test_migration", True)
|
mock_manager.create_migration.assert_called_once_with(
|
||||||
|
"test_migration", True
|
||||||
|
)
|
||||||
|
|
||||||
def test_create_new_migration_error(self, cleanup_db, mock_logger):
|
def test_create_new_migration_error(self, cleanup_db, mock_logger):
|
||||||
"""Test create_new_migration handles errors."""
|
"""Test create_new_migration handles errors."""
|
||||||
# Call create_new_migration with an exception
|
# Call create_new_migration with an exception
|
||||||
with patch('ra_aid.database.migrations.init_migrations',
|
with patch(
|
||||||
side_effect=Exception("Test error")):
|
"ra_aid.database.migrations.init_migrations",
|
||||||
|
side_effect=Exception("Test error"),
|
||||||
|
):
|
||||||
result = create_new_migration("test_migration", auto=True)
|
result = create_new_migration("test_migration", auto=True)
|
||||||
|
|
||||||
# Verify result is None on error
|
# Verify result is None on error
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with("Failed to create migration: Test error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Failed to create migration: Test error"
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_migration_status(self, cleanup_db, mock_logger):
|
def test_get_migration_status(self, cleanup_db, mock_logger):
|
||||||
"""Test get_migration_status returns correct status information."""
|
"""Test get_migration_status returns correct status information."""
|
||||||
# Mock MigrationManager
|
# Mock MigrationManager
|
||||||
|
|
@ -449,13 +478,15 @@ class TestMigrationFunctions:
|
||||||
"applied": ["001_initial", "002_add_users"],
|
"applied": ["001_initial", "002_add_users"],
|
||||||
"pending": ["003_add_profiles"],
|
"pending": ["003_add_profiles"],
|
||||||
"migrations_dir": "/test/migrations",
|
"migrations_dir": "/test/migrations",
|
||||||
"db_path": "/test/db.sqlite"
|
"db_path": "/test/db.sqlite",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Call get_migration_status
|
# Call get_migration_status
|
||||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
with patch(
|
||||||
|
"ra_aid.database.migrations.init_migrations", return_value=mock_manager
|
||||||
|
):
|
||||||
status = get_migration_status()
|
status = get_migration_status()
|
||||||
|
|
||||||
# Verify status information
|
# Verify status information
|
||||||
assert status["applied_count"] == 2
|
assert status["applied_count"] == 2
|
||||||
assert status["pending_count"] == 1
|
assert status["pending_count"] == 1
|
||||||
|
|
@ -463,31 +494,35 @@ class TestMigrationFunctions:
|
||||||
assert status["pending"] == ["003_add_profiles"]
|
assert status["pending"] == ["003_add_profiles"]
|
||||||
assert status["migrations_dir"] == "/test/migrations"
|
assert status["migrations_dir"] == "/test/migrations"
|
||||||
assert status["db_path"] == "/test/db.sqlite"
|
assert status["db_path"] == "/test/db.sqlite"
|
||||||
|
|
||||||
# Verify migration status was retrieved
|
# Verify migration status was retrieved
|
||||||
mock_manager.get_migration_status.assert_called_once()
|
mock_manager.get_migration_status.assert_called_once()
|
||||||
|
|
||||||
def test_get_migration_status_error(self, cleanup_db, mock_logger):
|
def test_get_migration_status_error(self, cleanup_db, mock_logger):
|
||||||
"""Test get_migration_status handles errors."""
|
"""Test get_migration_status handles errors."""
|
||||||
# Call get_migration_status with an exception
|
# Call get_migration_status with an exception
|
||||||
with patch('ra_aid.database.migrations.init_migrations',
|
with patch(
|
||||||
side_effect=Exception("Test error")):
|
"ra_aid.database.migrations.init_migrations",
|
||||||
|
side_effect=Exception("Test error"),
|
||||||
|
):
|
||||||
status = get_migration_status()
|
status = get_migration_status()
|
||||||
|
|
||||||
# Verify default status on error
|
# Verify default status on error
|
||||||
assert status["error"] == "Test error"
|
assert status["error"] == "Test error"
|
||||||
assert status["applied_count"] == 0
|
assert status["applied_count"] == 0
|
||||||
assert status["pending_count"] == 0
|
assert status["pending_count"] == 0
|
||||||
assert status["applied"] == []
|
assert status["applied"] == []
|
||||||
assert status["pending"] == []
|
assert status["pending"] == []
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with("Failed to get migration status: Test error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Failed to get migration status: Test error"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestIntegration:
|
class TestIntegration:
|
||||||
"""Integration tests for the migrations module."""
|
"""Integration tests for the migrations module."""
|
||||||
|
|
||||||
def test_in_memory_migrations(self, cleanup_db):
|
def test_in_memory_migrations(self, cleanup_db):
|
||||||
"""Test migrations with in-memory database."""
|
"""Test migrations with in-memory database."""
|
||||||
# Initialize in-memory database
|
# Initialize in-memory database
|
||||||
|
|
@ -496,17 +531,19 @@ class TestIntegration:
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||||
os.makedirs(migrations_dir, exist_ok=True)
|
os.makedirs(migrations_dir, exist_ok=True)
|
||||||
|
|
||||||
# Create __init__.py to make it a proper package
|
# Create __init__.py to make it a proper package
|
||||||
with open(os.path.join(migrations_dir, "__init__.py"), "w") as f:
|
with open(os.path.join(migrations_dir, "__init__.py"), "w") as f:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Initialize migration manager
|
# Initialize migration manager
|
||||||
manager = MigrationManager(db_path=":memory:", migrations_dir=migrations_dir)
|
manager = MigrationManager(
|
||||||
|
db_path=":memory:", migrations_dir=migrations_dir
|
||||||
|
)
|
||||||
|
|
||||||
# Create a test migration
|
# Create a test migration
|
||||||
migration_name = manager.create_migration("test_migration", auto=False)
|
migration_name = manager.create_migration("test_migration", auto=False)
|
||||||
|
|
||||||
# Write a simple migration file
|
# Write a simple migration file
|
||||||
migration_path = os.path.join(migrations_dir, f"{migration_name}.py")
|
migration_path = os.path.join(migrations_dir, f"{migration_name}.py")
|
||||||
with open(migration_path, "w") as f:
|
with open(migration_path, "w") as f:
|
||||||
|
|
@ -520,23 +557,27 @@ def migrate(migrator, database, fake=False, **kwargs):
|
||||||
def rollback(migrator, database, fake=False, **kwargs):
|
def rollback(migrator, database, fake=False, **kwargs):
|
||||||
migrator.drop_table('test_table')
|
migrator.drop_table('test_table')
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Check migrations
|
# Check migrations
|
||||||
applied, pending = manager.check_migrations()
|
applied, pending = manager.check_migrations()
|
||||||
assert len(applied) == 0
|
assert len(applied) == 0
|
||||||
assert len(pending) == 1
|
assert len(pending) == 1
|
||||||
assert migration_name in pending[0] # Instead of exact equality, check if name is contained
|
assert (
|
||||||
|
migration_name in pending[0]
|
||||||
|
) # Instead of exact equality, check if name is contained
|
||||||
|
|
||||||
# Apply migrations
|
# Apply migrations
|
||||||
result = manager.apply_migrations()
|
result = manager.apply_migrations()
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Check migrations again
|
# Check migrations again
|
||||||
applied, pending = manager.check_migrations()
|
applied, pending = manager.check_migrations()
|
||||||
assert len(applied) == 1
|
assert len(applied) == 1
|
||||||
assert len(pending) == 0
|
assert len(pending) == 0
|
||||||
assert migration_name in applied[0] # Instead of exact equality, check if name is contained
|
assert (
|
||||||
|
migration_name in applied[0]
|
||||||
|
) # Instead of exact equality, check if name is contained
|
||||||
|
|
||||||
# Verify migration status
|
# Verify migration status
|
||||||
status = manager.get_migration_status()
|
status = manager.get_migration_status()
|
||||||
assert status["applied_count"] == 1
|
assert status["applied_count"] == 1
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,11 @@ Tests for the database models module.
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
import peewee
|
import peewee
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.database.connection import db_var, init_db
|
||||||
from ra_aid.database.models import BaseModel
|
from ra_aid.database.models import BaseModel
|
||||||
from ra_aid.database.connection import (
|
|
||||||
db_var, get_db, init_db, close_db
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -26,10 +24,10 @@ def cleanup_db():
|
||||||
# Ignore errors when closing the database
|
# Ignore errors when closing the database
|
||||||
pass
|
pass
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Reset after the test
|
# Reset after the test
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
|
|
@ -47,21 +45,21 @@ def setup_test_model(cleanup_db):
|
||||||
"""Set up a test model class for testing."""
|
"""Set up a test model class for testing."""
|
||||||
# Initialize an in-memory database connection
|
# Initialize an in-memory database connection
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Define a test model
|
# Define a test model
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
name = peewee.CharField()
|
name = peewee.CharField()
|
||||||
value = peewee.IntegerField(null=True)
|
value = peewee.IntegerField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = db
|
database = db
|
||||||
|
|
||||||
# Create the table
|
# Create the table
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
db.create_tables([TestModel], safe=True)
|
db.create_tables([TestModel], safe=True)
|
||||||
|
|
||||||
yield TestModel
|
yield TestModel
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
TestModel.drop_table(safe=True)
|
TestModel.drop_table(safe=True)
|
||||||
|
|
@ -70,27 +68,28 @@ def setup_test_model(cleanup_db):
|
||||||
def test_base_model_save_updates_timestamps(setup_test_model):
|
def test_base_model_save_updates_timestamps(setup_test_model):
|
||||||
"""Test that saving a model updates the timestamps."""
|
"""Test that saving a model updates the timestamps."""
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# Create a new instance
|
# Create a new instance
|
||||||
instance = TestModel(name="test", value=42)
|
instance = TestModel(name="test", value=42)
|
||||||
instance.save()
|
instance.save()
|
||||||
|
|
||||||
# Check that created_at and updated_at are set
|
# Check that created_at and updated_at are set
|
||||||
assert instance.created_at is not None
|
assert instance.created_at is not None
|
||||||
assert instance.updated_at is not None
|
assert instance.updated_at is not None
|
||||||
|
|
||||||
# Store the original timestamps
|
# Store the original timestamps
|
||||||
original_created_at = instance.created_at
|
original_created_at = instance.created_at
|
||||||
original_updated_at = instance.updated_at
|
original_updated_at = instance.updated_at
|
||||||
|
|
||||||
# Wait a moment to ensure timestamps would be different
|
# Wait a moment to ensure timestamps would be different
|
||||||
import time
|
import time
|
||||||
|
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
|
||||||
# Update the instance
|
# Update the instance
|
||||||
instance.value = 43
|
instance.value = 43
|
||||||
instance.save()
|
instance.save()
|
||||||
|
|
||||||
# Check that updated_at changed but created_at didn't
|
# Check that updated_at changed but created_at didn't
|
||||||
assert instance.created_at == original_created_at
|
assert instance.created_at == original_created_at
|
||||||
assert instance.updated_at != original_updated_at
|
assert instance.updated_at != original_updated_at
|
||||||
|
|
@ -99,34 +98,36 @@ def test_base_model_save_updates_timestamps(setup_test_model):
|
||||||
def test_base_model_get_or_create(setup_test_model):
|
def test_base_model_get_or_create(setup_test_model):
|
||||||
"""Test the get_or_create method."""
|
"""Test the get_or_create method."""
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# First call should create a new instance
|
# First call should create a new instance
|
||||||
instance, created = TestModel.get_or_create(name="test", value=42)
|
instance, created = TestModel.get_or_create(name="test", value=42)
|
||||||
assert created is True
|
assert created is True
|
||||||
assert instance.name == "test"
|
assert instance.name == "test"
|
||||||
assert instance.value == 42
|
assert instance.value == 42
|
||||||
|
|
||||||
# Second call with same parameters should return existing instance
|
# Second call with same parameters should return existing instance
|
||||||
instance2, created2 = TestModel.get_or_create(name="test", value=42)
|
instance2, created2 = TestModel.get_or_create(name="test", value=42)
|
||||||
assert created2 is False
|
assert created2 is False
|
||||||
assert instance2.id == instance.id
|
assert instance2.id == instance.id
|
||||||
|
|
||||||
# Call with different parameters should create a new instance
|
# Call with different parameters should create a new instance
|
||||||
instance3, created3 = TestModel.get_or_create(name="test2", value=43)
|
instance3, created3 = TestModel.get_or_create(name="test2", value=43)
|
||||||
assert created3 is True
|
assert created3 is True
|
||||||
assert instance3.id != instance.id
|
assert instance3.id != instance.id
|
||||||
|
|
||||||
|
|
||||||
@patch('ra_aid.database.models.logger')
|
@patch("ra_aid.database.models.logger")
|
||||||
def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model):
|
def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model):
|
||||||
"""Test that get_or_create handles database errors properly."""
|
"""Test that get_or_create handles database errors properly."""
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# Mock the parent get_or_create to raise a DatabaseError
|
# Mock the parent get_or_create to raise a DatabaseError
|
||||||
with patch('peewee.Model.get_or_create', side_effect=peewee.DatabaseError("Test error")):
|
with patch(
|
||||||
|
"peewee.Model.get_or_create", side_effect=peewee.DatabaseError("Test error")
|
||||||
|
):
|
||||||
# Call should raise the error
|
# Call should raise the error
|
||||||
with pytest.raises(peewee.DatabaseError):
|
with pytest.raises(peewee.DatabaseError):
|
||||||
TestModel.get_or_create(name="test")
|
TestModel.get_or_create(name="test")
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logger.error.assert_called_with("Failed in get_or_create: Test error")
|
mock_logger.error.assert_called_with("Failed in get_or_create: Test error")
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,12 @@
|
||||||
Tests for the database utils module.
|
Tests for the database utils module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
import peewee
|
import peewee
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ra_aid.database.connection import (
|
from ra_aid.database.connection import db_var, init_db
|
||||||
db_var, get_db, init_db, close_db
|
|
||||||
)
|
|
||||||
from ra_aid.database.models import BaseModel
|
from ra_aid.database.models import BaseModel
|
||||||
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||||
|
|
||||||
|
|
@ -27,10 +25,10 @@ def cleanup_db():
|
||||||
# Ignore errors when closing the database
|
# Ignore errors when closing the database
|
||||||
pass
|
pass
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Reset after the test
|
# Reset after the test
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
|
|
@ -46,7 +44,7 @@ def cleanup_db():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_logger():
|
def mock_logger():
|
||||||
"""Mock the logger to test for output messages."""
|
"""Mock the logger to test for output messages."""
|
||||||
with patch('ra_aid.database.utils.logger') as mock:
|
with patch("ra_aid.database.utils.logger") as mock:
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -55,22 +53,22 @@ def setup_test_model(cleanup_db):
|
||||||
"""Set up a test model for database tests."""
|
"""Set up a test model for database tests."""
|
||||||
# Initialize the database in memory
|
# Initialize the database in memory
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Define a test model class
|
# Define a test model class
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
name = peewee.CharField(max_length=100)
|
name = peewee.CharField(max_length=100)
|
||||||
value = peewee.IntegerField(default=0)
|
value = peewee.IntegerField(default=0)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = db
|
database = db
|
||||||
|
|
||||||
# Create the test table in a transaction
|
# Create the test table in a transaction
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
db.create_tables([TestModel], safe=True)
|
db.create_tables([TestModel], safe=True)
|
||||||
|
|
||||||
# Yield control to the test
|
# Yield control to the test
|
||||||
yield TestModel
|
yield TestModel
|
||||||
|
|
||||||
# Clean up: drop the test table
|
# Clean up: drop the test table
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
db.drop_tables([TestModel], safe=True)
|
db.drop_tables([TestModel], safe=True)
|
||||||
|
|
@ -80,98 +78,90 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
||||||
"""Test ensure_tables_created with explicit models."""
|
"""Test ensure_tables_created with explicit models."""
|
||||||
# Initialize the database in memory
|
# Initialize the database in memory
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Define a test model that uses this database
|
# Define a test model that uses this database
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
name = peewee.CharField(max_length=100)
|
name = peewee.CharField(max_length=100)
|
||||||
value = peewee.IntegerField(default=0)
|
value = peewee.IntegerField(default=0)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = db
|
database = db
|
||||||
|
|
||||||
# Call ensure_tables_created with explicit models
|
# Call ensure_tables_created with explicit models
|
||||||
ensure_tables_created([TestModel])
|
ensure_tables_created([TestModel])
|
||||||
|
|
||||||
# Verify success message was logged
|
# Verify success message was logged
|
||||||
mock_logger.info.assert_called_with("Successfully created tables for 1 models")
|
mock_logger.info.assert_called_with("Successfully created tables for 1 models")
|
||||||
|
|
||||||
# Verify the table exists by trying to use it
|
# Verify the table exists by trying to use it
|
||||||
TestModel.create(name="test", value=42)
|
TestModel.create(name="test", value=42)
|
||||||
count = TestModel.select().count()
|
count = TestModel.select().count()
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_tables_created_no_models(cleanup_db, mock_logger):
|
@patch("ra_aid.database.utils.get_db")
|
||||||
"""Test ensure_tables_created with no models."""
|
def test_ensure_tables_created_database_error(
|
||||||
# Initialize the database in memory
|
mock_get_db, setup_test_model, cleanup_db, mock_logger
|
||||||
db = init_db(in_memory=True)
|
):
|
||||||
|
|
||||||
# Mock the import to simulate no models found
|
|
||||||
with patch('ra_aid.database.utils.importlib.import_module', side_effect=ImportError("No module")):
|
|
||||||
# Call ensure_tables_created with no models
|
|
||||||
ensure_tables_created()
|
|
||||||
|
|
||||||
# Verify warning message was logged
|
|
||||||
mock_logger.warning.assert_called_with("No models found to create tables for")
|
|
||||||
|
|
||||||
|
|
||||||
@patch('ra_aid.database.utils.get_db')
|
|
||||||
def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cleanup_db, mock_logger):
|
|
||||||
"""Test ensure_tables_created handles database errors."""
|
"""Test ensure_tables_created handles database errors."""
|
||||||
# Get the TestModel class from the fixture
|
# Get the TestModel class from the fixture
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# Create a mock database with a create_tables method that raises an error
|
# Create a mock database with a create_tables method that raises an error
|
||||||
mock_db = MagicMock()
|
mock_db = MagicMock()
|
||||||
mock_db.atomic.return_value.__enter__.return_value = None
|
mock_db.atomic.return_value.__enter__.return_value = None
|
||||||
mock_db.atomic.return_value.__exit__.return_value = None
|
mock_db.atomic.return_value.__exit__.return_value = None
|
||||||
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
|
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
|
||||||
|
|
||||||
# Configure get_db to return our mock
|
# Configure get_db to return our mock
|
||||||
mock_get_db.return_value = mock_db
|
mock_get_db.return_value = mock_db
|
||||||
|
|
||||||
# Call ensure_tables_created and expect an exception
|
# Call ensure_tables_created and expect an exception
|
||||||
with pytest.raises(peewee.DatabaseError):
|
with pytest.raises(peewee.DatabaseError):
|
||||||
ensure_tables_created([TestModel])
|
ensure_tables_created([TestModel])
|
||||||
|
|
||||||
# Verify error message was logged
|
# Verify error message was logged
|
||||||
mock_logger.error.assert_called_with("Database Error: Failed to create tables: Test database error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Database Error: Failed to create tables: Test database error"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_count(setup_test_model, mock_logger):
|
def test_get_model_count(setup_test_model, mock_logger):
|
||||||
"""Test get_model_count returns the correct count."""
|
"""Test get_model_count returns the correct count."""
|
||||||
# Get the TestModel class from the fixture
|
# Get the TestModel class from the fixture
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# First ensure the table is empty
|
# First ensure the table is empty
|
||||||
TestModel.delete().execute()
|
TestModel.delete().execute()
|
||||||
|
|
||||||
# Create some test records
|
# Create some test records
|
||||||
TestModel.create(name="test1", value=1)
|
TestModel.create(name="test1", value=1)
|
||||||
TestModel.create(name="test2", value=2)
|
TestModel.create(name="test2", value=2)
|
||||||
|
|
||||||
# Call get_model_count
|
# Call get_model_count
|
||||||
count = get_model_count(TestModel)
|
count = get_model_count(TestModel)
|
||||||
|
|
||||||
# Verify the count is correct
|
# Verify the count is correct
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
@patch('peewee.ModelSelect.count')
|
@patch("peewee.ModelSelect.count")
|
||||||
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
|
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
|
||||||
"""Test get_model_count handles database errors."""
|
"""Test get_model_count handles database errors."""
|
||||||
# Get the TestModel class from the fixture
|
# Get the TestModel class from the fixture
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# Configure the mock to raise a DatabaseError
|
# Configure the mock to raise a DatabaseError
|
||||||
mock_count.side_effect = peewee.DatabaseError("Test count error")
|
mock_count.side_effect = peewee.DatabaseError("Test count error")
|
||||||
|
|
||||||
# Call get_model_count
|
# Call get_model_count
|
||||||
count = get_model_count(TestModel)
|
count = get_model_count(TestModel)
|
||||||
|
|
||||||
# Verify error message was logged
|
# Verify error message was logged
|
||||||
mock_logger.error.assert_called_with("Database Error: Failed to count records: Test count error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Database Error: Failed to count records: Test count error"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the function returns 0 on error
|
# Verify the function returns 0 on error
|
||||||
assert count == 0
|
assert count == 0
|
||||||
|
|
||||||
|
|
@ -180,41 +170,45 @@ def test_truncate_table(setup_test_model, mock_logger):
|
||||||
"""Test truncate_table deletes all records."""
|
"""Test truncate_table deletes all records."""
|
||||||
# Get the TestModel class from the fixture
|
# Get the TestModel class from the fixture
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# Create some test records
|
# Create some test records
|
||||||
TestModel.create(name="test1", value=1)
|
TestModel.create(name="test1", value=1)
|
||||||
TestModel.create(name="test2", value=2)
|
TestModel.create(name="test2", value=2)
|
||||||
|
|
||||||
# Verify records exist
|
# Verify records exist
|
||||||
assert TestModel.select().count() == 2
|
assert TestModel.select().count() == 2
|
||||||
|
|
||||||
# Call truncate_table
|
# Call truncate_table
|
||||||
truncate_table(TestModel)
|
truncate_table(TestModel)
|
||||||
|
|
||||||
# Verify success message was logged
|
# Verify success message was logged
|
||||||
mock_logger.info.assert_called_with(f"Successfully truncated table for {TestModel.__name__}")
|
mock_logger.info.assert_called_with(
|
||||||
|
f"Successfully truncated table for {TestModel.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify all records were deleted
|
# Verify all records were deleted
|
||||||
assert TestModel.select().count() == 0
|
assert TestModel.select().count() == 0
|
||||||
|
|
||||||
|
|
||||||
@patch('ra_aid.database.models.BaseModel.delete')
|
@patch("ra_aid.database.models.BaseModel.delete")
|
||||||
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
|
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
|
||||||
"""Test truncate_table handles database errors."""
|
"""Test truncate_table handles database errors."""
|
||||||
# Get the TestModel class from the fixture
|
# Get the TestModel class from the fixture
|
||||||
TestModel = setup_test_model
|
TestModel = setup_test_model
|
||||||
|
|
||||||
# Create a test record
|
# Create a test record
|
||||||
TestModel.create(name="test", value=42)
|
TestModel.create(name="test", value=42)
|
||||||
|
|
||||||
# Configure the mock to return a mock query with execute that raises an error
|
# Configure the mock to return a mock query with execute that raises an error
|
||||||
mock_query = MagicMock()
|
mock_query = MagicMock()
|
||||||
mock_query.execute.side_effect = peewee.DatabaseError("Test delete error")
|
mock_query.execute.side_effect = peewee.DatabaseError("Test delete error")
|
||||||
mock_delete.return_value = mock_query
|
mock_delete.return_value = mock_query
|
||||||
|
|
||||||
# Call truncate_table and expect an exception
|
# Call truncate_table and expect an exception
|
||||||
with pytest.raises(peewee.DatabaseError):
|
with pytest.raises(peewee.DatabaseError):
|
||||||
truncate_table(TestModel)
|
truncate_table(TestModel)
|
||||||
|
|
||||||
# Verify error message was logged
|
# Verify error message was logged
|
||||||
mock_logger.error.assert_called_with("Database Error: Failed to truncate table: Test delete error")
|
mock_logger.error.assert_called_with(
|
||||||
|
"Database Error: Failed to truncate table: Test delete error"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
|
||||||
|
This is a test ascii file.
|
||||||
|
|
||||||
|
ASCII text, with very long lines (885), with CRLF line terminators
|
||||||
|
|
||||||
|
Mumblecore craft beer taxidermy, flannel YOLO pug brunch ugh you probably haven't heard of them art party next level Pinterest squid pork belly. Yr next level Carles, Thundercats dreamcatcher scenester master cleanse bitters disrupt tote bag keffiyeh narwhal organic salvia cray. Whatever heirloom Vice art party pickled, try-hard Williamsburg. Authentic pickled pop-up, letterpress bicycle rights cornhole vinyl Etsy readymade disrupt shabby chic Pitchfork keffiyeh. Master cleanse small batch keytar biodiesel Brooklyn, meggings four loko try-hard McSweeney's vinyl tattooed. Cred Schlitz selvage, tousled Odd Future literally before they sold out synth cardigan retro banh mi next level jean shorts meggings fap. Pork belly four dollar toast quinoa, stumptown taxidermy sriracha whatever you probably haven't heard of them squid single-origin coffee freegan disrupt cliche cardigan.
|
||||||
|
|
||||||
|
Bicycle rights cold-pressed Pinterest, beard butcher pickled pop-up synth DIY hashtag. Austin fanny pack farm-to-table keytar, kitsch fap tousled trust fund swag irony +1. Viral fanny pack vinyl, master cleanse 3 wolf moon readymade occupy before they sold out YOLO meggings XOXO art party fap try-hard. Photo booth you probably haven't heard of them artisan pickled Brooklyn cred umami meh, heirloom cray raw denim tousled drinking vinegar. Gentrify Williamsburg iPhone messenger bag heirloom, swag quinoa ennui brunch. Selvage tofu hella gastropub Pinterest, bicycle rights church-key cardigan semiotics cornhole Shoreditch iPhone fixie biodiesel narwhal. Small batch kogi Shoreditch cliche YOLO.
|
||||||
|
|
||||||
|
Literally yr ugh Truffaut raw denim four loko. Vice chia mustache, Intelligentsia authentic taxidermy Truffaut synth health goth. Locavore semiotics occupy, synth 8-bit hoodie umami meh PBR&B Wes Anderson brunch shabby chic Helvetica quinoa. YOLO beard pop-up Neutra PBR&B vinyl fixie, stumptown shabby chic flexitarian umami. Cronut Blue Bottle scenester sriracha keytar PBR ennui flannel VHS swag. Dreamcatcher 3 wolf moon fanny pack, tattooed XOXO bitters High Life fixie 8-bit Austin lomo single-origin coffee put a bird on it. High Life Kickstarter twee Blue Bottle shabby chic, biodiesel heirloom.
|
||||||
|
|
||||||
|
Wayfarers tousled stumptown pop-up slow-carb. Aesthetic American Apparel hoodie irony YOLO. Meggings synth meh, normcore lomo tote bag post-ironic twee sartorial butcher occupy. Tilde photo booth +1 kogi Williamsburg. Pork belly keytar seitan, pug iPhone fingerstache bitters. Ennui Schlitz actually, cardigan fashion axe Helvetica vegan. Swag lumbersexual blog Carles, cred synth asymmetrical heirloom Tumblr bitters letterpress aesthetic.
|
||||||
|
|
@ -176,7 +176,7 @@ def test_strip_trailing_whitespace():
|
||||||
# Create a command that outputs text with trailing whitespace
|
# Create a command that outputs text with trailing whitespace
|
||||||
cmd = 'echo "Line with spaces at end "; echo "Another trailing space line "; echo "Line with tabs at end\t\t"'
|
cmd = 'echo "Line with spaces at end "; echo "Another trailing space line "; echo "Line with tabs at end\t\t"'
|
||||||
output, retcode = run_interactive_command(["/bin/bash", "-c", cmd])
|
output, retcode = run_interactive_command(["/bin/bash", "-c", cmd])
|
||||||
|
|
||||||
# Check that the output contains the lines without trailing whitespace
|
# Check that the output contains the lines without trailing whitespace
|
||||||
lines = output.splitlines()
|
lines = output.splitlines()
|
||||||
assert b"Line with spaces at end" in lines[0]
|
assert b"Line with spaces at end" in lines[0]
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,17 @@
|
||||||
"""Tests for Windows-specific functionality."""
|
"""Tests for Windows-specific functionality."""
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import pytest
|
import sys
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.proc.interactive import (
|
||||||
|
create_process,
|
||||||
|
get_terminal_size,
|
||||||
|
run_interactive_command,
|
||||||
|
)
|
||||||
|
|
||||||
from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
|
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
|
||||||
class TestWindowsCompatibility:
|
class TestWindowsCompatibility:
|
||||||
|
|
@ -14,7 +19,7 @@ class TestWindowsCompatibility:
|
||||||
|
|
||||||
def test_get_terminal_size(self):
|
def test_get_terminal_size(self):
|
||||||
"""Test terminal size detection on Windows."""
|
"""Test terminal size detection on Windows."""
|
||||||
with patch('shutil.get_terminal_size') as mock_get_size:
|
with patch("shutil.get_terminal_size") as mock_get_size:
|
||||||
mock_get_size.return_value = MagicMock(columns=120, lines=30)
|
mock_get_size.return_value = MagicMock(columns=120, lines=30)
|
||||||
cols, rows = get_terminal_size()
|
cols, rows = get_terminal_size()
|
||||||
assert cols == 120
|
assert cols == 120
|
||||||
|
|
@ -23,30 +28,31 @@ class TestWindowsCompatibility:
|
||||||
|
|
||||||
def test_create_process(self):
|
def test_create_process(self):
|
||||||
"""Test process creation on Windows."""
|
"""Test process creation on Windows."""
|
||||||
with patch('subprocess.Popen') as mock_popen:
|
with patch("subprocess.Popen") as mock_popen:
|
||||||
mock_process = MagicMock()
|
mock_process = MagicMock()
|
||||||
mock_process.returncode = 0
|
mock_process.returncode = 0
|
||||||
mock_popen.return_value = mock_process
|
mock_popen.return_value = mock_process
|
||||||
|
|
||||||
proc, _ = create_process(['echo', 'test'])
|
proc, _ = create_process(["echo", "test"])
|
||||||
|
|
||||||
assert mock_popen.called
|
assert mock_popen.called
|
||||||
args, kwargs = mock_popen.call_args
|
args, kwargs = mock_popen.call_args
|
||||||
assert kwargs['stdin'] == subprocess.PIPE
|
assert kwargs["stdin"] == subprocess.PIPE
|
||||||
assert kwargs['stdout'] == subprocess.PIPE
|
assert kwargs["stdout"] == subprocess.PIPE
|
||||||
assert kwargs['stderr'] == subprocess.STDOUT
|
assert kwargs["stderr"] == subprocess.STDOUT
|
||||||
assert 'startupinfo' in kwargs
|
assert "startupinfo" in kwargs
|
||||||
assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW
|
assert kwargs["startupinfo"].dwFlags & subprocess.STARTF_USESHOWWINDOW
|
||||||
|
|
||||||
def test_run_interactive_command(self):
|
def test_run_interactive_command(self):
|
||||||
"""Test running an interactive command on Windows."""
|
"""Test running an interactive command on Windows."""
|
||||||
test_output = "Test output\n"
|
test_output = "Test output\n"
|
||||||
|
|
||||||
with patch('subprocess.Popen') as mock_popen, \
|
with (
|
||||||
patch('pyte.Stream') as mock_stream, \
|
patch("subprocess.Popen") as mock_popen,
|
||||||
patch('pyte.HistoryScreen') as mock_screen, \
|
patch("pyte.Stream") as mock_stream,
|
||||||
patch('threading.Thread') as mock_thread:
|
patch("pyte.HistoryScreen") as mock_screen,
|
||||||
|
patch("threading.Thread") as mock_thread,
|
||||||
|
):
|
||||||
# Setup mock process
|
# Setup mock process
|
||||||
mock_process = MagicMock()
|
mock_process = MagicMock()
|
||||||
mock_process.stdout = MagicMock()
|
mock_process.stdout = MagicMock()
|
||||||
|
|
@ -54,25 +60,25 @@ class TestWindowsCompatibility:
|
||||||
mock_process.poll.side_effect = [None, 0] # First None, then return 0
|
mock_process.poll.side_effect = [None, 0] # First None, then return 0
|
||||||
mock_process.returncode = 0
|
mock_process.returncode = 0
|
||||||
mock_popen.return_value = mock_process
|
mock_popen.return_value = mock_process
|
||||||
|
|
||||||
# Setup mock screen with history
|
# Setup mock screen with history
|
||||||
mock_screen_instance = MagicMock()
|
mock_screen_instance = MagicMock()
|
||||||
mock_screen_instance.history.top = []
|
mock_screen_instance.history.top = []
|
||||||
mock_screen_instance.history.bottom = []
|
mock_screen_instance.history.bottom = []
|
||||||
mock_screen_instance.display = ["Test output"]
|
mock_screen_instance.display = ["Test output"]
|
||||||
mock_screen.return_value = mock_screen_instance
|
mock_screen.return_value = mock_screen_instance
|
||||||
|
|
||||||
# Setup mock thread
|
# Setup mock thread
|
||||||
mock_thread_instance = MagicMock()
|
mock_thread_instance = MagicMock()
|
||||||
mock_thread.return_value = mock_thread_instance
|
mock_thread.return_value = mock_thread_instance
|
||||||
|
|
||||||
# Run the command
|
# Run the command
|
||||||
output, return_code = run_interactive_command(['echo', 'test'])
|
output, return_code = run_interactive_command(["echo", "test"])
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert return_code == 0
|
assert return_code == 0
|
||||||
assert "Test output" in output.decode()
|
assert "Test output" in output.decode()
|
||||||
|
|
||||||
# Verify the thread was started and joined
|
# Verify the thread was started and joined
|
||||||
mock_thread_instance.start.assert_called()
|
mock_thread_instance.start.assert_called()
|
||||||
mock_thread_instance.join.assert_called()
|
mock_thread_instance.join.assert_called()
|
||||||
|
|
@ -80,29 +86,29 @@ class TestWindowsCompatibility:
|
||||||
def test_windows_dependencies(self):
|
def test_windows_dependencies(self):
|
||||||
"""Test that required Windows dependencies are available."""
|
"""Test that required Windows dependencies are available."""
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
import msvcrt
|
|
||||||
|
|
||||||
# If we get here without ImportError, the test passes
|
# If we get here without ImportError, the test passes
|
||||||
assert True
|
assert True
|
||||||
|
|
||||||
def test_windows_output_handling(self):
|
def test_windows_output_handling(self):
|
||||||
"""Test handling of multi-chunk output on Windows."""
|
"""Test handling of multi-chunk output on Windows."""
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
pytest.skip("Windows-specific test")
|
pytest.skip("Windows-specific test")
|
||||||
|
|
||||||
# Test with multiple chunks of output to verify proper handling
|
# Test with multiple chunks of output to verify proper handling
|
||||||
with patch('subprocess.Popen') as mock_popen, \
|
with (
|
||||||
patch('msvcrt.kbhit', return_value=False), \
|
patch("subprocess.Popen") as mock_popen,
|
||||||
patch('threading.Thread') as mock_thread, \
|
patch("msvcrt.kbhit", return_value=False),
|
||||||
patch('time.sleep'): # Mock sleep to speed up test
|
patch("threading.Thread") as mock_thread,
|
||||||
|
patch("time.sleep"),
|
||||||
|
): # Mock sleep to speed up test
|
||||||
# Setup mock process
|
# Setup mock process
|
||||||
mock_process = MagicMock()
|
mock_process = MagicMock()
|
||||||
mock_process.stdout = MagicMock()
|
mock_process.stdout = MagicMock()
|
||||||
mock_process.poll.return_value = 0
|
mock_process.poll.return_value = 0
|
||||||
mock_process.returncode = 0
|
mock_process.returncode = 0
|
||||||
mock_popen.return_value = mock_process
|
mock_popen.return_value = mock_process
|
||||||
|
|
||||||
# Setup mock thread to simulate output collection
|
# Setup mock thread to simulate output collection
|
||||||
def side_effect(*args, **kwargs):
|
def side_effect(*args, **kwargs):
|
||||||
# Simulate thread collecting output
|
# Simulate thread collecting output
|
||||||
|
|
@ -110,15 +116,15 @@ class TestWindowsCompatibility:
|
||||||
b"First chunk\n",
|
b"First chunk\n",
|
||||||
b"Second chunk\n",
|
b"Second chunk\n",
|
||||||
b"Third chunk with unicode \xe2\x9c\x93\n", # UTF-8 checkmark
|
b"Third chunk with unicode \xe2\x9c\x93\n", # UTF-8 checkmark
|
||||||
None # End of output
|
None, # End of output
|
||||||
]
|
]
|
||||||
return MagicMock()
|
return MagicMock()
|
||||||
|
|
||||||
mock_thread.side_effect = side_effect
|
mock_thread.side_effect = side_effect
|
||||||
|
|
||||||
# Run the command
|
# Run the command
|
||||||
output, return_code = run_interactive_command(['test', 'command'])
|
output, return_code = run_interactive_command(["test", "command"])
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert return_code == 0
|
assert return_code == 0
|
||||||
# We can't verify exact output content in this test since we're mocking the thread
|
# We can't verify exact output content in this test since we're mocking the thread
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,19 @@
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
AgentContext,
|
AgentContext,
|
||||||
agent_context,
|
agent_context,
|
||||||
get_current_context,
|
|
||||||
mark_task_completed,
|
|
||||||
mark_plan_completed,
|
|
||||||
reset_completion_flags,
|
|
||||||
is_completed,
|
|
||||||
get_completion_message,
|
get_completion_message,
|
||||||
|
get_current_context,
|
||||||
|
is_completed,
|
||||||
|
mark_plan_completed,
|
||||||
mark_should_exit,
|
mark_should_exit,
|
||||||
|
mark_task_completed,
|
||||||
|
reset_completion_flags,
|
||||||
should_exit,
|
should_exit,
|
||||||
mark_agent_crashed,
|
|
||||||
is_crashed,
|
|
||||||
get_crash_message,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -120,42 +117,42 @@ class TestContextManager:
|
||||||
|
|
||||||
class TestExitPropagation:
|
class TestExitPropagation:
|
||||||
"""Test cases for the agent_should_exit flag propagation."""
|
"""Test cases for the agent_should_exit flag propagation."""
|
||||||
|
|
||||||
def test_mark_should_exit_propagation(self):
|
def test_mark_should_exit_propagation(self):
|
||||||
"""Test that mark_should_exit propagates to parent contexts."""
|
"""Test that mark_should_exit propagates to parent contexts."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Initially both contexts should have agent_should_exit as False
|
# Initially both contexts should have agent_should_exit as False
|
||||||
assert parent.agent_should_exit is False
|
assert parent.agent_should_exit is False
|
||||||
assert child.agent_should_exit is False
|
assert child.agent_should_exit is False
|
||||||
|
|
||||||
# Mark the child context as should exit
|
# Mark the child context as should exit
|
||||||
child.mark_should_exit()
|
child.mark_should_exit()
|
||||||
|
|
||||||
# Both child and parent should now have agent_should_exit as True
|
# Both child and parent should now have agent_should_exit as True
|
||||||
assert child.agent_should_exit is True
|
assert child.agent_should_exit is True
|
||||||
assert parent.agent_should_exit is True
|
assert parent.agent_should_exit is True
|
||||||
|
|
||||||
def test_nested_should_exit_propagation(self):
|
def test_nested_should_exit_propagation(self):
|
||||||
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
|
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
|
||||||
grandparent = AgentContext()
|
grandparent = AgentContext()
|
||||||
parent = AgentContext(parent_context=grandparent)
|
parent = AgentContext(parent_context=grandparent)
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Initially all contexts should have agent_should_exit as False
|
# Initially all contexts should have agent_should_exit as False
|
||||||
assert grandparent.agent_should_exit is False
|
assert grandparent.agent_should_exit is False
|
||||||
assert parent.agent_should_exit is False
|
assert parent.agent_should_exit is False
|
||||||
assert child.agent_should_exit is False
|
assert child.agent_should_exit is False
|
||||||
|
|
||||||
# Mark the child context as should exit
|
# Mark the child context as should exit
|
||||||
child.mark_should_exit()
|
child.mark_should_exit()
|
||||||
|
|
||||||
# All contexts should now have agent_should_exit as True
|
# All contexts should now have agent_should_exit as True
|
||||||
assert child.agent_should_exit is True
|
assert child.agent_should_exit is True
|
||||||
assert parent.agent_should_exit is True
|
assert parent.agent_should_exit is True
|
||||||
assert grandparent.agent_should_exit is True
|
assert grandparent.agent_should_exit is True
|
||||||
|
|
||||||
def test_context_manager_should_exit_propagation(self):
|
def test_context_manager_should_exit_propagation(self):
|
||||||
"""Test that mark_should_exit propagates when using context managers."""
|
"""Test that mark_should_exit propagates when using context managers."""
|
||||||
with agent_context() as outer:
|
with agent_context() as outer:
|
||||||
|
|
@ -163,10 +160,10 @@ class TestExitPropagation:
|
||||||
# Initially both contexts should have agent_should_exit as False
|
# Initially both contexts should have agent_should_exit as False
|
||||||
assert outer.agent_should_exit is False
|
assert outer.agent_should_exit is False
|
||||||
assert inner.agent_should_exit is False
|
assert inner.agent_should_exit is False
|
||||||
|
|
||||||
# Mark the inner context as should exit
|
# Mark the inner context as should exit
|
||||||
inner.mark_should_exit()
|
inner.mark_should_exit()
|
||||||
|
|
||||||
# Both inner and outer should now have agent_should_exit as True
|
# Both inner and outer should now have agent_should_exit as True
|
||||||
assert inner.agent_should_exit is True
|
assert inner.agent_should_exit is True
|
||||||
assert outer.agent_should_exit is True
|
assert outer.agent_should_exit is True
|
||||||
|
|
@ -174,39 +171,39 @@ class TestExitPropagation:
|
||||||
|
|
||||||
class TestCrashPropagation:
|
class TestCrashPropagation:
|
||||||
"""Test cases for the agent_has_crashed flag non-propagation."""
|
"""Test cases for the agent_has_crashed flag non-propagation."""
|
||||||
|
|
||||||
def test_mark_agent_crashed_no_propagation(self):
|
def test_mark_agent_crashed_no_propagation(self):
|
||||||
"""Test that mark_agent_crashed does not propagate to parent contexts."""
|
"""Test that mark_agent_crashed does not propagate to parent contexts."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Initially both contexts should have agent_has_crashed as False
|
# Initially both contexts should have agent_has_crashed as False
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert child.is_crashed() is False
|
assert child.is_crashed() is False
|
||||||
|
|
||||||
# Mark the child context as crashed
|
# Mark the child context as crashed
|
||||||
child.mark_agent_crashed("Child crashed")
|
child.mark_agent_crashed("Child crashed")
|
||||||
|
|
||||||
# Child should be crashed, but parent should not
|
# Child should be crashed, but parent should not
|
||||||
assert child.is_crashed() is True
|
assert child.is_crashed() is True
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert child.agent_crashed_message == "Child crashed"
|
assert child.agent_crashed_message == "Child crashed"
|
||||||
assert parent.agent_crashed_message is None
|
assert parent.agent_crashed_message is None
|
||||||
|
|
||||||
def test_nested_crash_no_propagation(self):
|
def test_nested_crash_no_propagation(self):
|
||||||
"""Test that crash states don't propagate through multiple levels of parent contexts."""
|
"""Test that crash states don't propagate through multiple levels of parent contexts."""
|
||||||
grandparent = AgentContext()
|
grandparent = AgentContext()
|
||||||
parent = AgentContext(parent_context=grandparent)
|
parent = AgentContext(parent_context=grandparent)
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Initially all contexts should have agent_has_crashed as False
|
# Initially all contexts should have agent_has_crashed as False
|
||||||
assert grandparent.is_crashed() is False
|
assert grandparent.is_crashed() is False
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert child.is_crashed() is False
|
assert child.is_crashed() is False
|
||||||
|
|
||||||
# Mark the child context as crashed
|
# Mark the child context as crashed
|
||||||
child.mark_agent_crashed("Child crashed")
|
child.mark_agent_crashed("Child crashed")
|
||||||
|
|
||||||
# Only child should be crashed, parent and grandparent should not
|
# Only child should be crashed, parent and grandparent should not
|
||||||
assert child.is_crashed() is True
|
assert child.is_crashed() is True
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
|
|
@ -214,7 +211,7 @@ class TestCrashPropagation:
|
||||||
assert child.agent_crashed_message == "Child crashed"
|
assert child.agent_crashed_message == "Child crashed"
|
||||||
assert parent.agent_crashed_message is None
|
assert parent.agent_crashed_message is None
|
||||||
assert grandparent.agent_crashed_message is None
|
assert grandparent.agent_crashed_message is None
|
||||||
|
|
||||||
def test_context_manager_crash_no_propagation(self):
|
def test_context_manager_crash_no_propagation(self):
|
||||||
"""Test that crash state doesn't propagate when using context managers."""
|
"""Test that crash state doesn't propagate when using context managers."""
|
||||||
with agent_context() as outer:
|
with agent_context() as outer:
|
||||||
|
|
@ -222,27 +219,27 @@ class TestCrashPropagation:
|
||||||
# Initially both contexts should have agent_has_crashed as False
|
# Initially both contexts should have agent_has_crashed as False
|
||||||
assert outer.is_crashed() is False
|
assert outer.is_crashed() is False
|
||||||
assert inner.is_crashed() is False
|
assert inner.is_crashed() is False
|
||||||
|
|
||||||
# Mark the inner context as crashed
|
# Mark the inner context as crashed
|
||||||
inner.mark_agent_crashed("Inner crashed")
|
inner.mark_agent_crashed("Inner crashed")
|
||||||
|
|
||||||
# Inner should be crashed, but outer should not
|
# Inner should be crashed, but outer should not
|
||||||
assert inner.is_crashed() is True
|
assert inner.is_crashed() is True
|
||||||
assert outer.is_crashed() is False
|
assert outer.is_crashed() is False
|
||||||
assert inner.agent_crashed_message == "Inner crashed"
|
assert inner.agent_crashed_message == "Inner crashed"
|
||||||
assert outer.agent_crashed_message is None
|
assert outer.agent_crashed_message is None
|
||||||
|
|
||||||
def test_crash_state_not_inherited(self):
|
def test_crash_state_not_inherited(self):
|
||||||
"""Test that new child contexts don't inherit crash states from parent contexts."""
|
"""Test that new child contexts don't inherit crash states from parent contexts."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
|
|
||||||
# Mark the parent as crashed
|
# Mark the parent as crashed
|
||||||
parent.mark_agent_crashed("Parent crashed")
|
parent.mark_agent_crashed("Parent crashed")
|
||||||
assert parent.is_crashed() is True
|
assert parent.is_crashed() is True
|
||||||
|
|
||||||
# Create a child context with the crashed parent as parent_context
|
# Create a child context with the crashed parent as parent_context
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Child should not be crashed even though parent is
|
# Child should not be crashed even though parent is
|
||||||
assert parent.is_crashed() is True
|
assert parent.is_crashed() is True
|
||||||
assert child.is_crashed() is False
|
assert child.is_crashed() is False
|
||||||
|
|
@ -313,18 +310,18 @@ class TestUtilityFunctions:
|
||||||
# These should have safe default returns
|
# These should have safe default returns
|
||||||
assert is_completed() is False
|
assert is_completed() is False
|
||||||
assert get_completion_message() == ""
|
assert get_completion_message() == ""
|
||||||
|
|
||||||
def test_mark_should_exit_utility(self):
|
def test_mark_should_exit_utility(self):
|
||||||
"""Test the mark_should_exit utility function."""
|
"""Test the mark_should_exit utility function."""
|
||||||
with agent_context() as outer:
|
with agent_context() as outer:
|
||||||
with agent_context() as inner:
|
with agent_context() as inner:
|
||||||
# Initially both contexts should have agent_should_exit as False
|
# Initially both contexts should have agent_should_exit as False
|
||||||
assert should_exit() is False
|
assert should_exit() is False
|
||||||
|
|
||||||
# Mark the current context (inner) as should exit
|
# Mark the current context (inner) as should exit
|
||||||
mark_should_exit()
|
mark_should_exit()
|
||||||
|
|
||||||
# Both inner and outer should now have agent_should_exit as True
|
# Both inner and outer should now have agent_should_exit as True
|
||||||
assert should_exit() is True
|
assert should_exit() is True
|
||||||
assert inner.agent_should_exit is True
|
assert inner.agent_should_exit is True
|
||||||
assert outer.agent_should_exit is True
|
assert outer.agent_should_exit is True
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,9 @@ import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags
|
from ra_aid.agent_context import (
|
||||||
|
agent_context,
|
||||||
|
)
|
||||||
from ra_aid.agent_utils import (
|
from ra_aid.agent_utils import (
|
||||||
AgentState,
|
AgentState,
|
||||||
create_agent,
|
create_agent,
|
||||||
|
|
@ -116,7 +118,10 @@ def test_create_agent_anthropic(mock_model, mock_memory):
|
||||||
|
|
||||||
assert agent == "react_agent"
|
assert agent == "react_agent"
|
||||||
mock_react.assert_called_once_with(
|
mock_react.assert_called_once_with(
|
||||||
mock_model, [], version='v2', state_modifier=mock_react.call_args[1]["state_modifier"]
|
mock_model,
|
||||||
|
[],
|
||||||
|
version="v2",
|
||||||
|
state_modifier=mock_react.call_args[1]["state_modifier"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -259,7 +264,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
|
||||||
agent = create_agent(mock_model, [])
|
agent = create_agent(mock_model, [])
|
||||||
|
|
||||||
assert agent == "react_agent"
|
assert agent == "react_agent"
|
||||||
mock_react.assert_called_once_with(mock_model, [], version='v2')
|
mock_react.assert_called_once_with(mock_model, [], version="v2")
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_research(mock_memory):
|
def test_get_model_token_limit_research(mock_memory):
|
||||||
|
|
@ -339,7 +344,7 @@ def test_run_agent_stream(monkeypatch):
|
||||||
ctx.plan_completed = True
|
ctx.plan_completed = True
|
||||||
ctx.task_completed = True
|
ctx.task_completed = True
|
||||||
ctx.completion_message = "existing"
|
ctx.completion_message = "existing"
|
||||||
|
|
||||||
call_flag = {"called": False}
|
call_flag = {"called": False}
|
||||||
|
|
||||||
def fake_print_agent_output(
|
def fake_print_agent_output(
|
||||||
|
|
@ -352,7 +357,7 @@ def test_run_agent_stream(monkeypatch):
|
||||||
)
|
)
|
||||||
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
|
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
|
||||||
assert call_flag["called"]
|
assert call_flag["called"]
|
||||||
|
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
assert ctx.plan_completed is False
|
assert ctx.plan_completed is False
|
||||||
assert ctx.task_completed is False
|
assert ctx.task_completed is False
|
||||||
|
|
@ -457,74 +462,93 @@ def test_is_anthropic_claude():
|
||||||
assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
|
assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
|
||||||
assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
|
assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
|
||||||
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
|
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
|
||||||
|
|
||||||
# Test OpenRouter provider cases
|
# Test OpenRouter provider cases
|
||||||
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"})
|
assert is_anthropic_claude(
|
||||||
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-instant"})
|
{"provider": "openrouter", "model": "anthropic/claude-2"}
|
||||||
|
)
|
||||||
|
assert is_anthropic_claude(
|
||||||
|
{"provider": "openrouter", "model": "anthropic/claude-instant"}
|
||||||
|
)
|
||||||
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
|
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
|
||||||
|
|
||||||
# Test edge cases
|
# Test edge cases
|
||||||
assert not is_anthropic_claude({}) # Empty config
|
assert not is_anthropic_claude({}) # Empty config
|
||||||
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
|
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
|
||||||
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
|
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
|
||||||
assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider
|
assert not is_anthropic_claude(
|
||||||
|
{"provider": "other", "model": "claude-2"}
|
||||||
|
) # Wrong provider
|
||||||
|
|
||||||
|
|
||||||
def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
||||||
"""Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
|
"""Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
|
||||||
from ra_aid.agent_utils import run_agent_with_retry
|
|
||||||
from ra_aid.agent_context import agent_context, mark_agent_crashed
|
from ra_aid.agent_context import agent_context, mark_agent_crashed
|
||||||
|
from ra_aid.agent_utils import run_agent_with_retry
|
||||||
|
|
||||||
# Setup mocks for dependencies to isolate our test
|
# Setup mocks for dependencies to isolate our test
|
||||||
dummy_agent = Mock()
|
dummy_agent = Mock()
|
||||||
|
|
||||||
# Track function calls
|
# Track function calls
|
||||||
mock_calls = {"run_agent_stream": 0}
|
mock_calls = {"run_agent_stream": 0}
|
||||||
|
|
||||||
def mock_run_agent_stream(*args, **kwargs):
|
def mock_run_agent_stream(*args, **kwargs):
|
||||||
mock_calls["run_agent_stream"] += 1
|
mock_calls["run_agent_stream"] += 1
|
||||||
|
|
||||||
def mock_setup_interrupt_handling():
|
def mock_setup_interrupt_handling():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def mock_restore_interrupt_handling(handler):
|
def mock_restore_interrupt_handling(handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_increment_agent_depth():
|
def mock_increment_agent_depth():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_decrement_agent_depth():
|
def mock_decrement_agent_depth():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_is_crashed():
|
def mock_is_crashed():
|
||||||
return ctx.is_crashed() if ctx else False
|
return ctx.is_crashed() if ctx else False
|
||||||
|
|
||||||
def mock_get_crash_message():
|
def mock_get_crash_message():
|
||||||
return ctx.agent_crashed_message if ctx and ctx.is_crashed() else None
|
return ctx.agent_crashed_message if ctx and ctx.is_crashed() else None
|
||||||
|
|
||||||
# Apply mocks
|
# Apply mocks
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||||
|
mock_restore_interrupt_handling,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||||
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||||
|
|
||||||
# First, run without a crash - agent should be run
|
# First, run without a crash - agent should be run
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||||
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_context.get_crash_message", mock_get_crash_message
|
||||||
|
)
|
||||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||||
assert mock_calls["run_agent_stream"] == 1
|
assert mock_calls["run_agent_stream"] == 1
|
||||||
|
|
||||||
# Reset call counter
|
# Reset call counter
|
||||||
mock_calls["run_agent_stream"] = 0
|
mock_calls["run_agent_stream"] = 0
|
||||||
|
|
||||||
# Now run with a crash - agent should not be run
|
# Now run with a crash - agent should not be run
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
mark_agent_crashed("Test crash message")
|
mark_agent_crashed("Test crash message")
|
||||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||||
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_context.get_crash_message", mock_get_crash_message
|
||||||
|
)
|
||||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||||
# Verify _run_agent_stream was not called
|
# Verify _run_agent_stream was not called
|
||||||
assert mock_calls["run_agent_stream"] == 0
|
assert mock_calls["run_agent_stream"] == 0
|
||||||
|
|
@ -534,54 +558,65 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
||||||
|
|
||||||
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
||||||
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
||||||
|
from ra_aid.agent_context import agent_context, is_crashed
|
||||||
from ra_aid.agent_utils import run_agent_with_retry
|
from ra_aid.agent_utils import run_agent_with_retry
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.agent_context import agent_context, is_crashed
|
|
||||||
|
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
dummy_agent = Mock()
|
dummy_agent = Mock()
|
||||||
|
|
||||||
# Track function calls and simulate BadRequestError
|
# Track function calls and simulate BadRequestError
|
||||||
run_count = [0]
|
run_count = [0]
|
||||||
|
|
||||||
def mock_run_agent_stream(*args, **kwargs):
|
def mock_run_agent_stream(*args, **kwargs):
|
||||||
run_count[0] += 1
|
run_count[0] += 1
|
||||||
if run_count[0] == 1:
|
if run_count[0] == 1:
|
||||||
# First call throws a 400 BadRequestError
|
# First call throws a 400 BadRequestError
|
||||||
raise ToolExecutionError("400 Bad Request: Invalid input")
|
raise ToolExecutionError("400 Bad Request: Invalid input")
|
||||||
# If it's called again, it should run normally
|
# If it's called again, it should run normally
|
||||||
|
|
||||||
def mock_setup_interrupt_handling():
|
def mock_setup_interrupt_handling():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def mock_restore_interrupt_handling(handler):
|
def mock_restore_interrupt_handling(handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_increment_agent_depth():
|
def mock_increment_agent_depth():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_decrement_agent_depth():
|
def mock_decrement_agent_depth():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_mark_agent_crashed(message):
|
def mock_mark_agent_crashed(message):
|
||||||
ctx.agent_has_crashed = True
|
ctx.agent_has_crashed = True
|
||||||
ctx.agent_crashed_message = message
|
ctx.agent_crashed_message = message
|
||||||
|
|
||||||
def mock_is_crashed():
|
def mock_is_crashed():
|
||||||
return ctx.is_crashed() if ctx else False
|
return ctx.is_crashed() if ctx else False
|
||||||
|
|
||||||
# Apply mocks
|
# Apply mocks
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||||
|
mock_restore_interrupt_handling,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||||
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||||
|
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed
|
||||||
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||||
|
|
||||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||||
# Verify the agent was only run once and not retried
|
# Verify the agent was only run once and not retried
|
||||||
assert run_count[0] == 1
|
assert run_count[0] == 1
|
||||||
|
|
@ -594,60 +629,73 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
||||||
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
||||||
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
||||||
# Import APIError from anthropic module and patch it on the agent_utils module
|
# Import APIError from anthropic module and patch it on the agent_utils module
|
||||||
from anthropic import APIError as AnthropicAPIError
|
|
||||||
from ra_aid.agent_utils import run_agent_with_retry
|
|
||||||
from ra_aid.agent_context import agent_context, is_crashed
|
from ra_aid.agent_context import agent_context, is_crashed
|
||||||
|
from ra_aid.agent_utils import run_agent_with_retry
|
||||||
|
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
dummy_agent = Mock()
|
dummy_agent = Mock()
|
||||||
|
|
||||||
# Track function calls and simulate BadRequestError
|
# Track function calls and simulate BadRequestError
|
||||||
run_count = [0]
|
run_count = [0]
|
||||||
|
|
||||||
# Create a mock APIError class that simulates Anthropic's APIError
|
# Create a mock APIError class that simulates Anthropic's APIError
|
||||||
class MockAPIError(Exception):
|
class MockAPIError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_run_agent_stream(*args, **kwargs):
|
def mock_run_agent_stream(*args, **kwargs):
|
||||||
run_count[0] += 1
|
run_count[0] += 1
|
||||||
if run_count[0] == 1:
|
if run_count[0] == 1:
|
||||||
# First call throws a 400 Bad Request APIError
|
# First call throws a 400 Bad Request APIError
|
||||||
mock_error = MockAPIError("400 Bad Request")
|
mock_error = MockAPIError("400 Bad Request")
|
||||||
mock_error.__class__.__name__ = "APIError" # Make it look like Anthropic's APIError
|
mock_error.__class__.__name__ = (
|
||||||
|
"APIError" # Make it look like Anthropic's APIError
|
||||||
|
)
|
||||||
raise mock_error
|
raise mock_error
|
||||||
# If it's called again, it should run normally
|
# If it's called again, it should run normally
|
||||||
|
|
||||||
def mock_setup_interrupt_handling():
|
def mock_setup_interrupt_handling():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def mock_restore_interrupt_handling(handler):
|
def mock_restore_interrupt_handling(handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_increment_agent_depth():
|
def mock_increment_agent_depth():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_decrement_agent_depth():
|
def mock_decrement_agent_depth():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_mark_agent_crashed(message):
|
def mock_mark_agent_crashed(message):
|
||||||
ctx.agent_has_crashed = True
|
ctx.agent_has_crashed = True
|
||||||
ctx.agent_crashed_message = message
|
ctx.agent_crashed_message = message
|
||||||
|
|
||||||
def mock_is_crashed():
|
def mock_is_crashed():
|
||||||
return ctx.is_crashed() if ctx else False
|
return ctx.is_crashed() if ctx else False
|
||||||
|
|
||||||
# Apply mocks
|
# Apply mocks
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
"ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||||
|
mock_restore_interrupt_handling,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||||
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
||||||
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed
|
||||||
|
)
|
||||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||||
|
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||||
# Verify the agent was only run once and not retried
|
# Verify the agent was only run once and not retried
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -12,14 +12,18 @@ def clean_env():
|
||||||
"""Remove relevant environment variables before each test."""
|
"""Remove relevant environment variables before each test."""
|
||||||
# Save existing values
|
# Save existing values
|
||||||
saved_vars = {}
|
saved_vars = {}
|
||||||
for var in ['ANTHROPIC_API_KEY', 'EXPERT_ANTHROPIC_API_KEY',
|
for var in [
|
||||||
'ANTHROPIC_MODEL', 'EXPERT_ANTHROPIC_MODEL']:
|
"ANTHROPIC_API_KEY",
|
||||||
|
"EXPERT_ANTHROPIC_API_KEY",
|
||||||
|
"ANTHROPIC_MODEL",
|
||||||
|
"EXPERT_ANTHROPIC_MODEL",
|
||||||
|
]:
|
||||||
saved_vars[var] = os.environ.get(var)
|
saved_vars[var] = os.environ.get(var)
|
||||||
if var in os.environ:
|
if var in os.environ:
|
||||||
del os.environ[var]
|
del os.environ[var]
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Restore saved values
|
# Restore saved values
|
||||||
for var, value in saved_vars.items():
|
for var, value in saved_vars.items():
|
||||||
if value is not None:
|
if value is not None:
|
||||||
|
|
@ -31,6 +35,7 @@ def clean_env():
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockArgs:
|
class MockArgs:
|
||||||
"""Mock arguments class for testing."""
|
"""Mock arguments class for testing."""
|
||||||
|
|
||||||
expert_provider: str
|
expert_provider: str
|
||||||
expert_model: Optional[str] = None
|
expert_model: Optional[str] = None
|
||||||
|
|
||||||
|
|
@ -39,9 +44,9 @@ def test_anthropic_expert_validation_message(clean_env):
|
||||||
"""Test that validation message refers to base key when neither key exists."""
|
"""Test that validation message refers to base key when neither key exists."""
|
||||||
strategy = AnthropicStrategy()
|
strategy = AnthropicStrategy()
|
||||||
args = MockArgs(expert_provider="anthropic")
|
args = MockArgs(expert_provider="anthropic")
|
||||||
|
|
||||||
result = strategy.validate(args)
|
result = strategy.validate(args)
|
||||||
|
|
||||||
assert not result.valid
|
assert not result.valid
|
||||||
assert len(result.missing_vars) > 0
|
assert len(result.missing_vars) > 0
|
||||||
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars[0]
|
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars[0]
|
||||||
|
|
|
||||||
|
|
@ -1,56 +1,55 @@
|
||||||
"""Unit tests for crash propagation behavior in agent_context."""
|
"""Unit tests for crash propagation behavior in agent_context."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
AgentContext,
|
AgentContext,
|
||||||
agent_context,
|
agent_context,
|
||||||
mark_agent_crashed,
|
|
||||||
is_crashed,
|
|
||||||
get_crash_message,
|
get_crash_message,
|
||||||
|
is_crashed,
|
||||||
|
mark_agent_crashed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCrashPropagation:
|
class TestCrashPropagation:
|
||||||
"""Test cases for crash state propagation behavior."""
|
"""Test cases for crash state propagation behavior."""
|
||||||
|
|
||||||
def test_mark_agent_crashed_no_propagation(self):
|
def test_mark_agent_crashed_no_propagation(self):
|
||||||
"""Test that mark_agent_crashed does not propagate to parent contexts."""
|
"""Test that mark_agent_crashed does not propagate to parent contexts."""
|
||||||
parent = AgentContext()
|
parent = AgentContext()
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Initially both contexts should have is_crashed as False
|
# Initially both contexts should have is_crashed as False
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert child.is_crashed() is False
|
assert child.is_crashed() is False
|
||||||
|
|
||||||
# Mark the child context as crashed
|
# Mark the child context as crashed
|
||||||
child.mark_agent_crashed("Child crashed")
|
child.mark_agent_crashed("Child crashed")
|
||||||
|
|
||||||
# Child should be crashed but parent should not
|
# Child should be crashed but parent should not
|
||||||
assert child.is_crashed() is True
|
assert child.is_crashed() is True
|
||||||
assert child.agent_crashed_message == "Child crashed"
|
assert child.agent_crashed_message == "Child crashed"
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert parent.agent_crashed_message is None
|
assert parent.agent_crashed_message is None
|
||||||
|
|
||||||
def test_nested_crash_no_propagation(self):
|
def test_nested_crash_no_propagation(self):
|
||||||
"""Test that crash state doesn't propagate through multiple levels of parent contexts."""
|
"""Test that crash state doesn't propagate through multiple levels of parent contexts."""
|
||||||
grandparent = AgentContext()
|
grandparent = AgentContext()
|
||||||
parent = AgentContext(parent_context=grandparent)
|
parent = AgentContext(parent_context=grandparent)
|
||||||
child = AgentContext(parent_context=parent)
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
# Initially all contexts should have is_crashed as False
|
# Initially all contexts should have is_crashed as False
|
||||||
assert grandparent.is_crashed() is False
|
assert grandparent.is_crashed() is False
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert child.is_crashed() is False
|
assert child.is_crashed() is False
|
||||||
|
|
||||||
# Mark the child context as crashed
|
# Mark the child context as crashed
|
||||||
child.mark_agent_crashed("Child crashed")
|
child.mark_agent_crashed("Child crashed")
|
||||||
|
|
||||||
# Only child should be crashed
|
# Only child should be crashed
|
||||||
assert child.is_crashed() is True
|
assert child.is_crashed() is True
|
||||||
assert parent.is_crashed() is False
|
assert parent.is_crashed() is False
|
||||||
assert grandparent.is_crashed() is False
|
assert grandparent.is_crashed() is False
|
||||||
|
|
||||||
def test_context_manager_crash_no_propagation(self):
|
def test_context_manager_crash_no_propagation(self):
|
||||||
"""Test that crash states don't propagate when using context managers."""
|
"""Test that crash states don't propagate when using context managers."""
|
||||||
with agent_context() as outer:
|
with agent_context() as outer:
|
||||||
|
|
@ -58,14 +57,14 @@ class TestCrashPropagation:
|
||||||
# Initially both contexts should have is_crashed as False
|
# Initially both contexts should have is_crashed as False
|
||||||
assert outer.is_crashed() is False
|
assert outer.is_crashed() is False
|
||||||
assert inner.is_crashed() is False
|
assert inner.is_crashed() is False
|
||||||
|
|
||||||
# Mark the inner context as crashed
|
# Mark the inner context as crashed
|
||||||
inner.mark_agent_crashed("Inner crashed")
|
inner.mark_agent_crashed("Inner crashed")
|
||||||
|
|
||||||
# Inner should be crashed but outer should not
|
# Inner should be crashed but outer should not
|
||||||
assert inner.is_crashed() is True
|
assert inner.is_crashed() is True
|
||||||
assert outer.is_crashed() is False
|
assert outer.is_crashed() is False
|
||||||
|
|
||||||
def test_utility_functions_for_crash_state(self):
|
def test_utility_functions_for_crash_state(self):
|
||||||
"""Test utility functions for crash state."""
|
"""Test utility functions for crash state."""
|
||||||
with agent_context() as outer:
|
with agent_context() as outer:
|
||||||
|
|
@ -73,12 +72,12 @@ class TestCrashPropagation:
|
||||||
# Initially both contexts should have is_crashed as False
|
# Initially both contexts should have is_crashed as False
|
||||||
assert is_crashed() is False
|
assert is_crashed() is False
|
||||||
assert get_crash_message() is None
|
assert get_crash_message() is None
|
||||||
|
|
||||||
# Mark the current context (inner) as crashed
|
# Mark the current context (inner) as crashed
|
||||||
mark_agent_crashed("Utility function crash")
|
mark_agent_crashed("Utility function crash")
|
||||||
|
|
||||||
# Current context should be crashed but outer should not
|
# Current context should be crashed but outer should not
|
||||||
assert is_crashed() is True
|
assert is_crashed() is True
|
||||||
assert get_crash_message() == "Utility function crash"
|
assert get_crash_message() == "Utility function crash"
|
||||||
assert inner.is_crashed() is True
|
assert inner.is_crashed() is True
|
||||||
assert outer.is_crashed() is False
|
assert outer.is_crashed() is False
|
||||||
|
|
|
||||||
|
|
@ -45,12 +45,23 @@ def test_default_anthropic_provider(clean_env, monkeypatch):
|
||||||
"""Test that Anthropic is the default provider when no environment variables are set."""
|
"""Test that Anthropic is the default provider when no environment variables are set."""
|
||||||
args = parse_arguments(["-m", "test message"])
|
args = parse_arguments(["-m", "test message"])
|
||||||
assert args.provider == "anthropic"
|
assert args.provider == "anthropic"
|
||||||
assert args.model == "claude-3-7-sonnet-20250219" # Updated to match current default
|
assert (
|
||||||
|
args.model == "claude-3-7-sonnet-20250219"
|
||||||
|
) # Updated to match current default
|
||||||
|
|
||||||
|
|
||||||
def test_respects_user_specified_anthropic_model(clean_env):
|
def test_respects_user_specified_anthropic_model(clean_env):
|
||||||
"""Test that user-specified Anthropic models are respected."""
|
"""Test that user-specified Anthropic models are respected."""
|
||||||
args = parse_arguments(["-m", "test message", "--provider", "anthropic", "--model", "claude-3-5-sonnet-20241022"])
|
args = parse_arguments(
|
||||||
|
[
|
||||||
|
"-m",
|
||||||
|
"test message",
|
||||||
|
"--provider",
|
||||||
|
"anthropic",
|
||||||
|
"--model",
|
||||||
|
"claude-3-5-sonnet-20241022",
|
||||||
|
]
|
||||||
|
)
|
||||||
assert args.provider == "anthropic"
|
assert args.provider == "anthropic"
|
||||||
assert args.model == "claude-3-5-sonnet-20241022" # Should not be overridden
|
assert args.model == "claude-3-5-sonnet-20241022" # Should not be overridden
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
|
||||||
|
|
||||||
# Check that mock_anthropic was called
|
# Check that mock_anthropic was called
|
||||||
assert mock_anthropic.called
|
assert mock_anthropic.called
|
||||||
|
|
||||||
# Verify essential parameters
|
# Verify essential parameters
|
||||||
kwargs = mock_anthropic.call_args.kwargs
|
kwargs = mock_anthropic.call_args.kwargs
|
||||||
assert kwargs["api_key"] == "test-key"
|
assert kwargs["api_key"] == "test-key"
|
||||||
|
|
@ -123,10 +123,7 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
|
||||||
temperature=0,
|
temperature=0,
|
||||||
timeout=180,
|
timeout=180,
|
||||||
max_retries=5,
|
max_retries=5,
|
||||||
default_headers={
|
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||||
"HTTP-Referer": "https://ra-aid.ai",
|
|
||||||
"X-Title": "RA.Aid"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -203,7 +200,7 @@ def test_initialize_anthropic(clean_env, mock_anthropic):
|
||||||
|
|
||||||
# Check that mock_anthropic was called
|
# Check that mock_anthropic was called
|
||||||
assert mock_anthropic.called
|
assert mock_anthropic.called
|
||||||
|
|
||||||
# Verify essential parameters
|
# Verify essential parameters
|
||||||
kwargs = mock_anthropic.call_args.kwargs
|
kwargs = mock_anthropic.call_args.kwargs
|
||||||
assert kwargs["api_key"] == "test-key"
|
assert kwargs["api_key"] == "test-key"
|
||||||
|
|
@ -225,10 +222,7 @@ def test_initialize_openrouter(clean_env, mock_openai):
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
timeout=180,
|
timeout=180,
|
||||||
max_retries=5,
|
max_retries=5,
|
||||||
default_headers={
|
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||||
"HTTP-Referer": "https://ra-aid.ai",
|
|
||||||
"X-Title": "RA.Aid"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -285,7 +279,7 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin
|
||||||
)
|
)
|
||||||
|
|
||||||
initialize_llm("anthropic", "test-model")
|
initialize_llm("anthropic", "test-model")
|
||||||
|
|
||||||
# Verify essential parameters for Anthropic
|
# Verify essential parameters for Anthropic
|
||||||
kwargs = mock_anthropic.call_args.kwargs
|
kwargs = mock_anthropic.call_args.kwargs
|
||||||
assert kwargs["api_key"] == "test-key"
|
assert kwargs["api_key"] == "test-key"
|
||||||
|
|
@ -354,7 +348,7 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
|
||||||
|
|
||||||
# Test Anthropic
|
# Test Anthropic
|
||||||
initialize_llm("anthropic", "test-model", temperature=test_temp)
|
initialize_llm("anthropic", "test-model", temperature=test_temp)
|
||||||
|
|
||||||
# Verify essential parameters for Anthropic
|
# Verify essential parameters for Anthropic
|
||||||
kwargs = mock_anthropic.call_args.kwargs
|
kwargs = mock_anthropic.call_args.kwargs
|
||||||
assert kwargs["api_key"] == "test-key"
|
assert kwargs["api_key"] == "test-key"
|
||||||
|
|
@ -372,10 +366,7 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
|
||||||
temperature=test_temp,
|
temperature=test_temp,
|
||||||
timeout=180,
|
timeout=180,
|
||||||
max_retries=5,
|
max_retries=5,
|
||||||
default_headers={
|
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||||
"HTTP-Referer": "https://ra-aid.ai",
|
|
||||||
"X-Title": "RA.Aid"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -482,7 +473,7 @@ def test_initialize_llm_cross_provider(
|
||||||
# Initialize Anthropic
|
# Initialize Anthropic
|
||||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
|
||||||
_llm2 = initialize_llm("anthropic", "claude-3", temperature=0.7)
|
_llm2 = initialize_llm("anthropic", "claude-3", temperature=0.7)
|
||||||
|
|
||||||
# Verify essential parameters for Anthropic
|
# Verify essential parameters for Anthropic
|
||||||
kwargs = mock_anthropic.call_args.kwargs
|
kwargs = mock_anthropic.call_args.kwargs
|
||||||
assert kwargs["api_key"] == "anthropic-key"
|
assert kwargs["api_key"] == "anthropic-key"
|
||||||
|
|
@ -586,13 +577,15 @@ def mock_deepseek_reasoner():
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
def test_reasoning_effort_only_passed_to_supported_models(clean_env, mock_openai, monkeypatch):
|
def test_reasoning_effort_only_passed_to_supported_models(
|
||||||
|
clean_env, mock_openai, monkeypatch
|
||||||
|
):
|
||||||
"""Test that reasoning_effort is only passed to supported models."""
|
"""Test that reasoning_effort is only passed to supported models."""
|
||||||
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
||||||
|
|
||||||
# Initialize expert LLM with GPT-4 (which doesn't support reasoning_effort)
|
# Initialize expert LLM with GPT-4 (which doesn't support reasoning_effort)
|
||||||
_llm = initialize_expert_llm("openai", "gpt-4")
|
_llm = initialize_expert_llm("openai", "gpt-4")
|
||||||
|
|
||||||
# Verify reasoning_effort was not included in kwargs
|
# Verify reasoning_effort was not included in kwargs
|
||||||
mock_openai.assert_called_with(
|
mock_openai.assert_called_with(
|
||||||
api_key="test-key",
|
api_key="test-key",
|
||||||
|
|
@ -603,13 +596,15 @@ def test_reasoning_effort_only_passed_to_supported_models(clean_env, mock_openai
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_reasoning_effort_passed_to_supported_models(clean_env, mock_openai, monkeypatch):
|
def test_reasoning_effort_passed_to_supported_models(
|
||||||
|
clean_env, mock_openai, monkeypatch
|
||||||
|
):
|
||||||
"""Test that reasoning_effort is passed to models that support it."""
|
"""Test that reasoning_effort is passed to models that support it."""
|
||||||
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
||||||
|
|
||||||
# Initialize expert LLM with o1 (which supports reasoning_effort)
|
# Initialize expert LLM with o1 (which supports reasoning_effort)
|
||||||
_llm = initialize_expert_llm("openai", "o1")
|
_llm = initialize_expert_llm("openai", "o1")
|
||||||
|
|
||||||
# Verify reasoning_effort was included in kwargs
|
# Verify reasoning_effort was included in kwargs
|
||||||
mock_openai.assert_called_with(
|
mock_openai.assert_called_with(
|
||||||
api_key="test-key",
|
api_key="test-key",
|
||||||
|
|
@ -664,8 +659,5 @@ def test_initialize_openrouter_deepseek(
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
timeout=180,
|
timeout=180,
|
||||||
max_retries=5,
|
max_retries=5,
|
||||||
default_headers={
|
default_headers={"HTTP-Referer": "https://ra-aid.ai", "X-Title": "RA.Aid"},
|
||||||
"HTTP-Referer": "https://ra-aid.ai",
|
|
||||||
"X-Title": "RA.Aid"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -204,15 +204,15 @@ def test_use_aider_flag(mock_dependencies):
|
||||||
"""Test that use-aider flag is correctly stored in config."""
|
"""Test that use-aider flag is correctly stored in config."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
|
||||||
|
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
|
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
||||||
|
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
|
|
||||||
# Reset to default state
|
# Reset to default state
|
||||||
set_modification_tools(False)
|
set_modification_tools(False)
|
||||||
|
|
||||||
# Check default behavior (use_aider=False)
|
# Check default behavior (use_aider=False)
|
||||||
with patch.object(
|
with patch.object(
|
||||||
sys,
|
sys,
|
||||||
|
|
@ -222,15 +222,15 @@ def test_use_aider_flag(mock_dependencies):
|
||||||
main()
|
main()
|
||||||
config = _global_memory["config"]
|
config = _global_memory["config"]
|
||||||
assert config.get("use_aider") is False
|
assert config.get("use_aider") is False
|
||||||
|
|
||||||
# Check that file tools are enabled by default
|
# Check that file tools are enabled by default
|
||||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||||
assert "file_str_replace" in tool_names
|
assert "file_str_replace" in tool_names
|
||||||
assert "put_complete_file_contents" in tool_names
|
assert "put_complete_file_contents" in tool_names
|
||||||
assert "run_programming_task" not in tool_names
|
assert "run_programming_task" not in tool_names
|
||||||
|
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
|
|
||||||
# Check with --use-aider flag
|
# Check with --use-aider flag
|
||||||
with patch.object(
|
with patch.object(
|
||||||
sys,
|
sys,
|
||||||
|
|
@ -240,12 +240,12 @@ def test_use_aider_flag(mock_dependencies):
|
||||||
main()
|
main()
|
||||||
config = _global_memory["config"]
|
config = _global_memory["config"]
|
||||||
assert config.get("use_aider") is True
|
assert config.get("use_aider") is True
|
||||||
|
|
||||||
# Check that run_programming_task is enabled
|
# Check that run_programming_task is enabled
|
||||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||||
assert "file_str_replace" not in tool_names
|
assert "file_str_replace" not in tool_names
|
||||||
assert "put_complete_file_contents" not in tool_names
|
assert "put_complete_file_contents" not in tool_names
|
||||||
assert "run_programming_task" in tool_names
|
assert "run_programming_task" in tool_names
|
||||||
|
|
||||||
# Reset to default state for other tests
|
# Reset to default state for other tests
|
||||||
set_modification_tools(False)
|
set_modification_tools(False)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
from ra_aid.tool_configs import (
|
from ra_aid.tool_configs import (
|
||||||
|
MODIFICATION_TOOLS,
|
||||||
get_implementation_tools,
|
get_implementation_tools,
|
||||||
get_planning_tools,
|
get_planning_tools,
|
||||||
get_read_only_tools,
|
get_read_only_tools,
|
||||||
get_research_tools,
|
get_research_tools,
|
||||||
get_web_research_tools,
|
get_web_research_tools,
|
||||||
set_modification_tools,
|
set_modification_tools,
|
||||||
MODIFICATION_TOOLS,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,16 +14,16 @@ def test_get_read_only_tools():
|
||||||
tools = get_read_only_tools(human_interaction=False, use_aider=False)
|
tools = get_read_only_tools(human_interaction=False, use_aider=False)
|
||||||
assert len(tools) > 0
|
assert len(tools) > 0
|
||||||
assert all(callable(tool) for tool in tools)
|
assert all(callable(tool) for tool in tools)
|
||||||
|
|
||||||
# Check emit_related_files is not included when use_aider is False
|
# Check emit_related_files is not included when use_aider is False
|
||||||
tool_names = [tool.name for tool in tools]
|
tool_names = [tool.name for tool in tools]
|
||||||
assert "emit_related_files" not in tool_names
|
assert "emit_related_files" not in tool_names
|
||||||
|
|
||||||
# Test with use_aider=True
|
# Test with use_aider=True
|
||||||
tools_with_aider = get_read_only_tools(human_interaction=False, use_aider=True)
|
tools_with_aider = get_read_only_tools(human_interaction=False, use_aider=True)
|
||||||
tool_names_with_aider = [tool.name for tool in tools_with_aider]
|
tool_names_with_aider = [tool.name for tool in tools_with_aider]
|
||||||
assert "emit_related_files" in tool_names_with_aider
|
assert "emit_related_files" in tool_names_with_aider
|
||||||
|
|
||||||
# Test with human interaction
|
# Test with human interaction
|
||||||
tools_with_human = get_read_only_tools(human_interaction=True, use_aider=False)
|
tools_with_human = get_read_only_tools(human_interaction=True, use_aider=False)
|
||||||
assert len(tools_with_human) == len(tools) + 1
|
assert len(tools_with_human) == len(tools) + 1
|
||||||
|
|
@ -102,13 +102,13 @@ def test_set_modification_tools():
|
||||||
assert "file_str_replace" in tool_names
|
assert "file_str_replace" in tool_names
|
||||||
assert "put_complete_file_contents" in tool_names
|
assert "put_complete_file_contents" in tool_names
|
||||||
assert "run_programming_task" not in tool_names
|
assert "run_programming_task" not in tool_names
|
||||||
|
|
||||||
# Test with use_aider=True
|
# Test with use_aider=True
|
||||||
set_modification_tools(use_aider=True)
|
set_modification_tools(use_aider=True)
|
||||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||||
assert "file_str_replace" not in tool_names
|
assert "file_str_replace" not in tool_names
|
||||||
assert "put_complete_file_contents" not in tool_names
|
assert "put_complete_file_contents" not in tool_names
|
||||||
assert "run_programming_task" in tool_names
|
assert "run_programming_task" in tool_names
|
||||||
|
|
||||||
# Reset to default for other tests
|
# Reset to default for other tests
|
||||||
set_modification_tools(use_aider=False)
|
set_modification_tools(use_aider=False)
|
||||||
|
|
|
||||||
|
|
@ -222,7 +222,7 @@ def test_emit_key_snippet(reset_memory):
|
||||||
|
|
||||||
# Verify counter incremented correctly
|
# Verify counter incremented correctly
|
||||||
assert _global_memory["key_snippet_id_counter"] == 1
|
assert _global_memory["key_snippet_id_counter"] == 1
|
||||||
|
|
||||||
# Test snippet without description
|
# Test snippet without description
|
||||||
snippet2 = {
|
snippet2 = {
|
||||||
"filepath": "main.py",
|
"filepath": "main.py",
|
||||||
|
|
@ -230,16 +230,16 @@ def test_emit_key_snippet(reset_memory):
|
||||||
"snippet": "print('hello')",
|
"snippet": "print('hello')",
|
||||||
"description": None,
|
"description": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Emit second snippet
|
# Emit second snippet
|
||||||
result = emit_key_snippet.invoke({"snippet_info": snippet2})
|
result = emit_key_snippet.invoke({"snippet_info": snippet2})
|
||||||
|
|
||||||
# Verify return message
|
# Verify return message
|
||||||
assert result == "Snippet #1 stored."
|
assert result == "Snippet #1 stored."
|
||||||
|
|
||||||
# Verify snippet stored correctly
|
# Verify snippet stored correctly
|
||||||
assert _global_memory["key_snippets"][1] == snippet2
|
assert _global_memory["key_snippets"][1] == snippet2
|
||||||
|
|
||||||
# Verify counter incremented correctly
|
# Verify counter incremented correctly
|
||||||
assert _global_memory["key_snippet_id_counter"] == 2
|
assert _global_memory["key_snippet_id_counter"] == 2
|
||||||
|
|
||||||
|
|
@ -723,37 +723,40 @@ def test_emit_related_files_binary_filtering(reset_memory, tmp_path, monkeypatch
|
||||||
text_file1.write_text("Text file 1 content")
|
text_file1.write_text("Text file 1 content")
|
||||||
text_file2 = tmp_path / "text2.txt"
|
text_file2 = tmp_path / "text2.txt"
|
||||||
text_file2.write_text("Text file 2 content")
|
text_file2.write_text("Text file 2 content")
|
||||||
|
|
||||||
# Create test "binary" files
|
# Create test "binary" files
|
||||||
binary_file1 = tmp_path / "binary1.bin"
|
binary_file1 = tmp_path / "binary1.bin"
|
||||||
binary_file1.write_text("Binary file 1 content")
|
binary_file1.write_text("Binary file 1 content")
|
||||||
binary_file2 = tmp_path / "binary2.bin"
|
binary_file2 = tmp_path / "binary2.bin"
|
||||||
binary_file2.write_text("Binary file 2 content")
|
binary_file2.write_text("Binary file 2 content")
|
||||||
|
|
||||||
# Mock the is_binary_file function to identify our "binary" files
|
# Mock the is_binary_file function to identify our "binary" files
|
||||||
def mock_is_binary_file(filepath):
|
def mock_is_binary_file(filepath):
|
||||||
return ".bin" in str(filepath)
|
return ".bin" in str(filepath)
|
||||||
|
|
||||||
# Apply the mock
|
# Apply the mock
|
||||||
import ra_aid.tools.memory
|
import ra_aid.tools.memory
|
||||||
|
|
||||||
monkeypatch.setattr(ra_aid.tools.memory, "is_binary_file", mock_is_binary_file)
|
monkeypatch.setattr(ra_aid.tools.memory, "is_binary_file", mock_is_binary_file)
|
||||||
|
|
||||||
# Call emit_related_files with mix of text and binary files
|
# Call emit_related_files with mix of text and binary files
|
||||||
result = emit_related_files.invoke({
|
result = emit_related_files.invoke(
|
||||||
"files": [
|
{
|
||||||
str(text_file1),
|
"files": [
|
||||||
str(binary_file1),
|
str(text_file1),
|
||||||
str(text_file2),
|
str(binary_file1),
|
||||||
str(binary_file2)
|
str(text_file2),
|
||||||
]
|
str(binary_file2),
|
||||||
})
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the result message mentions skipped binary files
|
# Verify the result message mentions skipped binary files
|
||||||
assert "Files noted." in result
|
assert "Files noted." in result
|
||||||
assert "Binary files skipped:" in result
|
assert "Binary files skipped:" in result
|
||||||
assert f"'{binary_file1}'" in result
|
assert f"'{binary_file1}'" in result
|
||||||
assert f"'{binary_file2}'" in result
|
assert f"'{binary_file2}'" in result
|
||||||
|
|
||||||
# Verify only text files were added to related_files
|
# Verify only text files were added to related_files
|
||||||
assert len(_global_memory["related_files"]) == 2
|
assert len(_global_memory["related_files"]) == 2
|
||||||
file_values = list(_global_memory["related_files"].values())
|
file_values = list(_global_memory["related_files"].values())
|
||||||
|
|
@ -761,6 +764,60 @@ def test_emit_related_files_binary_filtering(reset_memory, tmp_path, monkeypatch
|
||||||
assert str(text_file2) in file_values
|
assert str(text_file2) in file_values
|
||||||
assert str(binary_file1) not in file_values
|
assert str(binary_file1) not in file_values
|
||||||
assert str(binary_file2) not in file_values
|
assert str(binary_file2) not in file_values
|
||||||
|
|
||||||
# Verify counter is correct (only incremented for text files)
|
# Verify counter is correct (only incremented for text files)
|
||||||
assert _global_memory["related_file_id_counter"] == 2
|
assert _global_memory["related_file_id_counter"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_binary_file_with_ascii(reset_memory, monkeypatch):
|
||||||
|
"""Test that ASCII files are correctly identified as text files"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import ra_aid.tools.memory
|
||||||
|
|
||||||
|
# Path to the mock ASCII file
|
||||||
|
ascii_file_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "mocks", "ascii.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with magic library if available
|
||||||
|
if ra_aid.tools.memory.magic:
|
||||||
|
# Test real implementation with ASCII file
|
||||||
|
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
|
||||||
|
assert not is_binary, "ASCII file should not be identified as binary"
|
||||||
|
|
||||||
|
# Test fallback implementation
|
||||||
|
# Mock magic to be None to force fallback implementation
|
||||||
|
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
|
||||||
|
|
||||||
|
# Test fallback with ASCII file
|
||||||
|
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
|
||||||
|
assert (
|
||||||
|
not is_binary
|
||||||
|
), "ASCII file should not be identified as binary with fallback method"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_binary_file_with_null_bytes(reset_memory, tmp_path, monkeypatch):
|
||||||
|
"""Test that files with null bytes are correctly identified as binary"""
|
||||||
|
import ra_aid.tools.memory
|
||||||
|
|
||||||
|
# Create a file with null bytes (binary content)
|
||||||
|
binary_file = tmp_path / "binary_with_nulls.bin"
|
||||||
|
with open(binary_file, "wb") as f:
|
||||||
|
f.write(b"Some text with \x00 null \x00 bytes")
|
||||||
|
|
||||||
|
# Test with magic library if available
|
||||||
|
if ra_aid.tools.memory.magic:
|
||||||
|
# Test real implementation with binary file
|
||||||
|
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
|
||||||
|
assert is_binary, "File with null bytes should be identified as binary"
|
||||||
|
|
||||||
|
# Test fallback implementation
|
||||||
|
# Mock magic to be None to force fallback implementation
|
||||||
|
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
|
||||||
|
|
||||||
|
# Test fallback with binary file
|
||||||
|
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
|
||||||
|
assert (
|
||||||
|
is_binary
|
||||||
|
), "File with null bytes should be identified as binary with fallback method"
|
||||||
|
|
|
||||||
|
|
@ -138,13 +138,13 @@ def test_verify_fix(tmp_path):
|
||||||
# Create a .ra-aid directory inside the temporary directory
|
# Create a .ra-aid directory inside the temporary directory
|
||||||
ra_aid_dir = tmp_path / ".ra-aid"
|
ra_aid_dir = tmp_path / ".ra-aid"
|
||||||
ra_aid_dir.mkdir()
|
ra_aid_dir.mkdir()
|
||||||
|
|
||||||
# Check that is_new_project() returns True (only .ra-aid directory)
|
# Check that is_new_project() returns True (only .ra-aid directory)
|
||||||
assert is_new_project(str(tmp_path)) is True
|
assert is_new_project(str(tmp_path)) is True
|
||||||
|
|
||||||
# Add a README.md file to the directory
|
# Add a README.md file to the directory
|
||||||
readme_file = tmp_path / "README.md"
|
readme_file = tmp_path / "README.md"
|
||||||
readme_file.write_text("# Test Project")
|
readme_file.write_text("# Test Project")
|
||||||
|
|
||||||
# Check that is_new_project() now returns False (has actual content)
|
# Check that is_new_project() now returns False (has actual content)
|
||||||
assert is_new_project(str(tmp_path)) is False
|
assert is_new_project(str(tmp_path)) is False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue