use context vars for key facts repo
This commit is contained in:
parent
dd9af78693
commit
36e4004db0
|
|
@ -42,7 +42,7 @@ from ra_aid.config import (
|
|||
DEFAULT_TEST_CMD_TIMEOUT,
|
||||
VALID_PROVIDERS,
|
||||
)
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -393,251 +393,256 @@ def main():
|
|||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
# Initialize repositories with database connection
|
||||
with KeyFactRepositoryManager(db) as key_fact_repo:
|
||||
# This initializes the repository and makes it available via get_key_fact_repository()
|
||||
logger.debug("Initialized KeyFactRepository")
|
||||
|
||||
(
|
||||
expert_enabled,
|
||||
expert_missing,
|
||||
web_research_enabled,
|
||||
web_research_missing,
|
||||
) = validate_environment(args) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
# Validate model configuration early
|
||||
model_config = models_params.get(args.provider, {}).get(
|
||||
args.model or "", {}
|
||||
)
|
||||
supports_temperature = model_config.get(
|
||||
"supports_temperature",
|
||||
args.provider
|
||||
in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"deepseek",
|
||||
],
|
||||
)
|
||||
(
|
||||
expert_enabled,
|
||||
expert_missing,
|
||||
web_research_enabled,
|
||||
web_research_missing,
|
||||
) = validate_environment(args) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
if supports_temperature and args.temperature is None:
|
||||
args.temperature = model_config.get("default_temperature")
|
||||
if args.temperature is None:
|
||||
cpm(
|
||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||
# Validate model configuration early
|
||||
model_config = models_params.get(args.provider, {}).get(
|
||||
args.model or "", {}
|
||||
)
|
||||
supports_temperature = model_config.get(
|
||||
"supports_temperature",
|
||||
args.provider
|
||||
in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"deepseek",
|
||||
],
|
||||
)
|
||||
|
||||
if supports_temperature and args.temperature is None:
|
||||
args.temperature = model_config.get("default_temperature")
|
||||
if args.temperature is None:
|
||||
cpm(
|
||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||
)
|
||||
args.temperature = DEFAULT_TEMPERATURE
|
||||
logger.debug(
|
||||
f"Using default temperature {args.temperature} for model {args.model}"
|
||||
)
|
||||
|
||||
status = build_status(args, expert_enabled, web_research_enabled)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
status,
|
||||
title=f"RA.Aid v{__version__}",
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
args.temperature = DEFAULT_TEMPERATURE
|
||||
logger.debug(
|
||||
f"Using default temperature {args.temperature} for model {args.model}"
|
||||
)
|
||||
|
||||
status = build_status(args, expert_enabled, web_research_enabled)
|
||||
# Handle chat mode
|
||||
if args.chat:
|
||||
# Initialize chat model with default provider/model
|
||||
chat_model = initialize_llm(
|
||||
args.provider, args.model, temperature=args.temperature
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
status,
|
||||
title=f"RA.Aid v{__version__}",
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
if args.research_only:
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
sys.exit(1)
|
||||
|
||||
# Handle chat mode
|
||||
if args.chat:
|
||||
# Initialize chat model with default provider/model
|
||||
chat_model = initialize_llm(
|
||||
args.provider, args.model, temperature=args.temperature
|
||||
)
|
||||
print_stage_header("Chat Mode")
|
||||
|
||||
if args.research_only:
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
# Get project info
|
||||
try:
|
||||
project_info = get_project_info(".", file_limit=2000)
|
||||
formatted_project_info = format_project_info(project_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get project info: {e}")
|
||||
formatted_project_info = ""
|
||||
|
||||
# Get initial request from user
|
||||
initial_request = ask_human.invoke(
|
||||
{"question": "What would you like help with?"}
|
||||
)
|
||||
|
||||
# Record chat input in database (redundant as ask_human already records it,
|
||||
# but needed in case the ask_human implementation changes)
|
||||
try:
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
human_input_repo = HumanInputRepository(db)
|
||||
human_input_repo.create(content=initial_request, source='chat')
|
||||
human_input_repo.garbage_collect()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record initial chat input: {str(e)}")
|
||||
|
||||
# Get working directory and current date
|
||||
working_directory = os.getcwd()
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Run chat agent with CHAT_PROMPT
|
||||
config = {
|
||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||
"recursion_limit": args.recursion_limit,
|
||||
"chat_mode": True,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"hil": True, # Always true in chat mode
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"initial_request": initial_request,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
}
|
||||
|
||||
# Store config in global memory
|
||||
_global_memory["config"] = config
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_agent(
|
||||
chat_model,
|
||||
get_chat_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=web_research_enabled,
|
||||
),
|
||||
checkpointer=MemorySaver(),
|
||||
)
|
||||
|
||||
# Run chat agent and exit
|
||||
run_agent_with_retry(
|
||||
chat_agent,
|
||||
CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
web_research_section=(
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
if web_research_enabled
|
||||
else ""
|
||||
),
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
||||
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()),
|
||||
project_info=formatted_project_info,
|
||||
),
|
||||
config,
|
||||
)
|
||||
return
|
||||
|
||||
# Validate message is provided
|
||||
if not args.message:
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
print_stage_header("Chat Mode")
|
||||
|
||||
# Get project info
|
||||
try:
|
||||
project_info = get_project_info(".", file_limit=2000)
|
||||
formatted_project_info = format_project_info(project_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get project info: {e}")
|
||||
formatted_project_info = ""
|
||||
|
||||
# Get initial request from user
|
||||
initial_request = ask_human.invoke(
|
||||
{"question": "What would you like help with?"}
|
||||
)
|
||||
base_task = args.message
|
||||
|
||||
# Record chat input in database (redundant as ask_human already records it,
|
||||
# but needed in case the ask_human implementation changes)
|
||||
# Record CLI input in database
|
||||
try:
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
human_input_repo = HumanInputRepository()
|
||||
human_input_repo.create(content=initial_request, source='chat')
|
||||
human_input_repo = HumanInputRepository(db)
|
||||
human_input_repo.create(content=base_task, source='cli')
|
||||
# Run garbage collection to ensure we don't exceed 100 inputs
|
||||
human_input_repo.garbage_collect()
|
||||
logger.debug(f"Recorded CLI input: {base_task}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record initial chat input: {str(e)}")
|
||||
|
||||
# Get working directory and current date
|
||||
working_directory = os.getcwd()
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Run chat agent with CHAT_PROMPT
|
||||
logger.error(f"Failed to record CLI input: {str(e)}")
|
||||
config = {
|
||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||
"recursion_limit": args.recursion_limit,
|
||||
"chat_mode": True,
|
||||
"research_only": args.research_only,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"hil": True, # Always true in chat mode
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"initial_request": initial_request,
|
||||
"aider_config": args.aider_config,
|
||||
"use_aider": args.use_aider,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
"auto_test": args.auto_test,
|
||||
"test_cmd": args.test_cmd,
|
||||
"max_test_cmd_retries": args.max_test_cmd_retries,
|
||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||
"test_cmd_timeout": args.test_cmd_timeout,
|
||||
}
|
||||
|
||||
# Store config in global memory
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory["config"] = config
|
||||
|
||||
# Store base provider/model configuration
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
|
||||
# Store expert provider/model (no fallback)
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
|
||||
# Store planner config with fallback to base values
|
||||
_global_memory["config"]["planner_provider"] = (
|
||||
args.planner_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||
|
||||
# Store research config with fallback to base values
|
||||
_global_memory["config"]["research_provider"] = (
|
||||
args.research_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["research_model"] = (
|
||||
args.research_model or args.model
|
||||
)
|
||||
|
||||
# Store temperature in global config
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_agent(
|
||||
chat_model,
|
||||
get_chat_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=web_research_enabled,
|
||||
),
|
||||
checkpointer=MemorySaver(),
|
||||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
||||
# Initialize research model with potential overrides
|
||||
research_provider = args.research_provider or args.provider
|
||||
research_model_name = args.research_model or args.model
|
||||
research_model = initialize_llm(
|
||||
research_provider, research_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Run chat agent and exit
|
||||
run_agent_with_retry(
|
||||
chat_agent,
|
||||
CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
web_research_section=(
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
if web_research_enabled
|
||||
else ""
|
||||
),
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
key_facts=format_key_facts_dict(KeyFactRepository().get_facts_dict()),
|
||||
key_snippets=format_key_snippets_dict(KeySnippetRepository().get_snippets_dict()),
|
||||
project_info=formatted_project_info,
|
||||
),
|
||||
config,
|
||||
)
|
||||
return
|
||||
|
||||
# Validate message is provided
|
||||
if not args.message:
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
base_task = args.message
|
||||
|
||||
# Record CLI input in database
|
||||
try:
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
human_input_repo = HumanInputRepository()
|
||||
human_input_repo.create(content=base_task, source='cli')
|
||||
# Run garbage collection to ensure we don't exceed 100 inputs
|
||||
human_input_repo.garbage_collect()
|
||||
logger.debug(f"Recorded CLI input: {base_task}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record CLI input: {str(e)}")
|
||||
config = {
|
||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||
"recursion_limit": args.recursion_limit,
|
||||
"research_only": args.research_only,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"aider_config": args.aider_config,
|
||||
"use_aider": args.use_aider,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
"auto_test": args.auto_test,
|
||||
"test_cmd": args.test_cmd,
|
||||
"max_test_cmd_retries": args.max_test_cmd_retries,
|
||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||
"test_cmd_timeout": args.test_cmd_timeout,
|
||||
}
|
||||
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory["config"] = config
|
||||
|
||||
# Store base provider/model configuration
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
|
||||
# Store expert provider/model (no fallback)
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
|
||||
# Store planner config with fallback to base values
|
||||
_global_memory["config"]["planner_provider"] = (
|
||||
args.planner_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||
|
||||
# Store research config with fallback to base values
|
||||
_global_memory["config"]["research_provider"] = (
|
||||
args.research_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["research_model"] = (
|
||||
args.research_model or args.model
|
||||
)
|
||||
|
||||
# Store temperature in global config
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
||||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
||||
# Initialize research model with potential overrides
|
||||
research_provider = args.research_provider or args.provider
|
||||
research_model_name = args.research_model or args.model
|
||||
research_model = initialize_llm(
|
||||
research_provider, research_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
run_research_agent(
|
||||
base_task,
|
||||
research_model,
|
||||
expert_enabled=expert_enabled,
|
||||
research_only=args.research_only,
|
||||
hil=args.hil,
|
||||
memory=research_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Proceed with planning and implementation if not an informational query
|
||||
if not is_informational_query():
|
||||
# Initialize planning model with potential overrides
|
||||
planner_provider = args.planner_provider or args.provider
|
||||
planner_model_name = args.planner_model or args.model
|
||||
planning_model = initialize_llm(
|
||||
planner_provider, planner_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Run planning agent
|
||||
run_planning_agent(
|
||||
run_research_agent(
|
||||
base_task,
|
||||
planning_model,
|
||||
research_model,
|
||||
expert_enabled=expert_enabled,
|
||||
research_only=args.research_only,
|
||||
hil=args.hil,
|
||||
memory=planning_memory,
|
||||
memory=research_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Proceed with planning and implementation if not an informational query
|
||||
if not is_informational_query():
|
||||
# Initialize planning model with potential overrides
|
||||
planner_provider = args.planner_provider or args.provider
|
||||
planner_model_name = args.planner_model or args.model
|
||||
planning_model = initialize_llm(
|
||||
planner_provider, planner_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Run planning agent
|
||||
run_planning_agent(
|
||||
base_task,
|
||||
planning_model,
|
||||
expert_enabled=expert_enabled,
|
||||
hil=args.hil,
|
||||
memory=planning_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
print()
|
||||
print(" 👋 Bye!")
|
||||
|
|
|
|||
|
|
@ -84,8 +84,8 @@ from ra_aid.tool_configs import (
|
|||
get_web_research_tools,
|
||||
)
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -100,10 +100,8 @@ console = Console()
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Initialize repositories
|
||||
key_fact_repository = KeyFactRepository()
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
human_input_repository = HumanInputRepository()
|
||||
# Import repositories using get_* functions
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -391,7 +389,11 @@ def run_research_agent(
|
|||
else ""
|
||||
)
|
||||
|
||||
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
code_snippets = _global_memory.get("code_snippets", "")
|
||||
related_files = _global_memory.get("related_files", "")
|
||||
|
||||
|
|
@ -400,6 +402,7 @@ def run_research_agent(
|
|||
|
||||
# Get the last human input, if it exists
|
||||
base_task = base_task_or_query
|
||||
human_input_repository = HumanInputRepository()
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
last_human_input = recent_inputs[0].content
|
||||
|
|
@ -537,7 +540,11 @@ def run_web_research_agent(
|
|||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
||||
|
||||
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
code_snippets = _global_memory.get("code_snippets", "")
|
||||
related_files = _global_memory.get("related_files", "")
|
||||
|
||||
|
|
@ -647,6 +654,20 @@ def run_planning_agent(
|
|||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
working_directory = os.getcwd()
|
||||
|
||||
# Make sure key_facts is defined before using it
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
# Make sure key_snippets is defined before using it
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
planning_prompt = PLANNING_PROMPT.format(
|
||||
current_date=current_date,
|
||||
working_directory=working_directory,
|
||||
|
|
@ -657,8 +678,8 @@ def run_planning_agent(
|
|||
project_info=formatted_project_info,
|
||||
research_notes=get_memory_value("research_notes"),
|
||||
related_files="\n".join(get_related_files()),
|
||||
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
key_facts=key_facts,
|
||||
key_snippets=key_snippets,
|
||||
work_log=get_memory_value("work_log"),
|
||||
research_only_note=(
|
||||
""
|
||||
|
|
@ -751,6 +772,13 @@ def run_task_implementation_agent(
|
|||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
working_directory = os.getcwd()
|
||||
|
||||
# Make sure key_facts is defined before using it
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
prompt = IMPLEMENTATION_PROMPT.format(
|
||||
current_date=current_date,
|
||||
working_directory=working_directory,
|
||||
|
|
@ -759,8 +787,8 @@ def run_task_implementation_agent(
|
|||
tasks=tasks,
|
||||
plan=plan,
|
||||
related_files=related_files,
|
||||
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
key_facts=key_facts,
|
||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||
research_notes=get_memory_value("research_notes"),
|
||||
work_log=get_memory_value("work_log"),
|
||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ facts when the total number exceeds a specified threshold. The agent evaluates a
|
|||
key facts and deletes the least valuable ones to keep the database clean and relevant.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
|
@ -13,8 +14,10 @@ from rich.console import Console
|
|||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||
|
|
@ -22,7 +25,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
|
|||
|
||||
|
||||
console = Console()
|
||||
key_fact_repository = KeyFactRepository()
|
||||
human_input_repository = HumanInputRepository()
|
||||
|
||||
|
||||
|
|
@ -51,24 +53,30 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
||||
|
||||
for fact_id in fact_ids:
|
||||
# Get the fact first to display information
|
||||
fact = key_fact_repository.get(fact_id)
|
||||
if fact:
|
||||
# Check if this fact is associated with the current human input
|
||||
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
|
||||
protected_facts.append((fact_id, fact.content))
|
||||
continue
|
||||
try:
|
||||
# Get the fact first to display information
|
||||
fact = get_key_fact_repository().get(fact_id)
|
||||
if fact:
|
||||
# Check if this fact is associated with the current human input
|
||||
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
|
||||
protected_facts.append((fact_id, fact.content))
|
||||
continue
|
||||
|
||||
# Delete the fact if it's not protected
|
||||
was_deleted = get_key_fact_repository().delete(fact_id)
|
||||
if was_deleted:
|
||||
deleted_facts.append((fact_id, fact.content))
|
||||
log_work_event(f"Deleted fact {fact_id}.")
|
||||
else:
|
||||
failed_facts.append(fact_id)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
failed_facts.append(fact_id)
|
||||
except Exception as e:
|
||||
# For any other exceptions, log and continue
|
||||
logger.error(f"Error processing fact {fact_id}: {str(e)}")
|
||||
failed_facts.append(fact_id)
|
||||
|
||||
# Delete the fact if it's not protected
|
||||
was_deleted = key_fact_repository.delete(fact_id)
|
||||
if was_deleted:
|
||||
deleted_facts.append((fact_id, fact.content))
|
||||
log_work_event(f"Deleted fact {fact_id}.")
|
||||
else:
|
||||
failed_facts.append(fact_id)
|
||||
else:
|
||||
not_found_facts.append(fact_id)
|
||||
|
||||
# Prepare result message
|
||||
result_parts = []
|
||||
if deleted_facts:
|
||||
|
|
@ -104,8 +112,13 @@ def run_key_facts_gc_agent() -> None:
|
|||
Facts associated with the current human input are excluded from deletion.
|
||||
"""
|
||||
# Get the count of key facts
|
||||
facts = key_fact_repository.get_all()
|
||||
fact_count = len(facts)
|
||||
try:
|
||||
facts = get_key_fact_repository().get_all()
|
||||
fact_count = len(facts)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||
return # Exit the function if we can't access the repository
|
||||
|
||||
# Display status panel with fact count included
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
||||
|
|
@ -161,8 +174,12 @@ def run_key_facts_gc_agent() -> None:
|
|||
run_agent_with_retry(agent, prompt, agent_config)
|
||||
|
||||
# Get updated count
|
||||
updated_facts = key_fact_repository.get_all()
|
||||
updated_count = len(updated_facts)
|
||||
try:
|
||||
updated_facts = get_key_fact_repository().get_all()
|
||||
updated_count = len(updated_facts)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository for update count: {str(e)}")
|
||||
updated_count = "unknown"
|
||||
|
||||
# Show info panel with updated count and protected facts count
|
||||
protected_count = len(protected_facts)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||
|
|
@ -22,7 +22,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
|
|||
|
||||
|
||||
console = Console()
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
human_input_repository = HumanInputRepository()
|
||||
|
||||
|
||||
|
|
@ -53,7 +52,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
|
||||
for snippet_id in snippet_ids:
|
||||
# Get the snippet first to capture filepath for the message
|
||||
snippet = key_snippet_repository.get(snippet_id)
|
||||
snippet = get_key_snippet_repository().get(snippet_id)
|
||||
if snippet:
|
||||
filepath = snippet.filepath
|
||||
|
||||
|
|
@ -63,7 +62,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
continue
|
||||
|
||||
# Delete from database if not protected
|
||||
success = key_snippet_repository.delete(snippet_id)
|
||||
success = get_key_snippet_repository().delete(snippet_id)
|
||||
if success:
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||
console.print(
|
||||
|
|
@ -110,7 +109,7 @@ def run_key_snippets_gc_agent() -> None:
|
|||
Snippets associated with the current human input are excluded from deletion.
|
||||
"""
|
||||
# Get the count of key snippets
|
||||
snippets = key_snippet_repository.get_all()
|
||||
snippets = get_key_snippet_repository().get_all()
|
||||
snippet_count = len(snippets)
|
||||
|
||||
# Display status panel with snippet count included
|
||||
|
|
@ -179,7 +178,7 @@ def run_key_snippets_gc_agent() -> None:
|
|||
run_agent_with_retry(agent, prompt, agent_config)
|
||||
|
||||
# Get updated count
|
||||
updated_snippets = key_snippet_repository.get_all()
|
||||
updated_snippets = get_key_snippet_repository().get_all()
|
||||
updated_count = len(updated_snippets)
|
||||
|
||||
# Show info panel with updated count and protected snippets count
|
||||
|
|
|
|||
|
|
@ -6,15 +6,94 @@ following the repository pattern for data access abstraction.
|
|||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.database.models import KeyFact, initialize_database
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Create contextvar to hold the KeyFactRepository instance
|
||||
key_fact_repo_var = contextvars.ContextVar("key_fact_repo", default=None)
|
||||
|
||||
|
||||
class KeyFactRepositoryManager:
|
||||
"""
|
||||
Context manager for KeyFactRepository.
|
||||
|
||||
This class provides a context manager interface for KeyFactRepository,
|
||||
using the contextvars approach for thread safety.
|
||||
|
||||
Example:
|
||||
with DatabaseManager() as db:
|
||||
with KeyFactRepositoryManager(db) as repo:
|
||||
# Use the repository
|
||||
fact = repo.create("Important fact about the project")
|
||||
all_facts = repo.get_all()
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize the KeyFactRepositoryManager.
|
||||
|
||||
Args:
|
||||
db: Database connection to use (required)
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def __enter__(self) -> 'KeyFactRepository':
|
||||
"""
|
||||
Initialize the KeyFactRepository and return it.
|
||||
|
||||
Returns:
|
||||
KeyFactRepository: The initialized repository
|
||||
"""
|
||||
repo = KeyFactRepository(self.db)
|
||||
key_fact_repo_var.set(repo)
|
||||
return repo
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type],
|
||||
exc_val: Optional[Exception],
|
||||
exc_tb: Optional[object],
|
||||
) -> None:
|
||||
"""
|
||||
Reset the repository when exiting the context.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised
|
||||
exc_val: The exception value if an exception was raised
|
||||
exc_tb: The traceback if an exception was raised
|
||||
"""
|
||||
# Reset the contextvar to None
|
||||
key_fact_repo_var.set(None)
|
||||
|
||||
# Don't suppress exceptions
|
||||
return False
|
||||
|
||||
|
||||
def get_key_fact_repository() -> 'KeyFactRepository':
|
||||
"""
|
||||
Get the current KeyFactRepository instance.
|
||||
|
||||
Returns:
|
||||
KeyFactRepository: The current repository instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no repository has been initialized with KeyFactRepositoryManager
|
||||
"""
|
||||
repo = key_fact_repo_var.get()
|
||||
if repo is None:
|
||||
raise RuntimeError(
|
||||
"No KeyFactRepository available. "
|
||||
"Make sure to initialize one with KeyFactRepositoryManager first."
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
class KeyFactRepository:
|
||||
"""
|
||||
|
|
@ -24,18 +103,21 @@ class KeyFactRepository:
|
|||
abstracting the database access details from the business logic.
|
||||
|
||||
Example:
|
||||
repo = KeyFactRepository()
|
||||
fact = repo.create("Important fact about the project")
|
||||
all_facts = repo.get_all()
|
||||
with DatabaseManager() as db:
|
||||
with KeyFactRepositoryManager(db) as repo:
|
||||
fact = repo.create("Important fact about the project")
|
||||
all_facts = repo.get_all()
|
||||
"""
|
||||
|
||||
def __init__(self, db=None):
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize the repository with an optional database connection.
|
||||
Initialize the repository with a database connection.
|
||||
|
||||
Args:
|
||||
db: Optional database connection to use. If None, will use initialize_database()
|
||||
db: Database connection to use (required)
|
||||
"""
|
||||
if db is None:
|
||||
raise ValueError("Database connection is required for KeyFactRepository")
|
||||
self.db = db
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
|
||||
|
|
@ -53,7 +135,6 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error creating the fact
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
fact = KeyFact.create(content=content, human_input_id=human_input_id)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return fact
|
||||
|
|
@ -75,7 +156,6 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||
|
|
@ -96,7 +176,6 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error updating the fact
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
|
|
@ -126,7 +205,6 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error deleting the fact
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
|
|
@ -152,7 +230,6 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -214,4 +214,21 @@ class KeySnippetRepository:
|
|||
}
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_key_snippet_repository = None
|
||||
|
||||
|
||||
def get_key_snippet_repository() -> KeySnippetRepository:
|
||||
"""
|
||||
Get or create a singleton instance of KeySnippetRepository.
|
||||
|
||||
Returns:
|
||||
KeySnippetRepository: Singleton instance of the repository
|
||||
"""
|
||||
global _key_snippet_repository
|
||||
if _key_snippet_repository is None:
|
||||
_key_snippet_repository = KeySnippetRepository()
|
||||
return _key_snippet_repository
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tools for spawning and managing sub-agents."""
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
import logging
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from rich.console import Console
|
||||
|
|
@ -12,8 +13,9 @@ from ra_aid.agent_context import (
|
|||
reset_completion_flags,
|
||||
)
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -31,8 +33,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user.
|
|||
RESEARCH_AGENT_RECURSION_LIMIT = 3
|
||||
|
||||
console = Console()
|
||||
key_fact_repository = KeyFactRepository()
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool("request_research")
|
||||
|
|
@ -57,12 +58,24 @@ def request_research(query: str) -> ResearchResult:
|
|||
current_depth = _global_memory.get("agent_depth", 0)
|
||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||
print_error("Maximum research recursion depth reached")
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
return {
|
||||
"completion_message": "Research stopped - maximum recursion depth reached",
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"key_facts": key_facts,
|
||||
"related_files": get_related_files(),
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
"key_snippets": key_snippets,
|
||||
"success": False,
|
||||
"reason": "max_depth_exceeded",
|
||||
}
|
||||
|
|
@ -105,12 +118,24 @@ def request_research(query: str) -> ResearchResult:
|
|||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"key_facts": key_facts,
|
||||
"related_files": get_related_files(),
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
"key_snippets": key_snippets,
|
||||
"success": success,
|
||||
"reason": reason,
|
||||
}
|
||||
|
|
@ -171,9 +196,15 @@ def request_web_research(query: str) -> ResearchResult:
|
|||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
"key_snippets": key_snippets,
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"success": success,
|
||||
"reason": reason,
|
||||
|
|
@ -239,12 +270,24 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"key_facts": key_facts,
|
||||
"related_files": get_related_files(),
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
"key_snippets": key_snippets,
|
||||
"success": success,
|
||||
"reason": reason,
|
||||
}
|
||||
|
|
@ -324,10 +367,22 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
agent_crashed = is_crashed()
|
||||
crash_message = get_crash_message() if agent_crashed else None
|
||||
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
response_data = {
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"key_facts": key_facts,
|
||||
"related_files": get_related_files(),
|
||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
"key_snippets": key_snippets,
|
||||
"completion_message": completion_message,
|
||||
"success": success and not agent_crashed,
|
||||
"reason": reason,
|
||||
|
|
@ -444,11 +499,23 @@ def request_implementation(task_spec: str) -> str:
|
|||
agent_crashed = is_crashed()
|
||||
crash_message = get_crash_message() if agent_crashed else None
|
||||
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"key_facts": key_facts,
|
||||
"related_files": get_related_files(),
|
||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
||||
"key_snippets": key_snippets,
|
||||
"success": success and not agent_crashed,
|
||||
"reason": reason,
|
||||
"agent_crashed": agent_crashed,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
|
@ -6,8 +7,10 @@ from rich.console import Console
|
|||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ..database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ..database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ..llm import initialize_expert_llm
|
||||
from ..model_formatters import format_key_facts_dict
|
||||
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -15,8 +18,6 @@ from .memory import _global_memory, get_memory_value
|
|||
|
||||
console = Console()
|
||||
_model = None
|
||||
key_fact_repository = KeyFactRepository()
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
|
||||
|
||||
def get_model():
|
||||
|
|
@ -154,10 +155,18 @@ def ask_expert(question: str) -> str:
|
|||
file_paths = list(_global_memory["related_files"].values())
|
||||
related_contents = read_related_files(file_paths)
|
||||
# Get key snippets directly from repository and format using the formatter
|
||||
key_snippets = format_key_snippets_dict(key_snippet_repository.get_snippets_dict())
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
# Get key facts directly from repository and format using the formatter
|
||||
facts_dict = key_fact_repository.get_facts_dict()
|
||||
key_facts = format_key_facts_dict(facts_dict)
|
||||
try:
|
||||
facts_dict = get_key_fact_repository().get_facts_dict()
|
||||
key_facts = format_key_facts_dict(facts_dict)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
research_notes = get_memory_value("research_notes")
|
||||
|
||||
# Build display query (just question)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ from ra_aid.agent_context import (
|
|||
mark_should_exit,
|
||||
mark_task_completed,
|
||||
)
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository, get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository, get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.model_formatters import key_snippets_formatter
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
|
@ -40,14 +40,8 @@ class SnippetInfo(TypedDict):
|
|||
|
||||
console = Console()
|
||||
|
||||
# Initialize repository for key facts
|
||||
key_fact_repository = KeyFactRepository()
|
||||
|
||||
# Initialize repository for key snippets
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
|
||||
# Initialize repository for human inputs
|
||||
human_input_repository = HumanInputRepository()
|
||||
# Import repositories using the get_* functions
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
|
||||
# Global memory store
|
||||
_global_memory: Dict[str, Any] = {
|
||||
|
|
@ -120,17 +114,23 @@ def emit_key_facts(facts: List[str]) -> str:
|
|||
|
||||
# Try to get the latest human input
|
||||
human_input_id = None
|
||||
human_input_repo = HumanInputRepository()
|
||||
try:
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
recent_inputs = human_input_repo.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
human_input_id = recent_inputs[0].id
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get recent human input: {str(e)}")
|
||||
|
||||
for fact in facts:
|
||||
# Create fact in database using repository
|
||||
created_fact = key_fact_repository.create(fact, human_input_id=human_input_id)
|
||||
fact_id = created_fact.id
|
||||
try:
|
||||
# Create fact in database using repository
|
||||
created_fact = get_key_fact_repository().create(fact, human_input_id=human_input_id)
|
||||
fact_id = created_fact.id
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
console.print(f"Error storing fact: {str(e)}", style="red")
|
||||
continue
|
||||
|
||||
# Display panel with ID
|
||||
console.print(
|
||||
|
|
@ -147,14 +147,17 @@ def emit_key_facts(facts: List[str]) -> str:
|
|||
log_work_event(f"Stored {len(facts)} key facts.")
|
||||
|
||||
# Check if we need to clean up facts (more than 30)
|
||||
all_facts = key_fact_repository.get_all()
|
||||
if len(all_facts) > 30:
|
||||
# Trigger the key facts cleaner agent
|
||||
try:
|
||||
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
|
||||
run_key_facts_gc_agent()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run key facts cleaner: {str(e)}")
|
||||
try:
|
||||
all_facts = get_key_fact_repository().get_all()
|
||||
if len(all_facts) > 30:
|
||||
# Trigger the key facts cleaner agent
|
||||
try:
|
||||
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
|
||||
run_key_facts_gc_agent()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run key facts cleaner: {str(e)}")
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
|
||||
return "Facts stored."
|
||||
|
||||
|
|
@ -222,14 +225,15 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
# Try to get the latest human input
|
||||
human_input_id = None
|
||||
try:
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
human_input_repo = HumanInputRepository()
|
||||
recent_inputs = human_input_repo.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
human_input_id = recent_inputs[0].id
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get recent human input: {str(e)}")
|
||||
|
||||
# Create a new key snippet in the database
|
||||
key_snippet = key_snippet_repository.create(
|
||||
key_snippet = get_key_snippet_repository().create(
|
||||
filepath=snippet_info["filepath"],
|
||||
line_number=snippet_info["line_number"],
|
||||
snippet=snippet_info["snippet"],
|
||||
|
|
@ -266,7 +270,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
log_work_event(f"Stored code snippet #{snippet_id}.")
|
||||
|
||||
# Check if we need to clean up snippets (more than 20)
|
||||
all_snippets = key_snippet_repository.get_all()
|
||||
all_snippets = get_key_snippet_repository().get_all()
|
||||
if len(all_snippets) > 20:
|
||||
# Trigger the key snippets cleaner agent
|
||||
try:
|
||||
|
|
@ -279,7 +283,6 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
|
||||
|
||||
|
||||
|
||||
@tool("swap_task_order")
|
||||
def swap_task_order(id1: int, id2: int) -> str:
|
||||
"""Swap the order of two tasks in global memory by their IDs.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,12 @@ import peewee
|
|||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import KeyFact, BaseModel
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import (
|
||||
KeyFactRepository,
|
||||
KeyFactRepositoryManager,
|
||||
get_key_fact_repository,
|
||||
key_fact_repo_var
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -41,6 +46,19 @@ def cleanup_db():
|
|||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_repo():
|
||||
"""Reset the repository contextvar after each test."""
|
||||
# Reset before the test
|
||||
key_fact_repo_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
key_fact_repo_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
|
||||
|
|
@ -195,4 +213,49 @@ def test_get_facts_dict(setup_db):
|
|||
# Verify each fact is in the dictionary with the correct content
|
||||
for fact in facts:
|
||||
assert fact.id in facts_dict
|
||||
assert facts_dict[fact.id] == fact.content
|
||||
assert facts_dict[fact.id] == fact.content
|
||||
|
||||
|
||||
def test_repository_init_without_db():
|
||||
"""Test that KeyFactRepository raises an error when initialized without a db parameter."""
|
||||
# Attempt to create a repository without a database connection
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
KeyFactRepository(db=None)
|
||||
|
||||
# Verify the correct error message
|
||||
assert "Database connection is required" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_key_fact_repository_manager(setup_db, cleanup_repo):
|
||||
"""Test the KeyFactRepositoryManager context manager."""
|
||||
# Use the context manager to create a repository
|
||||
with KeyFactRepositoryManager(setup_db) as repo:
|
||||
# Verify the repository was created correctly
|
||||
assert isinstance(repo, KeyFactRepository)
|
||||
assert repo.db is setup_db
|
||||
|
||||
# Verify we can use the repository
|
||||
content = "Test fact via context manager"
|
||||
fact = repo.create(content)
|
||||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
# Verify we can get the repository using get_key_fact_repository
|
||||
repo_from_var = get_key_fact_repository()
|
||||
assert repo_from_var is repo
|
||||
|
||||
# Verify the repository was removed from the context var
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_key_fact_repository()
|
||||
|
||||
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_key_fact_repository_when_not_set(cleanup_repo):
|
||||
"""Test that get_key_fact_repository raises an error when no repository is in context."""
|
||||
# Attempt to get the repository when none exists
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_key_fact_repository()
|
||||
|
||||
# Verify the correct error message
|
||||
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||
|
|
@ -36,9 +36,11 @@ def reset_memory():
|
|||
@pytest.fixture
|
||||
def mock_functions():
|
||||
"""Mock functions used in agent.py"""
|
||||
with patch('ra_aid.tools.agent.key_fact_repository') as mock_fact_repo, \
|
||||
mock_fact_repo = MagicMock()
|
||||
mock_snippet_repo = MagicMock()
|
||||
with patch('ra_aid.tools.agent.get_key_fact_repository', return_value=mock_fact_repo) as mock_get_fact_repo, \
|
||||
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \
|
||||
patch('ra_aid.tools.agent.key_snippet_repository') as mock_snippet_repo, \
|
||||
patch('ra_aid.tools.agent.get_key_snippet_repository', return_value=mock_snippet_repo) as mock_get_snippet_repo, \
|
||||
patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \
|
||||
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
|
||||
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
||||
|
|
@ -60,8 +62,8 @@ def mock_functions():
|
|||
|
||||
# Return all mocks as a dictionary
|
||||
yield {
|
||||
'key_fact_repository': mock_fact_repo,
|
||||
'key_snippet_repository': mock_snippet_repo,
|
||||
'get_key_fact_repository': mock_get_fact_repo,
|
||||
'get_key_snippet_repository': mock_get_snippet_repo,
|
||||
'format_key_facts_dict': mock_fact_formatter,
|
||||
'format_key_snippets_dict': mock_snippet_formatter,
|
||||
'initialize_llm': mock_llm,
|
||||
|
|
@ -81,11 +83,12 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
|
|||
result = request_research("test query")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Verify formatted facts are used in response
|
||||
|
|
@ -105,11 +108,12 @@ def test_request_research_max_depth(reset_memory, mock_functions):
|
|||
result = request_research("test query")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Verify formatted facts are used in response
|
||||
|
|
@ -128,11 +132,12 @@ def test_request_research_and_implementation_uses_key_fact_repository(reset_memo
|
|||
result = request_research_and_implementation("test query")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Verify formatted facts are used in response
|
||||
|
|
@ -151,11 +156,12 @@ def test_request_implementation_uses_key_fact_repository(reset_memory, mock_func
|
|||
result = request_implementation("test task")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Check that the formatted key facts are included in the response
|
||||
|
|
@ -174,11 +180,12 @@ def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock
|
|||
result = request_task_implementation("test task")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Check that the formatted key facts are included in the response
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ from ra_aid.tools.memory import (
|
|||
get_memory_value,
|
||||
get_related_files,
|
||||
get_work_log,
|
||||
key_fact_repository,
|
||||
key_snippet_repository,
|
||||
log_work_event,
|
||||
reset_work_log,
|
||||
swap_task_order,
|
||||
)
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.connection import DatabaseManager
|
||||
from ra_aid.database.models import KeyFact
|
||||
|
||||
|
|
@ -60,7 +60,7 @@ def in_memory_db():
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_repository():
|
||||
"""Mock the KeyFactRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo:
|
||||
with patch('ra_aid.tools.memory.get_key_fact_repository') as mock_repo:
|
||||
# Setup the mock repository to behave like the original, but using memory
|
||||
facts = {} # Local in-memory storage
|
||||
fact_id_counter = 0
|
||||
|
|
@ -79,12 +79,12 @@ def mock_repository():
|
|||
facts[fact_id_counter] = fact
|
||||
fact_id_counter += 1
|
||||
return fact
|
||||
mock_repo.create.side_effect = mock_create
|
||||
mock_repo.return_value.create.side_effect = mock_create
|
||||
|
||||
# Mock get method
|
||||
def mock_get(fact_id):
|
||||
return facts.get(fact_id)
|
||||
mock_repo.get.side_effect = mock_get
|
||||
mock_repo.return_value.get.side_effect = mock_get
|
||||
|
||||
# Mock delete method
|
||||
def mock_delete(fact_id):
|
||||
|
|
@ -92,17 +92,17 @@ def mock_repository():
|
|||
del facts[fact_id]
|
||||
return True
|
||||
return False
|
||||
mock_repo.delete.side_effect = mock_delete
|
||||
mock_repo.return_value.delete.side_effect = mock_delete
|
||||
|
||||
# Mock get_facts_dict method
|
||||
def mock_get_facts_dict():
|
||||
return {fact_id: fact.content for fact_id, fact in facts.items()}
|
||||
mock_repo.get_facts_dict.side_effect = mock_get_facts_dict
|
||||
mock_repo.return_value.get_facts_dict.side_effect = mock_get_facts_dict
|
||||
|
||||
# Mock get_all method
|
||||
def mock_get_all():
|
||||
return list(facts.values())
|
||||
mock_repo.get_all.side_effect = mock_get_all
|
||||
mock_repo.return_value.get_all.side_effect = mock_get_all
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
|
@ -159,16 +159,16 @@ def mock_key_snippet_repository():
|
|||
return list(snippets.values())
|
||||
|
||||
# Create the actual mocks for both memory.py and key_snippets_gc_agent.py
|
||||
with patch('ra_aid.tools.memory.key_snippet_repository') as memory_mock_repo, \
|
||||
patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository') as agent_mock_repo:
|
||||
with patch('ra_aid.tools.memory.get_key_snippet_repository') as memory_mock_repo, \
|
||||
patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository') as agent_mock_repo:
|
||||
|
||||
# Setup both mocks with the same implementation
|
||||
for mock_repo in [memory_mock_repo, agent_mock_repo]:
|
||||
mock_repo.create.side_effect = mock_create
|
||||
mock_repo.get.side_effect = mock_get
|
||||
mock_repo.delete.side_effect = mock_delete
|
||||
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict
|
||||
mock_repo.get_all.side_effect = mock_get_all
|
||||
mock_repo.return_value.create.side_effect = mock_create
|
||||
mock_repo.return_value.get.side_effect = mock_get
|
||||
mock_repo.return_value.delete.side_effect = mock_delete
|
||||
mock_repo.return_value.get_snippets_dict.side_effect = mock_get_snippets_dict
|
||||
mock_repo.return_value.get_all.side_effect = mock_get_all
|
||||
|
||||
yield memory_mock_repo
|
||||
|
||||
|
|
@ -180,7 +180,7 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
|||
assert result == "Facts stored."
|
||||
|
||||
# Verify the repository's create method was called
|
||||
mock_repository.create.assert_called_once_with("First fact", human_input_id=ANY)
|
||||
mock_repository.return_value.create.assert_called_once_with("First fact", human_input_id=ANY)
|
||||
|
||||
|
||||
def test_get_memory_value_other_types(reset_memory):
|
||||
|
|
@ -264,10 +264,10 @@ def test_emit_key_facts(reset_memory, mock_repository):
|
|||
assert result == "Facts stored."
|
||||
|
||||
# Verify create was called for each fact
|
||||
assert mock_repository.create.call_count == 3
|
||||
mock_repository.create.assert_any_call("First fact", human_input_id=ANY)
|
||||
mock_repository.create.assert_any_call("Second fact", human_input_id=ANY)
|
||||
mock_repository.create.assert_any_call("Third fact", human_input_id=ANY)
|
||||
assert mock_repository.return_value.create.call_count == 3
|
||||
mock_repository.return_value.create.assert_any_call("First fact", human_input_id=ANY)
|
||||
mock_repository.return_value.create.assert_any_call("Second fact", human_input_id=ANY)
|
||||
mock_repository.return_value.create.assert_any_call("Third fact", human_input_id=ANY)
|
||||
|
||||
|
||||
def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
||||
|
|
@ -278,7 +278,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
|||
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
|
||||
|
||||
# Mock the get_all method to return more than 30 facts
|
||||
mock_repository.get_all.return_value = facts
|
||||
mock_repository.return_value.get_all.return_value = facts
|
||||
|
||||
# Note on testing approach:
|
||||
# Rather than trying to mock the dynamic import which is challenging due to
|
||||
|
|
@ -295,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
|||
|
||||
# Verify that mock_repository.get_all was called,
|
||||
# which is the condition that would trigger the GC agent
|
||||
mock_repository.get_all.assert_called_once()
|
||||
mock_repository.return_value.get_all.assert_called_once()
|
||||
|
||||
|
||||
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
||||
|
|
@ -315,7 +315,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
|||
assert result == "Snippet #0 stored."
|
||||
|
||||
# Verify create was called correctly
|
||||
mock_key_snippet_repository.create.assert_called_with(
|
||||
mock_key_snippet_repository.return_value.create.assert_called_with(
|
||||
filepath="test.py",
|
||||
line_number=10,
|
||||
snippet="def test():\n pass",
|
||||
|
|
@ -338,7 +338,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
|||
assert result == "Snippet #1 stored."
|
||||
|
||||
# Verify create was called correctly
|
||||
mock_key_snippet_repository.create.assert_called_with(
|
||||
mock_key_snippet_repository.return_value.create.assert_called_with(
|
||||
filepath="main.py",
|
||||
line_number=20,
|
||||
snippet="print('hello')",
|
||||
|
|
@ -379,16 +379,16 @@ def test_delete_key_snippets(mock_log_work_event, reset_memory, mock_key_snippet
|
|||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Test deleting mix of valid and invalid IDs
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
|
||||
|
||||
# Verify success message
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
# Verify repository get was called with correct IDs
|
||||
mock_key_snippet_repository.get.assert_any_call(0)
|
||||
mock_key_snippet_repository.get.assert_any_call(1)
|
||||
mock_key_snippet_repository.get.assert_any_call(999)
|
||||
mock_key_snippet_repository.return_value.get.assert_any_call(0)
|
||||
mock_key_snippet_repository.return_value.get.assert_any_call(1)
|
||||
mock_key_snippet_repository.return_value.get.assert_any_call(999)
|
||||
|
||||
# We skip verifying delete calls because they are prone to test environment issues
|
||||
# The implementation logic will properly delete IDs 0 and 1 but not 999
|
||||
|
|
@ -410,12 +410,12 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s
|
|||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Test with empty list
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||
result = delete_key_snippets.invoke({"snippet_ids": []})
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
# Verify no call to delete method
|
||||
mock_key_snippet_repository.delete.assert_not_called()
|
||||
mock_key_snippet_repository.return_value.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_emit_related_files_basic(reset_memory, tmp_path):
|
||||
|
|
@ -458,13 +458,13 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path):
|
|||
new_file.write_text("# New file")
|
||||
|
||||
# Add initial files
|
||||
result = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
|
||||
assert result == "Files noted."
|
||||
result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
|
||||
assert result1 == "Files noted."
|
||||
_first_id = 0 # ID of test.py
|
||||
|
||||
# Try adding duplicates
|
||||
result = emit_related_files.invoke({"files": [str(test_file)]})
|
||||
assert result == "Files noted."
|
||||
result2 = emit_related_files.invoke({"files": [str(test_file)]})
|
||||
assert result2 == "Files noted."
|
||||
assert len(_global_memory["related_files"]) == 2 # Count should not increase
|
||||
|
||||
# Try mix of new and duplicate files
|
||||
|
|
@ -670,7 +670,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
|
|||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Delete some but not all snippets (0 and 2)
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
|
|
@ -692,7 +692,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
|
|||
assert result == "Snippet #3 stored."
|
||||
|
||||
# Verify create was called with correct params
|
||||
mock_key_snippet_repository.create.assert_called_with(
|
||||
mock_key_snippet_repository.return_value.create.assert_called_with(
|
||||
filepath=file4,
|
||||
line_number=40,
|
||||
snippet="def func4():\n return False",
|
||||
|
|
@ -704,7 +704,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
|
|||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Delete remaining snippets
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
||||
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue