use context vars for key facts repo

This commit is contained in:
AI Christianson 2025-03-03 16:58:03 -05:00
parent dd9af78693
commit 36e4004db0
12 changed files with 652 additions and 360 deletions

View File

@ -42,7 +42,7 @@ from ra_aid.config import (
DEFAULT_TEST_CMD_TIMEOUT,
VALID_PROVIDERS,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -393,251 +393,256 @@ def main():
except Exception as e:
logger.error(f"Database migration error: {str(e)}")
# Check dependencies before proceeding
check_dependencies()
# Initialize repositories with database connection
with KeyFactRepositoryManager(db) as key_fact_repo:
# This initializes the repository and makes it available via get_key_fact_repository()
logger.debug("Initialized KeyFactRepository")
(
expert_enabled,
expert_missing,
web_research_enabled,
web_research_missing,
) = validate_environment(args) # Will exit if main env vars missing
logger.debug("Environment validation successful")
# Check dependencies before proceeding
check_dependencies()
# Validate model configuration early
model_config = models_params.get(args.provider, {}).get(
args.model or "", {}
)
supports_temperature = model_config.get(
"supports_temperature",
args.provider
in [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
],
)
(
expert_enabled,
expert_missing,
web_research_enabled,
web_research_missing,
) = validate_environment(args) # Will exit if main env vars missing
logger.debug("Environment validation successful")
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
# Validate model configuration early
model_config = models_params.get(args.provider, {}).get(
args.model or "", {}
)
supports_temperature = model_config.get(
"supports_temperature",
args.provider
in [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
],
)
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
)
status = build_status(args, expert_enabled, web_research_enabled)
console.print(
Panel(
status,
title=f"RA.Aid v{__version__}",
border_style="bright_blue",
padding=(0, 1),
)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
)
status = build_status(args, expert_enabled, web_research_enabled)
# Handle chat mode
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
console.print(
Panel(
status,
title=f"RA.Aid v{__version__}",
border_style="bright_blue",
padding=(0, 1),
)
)
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
sys.exit(1)
# Handle chat mode
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
print_stage_header("Chat Mode")
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
# Get project info
try:
project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info)
except Exception as e:
logger.warning(f"Failed to get project info: {e}")
formatted_project_info = ""
# Get initial request from user
initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
# Record chat input in database (redundant as ask_human already records it,
# but needed in case the ask_human implementation changes)
try:
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
human_input_repo = HumanInputRepository(db)
human_input_repo.create(content=initial_request, source='chat')
human_input_repo.garbage_collect()
except Exception as e:
logger.error(f"Failed to record initial chat input: {str(e)}")
# Get working directory and current date
working_directory = os.getcwd()
current_date = datetime.now().strftime("%Y-%m-%d")
# Run chat agent with CHAT_PROMPT
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"chat_mode": True,
"cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled,
"initial_request": initial_request,
"limit_tokens": args.disable_limit_tokens,
}
# Store config in global memory
_global_memory["config"] = config
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
# Create chat agent with appropriate tools
chat_agent = create_agent(
chat_model,
get_chat_tools(
expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
),
checkpointer=MemorySaver(),
)
# Run chat agent and exit
run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT
if web_research_enabled
else ""
),
working_directory=working_directory,
current_date=current_date,
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()),
project_info=formatted_project_info,
),
config,
)
return
# Validate message is provided
if not args.message:
print_error("--message is required")
sys.exit(1)
print_stage_header("Chat Mode")
# Get project info
try:
project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info)
except Exception as e:
logger.warning(f"Failed to get project info: {e}")
formatted_project_info = ""
# Get initial request from user
initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
base_task = args.message
# Record chat input in database (redundant as ask_human already records it,
# but needed in case the ask_human implementation changes)
# Record CLI input in database
try:
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
human_input_repo = HumanInputRepository()
human_input_repo.create(content=initial_request, source='chat')
human_input_repo = HumanInputRepository(db)
human_input_repo.create(content=base_task, source='cli')
# Run garbage collection to ensure we don't exceed 100 inputs
human_input_repo.garbage_collect()
logger.debug(f"Recorded CLI input: {base_task}")
except Exception as e:
logger.error(f"Failed to record initial chat input: {str(e)}")
# Get working directory and current date
working_directory = os.getcwd()
current_date = datetime.now().strftime("%Y-%m-%d")
# Run chat agent with CHAT_PROMPT
logger.error(f"Failed to record CLI input: {str(e)}")
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"chat_mode": True,
"research_only": args.research_only,
"cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled,
"initial_request": initial_request,
"aider_config": args.aider_config,
"use_aider": args.use_aider,
"limit_tokens": args.disable_limit_tokens,
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
}
# Store config in global memory
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = (
args.research_model or args.model
)
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
# Create chat agent with appropriate tools
chat_agent = create_agent(
chat_model,
get_chat_tools(
expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
),
checkpointer=MemorySaver(),
# Run research stage
print_stage_header("Research Stage")
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model
research_model = initialize_llm(
research_provider, research_model_name, temperature=args.temperature
)
# Run chat agent and exit
run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT
if web_research_enabled
else ""
),
working_directory=working_directory,
current_date=current_date,
key_facts=format_key_facts_dict(KeyFactRepository().get_facts_dict()),
key_snippets=format_key_snippets_dict(KeySnippetRepository().get_snippets_dict()),
project_info=formatted_project_info,
),
config,
)
return
# Validate message is provided
if not args.message:
print_error("--message is required")
sys.exit(1)
base_task = args.message
# Record CLI input in database
try:
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
human_input_repo = HumanInputRepository()
human_input_repo.create(content=base_task, source='cli')
# Run garbage collection to ensure we don't exceed 100 inputs
human_input_repo.garbage_collect()
logger.debug(f"Recorded CLI input: {base_task}")
except Exception as e:
logger.error(f"Failed to record CLI input: {str(e)}")
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"research_only": args.research_only,
"cowboy_mode": args.cowboy_mode,
"web_research_enabled": web_research_enabled,
"aider_config": args.aider_config,
"use_aider": args.use_aider,
"limit_tokens": args.disable_limit_tokens,
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
}
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = (
args.research_model or args.model
)
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
# Run research stage
print_stage_header("Research Stage")
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model
research_model = initialize_llm(
research_provider, research_model_name, temperature=args.temperature
)
run_research_agent(
base_task,
research_model,
expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil,
memory=research_memory,
config=config,
)
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
run_research_agent(
base_task,
planning_model,
research_model,
expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil,
memory=planning_memory,
memory=research_memory,
config=config,
)
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
base_task,
planning_model,
expert_enabled=expert_enabled,
hil=args.hil,
memory=planning_memory,
config=config,
)
except (KeyboardInterrupt, AgentInterrupt):
print()
print(" 👋 Bye!")

