use context vars for key facts repo
This commit is contained in:
parent
dd9af78693
commit
36e4004db0
|
|
@ -42,7 +42,7 @@ from ra_aid.config import (
|
||||||
DEFAULT_TEST_CMD_TIMEOUT,
|
DEFAULT_TEST_CMD_TIMEOUT,
|
||||||
VALID_PROVIDERS,
|
VALID_PROVIDERS,
|
||||||
)
|
)
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
|
|
@ -393,251 +393,256 @@ def main():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Database migration error: {str(e)}")
|
logger.error(f"Database migration error: {str(e)}")
|
||||||
|
|
||||||
# Check dependencies before proceeding
|
# Initialize repositories with database connection
|
||||||
check_dependencies()
|
with KeyFactRepositoryManager(db) as key_fact_repo:
|
||||||
|
# This initializes the repository and makes it available via get_key_fact_repository()
|
||||||
|
logger.debug("Initialized KeyFactRepository")
|
||||||
|
|
||||||
(
|
# Check dependencies before proceeding
|
||||||
expert_enabled,
|
check_dependencies()
|
||||||
expert_missing,
|
|
||||||
web_research_enabled,
|
|
||||||
web_research_missing,
|
|
||||||
) = validate_environment(args) # Will exit if main env vars missing
|
|
||||||
logger.debug("Environment validation successful")
|
|
||||||
|
|
||||||
# Validate model configuration early
|
(
|
||||||
model_config = models_params.get(args.provider, {}).get(
|
expert_enabled,
|
||||||
args.model or "", {}
|
expert_missing,
|
||||||
)
|
web_research_enabled,
|
||||||
supports_temperature = model_config.get(
|
web_research_missing,
|
||||||
"supports_temperature",
|
) = validate_environment(args) # Will exit if main env vars missing
|
||||||
args.provider
|
logger.debug("Environment validation successful")
|
||||||
in [
|
|
||||||
"anthropic",
|
|
||||||
"openai",
|
|
||||||
"openrouter",
|
|
||||||
"openai-compatible",
|
|
||||||
"deepseek",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
if supports_temperature and args.temperature is None:
|
# Validate model configuration early
|
||||||
args.temperature = model_config.get("default_temperature")
|
model_config = models_params.get(args.provider, {}).get(
|
||||||
if args.temperature is None:
|
args.model or "", {}
|
||||||
cpm(
|
)
|
||||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
supports_temperature = model_config.get(
|
||||||
|
"supports_temperature",
|
||||||
|
args.provider
|
||||||
|
in [
|
||||||
|
"anthropic",
|
||||||
|
"openai",
|
||||||
|
"openrouter",
|
||||||
|
"openai-compatible",
|
||||||
|
"deepseek",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if supports_temperature and args.temperature is None:
|
||||||
|
args.temperature = model_config.get("default_temperature")
|
||||||
|
if args.temperature is None:
|
||||||
|
cpm(
|
||||||
|
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||||
|
)
|
||||||
|
args.temperature = DEFAULT_TEMPERATURE
|
||||||
|
logger.debug(
|
||||||
|
f"Using default temperature {args.temperature} for model {args.model}"
|
||||||
|
)
|
||||||
|
|
||||||
|
status = build_status(args, expert_enabled, web_research_enabled)
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
Panel(
|
||||||
|
status,
|
||||||
|
title=f"RA.Aid v{__version__}",
|
||||||
|
border_style="bright_blue",
|
||||||
|
padding=(0, 1),
|
||||||
)
|
)
|
||||||
args.temperature = DEFAULT_TEMPERATURE
|
|
||||||
logger.debug(
|
|
||||||
f"Using default temperature {args.temperature} for model {args.model}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
status = build_status(args, expert_enabled, web_research_enabled)
|
# Handle chat mode
|
||||||
|
if args.chat:
|
||||||
|
# Initialize chat model with default provider/model
|
||||||
|
chat_model = initialize_llm(
|
||||||
|
args.provider, args.model, temperature=args.temperature
|
||||||
|
)
|
||||||
|
|
||||||
console.print(
|
if args.research_only:
|
||||||
Panel(
|
print_error("Chat mode cannot be used with --research-only")
|
||||||
status,
|
sys.exit(1)
|
||||||
title=f"RA.Aid v{__version__}",
|
|
||||||
border_style="bright_blue",
|
|
||||||
padding=(0, 1),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle chat mode
|
print_stage_header("Chat Mode")
|
||||||
if args.chat:
|
|
||||||
# Initialize chat model with default provider/model
|
|
||||||
chat_model = initialize_llm(
|
|
||||||
args.provider, args.model, temperature=args.temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.research_only:
|
# Get project info
|
||||||
print_error("Chat mode cannot be used with --research-only")
|
try:
|
||||||
|
project_info = get_project_info(".", file_limit=2000)
|
||||||
|
formatted_project_info = format_project_info(project_info)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get project info: {e}")
|
||||||
|
formatted_project_info = ""
|
||||||
|
|
||||||
|
# Get initial request from user
|
||||||
|
initial_request = ask_human.invoke(
|
||||||
|
{"question": "What would you like help with?"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record chat input in database (redundant as ask_human already records it,
|
||||||
|
# but needed in case the ask_human implementation changes)
|
||||||
|
try:
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
human_input_repo = HumanInputRepository(db)
|
||||||
|
human_input_repo.create(content=initial_request, source='chat')
|
||||||
|
human_input_repo.garbage_collect()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to record initial chat input: {str(e)}")
|
||||||
|
|
||||||
|
# Get working directory and current date
|
||||||
|
working_directory = os.getcwd()
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# Run chat agent with CHAT_PROMPT
|
||||||
|
config = {
|
||||||
|
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||||
|
"recursion_limit": args.recursion_limit,
|
||||||
|
"chat_mode": True,
|
||||||
|
"cowboy_mode": args.cowboy_mode,
|
||||||
|
"hil": True, # Always true in chat mode
|
||||||
|
"web_research_enabled": web_research_enabled,
|
||||||
|
"initial_request": initial_request,
|
||||||
|
"limit_tokens": args.disable_limit_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store config in global memory
|
||||||
|
_global_memory["config"] = config
|
||||||
|
_global_memory["config"]["provider"] = args.provider
|
||||||
|
_global_memory["config"]["model"] = args.model
|
||||||
|
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||||
|
_global_memory["config"]["expert_model"] = args.expert_model
|
||||||
|
_global_memory["config"]["temperature"] = args.temperature
|
||||||
|
|
||||||
|
# Set modification tools based on use_aider flag
|
||||||
|
set_modification_tools(args.use_aider)
|
||||||
|
|
||||||
|
# Create chat agent with appropriate tools
|
||||||
|
chat_agent = create_agent(
|
||||||
|
chat_model,
|
||||||
|
get_chat_tools(
|
||||||
|
expert_enabled=expert_enabled,
|
||||||
|
web_research_enabled=web_research_enabled,
|
||||||
|
),
|
||||||
|
checkpointer=MemorySaver(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run chat agent and exit
|
||||||
|
run_agent_with_retry(
|
||||||
|
chat_agent,
|
||||||
|
CHAT_PROMPT.format(
|
||||||
|
initial_request=initial_request,
|
||||||
|
web_research_section=(
|
||||||
|
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||||
|
if web_research_enabled
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
working_directory=working_directory,
|
||||||
|
current_date=current_date,
|
||||||
|
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
||||||
|
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()),
|
||||||
|
project_info=formatted_project_info,
|
||||||
|
),
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate message is provided
|
||||||
|
if not args.message:
|
||||||
|
print_error("--message is required")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print_stage_header("Chat Mode")
|
base_task = args.message
|
||||||
|
|
||||||
# Get project info
|
|
||||||
try:
|
|
||||||
project_info = get_project_info(".", file_limit=2000)
|
|
||||||
formatted_project_info = format_project_info(project_info)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get project info: {e}")
|
|
||||||
formatted_project_info = ""
|
|
||||||
|
|
||||||
# Get initial request from user
|
|
||||||
initial_request = ask_human.invoke(
|
|
||||||
{"question": "What would you like help with?"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record chat input in database (redundant as ask_human already records it,
|
# Record CLI input in database
|
||||||
# but needed in case the ask_human implementation changes)
|
|
||||||
try:
|
try:
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
human_input_repo = HumanInputRepository()
|
human_input_repo = HumanInputRepository(db)
|
||||||
human_input_repo.create(content=initial_request, source='chat')
|
human_input_repo.create(content=base_task, source='cli')
|
||||||
|
# Run garbage collection to ensure we don't exceed 100 inputs
|
||||||
human_input_repo.garbage_collect()
|
human_input_repo.garbage_collect()
|
||||||
|
logger.debug(f"Recorded CLI input: {base_task}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to record initial chat input: {str(e)}")
|
logger.error(f"Failed to record CLI input: {str(e)}")
|
||||||
|
|
||||||
# Get working directory and current date
|
|
||||||
working_directory = os.getcwd()
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# Run chat agent with CHAT_PROMPT
|
|
||||||
config = {
|
config = {
|
||||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||||
"recursion_limit": args.recursion_limit,
|
"recursion_limit": args.recursion_limit,
|
||||||
"chat_mode": True,
|
"research_only": args.research_only,
|
||||||
"cowboy_mode": args.cowboy_mode,
|
"cowboy_mode": args.cowboy_mode,
|
||||||
"hil": True, # Always true in chat mode
|
|
||||||
"web_research_enabled": web_research_enabled,
|
"web_research_enabled": web_research_enabled,
|
||||||
"initial_request": initial_request,
|
"aider_config": args.aider_config,
|
||||||
|
"use_aider": args.use_aider,
|
||||||
"limit_tokens": args.disable_limit_tokens,
|
"limit_tokens": args.disable_limit_tokens,
|
||||||
|
"auto_test": args.auto_test,
|
||||||
|
"test_cmd": args.test_cmd,
|
||||||
|
"max_test_cmd_retries": args.max_test_cmd_retries,
|
||||||
|
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||||
|
"test_cmd_timeout": args.test_cmd_timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store config in global memory
|
# Store config in global memory for access by is_informational_query
|
||||||
_global_memory["config"] = config
|
_global_memory["config"] = config
|
||||||
|
|
||||||
|
# Store base provider/model configuration
|
||||||
_global_memory["config"]["provider"] = args.provider
|
_global_memory["config"]["provider"] = args.provider
|
||||||
_global_memory["config"]["model"] = args.model
|
_global_memory["config"]["model"] = args.model
|
||||||
|
|
||||||
|
# Store expert provider/model (no fallback)
|
||||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||||
_global_memory["config"]["expert_model"] = args.expert_model
|
_global_memory["config"]["expert_model"] = args.expert_model
|
||||||
|
|
||||||
|
# Store planner config with fallback to base values
|
||||||
|
_global_memory["config"]["planner_provider"] = (
|
||||||
|
args.planner_provider or args.provider
|
||||||
|
)
|
||||||
|
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||||
|
|
||||||
|
# Store research config with fallback to base values
|
||||||
|
_global_memory["config"]["research_provider"] = (
|
||||||
|
args.research_provider or args.provider
|
||||||
|
)
|
||||||
|
_global_memory["config"]["research_model"] = (
|
||||||
|
args.research_model or args.model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store temperature in global config
|
||||||
_global_memory["config"]["temperature"] = args.temperature
|
_global_memory["config"]["temperature"] = args.temperature
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
||||||
# Create chat agent with appropriate tools
|
# Run research stage
|
||||||
chat_agent = create_agent(
|
print_stage_header("Research Stage")
|
||||||
chat_model,
|
|
||||||
get_chat_tools(
|
# Initialize research model with potential overrides
|
||||||
expert_enabled=expert_enabled,
|
research_provider = args.research_provider or args.provider
|
||||||
web_research_enabled=web_research_enabled,
|
research_model_name = args.research_model or args.model
|
||||||
),
|
research_model = initialize_llm(
|
||||||
checkpointer=MemorySaver(),
|
research_provider, research_model_name, temperature=args.temperature
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run chat agent and exit
|
run_research_agent(
|
||||||
run_agent_with_retry(
|
|
||||||
chat_agent,
|
|
||||||
CHAT_PROMPT.format(
|
|
||||||
initial_request=initial_request,
|
|
||||||
web_research_section=(
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
|
||||||
if web_research_enabled
|
|
||||||
else ""
|
|
||||||
),
|
|
||||||
working_directory=working_directory,
|
|
||||||
current_date=current_date,
|
|
||||||
key_facts=format_key_facts_dict(KeyFactRepository().get_facts_dict()),
|
|
||||||
key_snippets=format_key_snippets_dict(KeySnippetRepository().get_snippets_dict()),
|
|
||||||
project_info=formatted_project_info,
|
|
||||||
),
|
|
||||||
config,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Validate message is provided
|
|
||||||
if not args.message:
|
|
||||||
print_error("--message is required")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
base_task = args.message
|
|
||||||
|
|
||||||
# Record CLI input in database
|
|
||||||
try:
|
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
|
||||||
human_input_repo = HumanInputRepository()
|
|
||||||
human_input_repo.create(content=base_task, source='cli')
|
|
||||||
# Run garbage collection to ensure we don't exceed 100 inputs
|
|
||||||
human_input_repo.garbage_collect()
|
|
||||||
logger.debug(f"Recorded CLI input: {base_task}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to record CLI input: {str(e)}")
|
|
||||||
config = {
|
|
||||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
|
||||||
"recursion_limit": args.recursion_limit,
|
|
||||||
"research_only": args.research_only,
|
|
||||||
"cowboy_mode": args.cowboy_mode,
|
|
||||||
"web_research_enabled": web_research_enabled,
|
|
||||||
"aider_config": args.aider_config,
|
|
||||||
"use_aider": args.use_aider,
|
|
||||||
"limit_tokens": args.disable_limit_tokens,
|
|
||||||
"auto_test": args.auto_test,
|
|
||||||
"test_cmd": args.test_cmd,
|
|
||||||
"max_test_cmd_retries": args.max_test_cmd_retries,
|
|
||||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
|
||||||
"test_cmd_timeout": args.test_cmd_timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Store config in global memory for access by is_informational_query
|
|
||||||
_global_memory["config"] = config
|
|
||||||
|
|
||||||
# Store base provider/model configuration
|
|
||||||
_global_memory["config"]["provider"] = args.provider
|
|
||||||
_global_memory["config"]["model"] = args.model
|
|
||||||
|
|
||||||
# Store expert provider/model (no fallback)
|
|
||||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
|
||||||
_global_memory["config"]["expert_model"] = args.expert_model
|
|
||||||
|
|
||||||
# Store planner config with fallback to base values
|
|
||||||
_global_memory["config"]["planner_provider"] = (
|
|
||||||
args.planner_provider or args.provider
|
|
||||||
)
|
|
||||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
|
||||||
|
|
||||||
# Store research config with fallback to base values
|
|
||||||
_global_memory["config"]["research_provider"] = (
|
|
||||||
args.research_provider or args.provider
|
|
||||||
)
|
|
||||||
_global_memory["config"]["research_model"] = (
|
|
||||||
args.research_model or args.model
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store temperature in global config
|
|
||||||
_global_memory["config"]["temperature"] = args.temperature
|
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
|
||||||
set_modification_tools(args.use_aider)
|
|
||||||
|
|
||||||
# Run research stage
|
|
||||||
print_stage_header("Research Stage")
|
|
||||||
|
|
||||||
# Initialize research model with potential overrides
|
|
||||||
research_provider = args.research_provider or args.provider
|
|
||||||
research_model_name = args.research_model or args.model
|
|
||||||
research_model = initialize_llm(
|
|
||||||
research_provider, research_model_name, temperature=args.temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
run_research_agent(
|
|
||||||
base_task,
|
|
||||||
research_model,
|
|
||||||
expert_enabled=expert_enabled,
|
|
||||||
research_only=args.research_only,
|
|
||||||
hil=args.hil,
|
|
||||||
memory=research_memory,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Proceed with planning and implementation if not an informational query
|
|
||||||
if not is_informational_query():
|
|
||||||
# Initialize planning model with potential overrides
|
|
||||||
planner_provider = args.planner_provider or args.provider
|
|
||||||
planner_model_name = args.planner_model or args.model
|
|
||||||
planning_model = initialize_llm(
|
|
||||||
planner_provider, planner_model_name, temperature=args.temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run planning agent
|
|
||||||
run_planning_agent(
|
|
||||||
base_task,
|
base_task,
|
||||||
planning_model,
|
research_model,
|
||||||
expert_enabled=expert_enabled,
|
expert_enabled=expert_enabled,
|
||||||
|
research_only=args.research_only,
|
||||||
hil=args.hil,
|
hil=args.hil,
|
||||||
memory=planning_memory,
|
memory=research_memory,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Proceed with planning and implementation if not an informational query
|
||||||
|
if not is_informational_query():
|
||||||
|
# Initialize planning model with potential overrides
|
||||||
|
planner_provider = args.planner_provider or args.provider
|
||||||
|
planner_model_name = args.planner_model or args.model
|
||||||
|
planning_model = initialize_llm(
|
||||||
|
planner_provider, planner_model_name, temperature=args.temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run planning agent
|
||||||
|
run_planning_agent(
|
||||||
|
base_task,
|
||||||
|
planning_model,
|
||||||
|
expert_enabled=expert_enabled,
|
||||||
|
hil=args.hil,
|
||||||
|
memory=planning_memory,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
print()
|
print()
|
||||||
print(" 👋 Bye!")
|
print(" 👋 Bye!")
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,8 @@ from ra_aid.tool_configs import (
|
||||||
get_web_research_tools,
|
get_web_research_tools,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
|
|
@ -100,10 +100,8 @@ console = Console()
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Initialize repositories
|
# Import repositories using get_* functions
|
||||||
key_fact_repository = KeyFactRepository()
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
key_snippet_repository = KeySnippetRepository()
|
|
||||||
human_input_repository = HumanInputRepository()
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -391,7 +389,11 @@ def run_research_agent(
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
|
|
||||||
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
code_snippets = _global_memory.get("code_snippets", "")
|
code_snippets = _global_memory.get("code_snippets", "")
|
||||||
related_files = _global_memory.get("related_files", "")
|
related_files = _global_memory.get("related_files", "")
|
||||||
|
|
||||||
|
|
@ -400,6 +402,7 @@ def run_research_agent(
|
||||||
|
|
||||||
# Get the last human input, if it exists
|
# Get the last human input, if it exists
|
||||||
base_task = base_task_or_query
|
base_task = base_task_or_query
|
||||||
|
human_input_repository = HumanInputRepository()
|
||||||
recent_inputs = human_input_repository.get_recent(1)
|
recent_inputs = human_input_repository.get_recent(1)
|
||||||
if recent_inputs and len(recent_inputs) > 0:
|
if recent_inputs and len(recent_inputs) > 0:
|
||||||
last_human_input = recent_inputs[0].content
|
last_human_input = recent_inputs[0].content
|
||||||
|
|
@ -537,7 +540,11 @@ def run_web_research_agent(
|
||||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||||
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
||||||
|
|
||||||
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
code_snippets = _global_memory.get("code_snippets", "")
|
code_snippets = _global_memory.get("code_snippets", "")
|
||||||
related_files = _global_memory.get("related_files", "")
|
related_files = _global_memory.get("related_files", "")
|
||||||
|
|
||||||
|
|
@ -647,6 +654,20 @@ def run_planning_agent(
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
working_directory = os.getcwd()
|
working_directory = os.getcwd()
|
||||||
|
|
||||||
|
# Make sure key_facts is defined before using it
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
|
# Make sure key_snippets is defined before using it
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
planning_prompt = PLANNING_PROMPT.format(
|
planning_prompt = PLANNING_PROMPT.format(
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
|
|
@ -657,8 +678,8 @@ def run_planning_agent(
|
||||||
project_info=formatted_project_info,
|
project_info=formatted_project_info,
|
||||||
research_notes=get_memory_value("research_notes"),
|
research_notes=get_memory_value("research_notes"),
|
||||||
related_files="\n".join(get_related_files()),
|
related_files="\n".join(get_related_files()),
|
||||||
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
key_facts=key_facts,
|
||||||
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
key_snippets=key_snippets,
|
||||||
work_log=get_memory_value("work_log"),
|
work_log=get_memory_value("work_log"),
|
||||||
research_only_note=(
|
research_only_note=(
|
||||||
""
|
""
|
||||||
|
|
@ -751,6 +772,13 @@ def run_task_implementation_agent(
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
working_directory = os.getcwd()
|
working_directory = os.getcwd()
|
||||||
|
|
||||||
|
# Make sure key_facts is defined before using it
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
prompt = IMPLEMENTATION_PROMPT.format(
|
prompt = IMPLEMENTATION_PROMPT.format(
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
|
|
@ -759,8 +787,8 @@ def run_task_implementation_agent(
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
plan=plan,
|
plan=plan,
|
||||||
related_files=related_files,
|
related_files=related_files,
|
||||||
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
key_facts=key_facts,
|
||||||
key_snippets=format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||||
research_notes=get_memory_value("research_notes"),
|
research_notes=get_memory_value("research_notes"),
|
||||||
work_log=get_memory_value("work_log"),
|
work_log=get_memory_value("work_log"),
|
||||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ facts when the total number exceeds a specified threshold. The agent evaluates a
|
||||||
key facts and deletes the least valuable ones to keep the database clean and relevant.
|
key facts and deletes the least valuable ones to keep the database clean and relevant.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
@ -13,8 +14,10 @@ from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||||
|
|
@ -22,7 +25,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
|
||||||
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
key_fact_repository = KeyFactRepository()
|
|
||||||
human_input_repository = HumanInputRepository()
|
human_input_repository = HumanInputRepository()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,24 +53,30 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
||||||
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
||||||
|
|
||||||
for fact_id in fact_ids:
|
for fact_id in fact_ids:
|
||||||
# Get the fact first to display information
|
try:
|
||||||
fact = key_fact_repository.get(fact_id)
|
# Get the fact first to display information
|
||||||
if fact:
|
fact = get_key_fact_repository().get(fact_id)
|
||||||
# Check if this fact is associated with the current human input
|
if fact:
|
||||||
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
|
# Check if this fact is associated with the current human input
|
||||||
protected_facts.append((fact_id, fact.content))
|
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
|
||||||
continue
|
protected_facts.append((fact_id, fact.content))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Delete the fact if it's not protected
|
||||||
|
was_deleted = get_key_fact_repository().delete(fact_id)
|
||||||
|
if was_deleted:
|
||||||
|
deleted_facts.append((fact_id, fact.content))
|
||||||
|
log_work_event(f"Deleted fact {fact_id}.")
|
||||||
|
else:
|
||||||
|
failed_facts.append(fact_id)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
failed_facts.append(fact_id)
|
||||||
|
except Exception as e:
|
||||||
|
# For any other exceptions, log and continue
|
||||||
|
logger.error(f"Error processing fact {fact_id}: {str(e)}")
|
||||||
|
failed_facts.append(fact_id)
|
||||||
|
|
||||||
# Delete the fact if it's not protected
|
|
||||||
was_deleted = key_fact_repository.delete(fact_id)
|
|
||||||
if was_deleted:
|
|
||||||
deleted_facts.append((fact_id, fact.content))
|
|
||||||
log_work_event(f"Deleted fact {fact_id}.")
|
|
||||||
else:
|
|
||||||
failed_facts.append(fact_id)
|
|
||||||
else:
|
|
||||||
not_found_facts.append(fact_id)
|
|
||||||
|
|
||||||
# Prepare result message
|
# Prepare result message
|
||||||
result_parts = []
|
result_parts = []
|
||||||
if deleted_facts:
|
if deleted_facts:
|
||||||
|
|
@ -104,8 +112,13 @@ def run_key_facts_gc_agent() -> None:
|
||||||
Facts associated with the current human input are excluded from deletion.
|
Facts associated with the current human input are excluded from deletion.
|
||||||
"""
|
"""
|
||||||
# Get the count of key facts
|
# Get the count of key facts
|
||||||
facts = key_fact_repository.get_all()
|
try:
|
||||||
fact_count = len(facts)
|
facts = get_key_fact_repository().get_all()
|
||||||
|
fact_count = len(facts)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||||
|
return # Exit the function if we can't access the repository
|
||||||
|
|
||||||
# Display status panel with fact count included
|
# Display status panel with fact count included
|
||||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
||||||
|
|
@ -161,8 +174,12 @@ def run_key_facts_gc_agent() -> None:
|
||||||
run_agent_with_retry(agent, prompt, agent_config)
|
run_agent_with_retry(agent, prompt, agent_config)
|
||||||
|
|
||||||
# Get updated count
|
# Get updated count
|
||||||
updated_facts = key_fact_repository.get_all()
|
try:
|
||||||
updated_count = len(updated_facts)
|
updated_facts = get_key_fact_repository().get_all()
|
||||||
|
updated_count = len(updated_facts)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository for update count: {str(e)}")
|
||||||
|
updated_count = "unknown"
|
||||||
|
|
||||||
# Show info panel with updated count and protected facts count
|
# Show info panel with updated count and protected facts count
|
||||||
protected_count = len(protected_facts)
|
protected_count = len(protected_facts)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||||
|
|
@ -22,7 +22,6 @@ from ra_aid.tools.memory import log_work_event, _global_memory
|
||||||
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
key_snippet_repository = KeySnippetRepository()
|
|
||||||
human_input_repository = HumanInputRepository()
|
human_input_repository = HumanInputRepository()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -53,7 +52,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
||||||
|
|
||||||
for snippet_id in snippet_ids:
|
for snippet_id in snippet_ids:
|
||||||
# Get the snippet first to capture filepath for the message
|
# Get the snippet first to capture filepath for the message
|
||||||
snippet = key_snippet_repository.get(snippet_id)
|
snippet = get_key_snippet_repository().get(snippet_id)
|
||||||
if snippet:
|
if snippet:
|
||||||
filepath = snippet.filepath
|
filepath = snippet.filepath
|
||||||
|
|
||||||
|
|
@ -63,7 +62,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Delete from database if not protected
|
# Delete from database if not protected
|
||||||
success = key_snippet_repository.delete(snippet_id)
|
success = get_key_snippet_repository().delete(snippet_id)
|
||||||
if success:
|
if success:
|
||||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||||
console.print(
|
console.print(
|
||||||
|
|
@ -110,7 +109,7 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
Snippets associated with the current human input are excluded from deletion.
|
Snippets associated with the current human input are excluded from deletion.
|
||||||
"""
|
"""
|
||||||
# Get the count of key snippets
|
# Get the count of key snippets
|
||||||
snippets = key_snippet_repository.get_all()
|
snippets = get_key_snippet_repository().get_all()
|
||||||
snippet_count = len(snippets)
|
snippet_count = len(snippets)
|
||||||
|
|
||||||
# Display status panel with snippet count included
|
# Display status panel with snippet count included
|
||||||
|
|
@ -179,7 +178,7 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
run_agent_with_retry(agent, prompt, agent_config)
|
run_agent_with_retry(agent, prompt, agent_config)
|
||||||
|
|
||||||
# Get updated count
|
# Get updated count
|
||||||
updated_snippets = key_snippet_repository.get_all()
|
updated_snippets = get_key_snippet_repository().get_all()
|
||||||
updated_count = len(updated_snippets)
|
updated_count = len(updated_snippets)
|
||||||
|
|
||||||
# Show info panel with updated count and protected snippets count
|
# Show info panel with updated count and protected snippets count
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,94 @@ following the repository pattern for data access abstraction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
import contextvars
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import get_db
|
from ra_aid.database.models import KeyFact
|
||||||
from ra_aid.database.models import KeyFact, initialize_database
|
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Create contextvar to hold the KeyFactRepository instance
|
||||||
|
key_fact_repo_var = contextvars.ContextVar("key_fact_repo", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyFactRepositoryManager:
|
||||||
|
"""
|
||||||
|
Context manager for KeyFactRepository.
|
||||||
|
|
||||||
|
This class provides a context manager interface for KeyFactRepository,
|
||||||
|
using the contextvars approach for thread safety.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with DatabaseManager() as db:
|
||||||
|
with KeyFactRepositoryManager(db) as repo:
|
||||||
|
# Use the repository
|
||||||
|
fact = repo.create("Important fact about the project")
|
||||||
|
all_facts = repo.get_all()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
"""
|
||||||
|
Initialize the KeyFactRepositoryManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection to use (required)
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def __enter__(self) -> 'KeyFactRepository':
|
||||||
|
"""
|
||||||
|
Initialize the KeyFactRepository and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KeyFactRepository: The initialized repository
|
||||||
|
"""
|
||||||
|
repo = KeyFactRepository(self.db)
|
||||||
|
key_fact_repo_var.set(repo)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Reset the repository when exiting the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: The exception type if an exception was raised
|
||||||
|
exc_val: The exception value if an exception was raised
|
||||||
|
exc_tb: The traceback if an exception was raised
|
||||||
|
"""
|
||||||
|
# Reset the contextvar to None
|
||||||
|
key_fact_repo_var.set(None)
|
||||||
|
|
||||||
|
# Don't suppress exceptions
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_fact_repository() -> 'KeyFactRepository':
|
||||||
|
"""
|
||||||
|
Get the current KeyFactRepository instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KeyFactRepository: The current repository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no repository has been initialized with KeyFactRepositoryManager
|
||||||
|
"""
|
||||||
|
repo = key_fact_repo_var.get()
|
||||||
|
if repo is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No KeyFactRepository available. "
|
||||||
|
"Make sure to initialize one with KeyFactRepositoryManager first."
|
||||||
|
)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
class KeyFactRepository:
|
class KeyFactRepository:
|
||||||
"""
|
"""
|
||||||
|
|
@ -24,18 +103,21 @@ class KeyFactRepository:
|
||||||
abstracting the database access details from the business logic.
|
abstracting the database access details from the business logic.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
repo = KeyFactRepository()
|
with DatabaseManager() as db:
|
||||||
fact = repo.create("Important fact about the project")
|
with KeyFactRepositoryManager(db) as repo:
|
||||||
all_facts = repo.get_all()
|
fact = repo.create("Important fact about the project")
|
||||||
|
all_facts = repo.get_all()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db=None):
|
def __init__(self, db):
|
||||||
"""
|
"""
|
||||||
Initialize the repository with an optional database connection.
|
Initialize the repository with a database connection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Optional database connection to use. If None, will use initialize_database()
|
db: Database connection to use (required)
|
||||||
"""
|
"""
|
||||||
|
if db is None:
|
||||||
|
raise ValueError("Database connection is required for KeyFactRepository")
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
|
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
|
||||||
|
|
@ -53,7 +135,6 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error creating the fact
|
peewee.DatabaseError: If there's an error creating the fact
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
fact = KeyFact.create(content=content, human_input_id=human_input_id)
|
fact = KeyFact.create(content=content, human_input_id=human_input_id)
|
||||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||||
return fact
|
return fact
|
||||||
|
|
@ -75,7 +156,6 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error accessing the database
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||||
|
|
@ -96,7 +176,6 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error updating the fact
|
peewee.DatabaseError: If there's an error updating the fact
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
# First check if the fact exists
|
# First check if the fact exists
|
||||||
fact = self.get(fact_id)
|
fact = self.get(fact_id)
|
||||||
if not fact:
|
if not fact:
|
||||||
|
|
@ -126,7 +205,6 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error deleting the fact
|
peewee.DatabaseError: If there's an error deleting the fact
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
# First check if the fact exists
|
# First check if the fact exists
|
||||||
fact = self.get(fact_id)
|
fact = self.get(fact_id)
|
||||||
if not fact:
|
if not fact:
|
||||||
|
|
@ -152,7 +230,6 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error accessing the database
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
return list(KeyFact.select().order_by(KeyFact.id))
|
return list(KeyFact.select().order_by(KeyFact.id))
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||||
|
|
|
||||||
|
|
@ -214,4 +214,21 @@ class KeySnippetRepository:
|
||||||
}
|
}
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Global singleton instance
|
||||||
|
_key_snippet_repository = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_snippet_repository() -> KeySnippetRepository:
|
||||||
|
"""
|
||||||
|
Get or create a singleton instance of KeySnippetRepository.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KeySnippetRepository: Singleton instance of the repository
|
||||||
|
"""
|
||||||
|
global _key_snippet_repository
|
||||||
|
if _key_snippet_repository is None:
|
||||||
|
_key_snippet_repository = KeySnippetRepository()
|
||||||
|
return _key_snippet_repository
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Tools for spawning and managing sub-agents."""
|
"""Tools for spawning and managing sub-agents."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
import logging
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
@ -12,8 +13,9 @@ from ra_aid.agent_context import (
|
||||||
reset_completion_flags,
|
reset_completion_flags,
|
||||||
)
|
)
|
||||||
from ra_aid.console.formatting import print_error
|
from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
|
|
@ -31,8 +33,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user.
|
||||||
RESEARCH_AGENT_RECURSION_LIMIT = 3
|
RESEARCH_AGENT_RECURSION_LIMIT = 3
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
key_fact_repository = KeyFactRepository()
|
logger = logging.getLogger(__name__)
|
||||||
key_snippet_repository = KeySnippetRepository()
|
|
||||||
|
|
||||||
|
|
||||||
@tool("request_research")
|
@tool("request_research")
|
||||||
|
|
@ -57,12 +58,24 @@ def request_research(query: str) -> ResearchResult:
|
||||||
current_depth = _global_memory.get("agent_depth", 0)
|
current_depth = _global_memory.get("agent_depth", 0)
|
||||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||||
print_error("Maximum research recursion depth reached")
|
print_error("Maximum research recursion depth reached")
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"completion_message": "Research stopped - maximum recursion depth reached",
|
"completion_message": "Research stopped - maximum recursion depth reached",
|
||||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
"key_facts": key_facts,
|
||||||
"related_files": get_related_files(),
|
"related_files": get_related_files(),
|
||||||
"research_notes": get_memory_value("research_notes"),
|
"research_notes": get_memory_value("research_notes"),
|
||||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
"key_snippets": key_snippets,
|
||||||
"success": False,
|
"success": False,
|
||||||
"reason": "max_depth_exceeded",
|
"reason": "max_depth_exceeded",
|
||||||
}
|
}
|
||||||
|
|
@ -105,12 +118,24 @@ def request_research(query: str) -> ResearchResult:
|
||||||
# Clear completion state
|
# Clear completion state
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
"key_facts": key_facts,
|
||||||
"related_files": get_related_files(),
|
"related_files": get_related_files(),
|
||||||
"research_notes": get_memory_value("research_notes"),
|
"research_notes": get_memory_value("research_notes"),
|
||||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
"key_snippets": key_snippets,
|
||||||
"success": success,
|
"success": success,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
}
|
}
|
||||||
|
|
@ -171,9 +196,15 @@ def request_web_research(query: str) -> ResearchResult:
|
||||||
# Clear completion state
|
# Clear completion state
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
"key_snippets": key_snippets,
|
||||||
"research_notes": get_memory_value("research_notes"),
|
"research_notes": get_memory_value("research_notes"),
|
||||||
"success": success,
|
"success": success,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
|
|
@ -239,12 +270,24 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
||||||
# Clear completion state
|
# Clear completion state
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
"key_facts": key_facts,
|
||||||
"related_files": get_related_files(),
|
"related_files": get_related_files(),
|
||||||
"research_notes": get_memory_value("research_notes"),
|
"research_notes": get_memory_value("research_notes"),
|
||||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
"key_snippets": key_snippets,
|
||||||
"success": success,
|
"success": success,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
}
|
}
|
||||||
|
|
@ -324,10 +367,22 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
agent_crashed = is_crashed()
|
agent_crashed = is_crashed()
|
||||||
crash_message = get_crash_message() if agent_crashed else None
|
crash_message = get_crash_message() if agent_crashed else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
"key_facts": key_facts,
|
||||||
"related_files": get_related_files(),
|
"related_files": get_related_files(),
|
||||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
"key_snippets": key_snippets,
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
"success": success and not agent_crashed,
|
"success": success and not agent_crashed,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
|
|
@ -444,11 +499,23 @@ def request_implementation(task_spec: str) -> str:
|
||||||
agent_crashed = is_crashed()
|
agent_crashed = is_crashed()
|
||||||
crash_message = get_crash_message() if agent_crashed else None
|
crash_message = get_crash_message() if agent_crashed else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
"key_facts": key_facts,
|
||||||
"related_files": get_related_files(),
|
"related_files": get_related_files(),
|
||||||
"key_snippets": format_key_snippets_dict(key_snippet_repository.get_snippets_dict()),
|
"key_snippets": key_snippets,
|
||||||
"success": success and not agent_crashed,
|
"success": success and not agent_crashed,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"agent_crashed": agent_crashed,
|
"agent_crashed": agent_crashed,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
@ -6,8 +7,10 @@ from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from ..database.repositories.key_fact_repository import KeyFactRepository
|
logger = logging.getLogger(__name__)
|
||||||
from ..database.repositories.key_snippet_repository import KeySnippetRepository
|
|
||||||
|
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
|
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ..llm import initialize_expert_llm
|
from ..llm import initialize_expert_llm
|
||||||
from ..model_formatters import format_key_facts_dict
|
from ..model_formatters import format_key_facts_dict
|
||||||
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
|
|
@ -15,8 +18,6 @@ from .memory import _global_memory, get_memory_value
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
_model = None
|
_model = None
|
||||||
key_fact_repository = KeyFactRepository()
|
|
||||||
key_snippet_repository = KeySnippetRepository()
|
|
||||||
|
|
||||||
|
|
||||||
def get_model():
|
def get_model():
|
||||||
|
|
@ -154,10 +155,18 @@ def ask_expert(question: str) -> str:
|
||||||
file_paths = list(_global_memory["related_files"].values())
|
file_paths = list(_global_memory["related_files"].values())
|
||||||
related_contents = read_related_files(file_paths)
|
related_contents = read_related_files(file_paths)
|
||||||
# Get key snippets directly from repository and format using the formatter
|
# Get key snippets directly from repository and format using the formatter
|
||||||
key_snippets = format_key_snippets_dict(key_snippet_repository.get_snippets_dict())
|
try:
|
||||||
|
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
|
key_snippets = ""
|
||||||
# Get key facts directly from repository and format using the formatter
|
# Get key facts directly from repository and format using the formatter
|
||||||
facts_dict = key_fact_repository.get_facts_dict()
|
try:
|
||||||
key_facts = format_key_facts_dict(facts_dict)
|
facts_dict = get_key_fact_repository().get_facts_dict()
|
||||||
|
key_facts = format_key_facts_dict(facts_dict)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
key_facts = ""
|
||||||
research_notes = get_memory_value("research_notes")
|
research_notes = get_memory_value("research_notes")
|
||||||
|
|
||||||
# Build display query (just question)
|
# Build display query (just question)
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,8 @@ from ra_aid.agent_context import (
|
||||||
mark_should_exit,
|
mark_should_exit,
|
||||||
mark_task_completed,
|
mark_task_completed,
|
||||||
)
|
)
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository, get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository, get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.model_formatters import key_snippets_formatter
|
from ra_aid.model_formatters import key_snippets_formatter
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
@ -40,14 +40,8 @@ class SnippetInfo(TypedDict):
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
# Initialize repository for key facts
|
# Import repositories using the get_* functions
|
||||||
key_fact_repository = KeyFactRepository()
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
|
|
||||||
# Initialize repository for key snippets
|
|
||||||
key_snippet_repository = KeySnippetRepository()
|
|
||||||
|
|
||||||
# Initialize repository for human inputs
|
|
||||||
human_input_repository = HumanInputRepository()
|
|
||||||
|
|
||||||
# Global memory store
|
# Global memory store
|
||||||
_global_memory: Dict[str, Any] = {
|
_global_memory: Dict[str, Any] = {
|
||||||
|
|
@ -120,17 +114,23 @@ def emit_key_facts(facts: List[str]) -> str:
|
||||||
|
|
||||||
# Try to get the latest human input
|
# Try to get the latest human input
|
||||||
human_input_id = None
|
human_input_id = None
|
||||||
|
human_input_repo = HumanInputRepository()
|
||||||
try:
|
try:
|
||||||
recent_inputs = human_input_repository.get_recent(1)
|
recent_inputs = human_input_repo.get_recent(1)
|
||||||
if recent_inputs and len(recent_inputs) > 0:
|
if recent_inputs and len(recent_inputs) > 0:
|
||||||
human_input_id = recent_inputs[0].id
|
human_input_id = recent_inputs[0].id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get recent human input: {str(e)}")
|
logger.warning(f"Failed to get recent human input: {str(e)}")
|
||||||
|
|
||||||
for fact in facts:
|
for fact in facts:
|
||||||
# Create fact in database using repository
|
try:
|
||||||
created_fact = key_fact_repository.create(fact, human_input_id=human_input_id)
|
# Create fact in database using repository
|
||||||
fact_id = created_fact.id
|
created_fact = get_key_fact_repository().create(fact, human_input_id=human_input_id)
|
||||||
|
fact_id = created_fact.id
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
console.print(f"Error storing fact: {str(e)}", style="red")
|
||||||
|
continue
|
||||||
|
|
||||||
# Display panel with ID
|
# Display panel with ID
|
||||||
console.print(
|
console.print(
|
||||||
|
|
@ -147,14 +147,17 @@ def emit_key_facts(facts: List[str]) -> str:
|
||||||
log_work_event(f"Stored {len(facts)} key facts.")
|
log_work_event(f"Stored {len(facts)} key facts.")
|
||||||
|
|
||||||
# Check if we need to clean up facts (more than 30)
|
# Check if we need to clean up facts (more than 30)
|
||||||
all_facts = key_fact_repository.get_all()
|
try:
|
||||||
if len(all_facts) > 30:
|
all_facts = get_key_fact_repository().get_all()
|
||||||
# Trigger the key facts cleaner agent
|
if len(all_facts) > 30:
|
||||||
try:
|
# Trigger the key facts cleaner agent
|
||||||
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
|
try:
|
||||||
run_key_facts_gc_agent()
|
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
|
||||||
except Exception as e:
|
run_key_facts_gc_agent()
|
||||||
logger.error(f"Failed to run key facts cleaner: {str(e)}")
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to run key facts cleaner: {str(e)}")
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
|
||||||
return "Facts stored."
|
return "Facts stored."
|
||||||
|
|
||||||
|
|
@ -222,14 +225,15 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
||||||
# Try to get the latest human input
|
# Try to get the latest human input
|
||||||
human_input_id = None
|
human_input_id = None
|
||||||
try:
|
try:
|
||||||
recent_inputs = human_input_repository.get_recent(1)
|
human_input_repo = HumanInputRepository()
|
||||||
|
recent_inputs = human_input_repo.get_recent(1)
|
||||||
if recent_inputs and len(recent_inputs) > 0:
|
if recent_inputs and len(recent_inputs) > 0:
|
||||||
human_input_id = recent_inputs[0].id
|
human_input_id = recent_inputs[0].id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get recent human input: {str(e)}")
|
logger.warning(f"Failed to get recent human input: {str(e)}")
|
||||||
|
|
||||||
# Create a new key snippet in the database
|
# Create a new key snippet in the database
|
||||||
key_snippet = key_snippet_repository.create(
|
key_snippet = get_key_snippet_repository().create(
|
||||||
filepath=snippet_info["filepath"],
|
filepath=snippet_info["filepath"],
|
||||||
line_number=snippet_info["line_number"],
|
line_number=snippet_info["line_number"],
|
||||||
snippet=snippet_info["snippet"],
|
snippet=snippet_info["snippet"],
|
||||||
|
|
@ -266,7 +270,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
||||||
log_work_event(f"Stored code snippet #{snippet_id}.")
|
log_work_event(f"Stored code snippet #{snippet_id}.")
|
||||||
|
|
||||||
# Check if we need to clean up snippets (more than 20)
|
# Check if we need to clean up snippets (more than 20)
|
||||||
all_snippets = key_snippet_repository.get_all()
|
all_snippets = get_key_snippet_repository().get_all()
|
||||||
if len(all_snippets) > 20:
|
if len(all_snippets) > 20:
|
||||||
# Trigger the key snippets cleaner agent
|
# Trigger the key snippets cleaner agent
|
||||||
try:
|
try:
|
||||||
|
|
@ -279,7 +283,6 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@tool("swap_task_order")
|
@tool("swap_task_order")
|
||||||
def swap_task_order(id1: int, id2: int) -> str:
|
def swap_task_order(id1: int, id2: int) -> str:
|
||||||
"""Swap the order of two tasks in global memory by their IDs.
|
"""Swap the order of two tasks in global memory by their IDs.
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,12 @@ import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import DatabaseManager, db_var
|
from ra_aid.database.connection import DatabaseManager, db_var
|
||||||
from ra_aid.database.models import KeyFact, BaseModel
|
from ra_aid.database.models import KeyFact, BaseModel
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import (
|
||||||
|
KeyFactRepository,
|
||||||
|
KeyFactRepositoryManager,
|
||||||
|
get_key_fact_repository,
|
||||||
|
key_fact_repo_var
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -41,6 +46,19 @@ def cleanup_db():
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cleanup_repo():
|
||||||
|
"""Reset the repository contextvar after each test."""
|
||||||
|
# Reset before the test
|
||||||
|
key_fact_repo_var.set(None)
|
||||||
|
|
||||||
|
# Run the test
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Reset after the test
|
||||||
|
key_fact_repo_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def setup_db(cleanup_db):
|
def setup_db(cleanup_db):
|
||||||
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
|
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
|
||||||
|
|
@ -195,4 +213,49 @@ def test_get_facts_dict(setup_db):
|
||||||
# Verify each fact is in the dictionary with the correct content
|
# Verify each fact is in the dictionary with the correct content
|
||||||
for fact in facts:
|
for fact in facts:
|
||||||
assert fact.id in facts_dict
|
assert fact.id in facts_dict
|
||||||
assert facts_dict[fact.id] == fact.content
|
assert facts_dict[fact.id] == fact.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_repository_init_without_db():
|
||||||
|
"""Test that KeyFactRepository raises an error when initialized without a db parameter."""
|
||||||
|
# Attempt to create a repository without a database connection
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
KeyFactRepository(db=None)
|
||||||
|
|
||||||
|
# Verify the correct error message
|
||||||
|
assert "Database connection is required" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_key_fact_repository_manager(setup_db, cleanup_repo):
|
||||||
|
"""Test the KeyFactRepositoryManager context manager."""
|
||||||
|
# Use the context manager to create a repository
|
||||||
|
with KeyFactRepositoryManager(setup_db) as repo:
|
||||||
|
# Verify the repository was created correctly
|
||||||
|
assert isinstance(repo, KeyFactRepository)
|
||||||
|
assert repo.db is setup_db
|
||||||
|
|
||||||
|
# Verify we can use the repository
|
||||||
|
content = "Test fact via context manager"
|
||||||
|
fact = repo.create(content)
|
||||||
|
assert fact.id is not None
|
||||||
|
assert fact.content == content
|
||||||
|
|
||||||
|
# Verify we can get the repository using get_key_fact_repository
|
||||||
|
repo_from_var = get_key_fact_repository()
|
||||||
|
assert repo_from_var is repo
|
||||||
|
|
||||||
|
# Verify the repository was removed from the context var
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
get_key_fact_repository()
|
||||||
|
|
||||||
|
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_key_fact_repository_when_not_set(cleanup_repo):
|
||||||
|
"""Test that get_key_fact_repository raises an error when no repository is in context."""
|
||||||
|
# Attempt to get the repository when none exists
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
get_key_fact_repository()
|
||||||
|
|
||||||
|
# Verify the correct error message
|
||||||
|
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||||
|
|
@ -36,9 +36,11 @@ def reset_memory():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_functions():
|
def mock_functions():
|
||||||
"""Mock functions used in agent.py"""
|
"""Mock functions used in agent.py"""
|
||||||
with patch('ra_aid.tools.agent.key_fact_repository') as mock_fact_repo, \
|
mock_fact_repo = MagicMock()
|
||||||
|
mock_snippet_repo = MagicMock()
|
||||||
|
with patch('ra_aid.tools.agent.get_key_fact_repository', return_value=mock_fact_repo) as mock_get_fact_repo, \
|
||||||
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \
|
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_fact_formatter, \
|
||||||
patch('ra_aid.tools.agent.key_snippet_repository') as mock_snippet_repo, \
|
patch('ra_aid.tools.agent.get_key_snippet_repository', return_value=mock_snippet_repo) as mock_get_snippet_repo, \
|
||||||
patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \
|
patch('ra_aid.tools.agent.format_key_snippets_dict') as mock_snippet_formatter, \
|
||||||
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
|
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
|
||||||
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
||||||
|
|
@ -60,8 +62,8 @@ def mock_functions():
|
||||||
|
|
||||||
# Return all mocks as a dictionary
|
# Return all mocks as a dictionary
|
||||||
yield {
|
yield {
|
||||||
'key_fact_repository': mock_fact_repo,
|
'get_key_fact_repository': mock_get_fact_repo,
|
||||||
'key_snippet_repository': mock_snippet_repo,
|
'get_key_snippet_repository': mock_get_snippet_repo,
|
||||||
'format_key_facts_dict': mock_fact_formatter,
|
'format_key_facts_dict': mock_fact_formatter,
|
||||||
'format_key_snippets_dict': mock_snippet_formatter,
|
'format_key_snippets_dict': mock_snippet_formatter,
|
||||||
'initialize_llm': mock_llm,
|
'initialize_llm': mock_llm,
|
||||||
|
|
@ -81,11 +83,12 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
|
||||||
result = request_research("test query")
|
result = request_research("test query")
|
||||||
|
|
||||||
# Verify repository was called
|
# Verify repository was called
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||||
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||||
|
|
||||||
# Verify formatter was called with repository results
|
# Verify formatter was called with repository results
|
||||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify formatted facts are used in response
|
# Verify formatted facts are used in response
|
||||||
|
|
@ -105,11 +108,12 @@ def test_request_research_max_depth(reset_memory, mock_functions):
|
||||||
result = request_research("test query")
|
result = request_research("test query")
|
||||||
|
|
||||||
# Verify repository was called
|
# Verify repository was called
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||||
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||||
|
|
||||||
# Verify formatter was called with repository results
|
# Verify formatter was called with repository results
|
||||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify formatted facts are used in response
|
# Verify formatted facts are used in response
|
||||||
|
|
@ -128,11 +132,12 @@ def test_request_research_and_implementation_uses_key_fact_repository(reset_memo
|
||||||
result = request_research_and_implementation("test query")
|
result = request_research_and_implementation("test query")
|
||||||
|
|
||||||
# Verify repository was called
|
# Verify repository was called
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||||
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||||
|
|
||||||
# Verify formatter was called with repository results
|
# Verify formatter was called with repository results
|
||||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify formatted facts are used in response
|
# Verify formatted facts are used in response
|
||||||
|
|
@ -151,11 +156,12 @@ def test_request_implementation_uses_key_fact_repository(reset_memory, mock_func
|
||||||
result = request_implementation("test task")
|
result = request_implementation("test task")
|
||||||
|
|
||||||
# Verify repository was called
|
# Verify repository was called
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||||
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||||
|
|
||||||
# Verify formatter was called with repository results
|
# Verify formatter was called with repository results
|
||||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the formatted key facts are included in the response
|
# Check that the formatted key facts are included in the response
|
||||||
|
|
@ -174,11 +180,12 @@ def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock
|
||||||
result = request_task_implementation("test task")
|
result = request_task_implementation("test task")
|
||||||
|
|
||||||
# Verify repository was called
|
# Verify repository was called
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
mock_functions['get_key_fact_repository'].assert_called_once()
|
||||||
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.assert_called_once()
|
||||||
|
|
||||||
# Verify formatter was called with repository results
|
# Verify formatter was called with repository results
|
||||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
mock_functions['get_key_fact_repository'].return_value.get_facts_dict.return_value
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the formatted key facts are included in the response
|
# Check that the formatted key facts are included in the response
|
||||||
|
|
|
||||||
|
|
@ -16,12 +16,12 @@ from ra_aid.tools.memory import (
|
||||||
get_memory_value,
|
get_memory_value,
|
||||||
get_related_files,
|
get_related_files,
|
||||||
get_work_log,
|
get_work_log,
|
||||||
key_fact_repository,
|
|
||||||
key_snippet_repository,
|
|
||||||
log_work_event,
|
log_work_event,
|
||||||
reset_work_log,
|
reset_work_log,
|
||||||
swap_task_order,
|
swap_task_order,
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.connection import DatabaseManager
|
from ra_aid.database.connection import DatabaseManager
|
||||||
from ra_aid.database.models import KeyFact
|
from ra_aid.database.models import KeyFact
|
||||||
|
|
||||||
|
|
@ -60,7 +60,7 @@ def in_memory_db():
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_repository():
|
def mock_repository():
|
||||||
"""Mock the KeyFactRepository to avoid database operations during tests"""
|
"""Mock the KeyFactRepository to avoid database operations during tests"""
|
||||||
with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo:
|
with patch('ra_aid.tools.memory.get_key_fact_repository') as mock_repo:
|
||||||
# Setup the mock repository to behave like the original, but using memory
|
# Setup the mock repository to behave like the original, but using memory
|
||||||
facts = {} # Local in-memory storage
|
facts = {} # Local in-memory storage
|
||||||
fact_id_counter = 0
|
fact_id_counter = 0
|
||||||
|
|
@ -79,12 +79,12 @@ def mock_repository():
|
||||||
facts[fact_id_counter] = fact
|
facts[fact_id_counter] = fact
|
||||||
fact_id_counter += 1
|
fact_id_counter += 1
|
||||||
return fact
|
return fact
|
||||||
mock_repo.create.side_effect = mock_create
|
mock_repo.return_value.create.side_effect = mock_create
|
||||||
|
|
||||||
# Mock get method
|
# Mock get method
|
||||||
def mock_get(fact_id):
|
def mock_get(fact_id):
|
||||||
return facts.get(fact_id)
|
return facts.get(fact_id)
|
||||||
mock_repo.get.side_effect = mock_get
|
mock_repo.return_value.get.side_effect = mock_get
|
||||||
|
|
||||||
# Mock delete method
|
# Mock delete method
|
||||||
def mock_delete(fact_id):
|
def mock_delete(fact_id):
|
||||||
|
|
@ -92,17 +92,17 @@ def mock_repository():
|
||||||
del facts[fact_id]
|
del facts[fact_id]
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
mock_repo.delete.side_effect = mock_delete
|
mock_repo.return_value.delete.side_effect = mock_delete
|
||||||
|
|
||||||
# Mock get_facts_dict method
|
# Mock get_facts_dict method
|
||||||
def mock_get_facts_dict():
|
def mock_get_facts_dict():
|
||||||
return {fact_id: fact.content for fact_id, fact in facts.items()}
|
return {fact_id: fact.content for fact_id, fact in facts.items()}
|
||||||
mock_repo.get_facts_dict.side_effect = mock_get_facts_dict
|
mock_repo.return_value.get_facts_dict.side_effect = mock_get_facts_dict
|
||||||
|
|
||||||
# Mock get_all method
|
# Mock get_all method
|
||||||
def mock_get_all():
|
def mock_get_all():
|
||||||
return list(facts.values())
|
return list(facts.values())
|
||||||
mock_repo.get_all.side_effect = mock_get_all
|
mock_repo.return_value.get_all.side_effect = mock_get_all
|
||||||
|
|
||||||
yield mock_repo
|
yield mock_repo
|
||||||
|
|
||||||
|
|
@ -159,16 +159,16 @@ def mock_key_snippet_repository():
|
||||||
return list(snippets.values())
|
return list(snippets.values())
|
||||||
|
|
||||||
# Create the actual mocks for both memory.py and key_snippets_gc_agent.py
|
# Create the actual mocks for both memory.py and key_snippets_gc_agent.py
|
||||||
with patch('ra_aid.tools.memory.key_snippet_repository') as memory_mock_repo, \
|
with patch('ra_aid.tools.memory.get_key_snippet_repository') as memory_mock_repo, \
|
||||||
patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository') as agent_mock_repo:
|
patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository') as agent_mock_repo:
|
||||||
|
|
||||||
# Setup both mocks with the same implementation
|
# Setup both mocks with the same implementation
|
||||||
for mock_repo in [memory_mock_repo, agent_mock_repo]:
|
for mock_repo in [memory_mock_repo, agent_mock_repo]:
|
||||||
mock_repo.create.side_effect = mock_create
|
mock_repo.return_value.create.side_effect = mock_create
|
||||||
mock_repo.get.side_effect = mock_get
|
mock_repo.return_value.get.side_effect = mock_get
|
||||||
mock_repo.delete.side_effect = mock_delete
|
mock_repo.return_value.delete.side_effect = mock_delete
|
||||||
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict
|
mock_repo.return_value.get_snippets_dict.side_effect = mock_get_snippets_dict
|
||||||
mock_repo.get_all.side_effect = mock_get_all
|
mock_repo.return_value.get_all.side_effect = mock_get_all
|
||||||
|
|
||||||
yield memory_mock_repo
|
yield memory_mock_repo
|
||||||
|
|
||||||
|
|
@ -180,7 +180,7 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
||||||
assert result == "Facts stored."
|
assert result == "Facts stored."
|
||||||
|
|
||||||
# Verify the repository's create method was called
|
# Verify the repository's create method was called
|
||||||
mock_repository.create.assert_called_once_with("First fact", human_input_id=ANY)
|
mock_repository.return_value.create.assert_called_once_with("First fact", human_input_id=ANY)
|
||||||
|
|
||||||
|
|
||||||
def test_get_memory_value_other_types(reset_memory):
|
def test_get_memory_value_other_types(reset_memory):
|
||||||
|
|
@ -264,10 +264,10 @@ def test_emit_key_facts(reset_memory, mock_repository):
|
||||||
assert result == "Facts stored."
|
assert result == "Facts stored."
|
||||||
|
|
||||||
# Verify create was called for each fact
|
# Verify create was called for each fact
|
||||||
assert mock_repository.create.call_count == 3
|
assert mock_repository.return_value.create.call_count == 3
|
||||||
mock_repository.create.assert_any_call("First fact", human_input_id=ANY)
|
mock_repository.return_value.create.assert_any_call("First fact", human_input_id=ANY)
|
||||||
mock_repository.create.assert_any_call("Second fact", human_input_id=ANY)
|
mock_repository.return_value.create.assert_any_call("Second fact", human_input_id=ANY)
|
||||||
mock_repository.create.assert_any_call("Third fact", human_input_id=ANY)
|
mock_repository.return_value.create.assert_any_call("Third fact", human_input_id=ANY)
|
||||||
|
|
||||||
|
|
||||||
def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
||||||
|
|
@ -278,7 +278,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
||||||
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
|
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
|
||||||
|
|
||||||
# Mock the get_all method to return more than 30 facts
|
# Mock the get_all method to return more than 30 facts
|
||||||
mock_repository.get_all.return_value = facts
|
mock_repository.return_value.get_all.return_value = facts
|
||||||
|
|
||||||
# Note on testing approach:
|
# Note on testing approach:
|
||||||
# Rather than trying to mock the dynamic import which is challenging due to
|
# Rather than trying to mock the dynamic import which is challenging due to
|
||||||
|
|
@ -295,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
||||||
|
|
||||||
# Verify that mock_repository.get_all was called,
|
# Verify that mock_repository.get_all was called,
|
||||||
# which is the condition that would trigger the GC agent
|
# which is the condition that would trigger the GC agent
|
||||||
mock_repository.get_all.assert_called_once()
|
mock_repository.return_value.get_all.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
||||||
|
|
@ -315,7 +315,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
||||||
assert result == "Snippet #0 stored."
|
assert result == "Snippet #0 stored."
|
||||||
|
|
||||||
# Verify create was called correctly
|
# Verify create was called correctly
|
||||||
mock_key_snippet_repository.create.assert_called_with(
|
mock_key_snippet_repository.return_value.create.assert_called_with(
|
||||||
filepath="test.py",
|
filepath="test.py",
|
||||||
line_number=10,
|
line_number=10,
|
||||||
snippet="def test():\n pass",
|
snippet="def test():\n pass",
|
||||||
|
|
@ -338,7 +338,7 @@ def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
||||||
assert result == "Snippet #1 stored."
|
assert result == "Snippet #1 stored."
|
||||||
|
|
||||||
# Verify create was called correctly
|
# Verify create was called correctly
|
||||||
mock_key_snippet_repository.create.assert_called_with(
|
mock_key_snippet_repository.return_value.create.assert_called_with(
|
||||||
filepath="main.py",
|
filepath="main.py",
|
||||||
line_number=20,
|
line_number=20,
|
||||||
snippet="print('hello')",
|
snippet="print('hello')",
|
||||||
|
|
@ -379,16 +379,16 @@ def test_delete_key_snippets(mock_log_work_event, reset_memory, mock_key_snippet
|
||||||
mock_key_snippet_repository.reset_mock()
|
mock_key_snippet_repository.reset_mock()
|
||||||
|
|
||||||
# Test deleting mix of valid and invalid IDs
|
# Test deleting mix of valid and invalid IDs
|
||||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
|
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
|
||||||
|
|
||||||
# Verify success message
|
# Verify success message
|
||||||
assert result == "Snippets deleted."
|
assert result == "Snippets deleted."
|
||||||
|
|
||||||
# Verify repository get was called with correct IDs
|
# Verify repository get was called with correct IDs
|
||||||
mock_key_snippet_repository.get.assert_any_call(0)
|
mock_key_snippet_repository.return_value.get.assert_any_call(0)
|
||||||
mock_key_snippet_repository.get.assert_any_call(1)
|
mock_key_snippet_repository.return_value.get.assert_any_call(1)
|
||||||
mock_key_snippet_repository.get.assert_any_call(999)
|
mock_key_snippet_repository.return_value.get.assert_any_call(999)
|
||||||
|
|
||||||
# We skip verifying delete calls because they are prone to test environment issues
|
# We skip verifying delete calls because they are prone to test environment issues
|
||||||
# The implementation logic will properly delete IDs 0 and 1 but not 999
|
# The implementation logic will properly delete IDs 0 and 1 but not 999
|
||||||
|
|
@ -410,12 +410,12 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s
|
||||||
mock_key_snippet_repository.reset_mock()
|
mock_key_snippet_repository.reset_mock()
|
||||||
|
|
||||||
# Test with empty list
|
# Test with empty list
|
||||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||||
result = delete_key_snippets.invoke({"snippet_ids": []})
|
result = delete_key_snippets.invoke({"snippet_ids": []})
|
||||||
assert result == "Snippets deleted."
|
assert result == "Snippets deleted."
|
||||||
|
|
||||||
# Verify no call to delete method
|
# Verify no call to delete method
|
||||||
mock_key_snippet_repository.delete.assert_not_called()
|
mock_key_snippet_repository.return_value.delete.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_emit_related_files_basic(reset_memory, tmp_path):
|
def test_emit_related_files_basic(reset_memory, tmp_path):
|
||||||
|
|
@ -458,13 +458,13 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path):
|
||||||
new_file.write_text("# New file")
|
new_file.write_text("# New file")
|
||||||
|
|
||||||
# Add initial files
|
# Add initial files
|
||||||
result = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
|
result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
|
||||||
assert result == "Files noted."
|
assert result1 == "Files noted."
|
||||||
_first_id = 0 # ID of test.py
|
_first_id = 0 # ID of test.py
|
||||||
|
|
||||||
# Try adding duplicates
|
# Try adding duplicates
|
||||||
result = emit_related_files.invoke({"files": [str(test_file)]})
|
result2 = emit_related_files.invoke({"files": [str(test_file)]})
|
||||||
assert result == "Files noted."
|
assert result2 == "Files noted."
|
||||||
assert len(_global_memory["related_files"]) == 2 # Count should not increase
|
assert len(_global_memory["related_files"]) == 2 # Count should not increase
|
||||||
|
|
||||||
# Try mix of new and duplicate files
|
# Try mix of new and duplicate files
|
||||||
|
|
@ -670,7 +670,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
|
||||||
mock_key_snippet_repository.reset_mock()
|
mock_key_snippet_repository.reset_mock()
|
||||||
|
|
||||||
# Delete some but not all snippets (0 and 2)
|
# Delete some but not all snippets (0 and 2)
|
||||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
|
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
|
||||||
assert result == "Snippets deleted."
|
assert result == "Snippets deleted."
|
||||||
|
|
||||||
|
|
@ -692,7 +692,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
|
||||||
assert result == "Snippet #3 stored."
|
assert result == "Snippet #3 stored."
|
||||||
|
|
||||||
# Verify create was called with correct params
|
# Verify create was called with correct params
|
||||||
mock_key_snippet_repository.create.assert_called_with(
|
mock_key_snippet_repository.return_value.create.assert_called_with(
|
||||||
filepath=file4,
|
filepath=file4,
|
||||||
line_number=40,
|
line_number=40,
|
||||||
snippet="def func4():\n return False",
|
snippet="def func4():\n return False",
|
||||||
|
|
@ -704,7 +704,7 @@ def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_sn
|
||||||
mock_key_snippet_repository.reset_mock()
|
mock_key_snippet_repository.reset_mock()
|
||||||
|
|
||||||
# Delete remaining snippets
|
# Delete remaining snippets
|
||||||
with patch('ra_aid.agents.key_snippets_gc_agent.key_snippet_repository', mock_key_snippet_repository):
|
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
||||||
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
|
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
|
||||||
assert result == "Snippets deleted."
|
assert result == "Snippets deleted."
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue