use context vars for key facts repo

This commit is contained in:
AI Christianson 2025-03-03 16:58:03 -05:00
parent dd9af78693
commit 36e4004db0
12 changed files with 652 additions and 360 deletions

View File

@ -42,7 +42,7 @@ from ra_aid.config import (
DEFAULT_TEST_CMD_TIMEOUT, DEFAULT_TEST_CMD_TIMEOUT,
VALID_PROVIDERS, VALID_PROVIDERS,
) )
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -393,251 +393,256 @@ def main():
except Exception as e: except Exception as e:
logger.error(f"Database migration error: {str(e)}") logger.error(f"Database migration error: {str(e)}")
# Check dependencies before proceeding # Initialize repositories with database connection
check_dependencies() with KeyFactRepositoryManager(db) as key_fact_repo:
# This initializes the repository and makes it available via get_key_fact_repository()
logger.debug("Initialized KeyFactRepository")
( # Check dependencies before proceeding
expert_enabled, check_dependencies()
expert_missing,
web_research_enabled,
web_research_missing,
) = validate_environment(args) # Will exit if main env vars missing
logger.debug("Environment validation successful")
# Validate model configuration early (
model_config = models_params.get(args.provider, {}).get( expert_enabled,
args.model or "", {} expert_missing,
) web_research_enabled,
supports_temperature = model_config.get( web_research_missing,
"supports_temperature", ) = validate_environment(args) # Will exit if main env vars missing
args.provider logger.debug("Environment validation successful")
in [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
],
)
if supports_temperature and args.temperature is None: # Validate model configuration early
args.temperature = model_config.get("default_temperature") model_config = models_params.get(args.provider, {}).get(
if args.temperature is None: args.model or "", {}
cpm( )
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}." supports_temperature = model_config.get(
"supports_temperature",
args.provider
in [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
],
)
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
)
status = build_status(args, expert_enabled, web_research_enabled)
console.print(
Panel(
status,
title=f"RA.Aid v{__version__}",
border_style="bright_blue",
padding=(0, 1),
) )
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
) )
status = build_status(args, expert_enabled, web_research_enabled) # Handle chat mode
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
console.print( if args.research_only:
Panel( print_error("Chat mode cannot be used with --research-only")
status, sys.exit(1)
title=f"RA.Aid v{__version__}",
border_style="bright_blue",
padding=(0, 1),
)
)
# Handle chat mode print_stage_header("Chat Mode")
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
if args.research_only: # Get project info
print_error("Chat mode cannot be used with --research-only") try:
project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info)
except Exception as e:
logger.warning(f"Failed to get project info: {e}")
formatted_project_info = ""
# Get initial request from user
initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
# Record chat input in database (redundant as ask_human already records it,
# but needed in case the ask_human implementation changes)
try:
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
human_input_repo = HumanInputRepository(db)
human_input_repo.create(content=initial_request, source='chat')
human_input_repo.garbage_collect()
except Exception as e:
logger.error(f"Failed to record initial chat input: {str(e)}")
# Get working directory and current date
working_directory = os.getcwd()
current_date = datetime.now().strftime("%Y-%m-%d")
# Run chat agent with CHAT_PROMPT
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"chat_mode": True,
"cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled,
"initial_request": initial_request,
"limit_tokens": args.disable_limit_tokens,
}
# Store config in global memory
_global_memory["config"] = config
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
# Create chat agent with appropriate tools
chat_agent = create_agent(
chat_model,
get_chat_tools(
expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
),
checkpointer=MemorySaver(),
)
# Run chat agent and exit
run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT
if web_research_enabled
else ""
),
working_directory=working_directory,
current_date=current_date,
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()),
project_info=formatted_project_info,
),
config,
)
return
# Validate message is provided
if not args.message:
print_error("--message is required")
sys.exit(1) sys.exit(1)
print_stage_header("Chat Mode") base_task = args.message
# Get project info # Record CLI input in database
try:
project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info)
except Exception as e:
logger.warning(f"Failed to get project info: {e}")
formatted_project_info = ""
# Get initial request from user
initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
# Record chat input in database (redundant as ask_human already records it,
# but needed in case the ask_human implementation changes)
try: try:
from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.human_input_repository import HumanInputRepository
human_input_repo = HumanInputRepository() human_input_repo = HumanInputRepository(db)
human_input_repo.create(content=initial_request, source='chat') human_input_repo.create(content=base_task, source='cli')
# Run garbage collection to ensure we don't exceed 100 inputs
human_input_repo.garbage_collect() human_input_repo.garbage_collect()
logger.debug(f"Recorded CLI input: {base_task}")
except Exception as e: except Exception as e:
logger.error(f"Failed to record initial chat input: {str(e)}") logger.error(f"Failed to record CLI input: {str(e)}")
# Get working directory and current date
working_directory = os.getcwd()
current_date = datetime.now().strftime("%Y-%m-%d")
# Run chat agent with CHAT_PROMPT
config = { config = {
"configurable": {"thread_id": str(uuid.uuid4())}, "configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit, "recursion_limit": args.recursion_limit,
"chat_mode": True, "research_only": args.research_only,
"cowboy_mode": args.cowboy_mode, "cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled, "web_research_enabled": web_research_enabled,
"initial_request": initial_request, "aider_config": args.aider_config,
"use_aider": args.use_aider,
"limit_tokens": args.disable_limit_tokens, "limit_tokens": args.disable_limit_tokens,
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
} }
# Store config in global memory # Store config in global memory for access by is_informational_query
_global_memory["config"] = config _global_memory["config"] = config
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider _global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model _global_memory["config"]["model"] = args.model
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider _global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model _global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = (
args.research_model or args.model
)
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature _global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag # Set modification tools based on use_aider flag
set_modification_tools(args.use_aider) set_modification_tools(args.use_aider)
# Create chat agent with appropriate tools # Run research stage
chat_agent = create_agent( print_stage_header("Research Stage")
chat_model,
get_chat_tools( # Initialize research model with potential overrides
expert_enabled=expert_enabled, research_provider = args.research_provider or args.provider
web_research_enabled=web_research_enabled, research_model_name = args.research_model or args.model
), research_model = initialize_llm(
checkpointer=MemorySaver(), research_provider, research_model_name, temperature=args.temperature
) )
# Run chat agent and exit run_research_agent(
run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT
if web_research_enabled
else ""
),
working_directory=working_directory,
current_date=current_date,
key_facts=format_key_facts_dict(KeyFactRepository().get_facts_dict()),
key_snippets=format_key_snippets_dict(KeySnippetRepository().get_snippets_dict()),
project_info=formatted_project_info,
),
config,
)
return
# Validate message is provided
if not args.message:
print_error("--message is required")
sys.exit(1)
base_task = args.message
# Record CLI input in database
try:
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
human_input_repo = HumanInputRepository()
human_input_repo.create(content=base_task, source='cli')
# Run garbage collection to ensure we don't exceed 100 inputs
human_input_repo.garbage_collect()
logger.debug(f"Recorded CLI input: {base_task}")
except Exception as e:
logger.error(f"Failed to record CLI input: {str(e)}")
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"research_only": args.research_only,
"cowboy_mode": args.cowboy_mode,
"web_research_enabled": web_research_enabled,
"aider_config": args.aider_config,
"use_aider": args.use_aider,
"limit_tokens": args.disable_limit_tokens,
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
}
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = (
args.research_model or args.model
)
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
# Run research stage
print_stage_header("Research Stage")
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model
research_model = initialize_llm(
research_provider, research_model_name, temperature=args.temperature
)
run_research_agent(
base_task,
research_model,
expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil,
memory=research_memory,
config=config,
)
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
base_task, base_task,
planning_model, research_model,
expert_enabled=expert_enabled, expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil, hil=args.hil,
memory=planning_memory, memory=research_memory,
config=config, config=config,
) )
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
base_task,
planning_model,
expert_enabled=expert_enabled,
hil=args.hil,
memory=planning_memory,
config=config,
)
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
print() print()
print(" 👋 Bye!") print(" 👋 Bye!")

View File

@ -84,8 +84,8 @@ from ra_aid.tool_configs import (
get_web_research_tools, get_web_research_tools,
) )
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -100,10 +100,8 @@ console = Console()
logger = get_logger(__name__) logger = get_logger(__name__)
# Initialize repositories # Import repositories using get_* functions
key_fact_repository = KeyFactRepository() from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
key_snippet_repository = KeySnippetRepository()
human_input_repository = HumanInputRepository()
@tool @tool
@ -391,7 +389,11 @@ def run_research_agent(
else "" else ""
) )
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict()) try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
code_snippets = _global_memory.get("code_snippets", "") code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "") related_files = _global_memory.get("related_files", "")
@ -400,6 +402,7 @@ def run_research_agent(
# Get the last human input, if it exists # Get the last human input, if it exists
base_task = base_task_or_query base_task = base_task_or_query
human_input_repository = HumanInputRepository()
recent_inputs = human_input_repository.get_recent(1) recent_inputs = human_input_repository.get_recent(1)
if recent_inputs and len(recent_inputs) > 0: if recent_inputs and len(recent_inputs) > 0:
last_human_input = recent_inputs[0].content last_human_input = recent_inputs[0].content
@ -537,7 +540,11 @@ def run_web_research_agent(
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict()) try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
code_snippets = _global_memory.get("code_snippets", "") code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "") related_files = _global_memory.get("related_files", "")
@ -647,6 +654,20 @@ def run_planning_agent(
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd() working_directory = os.getcwd()
# Make sure key_facts is defined before using it
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
# Make sure key_snippets is defined before using it
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
planning_prompt = PLANNING_PROMPT.format( planning_prompt = PLANNING_PROMPT.format(
current_date=current_date, current_date=current_date,
working_directory=working_directory, working_directory=working_directory,
@ -657,8 +678,8 @@ def run_planning_agent(
project_info=formatted_project_info, project_info=formatted_project_info,
research_notes=get_memory_value("research_notes"), research_notes=get_memory_value("research_notes"),
related_files="\n".join(get_related_files()), related_files="\n".join(get_related_files()),
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()), key_facts=key_facts,
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), key_snippets=key_snippets,
work_log=get_memory_value("work_log"), work_log=get_memory_value("work_log"),
research_only_note=( research_only_note=(
"" ""
@ -751,6 +772,13 @@ def run_task_implementation_agent(
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd() working_directory = os.getcwd()
# Make sure key_facts is defined before using it
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
prompt = IMPLEMENTATION_PROMPT.format( prompt = IMPLEMENTATION_PROMPT.format(
current_date=current_date, current_date=current_date,
working_directory=working_directory, working_directory=working_directory,
@ -759,8 +787,8 @@ def run_task_implementation_agent(
tasks=tasks, tasks=tasks,
plan=plan, plan=plan,
related_files=related_files, related_files=related_files,
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()), key_facts=key_facts,
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
research_notes=get_memory_value("research_notes"), research_notes=get_memory_value("research_notes"),
work_log=get_memory_value("work_log"), work_log=get_memory_value("work_log"),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",

View File

@ -6,6 +6,7 @@ facts when the total number exceeds a specified threshold. The agent evaluates a
key facts and deletes the least valuable ones to keep the database clean and relevant. key facts and deletes the least valuable ones to keep the database clean and relevant.
""" """
import logging
from typing import List from typing import List
from langchain_core.tools import tool from langchain_core.tools import tool
@ -13,8 +14,10 @@ from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.panel import Panel from rich.panel import Panel
logger = logging.getLogger(__name__)
from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.llm import initialize_llm from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
@ -22,7 +25,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
console = Console() console = Console()
key_fact_repository = KeyFactRepository()
human_input_repository = HumanInputRepository() human_input_repository = HumanInputRepository()
@ -51,23 +53,29 @@ def delete_key_facts(fact_ids: List[int]) -> str:
console.print(f"Warning: Could not retrieve current human input: {str(e)}") console.print(f"Warning: Could not retrieve current human input: {str(e)}")
for fact_id in fact_ids: for fact_id in fact_ids:
# Get the fact first to display information try:
fact = key_fact_repository.get(fact_id) # Get the fact first to display information
if fact: fact = get_key_fact_repository().get(fact_id)
# Check if this fact is associated with the current human input if fact:
if current_human_input_id is not None and fact.human_input_id == current_human_input_id: # Check if this fact is associated with the current human input
protected_facts.append((fact_id, fact.content)) if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
continue protected_facts.append((fact_id, fact.content))
continue
# Delete the fact if it's not protected # Delete the fact if it's not protected
was_deleted = key_fact_repository.delete(fact_id) was_deleted = get_key_fact_repository().delete(fact_id)
if was_deleted: if was_deleted:
deleted_facts.append((fact_id, fact.content)) deleted_facts.append((fact_id, fact.content))
log_work_event(f"Deleted fact {fact_id}.") log_work_event(f"Deleted fact {fact_id}.")
else: else:
failed_facts.append(fact_id) failed_facts.append(fact_id)
else: except RuntimeError as e:
not_found_facts.append(fact_id) logger.error(f"Failed to access key fact repository: {str(e)}")
failed_facts.append(fact_id)
except Exception as e:
# For any other exceptions, log and continue
logger.error(f"Error processing fact {fact_id}: {str(e)}")
failed_facts.append(fact_id)
# Prepare result message # Prepare result message
result_parts = [] result_parts = []
@ -104,8 +112,13 @@ def run_key_facts_gc_agent() -> None:
Facts associated with the current human input are excluded from deletion. Facts associated with the current human input are excluded from deletion.
""" """
# Get the count of key facts # Get the count of key facts
facts = key_fact_repository.get_all() try:
fact_count = len(facts) facts = get_key_fact_repository().get_all()
fact_count = len(facts)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
return # Exit the function if we can't access the repository
# Display status panel with fact count included # Display status panel with fact count included
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection")) console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
@ -161,8 +174,12 @@ def run_key_facts_gc_agent() -> None:
run_agent_with_retry(agent, prompt, agent_config) run_agent_with_retry(agent, prompt, agent_config)
# Get updated count # Get updated count
updated_facts = key_fact_repository.get_all() try:
updated_count = len(updated_facts) updated_facts = get_key_fact_repository().get_all()
updated_count = len(updated_facts)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository for update count: {str(e)}")
updated_count = "unknown"
# Show info panel with updated count and protected facts count # Show info panel with updated count and protected facts count
protected_count = len(protected_facts) protected_count = len(protected_facts)

View File

@ -14,7 +14,7 @@ from rich.markdown import Markdown
from rich.panel import Panel from rich.panel import Panel
from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.llm import initialize_llm from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
@ -22,7 +22,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
console = Console() console = Console()
key_snippet_repository = KeySnippetRepository()
human_input_repository = HumanInputRepository() human_input_repository = HumanInputRepository()
@ -53,7 +52,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
for snippet_id in snippet_ids: for snippet_id in snippet_ids:
# Get the snippet first to capture filepath for the message # Get the snippet first to capture filepath for the message
snippet = key_snippet_repository.get(snippet_id) snippet = get_key_snippet_repository().get(snippet_id)
if snippet: if snippet:
filepath = snippet.filepath filepath = snippet.filepath
@ -63,7 +62,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
continue continue
# Delete from database if not protected # Delete from database if not protected
success = key_snippet_repository.delete(snippet_id) success = get_key_snippet_repository().delete(snippet_id)
if success: if success:
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}" success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
console.print( console.print(
@ -110,7 +109,7 @@ def run_key_snippets_gc_agent() -> None:
Snippets associated with the current human input are excluded from deletion. Snippets associated with the current human input are excluded from deletion.
""" """
# Get the count of key snippets # Get the count of key snippets
snippets = key_snippet_repository.get_all() snippets = get_key_snippet_repository().get_all()
snippet_count = len(snippets) snippet_count = len(snippets)
# Display status panel with snippet count included # Display status panel with snippet count included
@ -179,7 +178,7 @@ def run_key_snippets_gc_agent() -> None:
run_agent_with_retry(agent, prompt, agent_config) run_agent_with_retry(agent, prompt, agent_config)
# Get updated count # Get updated count
updated_snippets = key_snippet_repository.get_all() updated_snippets = get_key_snippet_repository().get_all()
updated_count = len(updated_snippets) updated_count = len(updated_snippets)
# Show info panel with updated count and protected snippets count # Show info panel with updated count and protected snippets count

View File

@ -6,15 +6,94 @@ following the repository pattern for data access abstraction.
""" """
from typing import Dict, List, Optional from typing import Dict, List, Optional
import contextvars
from contextlib import contextmanager
import peewee import peewee
from ra_aid.database.connection import get_db from ra_aid.database.models import KeyFact
from ra_aid.database.models import KeyFact, initialize_database
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
# Create contextvar to hold the KeyFactRepository instance
key_fact_repo_var = contextvars.ContextVar("key_fact_repo", default=None)
class KeyFactRepositoryManager:
"""
Context manager for KeyFactRepository.
This class provides a context manager interface for KeyFactRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with KeyFactRepositoryManager(db) as repo:
# Use the repository
fact = repo.create("Important fact about the project")
all_facts = repo.get_all()
"""
def __init__(self, db):
"""
Initialize the KeyFactRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'KeyFactRepository':
"""
Initialize the KeyFactRepository and return it.
Returns:
KeyFactRepository: The initialized repository
"""
repo = KeyFactRepository(self.db)
key_fact_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
key_fact_repo_var.set(None)
# Don't suppress exceptions
return False
def get_key_fact_repository() -> 'KeyFactRepository':
"""
Get the current KeyFactRepository instance.
Returns:
KeyFactRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with KeyFactRepositoryManager
"""
repo = key_fact_repo_var.get()
if repo is None:
raise RuntimeError(
"No KeyFactRepository available. "
"Make sure to initialize one with KeyFactRepositoryManager first."
)
return repo
class KeyFactRepository: class KeyFactRepository:
""" """
@ -24,18 +103,21 @@ class KeyFactRepository:
abstracting the database access details from the business logic. abstracting the database access details from the business logic.
Example: Example:
repo = KeyFactRepository() with DatabaseManager() as db:
fact = repo.create("Important fact about the project") with KeyFactRepositoryManager(db) as repo:
all_facts = repo.get_all() fact = repo.create("Important fact about the project")
all_facts = repo.get_all()
""" """
def __init__(self, db=None): def __init__(self, db):
""" """
Initialize the repository with an optional database connection. Initialize the repository with a database connection.
Args: Args:
db: Optional database connection to use. If None, will use initialize_database() db: Database connection to use (required)
""" """
if db is None:
raise ValueError("Database connection is required for KeyFactRepository")
self.db = db self.db = db
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact: def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
@ -53,7 +135,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error creating the fact peewee.DatabaseError: If there's an error creating the fact
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
fact = KeyFact.create(content=content, human_input_id=human_input_id) fact = KeyFact.create(content=content, human_input_id=human_input_id)
logger.debug(f"Created key fact ID {fact.id}: {content}") logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact return fact
@ -75,7 +156,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database peewee.DatabaseError: If there's an error accessing the database
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
return KeyFact.get_or_none(KeyFact.id == fact_id) return KeyFact.get_or_none(KeyFact.id == fact_id)
except peewee.DatabaseError as e: except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}") logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
@ -96,7 +176,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error updating the fact peewee.DatabaseError: If there's an error updating the fact
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
# First check if the fact exists # First check if the fact exists
fact = self.get(fact_id) fact = self.get(fact_id)
if not fact: if not fact:
@ -126,7 +205,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error deleting the fact peewee.DatabaseError: If there's an error deleting the fact
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
# First check if the fact exists # First check if the fact exists
fact = self.get(fact_id) fact = self.get(fact_id)
if not fact: if not fact:
@ -152,7 +230,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database peewee.DatabaseError: If there's an error accessing the database
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
return list(KeyFact.select().order_by(KeyFact.id)) return list(KeyFact.select().order_by(KeyFact.id))
except peewee.DatabaseError as e: except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key facts: {str(e)}") logger.error(f"Failed to fetch all key facts: {str(e)}")

View File

@ -215,3 +215,20 @@ class KeySnippetRepository:
except peewee.DatabaseError as e: except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}") logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
raise raise
# Global singleton instance
_key_snippet_repository = None
def get_key_snippet_repository() -> KeySnippetRepository:
"""
Get or create a singleton instance of KeySnippetRepository.
Returns:
KeySnippetRepository: Singleton instance of the repository
"""
global _key_snippet_repository
if _key_snippet_repository is None:
_key_snippet_repository = KeySnippetRepository()
return _key_snippet_repository

View File

@ -1,6 +1,7 @@
"""Tools for spawning and managing sub-agents.""" """Tools for spawning and managing sub-agents."""
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import logging
from langchain_core.tools import tool from langchain_core.tools import tool
from rich.console import Console from rich.console import Console
@ -12,8 +13,9 @@ from ra_aid.agent_context import (
reset_completion_flags, reset_completion_flags,
) )
from ra_aid.console.formatting import print_error from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.exceptions import AgentInterrupt from ra_aid.exceptions import AgentInterrupt
from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -31,8 +33,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user.
RESEARCH_AGENT_RECURSION_LIMIT = 3 RESEARCH_AGENT_RECURSION_LIMIT = 3
console = Console() console = Console()
key_fact_repository = KeyFactRepository() logger = logging.getLogger(__name__)
key_snippet_repository = KeySnippetRepository()
@tool("request_research") @tool("request_research")
@ -57,12 +58,24 @@ def request_research(query: str) -> ResearchResult:
current_depth = _global_memory.get("agent_depth", 0) current_depth = _global_memory.get("agent_depth", 0)
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT: if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
print_error("Maximum research recursion depth reached") print_error("Maximum research recursion depth reached")
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
return { return {
"completion_message": "Research stopped - maximum recursion depth reached", "completion_message": "Research stopped - maximum recursion depth reached",
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "key_facts": key_facts,
"related_files": get_related_files(), "related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"), "research_notes": get_memory_value("research_notes"),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), "key_snippets": key_snippets,
"success": False, "success": False,
"reason": "max_depth_exceeded", "reason": "max_depth_exceeded",
} }
@ -105,12 +118,24 @@ def request_research(query: str) -> ResearchResult:
# Clear completion state # Clear completion state
reset_completion_flags() reset_completion_flags()
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = { response_data = {
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "key_facts": key_facts,
"related_files": get_related_files(), "related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"), "research_notes": get_memory_value("research_notes"),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), "key_snippets": key_snippets,
"success": success, "success": success,
"reason": reason, "reason": reason,
} }
@ -171,9 +196,15 @@ def request_web_research(query: str) -> ResearchResult:
# Clear completion state # Clear completion state
reset_completion_flags() reset_completion_flags()
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = { response_data = {
"completion_message": completion_message, "completion_message": completion_message,
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), "key_snippets": key_snippets,
"research_notes": get_memory_value("research_notes"), "research_notes": get_memory_value("research_notes"),
"success": success, "success": success,
"reason": reason, "reason": reason,
@ -239,12 +270,24 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
# Clear completion state # Clear completion state
reset_completion_flags() reset_completion_flags()
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = { response_data = {
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "key_facts": key_facts,
"related_files": get_related_files(), "related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"), "research_notes": get_memory_value("research_notes"),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), "key_snippets": key_snippets,
"success": success, "success": success,
"reason": reason, "reason": reason,
} }
@ -324,10 +367,22 @@ def request_task_implementation(task_spec: str) -> str:
agent_crashed = is_crashed() agent_crashed = is_crashed()
crash_message = get_crash_message() if agent_crashed else None crash_message = get_crash_message() if agent_crashed else None
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = { response_data = {
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "key_facts": key_facts,
"related_files": get_related_files(), "related_files": get_related_files(),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), "key_snippets": key_snippets,
"completion_message": completion_message, "completion_message": completion_message,
"success": success and not agent_crashed, "success": success and not agent_crashed,
"reason": reason, "reason": reason,
@ -444,11 +499,23 @@ def request_implementation(task_spec: str) -> str:
agent_crashed = is_crashed() agent_crashed = is_crashed()
crash_message = get_crash_message() if agent_crashed else None crash_message = get_crash_message() if agent_crashed else None
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = { response_data = {
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "key_facts": key_facts,
"related_files": get_related_files(), "related_files": get_related_files(),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()), "key_snippets": key_snippets,
"success": success and not agent_crashed, "success": success and not agent_crashed,
"reason": reason, "reason": reason,
"agent_crashed": agent_crashed, "agent_crashed": agent_crashed,

View File

@ -1,4 +1,5 @@
import os import os
import logging
from typing import List from typing import List
from langchain_core.tools import tool from langchain_core.tools import tool
@ -6,8 +7,10 @@ from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.panel import Panel from rich.panel import Panel
from ..database.repositories.key_fact_repository import KeyFactRepository logger = logging.getLogger(__name__)
from ..database.repositories.key_snippet_repository import KeySnippetRepository
from ..database.repositories.key_fact_repository import get_key_fact_repository
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
from ..llm import initialize_expert_llm from ..llm import initialize_expert_llm
from ..model_formatters import format_key_facts_dict from ..model_formatters import format_key_facts_dict
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -15,8 +18,6 @@ from .memory import _global_memory, get_memory_value
console = Console() console = Console()
_model = None _model = None
key_fact_repository = KeyFactRepository()
key_snippet_repository = KeySnippetRepository()
def get_model(): def get_model():
@ -154,10 +155,18 @@ def ask_expert(question: str) -> str:
file_paths = list(_global_memory["related_files"].values()) file_paths = list(_global_memory["related_files"].values())
related_contents = read_related_files(file_paths) related_contents = read_related_files(file_paths)
# Get key snippets directly from repository and format using the formatter # Get key snippets directly from repository and format using the formatter
key_snippets = format_key_snippets_dict(key_snippet_repository.get_snippets_dict()) try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
# Get key facts directly from repository and format using the formatter # Get key facts directly from repository and format using the formatter
facts_dict = key_fact_repository.get_facts_dict() try:
key_facts = format_key_facts_dict(facts_dict) facts_dict = get_key_fact_repository().get_facts_dict()
key_facts = format_key_facts_dict(facts_dict)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
research_notes = get_memory_value("research_notes") research_notes = get_memory_value("research_notes")
# Build display query (just question) # Build display query (just question)

View File

@ -17,8 +17,8 @@ from ra_aid.agent_context import (
mark_should_exit, mark_should_exit,
mark_task_completed, mark_task_completed,
) )
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.key_fact_repository import KeyFactRepository, get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository, get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.model_formatters import key_snippets_formatter from ra_aid.model_formatters import key_snippets_formatter
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
@ -40,14 +40,8 @@ class SnippetInfo(TypedDict):
console = Console() console = Console()
# Initialize repository for key facts # Import repositories using the get_* functions
key_fact_repository = KeyFactRepository() from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
# Initialize repository for key snippets
key_snippet_repository = KeySnippetRepository()
# Initialize repository for human inputs
human_input_repository = HumanInputRepository()
# Global memory store # Global memory store
_global_memory: Dict[str, Any] = { _global_memory: Dict[str, Any] = {
@ -120,17 +114,23 @@ def emit_key_facts(facts: List[str]) -> str:
# Try to get the latest human input # Try to get the latest human input
human_input_id = None human_input_id = None
human_input_repo = HumanInputRepository()
try: try:
recent_inputs = human_input_repository.get_recent(1) recent_inputs = human_input_repo.get_recent(1)
if recent_inputs and len(recent_inputs) > 0: if recent_inputs and len(recent_inputs) > 0:
human_input_id = recent_inputs[0].id human_input_id = recent_inputs[0].id
except Exception as e: except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}") logger.warning(f"Failed to get recent human input: {str(e)}")
for fact in facts: for fact in facts:
# Create fact in database using repository try:
created_fact = key_fact_repository.create(fact, human_input_id=human_input_id) # Create fact in database using repository
fact_id = created_fact.id created_fact = get_key_fact_repository().create(fact, human_input_id=human_input_id)
fact_id = created_fact.id
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
console.print(f"Error storing fact: {str(e)}", style="red")
continue
# Display panel with ID # Display panel with ID
console.print( console.print(
@ -147,14 +147,17 @@ def emit_key_facts(facts: List[str]) -> str:
log_work_event(f"Stored {len(facts)} key facts.") log_work_event(f"Stored {len(facts)} key facts.")
# Check if we need to clean up facts (more than 30) # Check if we need to clean up facts (more than 30)
all_facts = key_fact_repository.get_all() try:
if len(all_facts) > 30: all_facts = get_key_fact_repository().get_all()
# Trigger the key facts cleaner agent if len(all_facts) > 30:
try: # Trigger the key facts cleaner agent
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent try:
run_key_facts_gc_agent() from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
except Exception as e: run_key_facts_gc_agent()
logger.error(f"Failed to run key facts cleaner: {str(e)}") except Exception as e:
logger.error(f"Failed to run key facts cleaner: {str(e)}")
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
return "Facts stored." return "Facts stored."
@ -222,14 +225,15 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
# Try to get the latest human input # Try to get the latest human input
human_input_id = None human_input_id = None
try: try:
recent_inputs = human_input_repository.get_recent(1) human_input_repo = HumanInputRepository()
recent_inputs = human_input_repo.get_recent(1)
if recent_inputs and len(recent_inputs) > 0: if recent_inputs and len(recent_inputs) > 0:
human_input_id = recent_inputs[0].id human_input_id = recent_inputs[0].id
except Exception as e: except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}") logger.warning(f"Failed to get recent human input: {str(e)}")
# Create a new key snippet in the database # Create a new key snippet in the database
key_snippet = key_snippet_repository.create( key_snippet = get_key_snippet_repository().create(
filepath=snippet_info["filepath"], filepath=snippet_info["filepath"],
line_number=snippet_info["line_number"], line_number=snippet_info["line_number"],
snippet=snippet_info["snippet"], snippet=snippet_info["snippet"],
@ -266,7 +270,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
log_work_event(f"Stored code snippet #{snippet_id}.") log_work_event(f"Stored code snippet #{snippet_id}.")
# Check if we need to clean up snippets (more than 20) # Check if we need to clean up snippets (more than 20)
all_snippets = key_snippet_repository.get_all() all_snippets = get_key_snippet_repository().get_all()
if len(all_snippets) > 20: if len(all_snippets) > 20:
# Trigger the key snippets cleaner agent # Trigger the key snippets cleaner agent
try: try:
@ -279,7 +283,6 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
@tool("swap_task_order") @tool("swap_task_order")
def swap_task_order(id1: int, id2: int) -> str: def swap_task_order(id1: int, id2: int) -> str:
"""Swap the order of two tasks in global memory by their IDs. """Swap the order of two tasks in global memory by their IDs.

View File

@ -9,7 +9,12 @@ import peewee
from ra_aid.database.connection import DatabaseManager, db_var from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeyFact, BaseModel from ra_aid.database.models import KeyFact, BaseModel
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.key_fact_repository import (
KeyFactRepository,
KeyFactRepositoryManager,
get_key_fact_repository,
key_fact_repo_var
)
@pytest.fixture @pytest.fixture
@ -41,6 +46,19 @@ def cleanup_db():
db_var.set(None) db_var.set(None)
@pytest.fixture
def cleanup_repo():
"""Reset the repository contextvar after each test."""
# Reset before the test
key_fact_repo_var.set(None)
# Run the test
yield
# Reset after the test
key_fact_repo_var.set(None)
@pytest.fixture @pytest.fixture
def setup_db(cleanup_db): def setup_db(cleanup_db):
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database.""" """Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
@ -196,3 +214,48 @@ def test_get_facts_dict(setup_db):
for fact in facts: for fact in facts:
assert fact.id in facts_dict assert fact.id in facts_dict
assert facts_dict[fact.id] == fact.content assert facts_dict[fact.id] == fact.content
def test_repository_init_without_db():
"""Test that KeyFactRepository raises an error when initialized without a db parameter."""
# Attempt to create a repository without a database connection
with pytest.raises(ValueError) as excinfo:
KeyFactRepository(db=None)
# Verify the correct error message
assert "Database connection is required" in str(excinfo.value)
def test_key_fact_repository_manager(setup_db, cleanup_repo):
"""Test the KeyFactRepositoryManager context manager."""
# Use the context manager to create a repository
with KeyFactRepositoryManager(setup_db) as repo:
# Verify the repository was created correctly
assert isinstance(repo, KeyFactRepository)
assert repo.db is setup_db
# Verify we can use the repository
content = "Test fact via context manager"
fact = repo.create(content)
assert fact.id is not None
assert fact.content == content
# Verify we can get the repository using get_key_fact_repository
repo_from_var = get_key_fact_repository()
assert repo_from_var is repo
# Verify the repository was removed from the context var
with pytest.raises(RuntimeError) as excinfo:
get_key_fact_repository()
assert "No KeyFactRepository available" in str(excinfo.value)
def test_get_key_fact_repository_when_not_set(cleanup_repo):
"""Test that get_key_fact_repository raises an error when no repository is in context."""
# Attempt to get the repository when none exists
with pytest.raises(RuntimeError) as excinfo:
get_key_fact_repository()
# Verify the correct error message
assert "No KeyFactRepository available" in str(excinfo.value)

View File

@ -36,9 +36,11 @@ def reset_memory():
@pytest.fixture @pytest.fixture
def mock_functions(): def mock_functions():
"""Mock functions used in agent.py""" """Mock functions used in agent.py"""
with patch('ra_aid.tools.agent.key_fact_repository') as mock_fact_repo, \ mock_fact_repo = MagicMock()
mock_snippet_repo = MagicMock()
with patch('ra_aid.tools.agent.get_key_fact_repository', return_value=mock_fact_repo) as mock_get_fact_repo, \
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \ patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \
patch('ra_aid.tools.agent.key_snippet_repository') as mock_snippet_repo, \ patch('ra_aid.tools.agent.get_key_snippet_repository', return_value=mock_snippet_repo) as mock_get_snippet_repo, \
patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \ patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \ patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \ patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
@ -60,8 +62,8 @@ def mock_functions():
# Return all mocks as a dictionary # Return all mocks as a dictionary
yield { yield {
'key_fact_repository': mock_fact_repo, 'get_key_fact_repository': mock_get_fact_repo,
'key_snippet_repository': mock_snippet_repo, 'get_key_snippet_repository': mock_get_snippet_repo,
'format_key_facts_dict': mock_fact_formatter, 'format_key_facts_dict': mock_fact_formatter,
'format_key_snippets_dict': mock_snippet_formatter, 'format_key_snippets_dict': mock_snippet_formatter,
'initialize_llm': mock_llm, 'initialize_llm': mock_llm,
@ -81,11 +83,12 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
result = request_research("test query") result = request_research("test query")
# Verify repository was called # Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results # Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with( mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
) )
# Verify formatted facts are used in response # Verify formatted facts are used in response
@ -105,11 +108,12 @@ def test_request_research_max_depth(reset_memory, mock_functions):
result = request_research("test query") result = request_research("test query")
# Verify repository was called # Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results # Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with( mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
) )
# Verify formatted facts are used in response # Verify formatted facts are used in response
@ -128,11 +132,12 @@ def test_request_research_and_implementation_uses_key_fact_repository(reset_memo
result = request_research_and_implementation("test query") result = request_research_and_implementation("test query")
# Verify repository was called # Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results # Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with( mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
) )
# Verify formatted facts are used in response # Verify formatted facts are used in response
@ -151,11 +156,12 @@ def test_request_implementation_uses_key_fact_repository(reset_memory, mock_func
result = request_implementation("test task") result = request_implementation("test task")
# Verify repository was called # Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results # Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with( mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
) )
# Check that the formatted key facts are included in the response # Check that the formatted key facts are included in the response
@ -174,11 +180,12 @@ def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock
result = request_task_implementation("test task") result = request_task_implementation("test task")
# Verify repository was called # Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results # Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with( mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
) )
# Check that the formatted key facts are included in the response # Check that the formatted key facts are included in the response

View File

@ -16,12 +16,12 @@ from ra_aid.tools.memory import (
get_memory_value, get_memory_value,
get_related_files, get_related_files,
get_work_log, get_work_log,
key_fact_repository,
key_snippet_repository,
log_work_event, log_work_event,
reset_work_log, reset_work_log,
swap_task_order, swap_task_order,
) )
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.connection import DatabaseManager from ra_aid.database.connection import DatabaseManager
from ra_aid.database.models import KeyFact from ra_aid.database.models import KeyFact
@ -60,7 +60,7 @@ def in_memory_db():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_repository(): def mock_repository():
"""Mock the KeyFactRepository to avoid database operations during tests""" """Mock the KeyFactRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo: with patch('ra_aid.tools.memory.get_key_fact_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory # Setup the mock repository to behave like the original, but using memory
facts = {} # Local in-memory storage facts = {} # Local in-memory storage
fact_id_counter = 0 fact_id_counter = 0
@ -79,12 +79,12 @@ def mock_repository():
facts[fact_id_counter] = fact facts[fact_id_counter] = fact
fact_id_counter += 1 fact_id_counter += 1
return fact return fact
mock_repo.create.side_effect = mock_create mock_repo.return_value.create.side_effect = mock_create
# Mock get method # Mock get method
def mock_get(fact_id): def mock_get(fact_id):
return facts.get(fact_id) return facts.get(fact_id)
mock_repo.get.side_effect = mock_get mock_repo.return_value.get.side_effect = mock_get
# Mock delete method # Mock delete method
def mock_delete(fact_id): def mock_delete(fact_id):
@ -92,17 +92,17 @@ def mock_repository():
del facts[fact_id] del facts[fact_id]
return True return True
return False return False
mock_repo.delete.side_effect = mock_delete mock_repo.return_value.delete.side_effect = mock_delete
# Mock get_facts_dict method # Mock get_facts_dict method
def mock_get_facts_dict(): def mock_get_facts_dict():
return {fact_id: fact.content for fact_id, fact in facts.items()} return {fact_id: fact.content for fact_id, fact in facts.items()}
mock_repo.get_facts_dict.side_effect = mock_get_facts_dict mock_repo.return_value.get_facts_dict.side_effect = mock_get_facts_dict
# Mock get_all method # Mock get_all method
def mock_get_all(): def mock_get_all():
return list(facts.values()) return list(facts.values())
mock_repo.get_all.side_effect = mock_get_all mock_repo.return_value.get_all.side_effect = mock_get_all
yield mock_repo yield mock_repo
@ -159,16 +159,16 @@ def mock_key_snippet_repository():
return list(snippets.values()) return list(snippets.values())
# Create the actual mocks for both memory.py and key_snippets_gc_agent.py # Create the actual mocks for both memory.py and key_snippets_gc_agent.py
with patch('ra_aid.tools.memory.key_snippet_repository') as memory_mock_repo, \ with patch('ra_aid.tools.memory.get_key_snippet_repository') as memory_mock_repo, \
patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository') as agent_mock_repo: patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository') as agent_mock_repo:
# Setup both mocks with the same implementation # Setup both mocks with the same implementation
for mock_repo in [memory_mock_repo, agent_mock_repo]: for mock_repo in [memory_mock_repo, agent_mock_repo]:
mock_repo.create.side_effect = mock_create mock_repo.return_value.create.side_effect = mock_create
mock_repo.get.side_effect = mock_get mock_repo.return_value.get.side_effect = mock_get
mock_repo.delete.side_effect = mock_delete mock_repo.return_value.delete.side_effect = mock_delete
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict mock_repo.return_value.get_snippets_dict.side_effect = mock_get_snippets_dict
mock_repo.get_all.side_effect = mock_get_all mock_repo.return_value.get_all.side_effect = mock_get_all
yield memory_mock_repo yield memory_mock_repo
@ -180,7 +180,7 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
assert result == "Facts stored." assert result == "Facts stored."
# Verify the repository's create method was called # Verify the repository's create method was called
mock_repository.create.assert_called_once_with("First fact", human_input_id=ANY) mock_repository.return_value.create.assert_called_once_with("First fact", human_input_id=ANY)
def test_get_memory_value_other_types(reset_memory): def test_get_memory_value_other_types(reset_memory):
@ -264,10 +264,10 @@ def test_emit_key_facts(reset_memory, mock_repository):
assert result == "Facts stored." assert result == "Facts stored."
# Verify create was called for each fact # Verify create was called for each fact
assert mock_repository.create.call_count == 3 assert mock_repository.return_value.create.call_count == 3
mock_repository.create.assert_any_call("First fact", human_input_id=ANY) mock_repository.return_value.create.assert_any_call("First fact", human_input_id=ANY)
mock_repository.create.assert_any_call("Second fact", human_input_id=ANY) mock_repository.return_value.create.assert_any_call("Second fact", human_input_id=ANY)
mock_repository.create.assert_any_call("Third fact", human_input_id=ANY) mock_repository.return_value.create.assert_any_call("Third fact", human_input_id=ANY)
def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository): def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
@ -278,7 +278,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None)) facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
# Mock the get_all method to return more than 30 facts # Mock the get_all method to return more than 30 facts
mock_repository.get_all.return_value = facts mock_repository.return_value.get_all.return_value = facts
# Note on testing approach: # Note on testing approach:
# Rather than trying to mock the dynamic import which is challenging due to # Rather than trying to mock the dynamic import which is challenging due to
@ -295,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
# Verify that mock_repository.get_all was called, # Verify that mock_repository.get_all was called,
# which is the condition that would trigger the GC agent # which is the condition that would trigger the GC agent
mock_repository.get_all.assert_called_once() mock_repository.return_value.get_all.assert_called_once()
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository): def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
@ -315,7 +315,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
assert result == "Snippet #0 stored." assert result == "Snippet #0 stored."
# Verify create was called correctly # Verify create was called correctly
mock_key_snippet_repository.create.assert_called_with( mock_key_snippet_repository.return_value.create.assert_called_with(
filepath="test.py", filepath="test.py",
line_number=10, line_number=10,
snippet="def test():\n pass", snippet="def test():\n pass",
@ -338,7 +338,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
assert result == "Snippet #1 stored." assert result == "Snippet #1 stored."
# Verify create was called correctly # Verify create was called correctly
mock_key_snippet_repository.create.assert_called_with( mock_key_snippet_repository.return_value.create.assert_called_with(
filepath="main.py", filepath="main.py",
line_number=20, line_number=20,
snippet="print('hello')", snippet="print('hello')",
@ -379,16 +379,16 @@ def test_delete_key_snippets(mock_log_work_event, reset_memory, mock_key_snippet
mock_key_snippet_repository.reset_mock() mock_key_snippet_repository.reset_mock()
# Test deleting mix of valid and invalid IDs # Test deleting mix of valid and invalid IDs
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]}) result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
# Verify success message # Verify success message
assert result == "Snippets deleted." assert result == "Snippets deleted."
# Verify repository get was called with correct IDs # Verify repository get was called with correct IDs
mock_key_snippet_repository.get.assert_any_call(0) mock_key_snippet_repository.return_value.get.assert_any_call(0)
mock_key_snippet_repository.get.assert_any_call(1) mock_key_snippet_repository.return_value.get.assert_any_call(1)
mock_key_snippet_repository.get.assert_any_call(999) mock_key_snippet_repository.return_value.get.assert_any_call(999)
# We skip verifying delete calls because they are prone to test environment issues # We skip verifying delete calls because they are prone to test environment issues
# The implementation logic will properly delete IDs 0 and 1 but not 999 # The implementation logic will properly delete IDs 0 and 1 but not 999
@ -410,12 +410,12 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s
mock_key_snippet_repository.reset_mock() mock_key_snippet_repository.reset_mock()
# Test with empty list # Test with empty list
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": []}) result = delete_key_snippets.invoke({"snippet_ids": []})
assert result == "Snippets deleted." assert result == "Snippets deleted."
# Verify no call to delete method # Verify no call to delete method
mock_key_snippet_repository.delete.assert_not_called() mock_key_snippet_repository.return_value.delete.assert_not_called()
def test_emit_related_files_basic(reset_memory, tmp_path): def test_emit_related_files_basic(reset_memory, tmp_path):
@ -458,13 +458,13 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path):
new_file.write_text("# New file") new_file.write_text("# New file")
# Add initial files # Add initial files
result = emit_related_files.invoke({"files": [str(test_file), str(main_file)]}) result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
assert result == "Files noted." assert result1 == "Files noted."
_first_id = 0 # ID of test.py _first_id = 0 # ID of test.py
# Try adding duplicates # Try adding duplicates
result = emit_related_files.invoke({"files": [str(test_file)]}) result2 = emit_related_files.invoke({"files": [str(test_file)]})
assert result == "Files noted." assert result2 == "Files noted."
assert len(_global_memory["related_files"]) == 2 # Count should not increase assert len(_global_memory["related_files"]) == 2 # Count should not increase
# Try mix of new and duplicate files # Try mix of new and duplicate files
@ -670,7 +670,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
mock_key_snippet_repository.reset_mock() mock_key_snippet_repository.reset_mock()
# Delete some but not all snippets (0 and 2) # Delete some but not all snippets (0 and 2)
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]}) result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
assert result == "Snippets deleted." assert result == "Snippets deleted."
@ -692,7 +692,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
assert result == "Snippet #3 stored." assert result == "Snippet #3 stored."
# Verify create was called with correct params # Verify create was called with correct params
mock_key_snippet_repository.create.assert_called_with( mock_key_snippet_repository.return_value.create.assert_called_with(
filepath=file4, filepath=file4,
line_number=40, line_number=40,
snippet="def func4():\n return False", snippet="def func4():\n return False",
@ -704,7 +704,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
mock_key_snippet_repository.reset_mock() mock_key_snippet_repository.reset_mock()
# Delete remaining snippets # Delete remaining snippets
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository): with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]}) result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
assert result == "Snippets deleted." assert result == "Snippets deleted."