View File

@ -84,8 +84,8 @@ from ra_aid.tool_configs import (
get_web_research_tools,
)
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -100,10 +100,8 @@ console = Console()
logger = get_logger(__name__)
# Initialize repositories
key_fact_repository = KeyFactRepository()
key_snippet_repository = KeySnippetRepository()
human_input_repository = HumanInputRepository()
# Import repositories using get_* functions
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
@tool
@ -391,7 +389,11 @@ def run_research_agent(
else ""
)
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
@ -400,6 +402,7 @@ def run_research_agent(
# Get the last human input, if it exists
base_task = base_task_or_query
human_input_repository = HumanInputRepository()
recent_inputs = human_input_repository.get_recent(1)
if recent_inputs and len(recent_inputs) > 0:
last_human_input = recent_inputs[0].content
@ -537,7 +540,11 @@ def run_web_research_agent(
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
@ -647,6 +654,20 @@ def run_planning_agent(
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()
# Make sure key_facts is defined before using it
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
# Make sure key_snippets is defined before using it
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
planning_prompt = PLANNING_PROMPT.format(
current_date=current_date,
working_directory=working_directory,
@ -657,8 +678,8 @@ def run_planning_agent(
project_info=formatted_project_info,
research_notes=get_memory_value("research_notes"),
related_files="\n".join(get_related_files()),
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
key_facts=key_facts,
key_snippets=key_snippets,
work_log=get_memory_value("work_log"),
research_only_note=(
""
@ -751,6 +772,13 @@ def run_task_implementation_agent(
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()
# Make sure key_facts is defined before using it
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
prompt = IMPLEMENTATION_PROMPT.format(
current_date=current_date,
working_directory=working_directory,
@ -759,8 +787,8 @@ def run_task_implementation_agent(
tasks=tasks,
plan=plan,
related_files=related_files,
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
key_facts=key_facts,
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
research_notes=get_memory_value("research_notes"),
work_log=get_memory_value("work_log"),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",

View File

@ -6,6 +6,7 @@ facts when the total number exceeds a specified threshold. The agent evaluates a
key facts and deletes the least valuable ones to keep the database clean and relevant.
"""
import logging
from typing import List
from langchain_core.tools import tool
@ -13,8 +14,10 @@ from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
logger = logging.getLogger(__name__)
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
@ -22,7 +25,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
console = Console()
key_fact_repository = KeyFactRepository()
human_input_repository = HumanInputRepository()
@ -51,24 +53,30 @@ def delete_key_facts(fact_ids: List[int]) -> str:
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
for fact_id in fact_ids:
# Get the fact first to display information
fact = key_fact_repository.get(fact_id)
if fact:
# Check if this fact is associated with the current human input
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
protected_facts.append((fact_id, fact.content))
continue
try:
# Get the fact first to display information
fact = get_key_fact_repository().get(fact_id)
if fact:
# Check if this fact is associated with the current human input
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
protected_facts.append((fact_id, fact.content))
continue
# Delete the fact if it's not protected
was_deleted = get_key_fact_repository().delete(fact_id)
if was_deleted:
deleted_facts.append((fact_id, fact.content))
log_work_event(f"Deleted fact {fact_id}.")
else:
failed_facts.append(fact_id)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
failed_facts.append(fact_id)
except Exception as e:
# For any other exceptions, log and continue
logger.error(f"Error processing fact {fact_id}: {str(e)}")
failed_facts.append(fact_id)
# Delete the fact if it's not protected
was_deleted = key_fact_repository.delete(fact_id)
if was_deleted:
deleted_facts.append((fact_id, fact.content))
log_work_event(f"Deleted fact {fact_id}.")
else:
failed_facts.append(fact_id)
else:
not_found_facts.append(fact_id)
# Prepare result message
result_parts = []
if deleted_facts:
@ -104,8 +112,13 @@ def run_key_facts_gc_agent() -> None:
Facts associated with the current human input are excluded from deletion.
"""
# Get the count of key facts
facts = key_fact_repository.get_all()
fact_count = len(facts)
try:
facts = get_key_fact_repository().get_all()
fact_count = len(facts)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
return # Exit the function if we can't access the repository
# Display status panel with fact count included
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
@ -161,8 +174,12 @@ def run_key_facts_gc_agent() -> None:
run_agent_with_retry(agent, prompt, agent_config)
# Get updated count
updated_facts = key_fact_repository.get_all()
updated_count = len(updated_facts)
try:
updated_facts = get_key_fact_repository().get_all()
updated_count = len(updated_facts)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository for update count: {str(e)}")
updated_count = "unknown"
# Show info panel with updated count and protected facts count
protected_count = len(protected_facts)

View File

@ -14,7 +14,7 @@ from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
@ -22,7 +22,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
console = Console()
key_snippet_repository = KeySnippetRepository()
human_input_repository = HumanInputRepository()
@ -53,7 +52,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
for snippet_id in snippet_ids:
# Get the snippet first to capture filepath for the message
snippet = key_snippet_repository.get(snippet_id)
snippet = get_key_snippet_repository().get(snippet_id)
if snippet:
filepath = snippet.filepath
@ -63,7 +62,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
continue
# Delete from database if not protected
success = key_snippet_repository.delete(snippet_id)
success = get_key_snippet_repository().delete(snippet_id)
if success:
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
console.print(
@ -110,7 +109,7 @@ def run_key_snippets_gc_agent() -> None:
Snippets associated with the current human input are excluded from deletion.
"""
# Get the count of key snippets
snippets = key_snippet_repository.get_all()
snippets = get_key_snippet_repository().get_all()
snippet_count = len(snippets)
# Display status panel with snippet count included
@ -179,7 +178,7 @@ def run_key_snippets_gc_agent() -> None:
run_agent_with_retry(agent, prompt, agent_config)
# Get updated count
updated_snippets = key_snippet_repository.get_all()
updated_snippets = get_key_snippet_repository().get_all()
updated_count = len(updated_snippets)
# Show info panel with updated count and protected snippets count

View File

@ -6,15 +6,94 @@ following the repository pattern for data access abstraction.
"""
from typing import Dict, List, Optional
import contextvars
from contextlib import contextmanager
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import KeyFact, initialize_database
from ra_aid.database.models import KeyFact
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
# Create contextvar to hold the KeyFactRepository instance
key_fact_repo_var = contextvars.ContextVar("key_fact_repo", default=None)
class KeyFactRepositoryManager:
"""
Context manager for KeyFactRepository.
This class provides a context manager interface for KeyFactRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with KeyFactRepositoryManager(db) as repo:
# Use the repository
fact = repo.create("Important fact about the project")
all_facts = repo.get_all()
"""
def __init__(self, db):
"""
Initialize the KeyFactRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'KeyFactRepository':
"""
Initialize the KeyFactRepository and return it.
Returns:
KeyFactRepository: The initialized repository
"""
repo = KeyFactRepository(self.db)
key_fact_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
key_fact_repo_var.set(None)
# Don't suppress exceptions
return False
def get_key_fact_repository() -> 'KeyFactRepository':
"""
Get the current KeyFactRepository instance.
Returns:
KeyFactRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with KeyFactRepositoryManager
"""
repo = key_fact_repo_var.get()
if repo is None:
raise RuntimeError(
"No KeyFactRepository available. "
"Make sure to initialize one with KeyFactRepositoryManager first."
)
return repo
class KeyFactRepository:
"""
@ -24,18 +103,21 @@ class KeyFactRepository:
abstracting the database access details from the business logic.
Example:
repo = KeyFactRepository()
fact = repo.create("Important fact about the project")
all_facts = repo.get_all()
with DatabaseManager() as db:
with KeyFactRepositoryManager(db) as repo:
fact = repo.create("Important fact about the project")
all_facts = repo.get_all()
"""
def __init__(self, db=None):
def __init__(self, db):
"""
Initialize the repository with an optional database connection.
Initialize the repository with a database connection.
Args:
db: Optional database connection to use. If None, will use initialize_database()
db: Database connection to use (required)
"""
if db is None:
raise ValueError("Database connection is required for KeyFactRepository")
self.db = db
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
@ -53,7 +135,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error creating the fact
"""
try:
db = self.db if self.db is not None else initialize_database()
fact = KeyFact.create(content=content, human_input_id=human_input_id)
logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact
@ -75,7 +156,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = self.db if self.db is not None else initialize_database()
return KeyFact.get_or_none(KeyFact.id == fact_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
@ -96,7 +176,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error updating the fact
"""
try:
db = self.db if self.db is not None else initialize_database()
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
@ -126,7 +205,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error deleting the fact
"""
try:
db = self.db if self.db is not None else initialize_database()
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
@ -152,7 +230,6 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = self.db if self.db is not None else initialize_database()
return list(KeyFact.select().order_by(KeyFact.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key facts: {str(e)}")

View File

@ -214,4 +214,21 @@ class KeySnippetRepository:
}
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
raise
raise
# Global singleton instance
_key_snippet_repository = None
def get_key_snippet_repository() -> KeySnippetRepository:
"""
Get or create a singleton instance of KeySnippetRepository.
Returns:
KeySnippetRepository: Singleton instance of the repository
"""
global _key_snippet_repository
if _key_snippet_repository is None:
_key_snippet_repository = KeySnippetRepository()
return _key_snippet_repository

View File

@ -1,6 +1,7 @@
"""Tools for spawning and managing sub-agents."""
from typing import Any, Dict, List, Union
import logging
from langchain_core.tools import tool
from rich.console import Console
@ -12,8 +13,9 @@ from ra_aid.agent_context import (
reset_completion_flags,
)
from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.exceptions import AgentInterrupt
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -31,8 +33,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user.
RESEARCH_AGENT_RECURSION_LIMIT = 3
console = Console()
key_fact_repository = KeyFactRepository()
key_snippet_repository = KeySnippetRepository()
logger = logging.getLogger(__name__)
@tool("request_research")
@ -57,12 +58,24 @@ def request_research(query: str) -> ResearchResult:
current_depth = _global_memory.get("agent_depth", 0)
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
print_error("Maximum research recursion depth reached")
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
return {
"completion_message": "Research stopped - maximum recursion depth reached",
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"key_facts": key_facts,
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
"key_snippets": key_snippets,
"success": False,
"reason": "max_depth_exceeded",
}
@ -105,12 +118,24 @@ def request_research(query: str) -> ResearchResult:
# Clear completion state
reset_completion_flags()
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = {
"completion_message": completion_message,
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"key_facts": key_facts,
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
"key_snippets": key_snippets,
"success": success,
"reason": reason,
}
@ -171,9 +196,15 @@ def request_web_research(query: str) -> ResearchResult:
# Clear completion state
reset_completion_flags()
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = {
"completion_message": completion_message,
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
"key_snippets": key_snippets,
"research_notes": get_memory_value("research_notes"),
"success": success,
"reason": reason,
@ -239,12 +270,24 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
# Clear completion state
reset_completion_flags()
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = {
"completion_message": completion_message,
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"key_facts": key_facts,
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
"key_snippets": key_snippets,
"success": success,
"reason": reason,
}
@ -324,10 +367,22 @@ def request_task_implementation(task_spec: str) -> str:
agent_crashed = is_crashed()
crash_message = get_crash_message() if agent_crashed else None
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = {
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"key_facts": key_facts,
"related_files": get_related_files(),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
"key_snippets": key_snippets,
"completion_message": completion_message,
"success": success and not agent_crashed,
"reason": reason,
@ -444,11 +499,23 @@ def request_implementation(task_spec: str) -> str:
agent_crashed = is_crashed()
crash_message = get_crash_message() if agent_crashed else None
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
response_data = {
"completion_message": completion_message,
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"key_facts": key_facts,
"related_files": get_related_files(),
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
"key_snippets": key_snippets,
"success": success and not agent_crashed,
"reason": reason,
"agent_crashed": agent_crashed,

View File

@ -1,4 +1,5 @@
import os
import logging
from typing import List
from langchain_core.tools import tool
@ -6,8 +7,10 @@ from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ..database.repositories.key_fact_repository import KeyFactRepository
from ..database.repositories.key_snippet_repository import KeySnippetRepository
logger = logging.getLogger(__name__)
from ..database.repositories.key_fact_repository import get_key_fact_repository
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
from ..llm import initialize_expert_llm
from ..model_formatters import format_key_facts_dict
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -15,8 +18,6 @@ from .memory import _global_memory, get_memory_value
console = Console()
_model = None
key_fact_repository = KeyFactRepository()
key_snippet_repository = KeySnippetRepository()
def get_model():
@ -154,10 +155,18 @@ def ask_expert(question: str) -> str:
file_paths = list(_global_memory["related_files"].values())
related_contents = read_related_files(file_paths)
# Get key snippets directly from repository and format using the formatter
key_snippets = format_key_snippets_dict(key_snippet_repository.get_snippets_dict())
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
# Get key facts directly from repository and format using the formatter
facts_dict = key_fact_repository.get_facts_dict()
key_facts = format_key_facts_dict(facts_dict)
try:
facts_dict = get_key_fact_repository().get_facts_dict()
key_facts = format_key_facts_dict(facts_dict)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
research_notes = get_memory_value("research_notes")
# Build display query (just question)

View File

@ -17,8 +17,8 @@ from ra_aid.agent_context import (
mark_should_exit,
mark_task_completed,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository, get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository, get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.model_formatters import key_snippets_formatter
from ra_aid.logging_config import get_logger
@ -40,14 +40,8 @@ class SnippetInfo(TypedDict):
console = Console()
# Initialize repository for key facts
key_fact_repository = KeyFactRepository()
# Initialize repository for key snippets
key_snippet_repository = KeySnippetRepository()
# Initialize repository for human inputs
human_input_repository = HumanInputRepository()
# Import repositories using the get_* functions
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
# Global memory store
_global_memory: Dict[str, Any] = {
@ -120,17 +114,23 @@ def emit_key_facts(facts: List[str]) -> str:
# Try to get the latest human input
human_input_id = None
human_input_repo = HumanInputRepository()
try:
recent_inputs = human_input_repository.get_recent(1)
recent_inputs = human_input_repo.get_recent(1)
if recent_inputs and len(recent_inputs) > 0:
human_input_id = recent_inputs[0].id
except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}")
for fact in facts:
# Create fact in database using repository
created_fact = key_fact_repository.create(fact, human_input_id=human_input_id)
fact_id = created_fact.id
try:
# Create fact in database using repository
created_fact = get_key_fact_repository().create(fact, human_input_id=human_input_id)
fact_id = created_fact.id
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
console.print(f"Error storing fact: {str(e)}", style="red")
continue
# Display panel with ID
console.print(
@ -147,14 +147,17 @@ def emit_key_facts(facts: List[str]) -> str:
log_work_event(f"Stored {len(facts)} key facts.")
# Check if we need to clean up facts (more than 30)
all_facts = key_fact_repository.get_all()
if len(all_facts) > 30:
# Trigger the key facts cleaner agent
try:
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
run_key_facts_gc_agent()
except Exception as e:
logger.error(f"Failed to run key facts cleaner: {str(e)}")
try:
all_facts = get_key_fact_repository().get_all()
if len(all_facts) > 30:
# Trigger the key facts cleaner agent
try:
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
run_key_facts_gc_agent()
except Exception as e:
logger.error(f"Failed to run key facts cleaner: {str(e)}")
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
return "Facts stored."
@ -222,14 +225,15 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
# Try to get the latest human input
human_input_id = None
try:
recent_inputs = human_input_repository.get_recent(1)
human_input_repo = HumanInputRepository()
recent_inputs = human_input_repo.get_recent(1)
if recent_inputs and len(recent_inputs) > 0:
human_input_id = recent_inputs[0].id
except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}")
# Create a new key snippet in the database
key_snippet = key_snippet_repository.create(
key_snippet = get_key_snippet_repository().create(
filepath=snippet_info["filepath"],
line_number=snippet_info["line_number"],
snippet=snippet_info["snippet"],
@ -266,7 +270,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
log_work_event(f"Stored code snippet #{snippet_id}.")
# Check if we need to clean up snippets (more than 20)
all_snippets = key_snippet_repository.get_all()
all_snippets = get_key_snippet_repository().get_all()
if len(all_snippets) > 20:
# Trigger the key snippets cleaner agent
try:
@ -279,7 +283,6 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
@tool("swap_task_order")
def swap_task_order(id1: int, id2: int) -> str:
"""Swap the order of two tasks in global memory by their IDs.

View File

@ -9,7 +9,12 @@ import peewee
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeyFact, BaseModel
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_fact_repository import (
KeyFactRepository,
KeyFactRepositoryManager,
get_key_fact_repository,
key_fact_repo_var
)
@pytest.fixture
@ -41,6 +46,19 @@ def cleanup_db():
db_var.set(None)
@pytest.fixture
def cleanup_repo():
"""Reset the repository contextvar after each test."""
# Reset before the test
key_fact_repo_var.set(None)
# Run the test
yield
# Reset after the test
key_fact_repo_var.set(None)
@pytest.fixture
def setup_db(cleanup_db):
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
@ -195,4 +213,49 @@ def test_get_facts_dict(setup_db):
# Verify each fact is in the dictionary with the correct content
for fact in facts:
assert fact.id in facts_dict
assert facts_dict[fact.id] == fact.content
assert facts_dict[fact.id] == fact.content
def test_repository_init_without_db():
"""Test that KeyFactRepository raises an error when initialized without a db parameter."""
# Attempt to create a repository without a database connection
with pytest.raises(ValueError) as excinfo:
KeyFactRepository(db=None)
# Verify the correct error message
assert "Database connection is required" in str(excinfo.value)
def test_key_fact_repository_manager(setup_db, cleanup_repo):
"""Test the KeyFactRepositoryManager context manager."""
# Use the context manager to create a repository
with KeyFactRepositoryManager(setup_db) as repo:
# Verify the repository was created correctly
assert isinstance(repo, KeyFactRepository)
assert repo.db is setup_db
# Verify we can use the repository
content = "Test fact via context manager"
fact = repo.create(content)
assert fact.id is not None
assert fact.content == content
# Verify we can get the repository using get_key_fact_repository
repo_from_var = get_key_fact_repository()
assert repo_from_var is repo
# Verify the repository was removed from the context var
with pytest.raises(RuntimeError) as excinfo:
get_key_fact_repository()
assert "No KeyFactRepository available" in str(excinfo.value)
def test_get_key_fact_repository_when_not_set(cleanup_repo):
"""Test that get_key_fact_repository raises an error when no repository is in context."""
# Attempt to get the repository when none exists
with pytest.raises(RuntimeError) as excinfo:
get_key_fact_repository()
# Verify the correct error message
assert "No KeyFactRepository available" in str(excinfo.value)

View File

@ -36,9 +36,11 @@ def reset_memory():
@pytest.fixture
def mock_functions():
"""Mock functions used in agent.py"""
with patch('ra_aid.tools.agent.key_fact_repository') as mock_fact_repo, \
mock_fact_repo = MagicMock()
mock_snippet_repo = MagicMock()
with patch('ra_aid.tools.agent.get_key_fact_repository', return_value=mock_fact_repo) as mock_get_fact_repo, \
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \
patch('ra_aid.tools.agent.key_snippet_repository') as mock_snippet_repo, \
patch('ra_aid.tools.agent.get_key_snippet_repository', return_value=mock_snippet_repo) as mock_get_snippet_repo, \
patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
@ -60,8 +62,8 @@ def mock_functions():
# Return all mocks as a dictionary
yield {
'key_fact_repository': mock_fact_repo,
'key_snippet_repository': mock_snippet_repo,
'get_key_fact_repository': mock_get_fact_repo,
'get_key_snippet_repository': mock_get_snippet_repo,
'format_key_facts_dict': mock_fact_formatter,
'format_key_snippets_dict': mock_snippet_formatter,
'initialize_llm': mock_llm,
@ -81,11 +83,12 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
result = request_research("test query")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
)
# Verify formatted facts are used in response
@ -105,11 +108,12 @@ def test_request_research_max_depth(reset_memory, mock_functions):
result = request_research("test query")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
)
# Verify formatted facts are used in response
@ -128,11 +132,12 @@ def test_request_research_and_implementation_uses_key_fact_repository(reset_memo
result = request_research_and_implementation("test query")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
)
# Verify formatted facts are used in response
@ -151,11 +156,12 @@ def test_request_implementation_uses_key_fact_repository(reset_memory, mock_func
result = request_implementation("test task")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
)
# Check that the formatted key facts are included in the response
@ -174,11 +180,12 @@ def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock
result = request_task_implementation("test task")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
mock_functions['get_key_fact_repository'].assert_called_once()
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
)
# Check that the formatted key facts are included in the response

View File

@ -16,12 +16,12 @@ from ra_aid.tools.memory import (
get_memory_value,
get_related_files,
get_work_log,
key_fact_repository,
key_snippet_repository,
log_work_event,
reset_work_log,
swap_task_order,
)
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.connection import DatabaseManager
from ra_aid.database.models import KeyFact
@ -60,7 +60,7 @@ def in_memory_db():
@pytest.fixture(autouse=True)
def mock_repository():
"""Mock the KeyFactRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo:
with patch('ra_aid.tools.memory.get_key_fact_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory
facts = {} # Local in-memory storage
fact_id_counter = 0
@ -79,12 +79,12 @@ def mock_repository():
facts[fact_id_counter] = fact
fact_id_counter += 1
return fact
mock_repo.create.side_effect = mock_create
mock_repo.return_value.create.side_effect = mock_create
# Mock get method
def mock_get(fact_id):
return facts.get(fact_id)
mock_repo.get.side_effect = mock_get
mock_repo.return_value.get.side_effect = mock_get
# Mock delete method
def mock_delete(fact_id):
@ -92,17 +92,17 @@ def mock_repository():
del facts[fact_id]
return True
return False
mock_repo.delete.side_effect = mock_delete
mock_repo.return_value.delete.side_effect = mock_delete
# Mock get_facts_dict method
def mock_get_facts_dict():
return {fact_id: fact.content for fact_id, fact in facts.items()}
mock_repo.get_facts_dict.side_effect = mock_get_facts_dict
mock_repo.return_value.get_facts_dict.side_effect = mock_get_facts_dict
# Mock get_all method
def mock_get_all():
return list(facts.values())
mock_repo.get_all.side_effect = mock_get_all
mock_repo.return_value.get_all.side_effect = mock_get_all
yield mock_repo
@ -159,16 +159,16 @@ def mock_key_snippet_repository():
return list(snippets.values())
# Create the actual mocks for both memory.py and key_snippets_gc_agent.py
with patch('ra_aid.tools.memory.key_snippet_repository') as memory_mock_repo, \
patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository') as agent_mock_repo:
with patch('ra_aid.tools.memory.get_key_snippet_repository') as memory_mock_repo, \
patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository') as agent_mock_repo:
# Setup both mocks with the same implementation
for mock_repo in [memory_mock_repo, agent_mock_repo]:
mock_repo.create.side_effect = mock_create
mock_repo.get.side_effect = mock_get
mock_repo.delete.side_effect = mock_delete
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict
mock_repo.get_all.side_effect = mock_get_all
mock_repo.return_value.create.side_effect = mock_create
mock_repo.return_value.get.side_effect = mock_get
mock_repo.return_value.delete.side_effect = mock_delete
mock_repo.return_value.get_snippets_dict.side_effect = mock_get_snippets_dict
mock_repo.return_value.get_all.side_effect = mock_get_all
yield memory_mock_repo
@ -180,7 +180,7 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
assert result == "Facts stored."
# Verify the repository's create method was called
mock_repository.create.assert_called_once_with("First fact", human_input_id=ANY)
mock_repository.return_value.create.assert_called_once_with("First fact", human_input_id=ANY)
def test_get_memory_value_other_types(reset_memory):
@ -264,10 +264,10 @@ def test_emit_key_facts(reset_memory, mock_repository):
assert result == "Facts stored."
# Verify create was called for each fact
assert mock_repository.create.call_count == 3
mock_repository.create.assert_any_call("First fact", human_input_id=ANY)
mock_repository.create.assert_any_call("Second fact", human_input_id=ANY)
mock_repository.create.assert_any_call("Third fact", human_input_id=ANY)
assert mock_repository.return_value.create.call_count == 3
mock_repository.return_value.create.assert_any_call("First fact", human_input_id=ANY)
mock_repository.return_value.create.assert_any_call("Second fact", human_input_id=ANY)
mock_repository.return_value.create.assert_any_call("Third fact", human_input_id=ANY)
def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
@ -278,7 +278,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
# Mock the get_all method to return more than 30 facts
mock_repository.get_all.return_value = facts
mock_repository.return_value.get_all.return_value = facts
# Note on testing approach:
# Rather than trying to mock the dynamic import which is challenging due to
@ -295,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
# Verify that mock_repository.get_all was called,
# which is the condition that would trigger the GC agent
mock_repository.get_all.assert_called_once()
mock_repository.return_value.get_all.assert_called_once()
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
@ -315,7 +315,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
assert result == "Snippet #0 stored."
# Verify create was called correctly
mock_key_snippet_repository.create.assert_called_with(
mock_key_snippet_repository.return_value.create.assert_called_with(
filepath="test.py",
line_number=10,
snippet="def test():\n pass",
@ -338,7 +338,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
assert result == "Snippet #1 stored."
# Verify create was called correctly
mock_key_snippet_repository.create.assert_called_with(
mock_key_snippet_repository.return_value.create.assert_called_with(
filepath="main.py",
line_number=20,
snippet="print('hello')",
@ -379,16 +379,16 @@ def test_delete_key_snippets(mock_log_work_event, reset_memory, mock_key_snippet
mock_key_snippet_repository.reset_mock()
# Test deleting mix of valid and invalid IDs
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
# Verify success message
assert result == "Snippets deleted."
# Verify repository get was called with correct IDs
mock_key_snippet_repository.get.assert_any_call(0)
mock_key_snippet_repository.get.assert_any_call(1)
mock_key_snippet_repository.get.assert_any_call(999)
mock_key_snippet_repository.return_value.get.assert_any_call(0)
mock_key_snippet_repository.return_value.get.assert_any_call(1)
mock_key_snippet_repository.return_value.get.assert_any_call(999)
# We skip verifying delete calls because they are prone to test environment issues
# The implementation logic will properly delete IDs 0 and 1 but not 999
@ -410,12 +410,12 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s
mock_key_snippet_repository.reset_mock()
# Test with empty list
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": []})
assert result == "Snippets deleted."
# Verify no call to delete method
mock_key_snippet_repository.delete.assert_not_called()
mock_key_snippet_repository.return_value.delete.assert_not_called()
def test_emit_related_files_basic(reset_memory, tmp_path):
@ -458,13 +458,13 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path):
new_file.write_text("# New file")
# Add initial files
result = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
assert result == "Files noted."
result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
assert result1 == "Files noted."
_first_id = 0 # ID of test.py
# Try adding duplicates
result = emit_related_files.invoke({"files": [str(test_file)]})
assert result == "Files noted."
result2 = emit_related_files.invoke({"files": [str(test_file)]})
assert result2 == "Files noted."
assert len(_global_memory["related_files"]) == 2 # Count should not increase
# Try mix of new and duplicate files
@ -670,7 +670,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
mock_key_snippet_repository.reset_mock()
# Delete some but not all snippets (0 and 2)
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
assert result == "Snippets deleted."
@ -692,7 +692,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
assert result == "Snippet #3 stored."
# Verify create was called with correct params
mock_key_snippet_repository.create.assert_called_with(
mock_key_snippet_repository.return_value.create.assert_called_with(
filepath=file4,
line_number=40,
snippet="def func4():\n return False",
@ -704,7 +704,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
mock_key_snippet_repository.reset_mock()
# Delete remaining snippets
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
assert result == "Snippets deleted."