From 36e4004db04bbf3520dc7a41054e871ce937c302 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Mon, 3 Mar 2025 16:58:03 -0500 Subject: [PATCH] use context vars for key facts repo --- ra_aid/__main__.py | 413 +++++++++--------- ra_aid/agent_utils.py | 52 ++- ra_aid/agents/key_facts_gc_agent.py | 63 ++- ra_aid/agents/key_snippets_gc_agent.py | 11 +- .../repositories/key_fact_repository.py | 103 ++++- .../repositories/key_snippet_repository.py | 19 +- ra_aid/tools/agent.py | 97 +++- ra_aid/tools/expert.py | 23 +- ra_aid/tools/memory.py | 55 +-- .../database/test_key_fact_repository.py | 67 ++- tests/ra_aid/tools/test_agent.py | 35 +- tests/ra_aid/tools/test_memory.py | 74 ++-- 12 files changed, 652 insertions(+), 360 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 3239c0a..fd1deb2 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -42,7 +42,7 @@ from ra_aid.config import ( DEFAULT_TEST_CMD_TIMEOUT, VALID_PROVIDERS, ) -from ra_aid.database.repositories.key_fact_repository import KeyFactRepository +from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict @@ -393,251 +393,256 @@ def main(): except Exception as e: logger.error(f"Database migration error: {str(e)}") - # Check dependencies before proceeding - check_dependencies() + # Initialize repositories with database connection + with KeyFactRepositoryManager(db) as key_fact_repo: + # This initializes the repository and makes it available via get_key_fact_repository() + logger.debug("Initialized KeyFactRepository") - ( - expert_enabled, - expert_missing, - web_research_enabled, - web_research_missing, - ) = validate_environment(args) # Will exit if main env vars missing - logger.debug("Environment validation successful") + # Check dependencies before proceeding + check_dependencies() - # Validate model configuration early - model_config = models_params.get(args.provider, {}).get( - args.model or "", {} - ) - supports_temperature = model_config.get( - "supports_temperature", - args.provider - in [ - "anthropic", - "openai", - "openrouter", - "openai-compatible", - "deepseek", - ], - ) + ( + expert_enabled, + expert_missing, + web_research_enabled, + web_research_missing, + ) = validate_environment(args) # Will exit if main env vars missing + logger.debug("Environment validation successful") - if supports_temperature and args.temperature is None: - args.temperature = model_config.get("default_temperature") - if args.temperature is None: - cpm( - f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}." + # Validate model configuration early + model_config = models_params.get(args.provider, {}).get( + args.model or "", {} + ) + supports_temperature = model_config.get( + "supports_temperature", + args.provider + in [ + "anthropic", + "openai", + "openrouter", + "openai-compatible", + "deepseek", + ], + ) + + if supports_temperature and args.temperature is None: + args.temperature = model_config.get("default_temperature") + if args.temperature is None: + cpm( + f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}." + ) + args.temperature = DEFAULT_TEMPERATURE + logger.debug( + f"Using default temperature {args.temperature} for model {args.model}" + ) + + status = build_status(args, expert_enabled, web_research_enabled) + + console.print( + Panel( + status, + title=f"RA.Aid v{__version__}", + border_style="bright_blue", + padding=(0, 1), ) - args.temperature = DEFAULT_TEMPERATURE - logger.debug( - f"Using default temperature {args.temperature} for model {args.model}" ) - status = build_status(args, expert_enabled, web_research_enabled) + # Handle chat mode + if args.chat: + # Initialize chat model with default provider/model + chat_model = initialize_llm( + args.provider, args.model, temperature=args.temperature + ) - console.print( - Panel( - status, - title=f"RA.Aid v{__version__}", - border_style="bright_blue", - padding=(0, 1), - ) - ) + if args.research_only: + print_error("Chat mode cannot be used with --research-only") + sys.exit(1) - # Handle chat mode - if args.chat: - # Initialize chat model with default provider/model - chat_model = initialize_llm( - args.provider, args.model, temperature=args.temperature - ) + print_stage_header("Chat Mode") - if args.research_only: - print_error("Chat mode cannot be used with --research-only") + # Get project info + try: + project_info = get_project_info(".", file_limit=2000) + formatted_project_info = format_project_info(project_info) + except Exception as e: + logger.warning(f"Failed to get project info: {e}") + formatted_project_info = "" + + # Get initial request from user + initial_request = ask_human.invoke( + {"question": "What would you like help with?"} + ) + + # Record chat input in database (redundant as ask_human already records it, + # but needed in case the ask_human implementation changes) + try: + from ra_aid.database.repositories.human_input_repository import HumanInputRepository + human_input_repo = HumanInputRepository(db) + human_input_repo.create(content=initial_request, source='chat') + human_input_repo.garbage_collect() + except Exception as e: + logger.error(f"Failed to record initial chat input: {str(e)}") + + # Get working directory and current date + working_directory = os.getcwd() + current_date = datetime.now().strftime("%Y-%m-%d") + + # Run chat agent with CHAT_PROMPT + config = { + "configurable": {"thread_id": str(uuid.uuid4())}, + "recursion_limit": args.recursion_limit, + "chat_mode": True, + "cowboy_mode": args.cowboy_mode, + "hil": True, # Always true in chat mode + "web_research_enabled": web_research_enabled, + "initial_request": initial_request, + "limit_tokens": args.disable_limit_tokens, + } + + # Store config in global memory + _global_memory["config"] = config + _global_memory["config"]["provider"] = args.provider + _global_memory["config"]["model"] = args.model + _global_memory["config"]["expert_provider"] = args.expert_provider + _global_memory["config"]["expert_model"] = args.expert_model + _global_memory["config"]["temperature"] = args.temperature + + # Set modification tools based on use_aider flag + set_modification_tools(args.use_aider) + + # Create chat agent with appropriate tools + chat_agent = create_agent( + chat_model, + get_chat_tools( + expert_enabled=expert_enabled, + web_research_enabled=web_research_enabled, + ), + checkpointer=MemorySaver(), + ) + + # Run chat agent and exit + run_agent_with_retry( + chat_agent, + CHAT_PROMPT.format( + initial_request=initial_request, + web_research_section=( + WEB_RESEARCH_PROMPT_SECTION_CHAT + if web_research_enabled + else "" + ), + working_directory=working_directory, + current_date=current_date, + key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()), + key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()), + project_info=formatted_project_info, + ), + config, + ) + return + + # Validate message is provided + if not args.message: + print_error("--message is required") sys.exit(1) - print_stage_header("Chat Mode") - - # Get project info - try: - project_info = get_project_info(".", file_limit=2000) - formatted_project_info = format_project_info(project_info) - except Exception as e: - logger.warning(f"Failed to get project info: {e}") - formatted_project_info = "" - - # Get initial request from user - initial_request = ask_human.invoke( - {"question": "What would you like help with?"} - ) + base_task = args.message - # Record chat input in database (redundant as ask_human already records it, - # but needed in case the ask_human implementation changes) + # Record CLI input in database try: from ra_aid.database.repositories.human_input_repository import HumanInputRepository - human_input_repo = HumanInputRepository() - human_input_repo.create(content=initial_request, source='chat') + human_input_repo = HumanInputRepository(db) + human_input_repo.create(content=base_task, source='cli') + # Run garbage collection to ensure we don't exceed 100 inputs human_input_repo.garbage_collect() + logger.debug(f"Recorded CLI input: {base_task}") except Exception as e: - logger.error(f"Failed to record initial chat input: {str(e)}") - - # Get working directory and current date - working_directory = os.getcwd() - current_date = datetime.now().strftime("%Y-%m-%d") - - # Run chat agent with CHAT_PROMPT + logger.error(f"Failed to record CLI input: {str(e)}") config = { "configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": args.recursion_limit, - "chat_mode": True, + "research_only": args.research_only, "cowboy_mode": args.cowboy_mode, - "hil": True, # Always true in chat mode "web_research_enabled": web_research_enabled, - "initial_request": initial_request, + "aider_config": args.aider_config, + "use_aider": args.use_aider, "limit_tokens": args.disable_limit_tokens, + "auto_test": args.auto_test, + "test_cmd": args.test_cmd, + "max_test_cmd_retries": args.max_test_cmd_retries, + "experimental_fallback_handler": args.experimental_fallback_handler, + "test_cmd_timeout": args.test_cmd_timeout, } - # Store config in global memory + # Store config in global memory for access by is_informational_query _global_memory["config"] = config + + # Store base provider/model configuration _global_memory["config"]["provider"] = args.provider _global_memory["config"]["model"] = args.model + + # Store expert provider/model (no fallback) _global_memory["config"]["expert_provider"] = args.expert_provider _global_memory["config"]["expert_model"] = args.expert_model + + # Store planner config with fallback to base values + _global_memory["config"]["planner_provider"] = ( + args.planner_provider or args.provider + ) + _global_memory["config"]["planner_model"] = args.planner_model or args.model + + # Store research config with fallback to base values + _global_memory["config"]["research_provider"] = ( + args.research_provider or args.provider + ) + _global_memory["config"]["research_model"] = ( + args.research_model or args.model + ) + + # Store temperature in global config _global_memory["config"]["temperature"] = args.temperature # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) - # Create chat agent with appropriate tools - chat_agent = create_agent( - chat_model, - get_chat_tools( - expert_enabled=expert_enabled, - web_research_enabled=web_research_enabled, - ), - checkpointer=MemorySaver(), + # Run research stage + print_stage_header("Research Stage") + + # Initialize research model with potential overrides + research_provider = args.research_provider or args.provider + research_model_name = args.research_model or args.model + research_model = initialize_llm( + research_provider, research_model_name, temperature=args.temperature ) - # Run chat agent and exit - run_agent_with_retry( - chat_agent, - CHAT_PROMPT.format( - initial_request=initial_request, - web_research_section=( - WEB_RESEARCH_PROMPT_SECTION_CHAT - if web_research_enabled - else "" - ), - working_directory=working_directory, - current_date=current_date, - key_facts=format_key_facts_dict(KeyFactRepository().get_facts_dict()), - key_snippets=format_key_snippets_dict(KeySnippetRepository().get_snippets_dict()), - project_info=formatted_project_info, - ), - config, - ) - return - - # Validate message is provided - if not args.message: - print_error("--message is required") - sys.exit(1) - - base_task = args.message - - # Record CLI input in database - try: - from ra_aid.database.repositories.human_input_repository import HumanInputRepository - human_input_repo = HumanInputRepository() - human_input_repo.create(content=base_task, source='cli') - # Run garbage collection to ensure we don't exceed 100 inputs - human_input_repo.garbage_collect() - logger.debug(f"Recorded CLI input: {base_task}") - except Exception as e: - logger.error(f"Failed to record CLI input: {str(e)}") - config = { - "configurable": {"thread_id": str(uuid.uuid4())}, - "recursion_limit": args.recursion_limit, - "research_only": args.research_only, - "cowboy_mode": args.cowboy_mode, - "web_research_enabled": web_research_enabled, - "aider_config": args.aider_config, - "use_aider": args.use_aider, - "limit_tokens": args.disable_limit_tokens, - "auto_test": args.auto_test, - "test_cmd": args.test_cmd, - "max_test_cmd_retries": args.max_test_cmd_retries, - "experimental_fallback_handler": args.experimental_fallback_handler, - "test_cmd_timeout": args.test_cmd_timeout, - } - - # Store config in global memory for access by is_informational_query - _global_memory["config"] = config - - # Store base provider/model configuration - _global_memory["config"]["provider"] = args.provider - _global_memory["config"]["model"] = args.model - - # Store expert provider/model (no fallback) - _global_memory["config"]["expert_provider"] = args.expert_provider - _global_memory["config"]["expert_model"] = args.expert_model - - # Store planner config with fallback to base values - _global_memory["config"]["planner_provider"] = ( - args.planner_provider or args.provider - ) - _global_memory["config"]["planner_model"] = args.planner_model or args.model - - # Store research config with fallback to base values - _global_memory["config"]["research_provider"] = ( - args.research_provider or args.provider - ) - _global_memory["config"]["research_model"] = ( - args.research_model or args.model - ) - - # Store temperature in global config - _global_memory["config"]["temperature"] = args.temperature - - # Set modification tools based on use_aider flag - set_modification_tools(args.use_aider) - - # Run research stage - print_stage_header("Research Stage") - - # Initialize research model with potential overrides - research_provider = args.research_provider or args.provider - research_model_name = args.research_model or args.model - research_model = initialize_llm( - research_provider, research_model_name, temperature=args.temperature - ) - - run_research_agent( - base_task, - research_model, - expert_enabled=expert_enabled, - research_only=args.research_only, - hil=args.hil, - memory=research_memory, - config=config, - ) - - # Proceed with planning and implementation if not an informational query - if not is_informational_query(): - # Initialize planning model with potential overrides - planner_provider = args.planner_provider or args.provider - planner_model_name = args.planner_model or args.model - planning_model = initialize_llm( - planner_provider, planner_model_name, temperature=args.temperature - ) - - # Run planning agent - run_planning_agent( + run_research_agent( base_task, - planning_model, + research_model, expert_enabled=expert_enabled, + research_only=args.research_only, hil=args.hil, - memory=planning_memory, + memory=research_memory, config=config, ) + # Proceed with planning and implementation if not an informational query + if not is_informational_query(): + # Initialize planning model with potential overrides + planner_provider = args.planner_provider or args.provider + planner_model_name = args.planner_model or args.model + planning_model = initialize_llm( + planner_provider, planner_model_name, temperature=args.temperature + ) + + # Run planning agent + run_planning_agent( + base_task, + planning_model, + expert_enabled=expert_enabled, + hil=args.hil, + memory=planning_memory, + config=config, + ) + except (KeyboardInterrupt, AgentInterrupt): print() print(" 👋 Bye!") diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 59b3eed..a3af9e8 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -84,8 +84,8 @@ from ra_aid.tool_configs import ( get_web_research_tools, ) from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command -from ra_aid.database.repositories.key_fact_repository import KeyFactRepository -from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository +from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository +from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict @@ -100,10 +100,8 @@ console = Console() logger = get_logger(__name__) -# Initialize repositories -key_fact_repository = KeyFactRepository() -key_snippet_repository = KeySnippetRepository() -human_input_repository = HumanInputRepository() +# Import repositories using get_* functions +from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository @tool @@ -391,7 +389,11 @@ def run_research_agent( else "" ) - key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict()) + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") @@ -400,6 +402,7 @@ def run_research_agent( # Get the last human input, if it exists base_task = base_task_or_query + human_input_repository = HumanInputRepository() recent_inputs = human_input_repository.get_recent(1) if recent_inputs and len(recent_inputs) > 0: last_human_input = recent_inputs[0].content @@ -537,7 +540,11 @@ def run_web_research_agent( expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" - key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict()) + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") @@ -647,6 +654,20 @@ def run_planning_agent( current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") working_directory = os.getcwd() + # Make sure key_facts is defined before using it + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + + # Make sure key_snippets is defined before using it + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + planning_prompt = PLANNING_PROMPT.format( current_date=current_date, working_directory=working_directory, @@ -657,8 +678,8 @@ def run_planning_agent( project_info=formatted_project_info, research_notes=get_memory_value("research_notes"), related_files="\n".join(get_related_files()), - key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()), - key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + key_facts=key_facts, + key_snippets=key_snippets, work_log=get_memory_value("work_log"), research_only_note=( "" @@ -751,6 +772,13 @@ def run_task_implementation_agent( current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") working_directory = os.getcwd() + # Make sure key_facts is defined before using it + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + prompt = IMPLEMENTATION_PROMPT.format( current_date=current_date, working_directory=working_directory, @@ -759,8 +787,8 @@ def run_task_implementation_agent( tasks=tasks, plan=plan, related_files=related_files, - key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()), - key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + key_facts=key_facts, + key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()), research_notes=get_memory_value("research_notes"), work_log=get_memory_value("work_log"), expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", diff --git a/ra_aid/agents/key_facts_gc_agent.py b/ra_aid/agents/key_facts_gc_agent.py index f8a5543..b58adbc 100644 --- a/ra_aid/agents/key_facts_gc_agent.py +++ b/ra_aid/agents/key_facts_gc_agent.py @@ -6,6 +6,7 @@ facts when the total number exceeds a specified threshold. The agent evaluates a key facts and deletes the least valuable ones to keep the database clean and relevant. """ +import logging from typing import List from langchain_core.tools import tool @@ -13,8 +14,10 @@ from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel +logger = logging.getLogger(__name__) + from ra_aid.agent_utils import create_agent, run_agent_with_retry -from ra_aid.database.repositories.key_fact_repository import KeyFactRepository +from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.llm import initialize_llm from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT @@ -22,7 +25,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory console = Console() -key_fact_repository = KeyFactRepository() human_input_repository = HumanInputRepository() @@ -51,24 +53,30 @@ def delete_key_facts(fact_ids: List[int]) -> str: console.print(f"Warning: Could not retrieve current human input: {str(e)}") for fact_id in fact_ids: - # Get the fact first to display information - fact = key_fact_repository.get(fact_id) - if fact: - # Check if this fact is associated with the current human input - if current_human_input_id is not None and fact.human_input_id == current_human_input_id: - protected_facts.append((fact_id, fact.content)) - continue + try: + # Get the fact first to display information + fact = get_key_fact_repository().get(fact_id) + if fact: + # Check if this fact is associated with the current human input + if current_human_input_id is not None and fact.human_input_id == current_human_input_id: + protected_facts.append((fact_id, fact.content)) + continue + + # Delete the fact if it's not protected + was_deleted = get_key_fact_repository().delete(fact_id) + if was_deleted: + deleted_facts.append((fact_id, fact.content)) + log_work_event(f"Deleted fact {fact_id}.") + else: + failed_facts.append(fact_id) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + failed_facts.append(fact_id) + except Exception as e: + # For any other exceptions, log and continue + logger.error(f"Error processing fact {fact_id}: {str(e)}") + failed_facts.append(fact_id) - # Delete the fact if it's not protected - was_deleted = key_fact_repository.delete(fact_id) - if was_deleted: - deleted_facts.append((fact_id, fact.content)) - log_work_event(f"Deleted fact {fact_id}.") - else: - failed_facts.append(fact_id) - else: - not_found_facts.append(fact_id) - # Prepare result message result_parts = [] if deleted_facts: @@ -104,8 +112,13 @@ def run_key_facts_gc_agent() -> None: Facts associated with the current human input are excluded from deletion. """ # Get the count of key facts - facts = key_fact_repository.get_all() - fact_count = len(facts) + try: + facts = get_key_fact_repository().get_all() + fact_count = len(facts) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red")) + return # Exit the function if we can't access the repository # Display status panel with fact count included console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection")) @@ -161,8 +174,12 @@ def run_key_facts_gc_agent() -> None: run_agent_with_retry(agent, prompt, agent_config) # Get updated count - updated_facts = key_fact_repository.get_all() - updated_count = len(updated_facts) + try: + updated_facts = get_key_fact_repository().get_all() + updated_count = len(updated_facts) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository for update count: {str(e)}") + updated_count = "unknown" # Show info panel with updated count and protected facts count protected_count = len(protected_facts) diff --git a/ra_aid/agents/key_snippets_gc_agent.py b/ra_aid/agents/key_snippets_gc_agent.py index ee4e390..2e13538 100644 --- a/ra_aid/agents/key_snippets_gc_agent.py +++ b/ra_aid/agents/key_snippets_gc_agent.py @@ -14,7 +14,7 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agent_utils import create_agent, run_agent_with_retry -from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository +from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.llm import initialize_llm from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT @@ -22,7 +22,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory console = Console() -key_snippet_repository = KeySnippetRepository() human_input_repository = HumanInputRepository() @@ -53,7 +52,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str: for snippet_id in snippet_ids: # Get the snippet first to capture filepath for the message - snippet = key_snippet_repository.get(snippet_id) + snippet = get_key_snippet_repository().get(snippet_id) if snippet: filepath = snippet.filepath @@ -63,7 +62,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str: continue # Delete from database if not protected - success = key_snippet_repository.delete(snippet_id) + success = get_key_snippet_repository().delete(snippet_id) if success: success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}" console.print( @@ -110,7 +109,7 @@ def run_key_snippets_gc_agent() -> None: Snippets associated with the current human input are excluded from deletion. """ # Get the count of key snippets - snippets = key_snippet_repository.get_all() + snippets = get_key_snippet_repository().get_all() snippet_count = len(snippets) # Display status panel with snippet count included @@ -179,7 +178,7 @@ def run_key_snippets_gc_agent() -> None: run_agent_with_retry(agent, prompt, agent_config) # Get updated count - updated_snippets = key_snippet_repository.get_all() + updated_snippets = get_key_snippet_repository().get_all() updated_count = len(updated_snippets) # Show info panel with updated count and protected snippets count diff --git a/ra_aid/database/repositories/key_fact_repository.py b/ra_aid/database/repositories/key_fact_repository.py index 1ad75dc..2c29c52 100644 --- a/ra_aid/database/repositories/key_fact_repository.py +++ b/ra_aid/database/repositories/key_fact_repository.py @@ -6,15 +6,94 @@ following the repository pattern for data access abstraction. """ from typing import Dict, List, Optional +import contextvars +from contextlib import contextmanager import peewee -from ra_aid.database.connection import get_db -from ra_aid.database.models import KeyFact, initialize_database +from ra_aid.database.models import KeyFact from ra_aid.logging_config import get_logger logger = get_logger(__name__) +# Create contextvar to hold the KeyFactRepository instance +key_fact_repo_var = contextvars.ContextVar("key_fact_repo", default=None) + + +class KeyFactRepositoryManager: + """ + Context manager for KeyFactRepository. + + This class provides a context manager interface for KeyFactRepository, + using the contextvars approach for thread safety. + + Example: + with DatabaseManager() as db: + with KeyFactRepositoryManager(db) as repo: + # Use the repository + fact = repo.create("Important fact about the project") + all_facts = repo.get_all() + """ + + def __init__(self, db): + """ + Initialize the KeyFactRepositoryManager. + + Args: + db: Database connection to use (required) + """ + self.db = db + + def __enter__(self) -> 'KeyFactRepository': + """ + Initialize the KeyFactRepository and return it. + + Returns: + KeyFactRepository: The initialized repository + """ + repo = KeyFactRepository(self.db) + key_fact_repo_var.set(repo) + return repo + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[Exception], + exc_tb: Optional[object], + ) -> None: + """ + Reset the repository when exiting the context. + + Args: + exc_type: The exception type if an exception was raised + exc_val: The exception value if an exception was raised + exc_tb: The traceback if an exception was raised + """ + # Reset the contextvar to None + key_fact_repo_var.set(None) + + # Don't suppress exceptions + return False + + +def get_key_fact_repository() -> 'KeyFactRepository': + """ + Get the current KeyFactRepository instance. + + Returns: + KeyFactRepository: The current repository instance + + Raises: + RuntimeError: If no repository has been initialized with KeyFactRepositoryManager + """ + repo = key_fact_repo_var.get() + if repo is None: + raise RuntimeError( + "No KeyFactRepository available. " + "Make sure to initialize one with KeyFactRepositoryManager first." + ) + return repo + class KeyFactRepository: """ @@ -24,18 +103,21 @@ class KeyFactRepository: abstracting the database access details from the business logic. Example: - repo = KeyFactRepository() - fact = repo.create("Important fact about the project") - all_facts = repo.get_all() + with DatabaseManager() as db: + with KeyFactRepositoryManager(db) as repo: + fact = repo.create("Important fact about the project") + all_facts = repo.get_all() """ - def __init__(self, db=None): + def __init__(self, db): """ - Initialize the repository with an optional database connection. + Initialize the repository with a database connection. Args: - db: Optional database connection to use. If None, will use initialize_database() + db: Database connection to use (required) """ + if db is None: + raise ValueError("Database connection is required for KeyFactRepository") self.db = db def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact: @@ -53,7 +135,6 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error creating the fact """ try: - db = self.db if self.db is not None else initialize_database() fact = KeyFact.create(content=content, human_input_id=human_input_id) logger.debug(f"Created key fact ID {fact.id}: {content}") return fact @@ -75,7 +156,6 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - db = self.db if self.db is not None else initialize_database() return KeyFact.get_or_none(KeyFact.id == fact_id) except peewee.DatabaseError as e: logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}") @@ -96,7 +176,6 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error updating the fact """ try: - db = self.db if self.db is not None else initialize_database() # First check if the fact exists fact = self.get(fact_id) if not fact: @@ -126,7 +205,6 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error deleting the fact """ try: - db = self.db if self.db is not None else initialize_database() # First check if the fact exists fact = self.get(fact_id) if not fact: @@ -152,7 +230,6 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - db = self.db if self.db is not None else initialize_database() return list(KeyFact.select().order_by(KeyFact.id)) except peewee.DatabaseError as e: logger.error(f"Failed to fetch all key facts: {str(e)}") diff --git a/ra_aid/database/repositories/key_snippet_repository.py b/ra_aid/database/repositories/key_snippet_repository.py index 24bbdaa..7991838 100644 --- a/ra_aid/database/repositories/key_snippet_repository.py +++ b/ra_aid/database/repositories/key_snippet_repository.py @@ -214,4 +214,21 @@ class KeySnippetRepository: } except peewee.DatabaseError as e: logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}") - raise \ No newline at end of file + raise + + +# Global singleton instance +_key_snippet_repository = None + + +def get_key_snippet_repository() -> KeySnippetRepository: + """ + Get or create a singleton instance of KeySnippetRepository. + + Returns: + KeySnippetRepository: Singleton instance of the repository + """ + global _key_snippet_repository + if _key_snippet_repository is None: + _key_snippet_repository = KeySnippetRepository() + return _key_snippet_repository \ No newline at end of file diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index d70b91e..64e27b8 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -1,6 +1,7 @@ """Tools for spawning and managing sub-agents.""" from typing import Any, Dict, List, Union +import logging from langchain_core.tools import tool from rich.console import Console @@ -12,8 +13,9 @@ from ra_aid.agent_context import ( reset_completion_flags, ) from ra_aid.console.formatting import print_error -from ra_aid.database.repositories.key_fact_repository import KeyFactRepository -from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository +from ra_aid.database.repositories.human_input_repository import HumanInputRepository +from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository +from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.exceptions import AgentInterrupt from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict @@ -31,8 +33,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. RESEARCH_AGENT_RECURSION_LIMIT = 3 console = Console() -key_fact_repository = KeyFactRepository() -key_snippet_repository = KeySnippetRepository() +logger = logging.getLogger(__name__) @tool("request_research") @@ -57,12 +58,24 @@ def request_research(query: str) -> ResearchResult: current_depth = _global_memory.get("agent_depth", 0) if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT: print_error("Maximum research recursion depth reached") + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + return { "completion_message": "Research stopped - maximum recursion depth reached", - "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), + "key_facts": key_facts, "related_files": get_related_files(), "research_notes": get_memory_value("research_notes"), - "key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + "key_snippets": key_snippets, "success": False, "reason": "max_depth_exceeded", } @@ -105,12 +118,24 @@ def request_research(query: str) -> ResearchResult: # Clear completion state reset_completion_flags() + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + response_data = { "completion_message": completion_message, - "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), + "key_facts": key_facts, "related_files": get_related_files(), "research_notes": get_memory_value("research_notes"), - "key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + "key_snippets": key_snippets, "success": success, "reason": reason, } @@ -171,9 +196,15 @@ def request_web_research(query: str) -> ResearchResult: # Clear completion state reset_completion_flags() + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + response_data = { "completion_message": completion_message, - "key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + "key_snippets": key_snippets, "research_notes": get_memory_value("research_notes"), "success": success, "reason": reason, @@ -239,12 +270,24 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: # Clear completion state reset_completion_flags() + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + response_data = { "completion_message": completion_message, - "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), + "key_facts": key_facts, "related_files": get_related_files(), "research_notes": get_memory_value("research_notes"), - "key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + "key_snippets": key_snippets, "success": success, "reason": reason, } @@ -324,10 +367,22 @@ def request_task_implementation(task_spec: str) -> str: agent_crashed = is_crashed() crash_message = get_crash_message() if agent_crashed else None + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + response_data = { - "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), + "key_facts": key_facts, "related_files": get_related_files(), - "key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + "key_snippets": key_snippets, "completion_message": completion_message, "success": success and not agent_crashed, "reason": reason, @@ -444,11 +499,23 @@ def request_implementation(task_spec: str) -> str: agent_crashed = is_crashed() crash_message = get_crash_message() if agent_crashed else None + try: + key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" + + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" + response_data = { "completion_message": completion_message, - "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), + "key_facts": key_facts, "related_files": get_related_files(), - "key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), + "key_snippets": key_snippets, "success": success and not agent_crashed, "reason": reason, "agent_crashed": agent_crashed, diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index ad96d16..6720a70 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -1,4 +1,5 @@ import os +import logging from typing import List from langchain_core.tools import tool @@ -6,8 +7,10 @@ from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel -from ..database.repositories.key_fact_repository import KeyFactRepository -from ..database.repositories.key_snippet_repository import KeySnippetRepository +logger = logging.getLogger(__name__) + +from ..database.repositories.key_fact_repository import get_key_fact_repository +from ..database.repositories.key_snippet_repository import get_key_snippet_repository from ..llm import initialize_expert_llm from ..model_formatters import format_key_facts_dict from ..model_formatters.key_snippets_formatter import format_key_snippets_dict @@ -15,8 +18,6 @@ from .memory import _global_memory, get_memory_value console = Console() _model = None -key_fact_repository = KeyFactRepository() -key_snippet_repository = KeySnippetRepository() def get_model(): @@ -154,10 +155,18 @@ def ask_expert(question: str) -> str: file_paths = list(_global_memory["related_files"].values()) related_contents = read_related_files(file_paths) # Get key snippets directly from repository and format using the formatter - key_snippets = format_key_snippets_dict(key_snippet_repository.get_snippets_dict()) + try: + key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()) + except RuntimeError as e: + logger.error(f"Failed to access key snippet repository: {str(e)}") + key_snippets = "" # Get key facts directly from repository and format using the formatter - facts_dict = key_fact_repository.get_facts_dict() - key_facts = format_key_facts_dict(facts_dict) + try: + facts_dict = get_key_fact_repository().get_facts_dict() + key_facts = format_key_facts_dict(facts_dict) + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + key_facts = "" research_notes = get_memory_value("research_notes") # Build display query (just question) diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 2152fe1..85d6bcd 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -17,8 +17,8 @@ from ra_aid.agent_context import ( mark_should_exit, mark_task_completed, ) -from ra_aid.database.repositories.key_fact_repository import KeyFactRepository -from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository +from ra_aid.database.repositories.key_fact_repository import KeyFactRepository, get_key_fact_repository +from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository, get_key_snippet_repository from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.model_formatters import key_snippets_formatter from ra_aid.logging_config import get_logger @@ -40,14 +40,8 @@ class SnippetInfo(TypedDict): console = Console() -# Initialize repository for key facts -key_fact_repository = KeyFactRepository() - -# Initialize repository for key snippets -key_snippet_repository = KeySnippetRepository() - -# Initialize repository for human inputs -human_input_repository = HumanInputRepository() +# Import repositories using the get_* functions +from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository # Global memory store _global_memory: Dict[str, Any] = { @@ -120,17 +114,23 @@ def emit_key_facts(facts: List[str]) -> str: # Try to get the latest human input human_input_id = None + human_input_repo = HumanInputRepository() try: - recent_inputs = human_input_repository.get_recent(1) + recent_inputs = human_input_repo.get_recent(1) if recent_inputs and len(recent_inputs) > 0: human_input_id = recent_inputs[0].id except Exception as e: logger.warning(f"Failed to get recent human input: {str(e)}") for fact in facts: - # Create fact in database using repository - created_fact = key_fact_repository.create(fact, human_input_id=human_input_id) - fact_id = created_fact.id + try: + # Create fact in database using repository + created_fact = get_key_fact_repository().create(fact, human_input_id=human_input_id) + fact_id = created_fact.id + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") + console.print(f"Error storing fact: {str(e)}", style="red") + continue # Display panel with ID console.print( @@ -147,14 +147,17 @@ def emit_key_facts(facts: List[str]) -> str: log_work_event(f"Stored {len(facts)} key facts.") # Check if we need to clean up facts (more than 30) - all_facts = key_fact_repository.get_all() - if len(all_facts) > 30: - # Trigger the key facts cleaner agent - try: - from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent - run_key_facts_gc_agent() - except Exception as e: - logger.error(f"Failed to run key facts cleaner: {str(e)}") + try: + all_facts = get_key_fact_repository().get_all() + if len(all_facts) > 30: + # Trigger the key facts cleaner agent + try: + from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent + run_key_facts_gc_agent() + except Exception as e: + logger.error(f"Failed to run key facts cleaner: {str(e)}") + except RuntimeError as e: + logger.error(f"Failed to access key fact repository: {str(e)}") return "Facts stored." @@ -222,14 +225,15 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str: # Try to get the latest human input human_input_id = None try: - recent_inputs = human_input_repository.get_recent(1) + human_input_repo = HumanInputRepository() + recent_inputs = human_input_repo.get_recent(1) if recent_inputs and len(recent_inputs) > 0: human_input_id = recent_inputs[0].id except Exception as e: logger.warning(f"Failed to get recent human input: {str(e)}") # Create a new key snippet in the database - key_snippet = key_snippet_repository.create( + key_snippet = get_key_snippet_repository().create( filepath=snippet_info["filepath"], line_number=snippet_info["line_number"], snippet=snippet_info["snippet"], @@ -266,7 +270,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str: log_work_event(f"Stored code snippet #{snippet_id}.") # Check if we need to clean up snippets (more than 20) - all_snippets = key_snippet_repository.get_all() + all_snippets = get_key_snippet_repository().get_all() if len(all_snippets) > 20: # Trigger the key snippets cleaner agent try: @@ -279,7 +283,6 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str: - @tool("swap_task_order") def swap_task_order(id1: int, id2: int) -> str: """Swap the order of two tasks in global memory by their IDs. diff --git a/tests/ra_aid/database/test_key_fact_repository.py b/tests/ra_aid/database/test_key_fact_repository.py index 6e5e31a..d97a455 100644 --- a/tests/ra_aid/database/test_key_fact_repository.py +++ b/tests/ra_aid/database/test_key_fact_repository.py @@ -9,7 +9,12 @@ import peewee from ra_aid.database.connection import DatabaseManager, db_var from ra_aid.database.models import KeyFact, BaseModel -from ra_aid.database.repositories.key_fact_repository import KeyFactRepository +from ra_aid.database.repositories.key_fact_repository import ( + KeyFactRepository, + KeyFactRepositoryManager, + get_key_fact_repository, + key_fact_repo_var +) @pytest.fixture @@ -41,6 +46,19 @@ def cleanup_db(): db_var.set(None) +@pytest.fixture +def cleanup_repo(): + """Reset the repository contextvar after each test.""" + # Reset before the test + key_fact_repo_var.set(None) + + # Run the test + yield + + # Reset after the test + key_fact_repo_var.set(None) + + @pytest.fixture def setup_db(cleanup_db): """Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database.""" @@ -195,4 +213,49 @@ def test_get_facts_dict(setup_db): # Verify each fact is in the dictionary with the correct content for fact in facts: assert fact.id in facts_dict - assert facts_dict[fact.id] == fact.content \ No newline at end of file + assert facts_dict[fact.id] == fact.content + + +def test_repository_init_without_db(): + """Test that KeyFactRepository raises an error when initialized without a db parameter.""" + # Attempt to create a repository without a database connection + with pytest.raises(ValueError) as excinfo: + KeyFactRepository(db=None) + + # Verify the correct error message + assert "Database connection is required" in str(excinfo.value) + + +def test_key_fact_repository_manager(setup_db, cleanup_repo): + """Test the KeyFactRepositoryManager context manager.""" + # Use the context manager to create a repository + with KeyFactRepositoryManager(setup_db) as repo: + # Verify the repository was created correctly + assert isinstance(repo, KeyFactRepository) + assert repo.db is setup_db + + # Verify we can use the repository + content = "Test fact via context manager" + fact = repo.create(content) + assert fact.id is not None + assert fact.content == content + + # Verify we can get the repository using get_key_fact_repository + repo_from_var = get_key_fact_repository() + assert repo_from_var is repo + + # Verify the repository was removed from the context var + with pytest.raises(RuntimeError) as excinfo: + get_key_fact_repository() + + assert "No KeyFactRepository available" in str(excinfo.value) + + +def test_get_key_fact_repository_when_not_set(cleanup_repo): + """Test that get_key_fact_repository raises an error when no repository is in context.""" + # Attempt to get the repository when none exists + with pytest.raises(RuntimeError) as excinfo: + get_key_fact_repository() + + # Verify the correct error message + assert "No KeyFactRepository available" in str(excinfo.value) \ No newline at end of file diff --git a/tests/ra_aid/tools/test_agent.py b/tests/ra_aid/tools/test_agent.py index d07b6e0..bc2a065 100644 --- a/tests/ra_aid/tools/test_agent.py +++ b/tests/ra_aid/tools/test_agent.py @@ -36,9 +36,11 @@ def reset_memory(): @pytest.fixture def mock_functions(): """Mock functions used in agent.py""" - with patch('ra_aid.tools.agent.key_fact_repository') as mock_fact_repo, \ + mock_fact_repo = MagicMock() + mock_snippet_repo = MagicMock() + with patch('ra_aid.tools.agent.get_key_fact_repository', return_value=mock_fact_repo) as mock_get_fact_repo, \ patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \ - patch('ra_aid.tools.agent.key_snippet_repository') as mock_snippet_repo, \ + patch('ra_aid.tools.agent.get_key_snippet_repository', return_value=mock_snippet_repo) as mock_get_snippet_repo, \ patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \ patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \ patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \ @@ -60,8 +62,8 @@ def mock_functions(): # Return all mocks as a dictionary yield { - 'key_fact_repository': mock_fact_repo, - 'key_snippet_repository': mock_snippet_repo, + 'get_key_fact_repository': mock_get_fact_repo, + 'get_key_snippet_repository': mock_get_snippet_repo, 'format_key_facts_dict': mock_fact_formatter, 'format_key_snippets_dict': mock_snippet_formatter, 'initialize_llm': mock_llm, @@ -81,11 +83,12 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions) result = request_research("test query") # Verify repository was called - mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + mock_functions['get_key_fact_repository'].assert_called_once() + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once() # Verify formatter was called with repository results mock_functions['format_key_facts_dict'].assert_called_once_with( - mock_functions['key_fact_repository'].get_facts_dict.return_value + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value ) # Verify formatted facts are used in response @@ -105,11 +108,12 @@ def test_request_research_max_depth(reset_memory, mock_functions): result = request_research("test query") # Verify repository was called - mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + mock_functions['get_key_fact_repository'].assert_called_once() + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once() # Verify formatter was called with repository results mock_functions['format_key_facts_dict'].assert_called_once_with( - mock_functions['key_fact_repository'].get_facts_dict.return_value + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value ) # Verify formatted facts are used in response @@ -128,11 +132,12 @@ def test_request_research_and_implementation_uses_key_fact_repository(reset_memo result = request_research_and_implementation("test query") # Verify repository was called - mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + mock_functions['get_key_fact_repository'].assert_called_once() + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once() # Verify formatter was called with repository results mock_functions['format_key_facts_dict'].assert_called_once_with( - mock_functions['key_fact_repository'].get_facts_dict.return_value + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value ) # Verify formatted facts are used in response @@ -151,11 +156,12 @@ def test_request_implementation_uses_key_fact_repository(reset_memory, mock_func result = request_implementation("test task") # Verify repository was called - mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + mock_functions['get_key_fact_repository'].assert_called_once() + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once() # Verify formatter was called with repository results mock_functions['format_key_facts_dict'].assert_called_once_with( - mock_functions['key_fact_repository'].get_facts_dict.return_value + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value ) # Check that the formatted key facts are included in the response @@ -174,11 +180,12 @@ def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock result = request_task_implementation("test task") # Verify repository was called - mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + mock_functions['get_key_fact_repository'].assert_called_once() + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once() # Verify formatter was called with repository results mock_functions['format_key_facts_dict'].assert_called_once_with( - mock_functions['key_fact_repository'].get_facts_dict.return_value + mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value ) # Check that the formatted key facts are included in the response diff --git a/tests/ra_aid/tools/test_memory.py b/tests/ra_aid/tools/test_memory.py index 24effe1..860de40 100644 --- a/tests/ra_aid/tools/test_memory.py +++ b/tests/ra_aid/tools/test_memory.py @@ -16,12 +16,12 @@ from ra_aid.tools.memory import ( get_memory_value, get_related_files, get_work_log, - key_fact_repository, - key_snippet_repository, log_work_event, reset_work_log, swap_task_order, ) +from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository +from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.connection import DatabaseManager from ra_aid.database.models import KeyFact @@ -60,7 +60,7 @@ def in_memory_db(): @pytest.fixture(autouse=True) def mock_repository(): """Mock the KeyFactRepository to avoid database operations during tests""" - with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo: + with patch('ra_aid.tools.memory.get_key_fact_repository') as mock_repo: # Setup the mock repository to behave like the original, but using memory facts = {} # Local in-memory storage fact_id_counter = 0 @@ -79,12 +79,12 @@ def mock_repository(): facts[fact_id_counter] = fact fact_id_counter += 1 return fact - mock_repo.create.side_effect = mock_create + mock_repo.return_value.create.side_effect = mock_create # Mock get method def mock_get(fact_id): return facts.get(fact_id) - mock_repo.get.side_effect = mock_get + mock_repo.return_value.get.side_effect = mock_get # Mock delete method def mock_delete(fact_id): @@ -92,17 +92,17 @@ def mock_repository(): del facts[fact_id] return True return False - mock_repo.delete.side_effect = mock_delete + mock_repo.return_value.delete.side_effect = mock_delete # Mock get_facts_dict method def mock_get_facts_dict(): return {fact_id: fact.content for fact_id, fact in facts.items()} - mock_repo.get_facts_dict.side_effect = mock_get_facts_dict + mock_repo.return_value.get_facts_dict.side_effect = mock_get_facts_dict # Mock get_all method def mock_get_all(): return list(facts.values()) - mock_repo.get_all.side_effect = mock_get_all + mock_repo.return_value.get_all.side_effect = mock_get_all yield mock_repo @@ -159,16 +159,16 @@ def mock_key_snippet_repository(): return list(snippets.values()) # Create the actual mocks for both memory.py and key_snippets_gc_agent.py - with patch('ra_aid.tools.memory.key_snippet_repository') as memory_mock_repo, \ - patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository') as agent_mock_repo: + with patch('ra_aid.tools.memory.get_key_snippet_repository') as memory_mock_repo, \ + patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository') as agent_mock_repo: # Setup both mocks with the same implementation for mock_repo in [memory_mock_repo, agent_mock_repo]: - mock_repo.create.side_effect = mock_create - mock_repo.get.side_effect = mock_get - mock_repo.delete.side_effect = mock_delete - mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict - mock_repo.get_all.side_effect = mock_get_all + mock_repo.return_value.create.side_effect = mock_create + mock_repo.return_value.get.side_effect = mock_get + mock_repo.return_value.delete.side_effect = mock_delete + mock_repo.return_value.get_snippets_dict.side_effect = mock_get_snippets_dict + mock_repo.return_value.get_all.side_effect = mock_get_all yield memory_mock_repo @@ -180,7 +180,7 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository): assert result == "Facts stored." # Verify the repository's create method was called - mock_repository.create.assert_called_once_with("First fact", human_input_id=ANY) + mock_repository.return_value.create.assert_called_once_with("First fact", human_input_id=ANY) def test_get_memory_value_other_types(reset_memory): @@ -264,10 +264,10 @@ def test_emit_key_facts(reset_memory, mock_repository): assert result == "Facts stored." # Verify create was called for each fact - assert mock_repository.create.call_count == 3 - mock_repository.create.assert_any_call("First fact", human_input_id=ANY) - mock_repository.create.assert_any_call("Second fact", human_input_id=ANY) - mock_repository.create.assert_any_call("Third fact", human_input_id=ANY) + assert mock_repository.return_value.create.call_count == 3 + mock_repository.return_value.create.assert_any_call("First fact", human_input_id=ANY) + mock_repository.return_value.create.assert_any_call("Second fact", human_input_id=ANY) + mock_repository.return_value.create.assert_any_call("Third fact", human_input_id=ANY) def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository): @@ -278,7 +278,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository): facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None)) # Mock the get_all method to return more than 30 facts - mock_repository.get_all.return_value = facts + mock_repository.return_value.get_all.return_value = facts # Note on testing approach: # Rather than trying to mock the dynamic import which is challenging due to @@ -295,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository): # Verify that mock_repository.get_all was called, # which is the condition that would trigger the GC agent - mock_repository.get_all.assert_called_once() + mock_repository.return_value.get_all.assert_called_once() def test_emit_key_snippet(reset_memory, mock_key_snippet_repository): @@ -315,7 +315,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository): assert result == "Snippet #0 stored." # Verify create was called correctly - mock_key_snippet_repository.create.assert_called_with( + mock_key_snippet_repository.return_value.create.assert_called_with( filepath="test.py", line_number=10, snippet="def test():\n pass", @@ -338,7 +338,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository): assert result == "Snippet #1 stored." # Verify create was called correctly - mock_key_snippet_repository.create.assert_called_with( + mock_key_snippet_repository.return_value.create.assert_called_with( filepath="main.py", line_number=20, snippet="print('hello')", @@ -379,16 +379,16 @@ def test_delete_key_snippets(mock_log_work_event, reset_memory, mock_key_snippet mock_key_snippet_repository.reset_mock() # Test deleting mix of valid and invalid IDs - with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): + with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository): result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]}) # Verify success message assert result == "Snippets deleted." # Verify repository get was called with correct IDs - mock_key_snippet_repository.get.assert_any_call(0) - mock_key_snippet_repository.get.assert_any_call(1) - mock_key_snippet_repository.get.assert_any_call(999) + mock_key_snippet_repository.return_value.get.assert_any_call(0) + mock_key_snippet_repository.return_value.get.assert_any_call(1) + mock_key_snippet_repository.return_value.get.assert_any_call(999) # We skip verifying delete calls because they are prone to test environment issues # The implementation logic will properly delete IDs 0 and 1 but not 999 @@ -410,12 +410,12 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s mock_key_snippet_repository.reset_mock() # Test with empty list - with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): + with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository): result = delete_key_snippets.invoke({"snippet_ids": []}) assert result == "Snippets deleted." # Verify no call to delete method - mock_key_snippet_repository.delete.assert_not_called() + mock_key_snippet_repository.return_value.delete.assert_not_called() def test_emit_related_files_basic(reset_memory, tmp_path): @@ -458,13 +458,13 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path): new_file.write_text("# New file") # Add initial files - result = emit_related_files.invoke({"files": [str(test_file), str(main_file)]}) - assert result == "Files noted." + result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]}) + assert result1 == "Files noted." _first_id = 0 # ID of test.py # Try adding duplicates - result = emit_related_files.invoke({"files": [str(test_file)]}) - assert result == "Files noted." + result2 = emit_related_files.invoke({"files": [str(test_file)]}) + assert result2 == "Files noted." assert len(_global_memory["related_files"]) == 2 # Count should not increase # Try mix of new and duplicate files @@ -670,7 +670,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn mock_key_snippet_repository.reset_mock() # Delete some but not all snippets (0 and 2) - with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): + with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository): result = delete_key_snippets.invoke({"snippet_ids": [0, 2]}) assert result == "Snippets deleted." @@ -692,7 +692,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn assert result == "Snippet #3 stored." # Verify create was called with correct params - mock_key_snippet_repository.create.assert_called_with( + mock_key_snippet_repository.return_value.create.assert_called_with( filepath=file4, line_number=40, snippet="def func4():\n return False", @@ -704,7 +704,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn mock_key_snippet_repository.reset_mock() # Delete remaining snippets - with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): + with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository): result = delete_key_snippets.invoke({"snippet_ids": [1, 3]}) assert result == "Snippets deleted."