diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 540d2b3..2d2b6e0 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -84,6 +84,8 @@ from ra_aid.tool_configs import ( get_web_research_tools, ) from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command +from ra_aid.database.repositories.key_fact_repository import KeyFactRepository +from ra_aid.text.key_facts_formatter import format_key_facts_dict from ra_aid.tools.memory import ( _global_memory, get_memory_value, @@ -95,6 +97,9 @@ console = Console() logger = get_logger(__name__) +# Initialize key fact repository +key_fact_repository = KeyFactRepository() + @tool def output_markdown_message(message: str) -> str: @@ -381,7 +386,7 @@ def run_research_agent( else "" ) - key_facts = _global_memory.get("key_facts", "") + key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict()) code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") @@ -520,7 +525,7 @@ def run_web_research_agent( expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" - key_facts = _global_memory.get("key_facts", "") + key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict()) code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") @@ -640,7 +645,7 @@ def run_planning_agent( project_info=formatted_project_info, research_notes=get_memory_value("research_notes"), related_files="\n".join(get_related_files()), - key_facts=get_memory_value("key_facts"), + key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()), key_snippets=get_memory_value("key_snippets"), work_log=get_memory_value("work_log"), research_only_note=( @@ -742,7 +747,7 @@ def run_task_implementation_agent( tasks=tasks, plan=plan, related_files=related_files, - key_facts=get_memory_value("key_facts"), + key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()), key_snippets=get_memory_value("key_snippets"), research_notes=get_memory_value("research_notes"), work_log=get_memory_value("work_log"), diff --git a/ra_aid/database/repositories/key_fact_repository.py b/ra_aid/database/repositories/key_fact_repository.py index ffa9441..1d0e38a 100644 --- a/ra_aid/database/repositories/key_fact_repository.py +++ b/ra_aid/database/repositories/key_fact_repository.py @@ -9,7 +9,7 @@ from typing import Dict, List, Optional import peewee -from ra_aid.database.connection import DatabaseManager, get_db +from ra_aid.database.connection import get_db from ra_aid.database.models import KeyFact from ra_aid.logging_config import get_logger @@ -43,10 +43,10 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error creating the fact """ try: - with DatabaseManager() as db: - fact = KeyFact.create(content=content) - logger.debug(f"Created key fact ID {fact.id}: {content}") - return fact + db = get_db() + fact = KeyFact.create(content=content) + logger.debug(f"Created key fact ID {fact.id}: {content}") + return fact except peewee.DatabaseError as e: logger.error(f"Failed to create key fact: {str(e)}") raise @@ -65,8 +65,8 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - with DatabaseManager() as db: - return KeyFact.get_or_none(KeyFact.id == fact_id) + db = get_db() + return KeyFact.get_or_none(KeyFact.id == fact_id) except peewee.DatabaseError as e: logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}") raise @@ -86,18 +86,18 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error updating the fact """ try: - with DatabaseManager() as db: - # First check if the fact exists - fact = self.get(fact_id) - if not fact: - logger.warning(f"Attempted to update non-existent key fact {fact_id}") - return None - - # Update the fact - fact.content = content - fact.save() - logger.debug(f"Updated key fact ID {fact_id}: {content}") - return fact + db = get_db() + # First check if the fact exists + fact = self.get(fact_id) + if not fact: + logger.warning(f"Attempted to update non-existent key fact {fact_id}") + return None + + # Update the fact + fact.content = content + fact.save() + logger.debug(f"Updated key fact ID {fact_id}: {content}") + return fact except peewee.DatabaseError as e: logger.error(f"Failed to update key fact {fact_id}: {str(e)}") raise @@ -116,17 +116,17 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error deleting the fact """ try: - with DatabaseManager() as db: - # First check if the fact exists - fact = self.get(fact_id) - if not fact: - logger.warning(f"Attempted to delete non-existent key fact {fact_id}") - return False - - # Delete the fact - fact.delete_instance() - logger.debug(f"Deleted key fact ID {fact_id}") - return True + db = get_db() + # First check if the fact exists + fact = self.get(fact_id) + if not fact: + logger.warning(f"Attempted to delete non-existent key fact {fact_id}") + return False + + # Delete the fact + fact.delete_instance() + logger.debug(f"Deleted key fact ID {fact_id}") + return True except peewee.DatabaseError as e: logger.error(f"Failed to delete key fact {fact_id}: {str(e)}") raise @@ -142,8 +142,8 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - with DatabaseManager() as db: - return list(KeyFact.select().order_by(KeyFact.id)) + db = get_db() + return list(KeyFact.select().order_by(KeyFact.id)) except peewee.DatabaseError as e: logger.error(f"Failed to fetch all key facts: {str(e)}") raise diff --git a/ra_aid/prompts/key_facts_gc_prompts.py b/ra_aid/prompts/key_facts_gc_prompts.py index f3070c6..31eb43c 100644 --- a/ra_aid/prompts/key_facts_gc_prompts.py +++ b/ra_aid/prompts/key_facts_gc_prompts.py @@ -37,6 +37,7 @@ Retention priority (from highest to lowest): - Build and deployment information - Testing approaches - Low-level implementation details that are easily rediscovered +- If there are contradictory facts, that probably means that the older fact is no longer true and should be deleted. For facts of similar importance, prefer to keep more recent facts if they supersede older information. diff --git a/ra_aid/tests/ra_aid/text/__init__.py b/ra_aid/tests/ra_aid/text/__init__.py new file mode 100644 index 0000000..5f6b9c8 --- /dev/null +++ b/ra_aid/tests/ra_aid/text/__init__.py @@ -0,0 +1 @@ +"""Tests for the text module.""" \ No newline at end of file diff --git a/ra_aid/tests/ra_aid/text/test_key_facts_formatter.py b/ra_aid/tests/ra_aid/text/test_key_facts_formatter.py new file mode 100644 index 0000000..49eed41 --- /dev/null +++ b/ra_aid/tests/ra_aid/text/test_key_facts_formatter.py @@ -0,0 +1,68 @@ +"""Unit tests for the key_facts_formatter module.""" + +from ra_aid.text.key_facts_formatter import format_key_fact, format_key_facts_dict + + +class TestKeyFactsFormatter: + """Test cases for key facts formatting functions.""" + + def test_format_key_fact(self): + """Test formatting a single key fact.""" + # Test with valid input + formatted = format_key_fact(1, "This is an important fact") + assert formatted == "## 🔑 Key Fact #1\n\nThis is an important fact" + + # Test with large ID number + formatted = format_key_fact(999, "Fact with large ID") + assert formatted == "## 🔑 Key Fact #999\n\nFact with large ID" + + # Test with empty content + formatted = format_key_fact(5, "") + assert formatted == "" + + # Test with multi-line content + multi_line = "Line 1\nLine 2\nLine 3" + formatted = format_key_fact(3, multi_line) + assert formatted == f"## 🔑 Key Fact #3\n\n{multi_line}" + + def test_format_key_facts_dict(self): + """Test formatting a dictionary of key facts.""" + # Test with multiple facts + facts_dict = { + 1: "First fact", + 2: "Second fact", + 5: "Fifth fact" + } + formatted = format_key_facts_dict(facts_dict) + expected = ( + "## 🔑 Key Fact #1\n\nFirst fact\n\n" + "## 🔑 Key Fact #2\n\nSecond fact\n\n" + "## 🔑 Key Fact #5\n\nFifth fact" + ) + assert formatted == expected + + # Test with empty dictionary + formatted = format_key_facts_dict({}) + assert formatted == "" + + # Test with None value + formatted = format_key_facts_dict(None) + assert formatted == "" + + # Test with single fact + formatted = format_key_facts_dict({3: "Only fact"}) + assert formatted == "## 🔑 Key Fact #3\n\nOnly fact" + + # Test ordering - should be ordered by key + unordered_dict = { + 5: "Fifth", + 1: "First", + 3: "Third" + } + formatted = format_key_facts_dict(unordered_dict) + expected = ( + "## 🔑 Key Fact #1\n\nFirst\n\n" + "## 🔑 Key Fact #3\n\nThird\n\n" + "## 🔑 Key Fact #5\n\nFifth" + ) + assert formatted == expected \ No newline at end of file diff --git a/ra_aid/text/key_facts_formatter.py b/ra_aid/text/key_facts_formatter.py new file mode 100644 index 0000000..f30caca --- /dev/null +++ b/ra_aid/text/key_facts_formatter.py @@ -0,0 +1,58 @@ +""" +Key facts text formatting module. + +This module provides utility functions for formatting key facts +with consistent markdown styling. +""" + +from typing import Dict, Optional + + +def format_key_fact(fact_id: int, content: str) -> str: + """ + Format a single key fact with markdown formatting. + + Args: + fact_id: The identifier of the fact + content: The text content of the fact + + Returns: + str: Formatted key fact as markdown + + Example: + >>> format_key_fact(1, "This is an important fact") + '## 🔑 Key Fact #1\n\nThis is an important fact' + """ + if not content: + return "" + + return f"## 🔑 Key Fact #{fact_id}\n\n{content}" + + +def format_key_facts_dict(facts_dict: Dict[int, str]) -> str: + """ + Format a dictionary of key facts with consistent markdown formatting. + + Args: + facts_dict: Dictionary mapping fact IDs to content strings + + Returns: + str: Formatted key facts as markdown with proper spacing and headings + + Example: + >>> format_key_facts_dict({1: "First fact", 2: "Second fact"}) + '## 🔑 Key Fact #1\n\nFirst fact\n\n## 🔑 Key Fact #2\n\nSecond fact' + """ + if not facts_dict: + return "" + + # Sort by ID for consistent output and format as markdown sections + facts = [] + for fact_id, content in sorted(facts_dict.items()): + facts.extend([ + format_key_fact(fact_id, content), + "" # Empty line between facts + ]) + + # Join all facts and remove trailing newline + return "\n".join(facts).rstrip() \ No newline at end of file diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 48ad62d..d5bd47d 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -12,7 +12,9 @@ from ra_aid.agent_context import ( reset_completion_flags, ) from ra_aid.console.formatting import print_error +from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.exceptions import AgentInterrupt +from ra_aid.text.key_facts_formatter import format_key_facts_dict from ra_aid.tools.memory import _global_memory from ..console import print_task_header @@ -27,6 +29,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. RESEARCH_AGENT_RECURSION_LIMIT = 3 console = Console() +key_fact_repository = KeyFactRepository() @tool("request_research") @@ -53,7 +56,7 @@ def request_research(query: str) -> ResearchResult: print_error("Maximum research recursion depth reached") return { "completion_message": "Research stopped - maximum recursion depth reached", - "key_facts": get_memory_value("key_facts"), + "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "related_files": get_related_files(), "research_notes": get_memory_value("research_notes"), "key_snippets": get_memory_value("key_snippets"), @@ -101,7 +104,7 @@ def request_research(query: str) -> ResearchResult: response_data = { "completion_message": completion_message, - "key_facts": get_memory_value("key_facts"), + "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "related_files": get_related_files(), "research_notes": get_memory_value("research_notes"), "key_snippets": get_memory_value("key_snippets"), @@ -235,7 +238,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: response_data = { "completion_message": completion_message, - "key_facts": get_memory_value("key_facts"), + "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "related_files": get_related_files(), "research_notes": get_memory_value("research_notes"), "key_snippets": get_memory_value("key_snippets"), @@ -319,7 +322,7 @@ def request_task_implementation(task_spec: str) -> str: crash_message = get_crash_message() if agent_crashed else None response_data = { - "key_facts": get_memory_value("key_facts"), + "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), "completion_message": completion_message, @@ -440,7 +443,7 @@ def request_implementation(task_spec: str) -> str: response_data = { "completion_message": completion_message, - "key_facts": get_memory_value("key_facts"), + "key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), "success": success and not agent_crashed, @@ -497,4 +500,4 @@ def request_implementation(task_spec: str) -> str: # Join all parts into a single markdown string markdown_output = "".join(markdown_parts) - return markdown_output + return markdown_output \ No newline at end of file diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index cd8183e..420cb04 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -6,11 +6,14 @@ from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel +from ..database.repositories.key_fact_repository import KeyFactRepository from ..llm import initialize_expert_llm +from ..text.key_facts_formatter import format_key_facts_dict from .memory import _global_memory, get_memory_value console = Console() _model = None +key_fact_repository = KeyFactRepository() def get_model(): @@ -148,7 +151,9 @@ def ask_expert(question: str) -> str: file_paths = list(_global_memory["related_files"].values()) related_contents = read_related_files(file_paths) key_snippets = get_memory_value("key_snippets") - key_facts = get_memory_value("key_facts") + # Get key facts directly from repository and format using the formatter + facts_dict = key_fact_repository.get_facts_dict() + key_facts = format_key_facts_dict(facts_dict) research_notes = get_memory_value("research_notes") # Build display query (just question) @@ -203,4 +208,4 @@ def ask_expert(question: str) -> str: Panel(Markdown(response.content), title="Expert Response", border_style="blue") ) - return response.content + return response.content \ No newline at end of file diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 1430c54..3c2fbef 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -577,10 +577,13 @@ def deregister_related_files(file_ids: List[int]) -> str: def get_memory_value(key: str) -> str: - """Get a value from global memory. + """ + Get a value from global memory. + + Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module, + not through this function. Different memory types return different formats: - - key_facts: Returns numbered list of facts in format '#ID: fact' - key_snippets: Returns formatted snippets with file path, line number and content - All other types: Returns newline-separated list of values @@ -589,48 +592,9 @@ def get_memory_value(key: str) -> str: Returns: String representation of the memory values: - - For key_facts: '#ID: fact' format, one per line - For key_snippets: Formatted snippet blocks - For other types: One value per line """ - if key == "key_facts": - try: - # Get facts from repository as a dictionary - facts_dict = key_fact_repository.get_facts_dict() - - # For empty dict, return empty string - if not facts_dict: - return "" - - # Sort by ID for consistent output and format as markdown sections - facts = [] - for k, v in sorted(facts_dict.items()): - facts.extend( - [ - f"## 🔑 Key Fact #{k}", - "", # Empty line for better markdown spacing - v, - "", # Empty line between facts - ] - ) - return "\n".join(facts).rstrip() # Remove trailing newline - except Exception: - # Fallback to old memory if database access fails - values = _global_memory.get(key, {}) - if not values: - return "" - facts = [] - for k, v in sorted(values.items()): - facts.extend( - [ - f"## 🔑 Key Fact #{k}", - "", - v, - "", - ] - ) - return "\n".join(facts).rstrip() - if key == "key_snippets": values = _global_memory.get(key, {}) if not values: diff --git a/tests/ra_aid/database/test_key_fact_repository.py b/tests/ra_aid/database/test_key_fact_repository.py new file mode 100644 index 0000000..c65e6d7 --- /dev/null +++ b/tests/ra_aid/database/test_key_fact_repository.py @@ -0,0 +1,192 @@ +""" +Tests for the KeyFactRepository class. +""" + +import pytest + +from ra_aid.database.connection import DatabaseManager, db_var +from ra_aid.database.models import KeyFact +from ra_aid.database.repositories.key_fact_repository import KeyFactRepository + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def setup_db(cleanup_db): + """Set up an in-memory database with the KeyFact table.""" + # Initialize an in-memory database connection + with DatabaseManager(in_memory=True) as db: + # Create the KeyFact table + with db.atomic(): + db.create_tables([KeyFact], safe=True) + + yield db + + # Clean up + with db.atomic(): + KeyFact.drop_table(safe=True) + + +def test_create_key_fact(setup_db): + """Test creating a key fact.""" + # Set up repository + repo = KeyFactRepository() + + # Create a key fact + content = "Test key fact" + fact = repo.create(content) + + # Verify the fact was created correctly + assert fact.id is not None + assert fact.content == content + + # Verify we can retrieve it from the database + fact_from_db = KeyFact.get_by_id(fact.id) + assert fact_from_db.content == content + + +def test_get_key_fact(setup_db): + """Test retrieving a key fact by ID.""" + # Set up repository + repo = KeyFactRepository() + + # Create a key fact + content = "Test key fact" + fact = repo.create(content) + + # Retrieve the fact by ID + retrieved_fact = repo.get(fact.id) + + # Verify the retrieved fact matches the original + assert retrieved_fact is not None + assert retrieved_fact.id == fact.id + assert retrieved_fact.content == content + + # Try to retrieve a non-existent fact + non_existent_fact = repo.get(999) + assert non_existent_fact is None + + +def test_update_key_fact(setup_db): + """Test updating a key fact.""" + # Set up repository + repo = KeyFactRepository() + + # Create a key fact + original_content = "Original content" + fact = repo.create(original_content) + + # Update the fact + new_content = "Updated content" + updated_fact = repo.update(fact.id, new_content) + + # Verify the fact was updated correctly + assert updated_fact is not None + assert updated_fact.id == fact.id + assert updated_fact.content == new_content + + # Verify we can retrieve the updated content from the database + fact_from_db = KeyFact.get_by_id(fact.id) + assert fact_from_db.content == new_content + + # Try to update a non-existent fact + non_existent_update = repo.update(999, "This shouldn't work") + assert non_existent_update is None + + +def test_delete_key_fact(setup_db): + """Test deleting a key fact.""" + # Set up repository + repo = KeyFactRepository() + + # Create a key fact + content = "Test key fact to delete" + fact = repo.create(content) + + # Verify the fact exists + assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None + + # Delete the fact + delete_result = repo.delete(fact.id) + + # Verify the delete operation was successful + assert delete_result is True + + # Verify the fact no longer exists in the database + assert KeyFact.get_or_none(KeyFact.id == fact.id) is None + + # Try to delete a non-existent fact + non_existent_delete = repo.delete(999) + assert non_existent_delete is False + + +def test_get_all_key_facts(setup_db): + """Test retrieving all key facts.""" + # Set up repository + repo = KeyFactRepository() + + # Create some key facts + contents = ["Fact 1", "Fact 2", "Fact 3"] + for content in contents: + repo.create(content) + + # Retrieve all facts + all_facts = repo.get_all() + + # Verify we got the correct number of facts + assert len(all_facts) == len(contents) + + # Verify the content of each fact + fact_contents = [fact.content for fact in all_facts] + for content in contents: + assert content in fact_contents + + +def test_get_facts_dict(setup_db): + """Test retrieving key facts as a dictionary.""" + # Set up repository + repo = KeyFactRepository() + + # Create some key facts + facts = [] + contents = ["Fact 1", "Fact 2", "Fact 3"] + for content in contents: + facts.append(repo.create(content)) + + # Retrieve facts as dictionary + facts_dict = repo.get_facts_dict() + + # Verify we got the correct number of facts + assert len(facts_dict) == len(contents) + + # Verify each fact is in the dictionary with the correct content + for fact in facts: + assert fact.id in facts_dict + assert facts_dict[fact.id] == fact.content \ No newline at end of file diff --git a/tests/ra_aid/tools/test_agent.py b/tests/ra_aid/tools/test_agent.py new file mode 100644 index 0000000..53550f3 --- /dev/null +++ b/tests/ra_aid/tools/test_agent.py @@ -0,0 +1,187 @@ +"""Tests for the agent.py module to verify KeyFactRepository integration.""" + +import pytest +from unittest.mock import patch, MagicMock + +from ra_aid.tools.agent import ( + request_research, + request_task_implementation, + request_implementation, + request_research_and_implementation, +) +from ra_aid.tools.memory import _global_memory + + +@pytest.fixture +def reset_memory(): + """Reset global memory before each test""" + _global_memory["key_facts"] = {} + _global_memory["key_fact_id_counter"] = 0 + _global_memory["key_snippets"] = {} + _global_memory["key_snippet_id_counter"] = 0 + _global_memory["research_notes"] = [] + _global_memory["plans"] = [] + _global_memory["tasks"] = {} + _global_memory["task_id_counter"] = 0 + _global_memory["related_files"] = {} + _global_memory["related_file_id_counter"] = 0 + _global_memory["work_log"] = [] + yield + # Clean up after test + _global_memory["key_facts"] = {} + _global_memory["key_fact_id_counter"] = 0 + _global_memory["key_snippets"] = {} + _global_memory["key_snippet_id_counter"] = 0 + _global_memory["research_notes"] = [] + _global_memory["plans"] = [] + _global_memory["tasks"] = {} + _global_memory["task_id_counter"] = 0 + _global_memory["related_files"] = {} + _global_memory["related_file_id_counter"] = 0 + _global_memory["work_log"] = [] + + +@pytest.fixture +def mock_functions(): + """Mock functions used in agent.py""" + with patch('ra_aid.tools.agent.key_fact_repository') as mock_repo, \ + patch('ra_aid.tools.agent.format_key_facts_dict') as mock_formatter, \ + patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \ + patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \ + patch('ra_aid.tools.agent.get_memory_value') as mock_get_memory, \ + patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \ + patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \ + patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion: + + # Setup mock return values + mock_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"} + mock_formatter.return_value = "Formatted facts" + mock_llm.return_value = MagicMock() + mock_get_files.return_value = ["file1.py", "file2.py"] + mock_get_memory.return_value = "Test memory value" + mock_get_work_log.return_value = "Test work log" + mock_get_completion.return_value = "Task completed" + + # Return all mocks as a dictionary + yield { + 'key_fact_repository': mock_repo, + 'format_key_facts_dict': mock_formatter, + 'initialize_llm': mock_llm, + 'get_related_files': mock_get_files, + 'get_memory_value': mock_get_memory, + 'get_work_log': mock_get_work_log, + 'reset_completion_flags': mock_reset, + 'get_completion_message': mock_get_completion + } + + +def test_request_research_uses_key_fact_repository(reset_memory, mock_functions): + """Test that request_research uses KeyFactRepository directly with formatting.""" + # Mock running the research agent + with patch('ra_aid.agent_utils.run_research_agent'): + # Call the function + result = request_research("test query") + + # Verify repository was called + mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + + # Verify formatter was called with repository results + mock_functions['format_key_facts_dict'].assert_called_once_with( + mock_functions['key_fact_repository'].get_facts_dict.return_value + ) + + # Verify formatted facts are used in response + assert result["key_facts"] == "Formatted facts" + + # Verify get_memory_value is not called with "key_facts" + for call in mock_functions['get_memory_value'].call_args_list: + assert call[0][0] != "key_facts" + + +def test_request_research_max_depth(reset_memory, mock_functions): + """Test that max recursion depth handling uses KeyFactRepository.""" + # Set recursion depth to max + _global_memory["agent_depth"] = 3 + + # Call the function (should hit max depth case) + result = request_research("test query") + + # Verify repository was called + mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + + # Verify formatter was called with repository results + mock_functions['format_key_facts_dict'].assert_called_once_with( + mock_functions['key_fact_repository'].get_facts_dict.return_value + ) + + # Verify formatted facts are used in response + assert result["key_facts"] == "Formatted facts" + + # Verify get_memory_value is not called with "key_facts" + for call in mock_functions['get_memory_value'].call_args_list: + assert call[0][0] != "key_facts" + + +def test_request_research_and_implementation_uses_key_fact_repository(reset_memory, mock_functions): + """Test that request_research_and_implementation uses KeyFactRepository correctly.""" + # Mock running the research agent + with patch('ra_aid.agent_utils.run_research_agent'): + # Call the function + result = request_research_and_implementation("test query") + + # Verify repository was called + mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + + # Verify formatter was called with repository results + mock_functions['format_key_facts_dict'].assert_called_once_with( + mock_functions['key_fact_repository'].get_facts_dict.return_value + ) + + # Verify formatted facts are used in response + assert result["key_facts"] == "Formatted facts" + + # Verify get_memory_value is not called with "key_facts" + for call in mock_functions['get_memory_value'].call_args_list: + assert call[0][0] != "key_facts" + + +def test_request_implementation_uses_key_fact_repository(reset_memory, mock_functions): + """Test that request_implementation uses KeyFactRepository correctly.""" + # Mock running the planning agent + with patch('ra_aid.agent_utils.run_planning_agent'): + # Call the function + result = request_implementation("test task") + + # Verify repository was called + mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + + # Verify formatter was called with repository results + mock_functions['format_key_facts_dict'].assert_called_once_with( + mock_functions['key_fact_repository'].get_facts_dict.return_value + ) + + # Check that the formatted key facts are included in the response + assert "Formatted facts" in result + + +def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock_functions): + """Test that request_task_implementation uses KeyFactRepository correctly.""" + # Set up _global_memory with required values + _global_memory["tasks"] = {0: "Task 1"} + _global_memory["base_task"] = "Base task" + + # Mock running the implementation agent + with patch('ra_aid.agent_utils.run_task_implementation_agent'): + # Call the function + result = request_task_implementation("test task") + + # Verify repository was called + mock_functions['key_fact_repository'].get_facts_dict.assert_called_once() + + # Verify formatter was called with repository results + mock_functions['format_key_facts_dict'].assert_called_once_with( + mock_functions['key_fact_repository'].get_facts_dict.return_value + ) + + # Check that the formatted key facts are included in the response + assert "Formatted facts" in result \ No newline at end of file diff --git a/tests/ra_aid/tools/test_memory.py b/tests/ra_aid/tools/test_memory.py index 8530270..db8cfbf 100644 --- a/tests/ra_aid/tools/test_memory.py +++ b/tests/ra_aid/tools/test_memory.py @@ -123,26 +123,6 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository): mock_repository.create.assert_called_once_with("First fact") -def test_get_memory_value_key_facts(reset_memory, mock_repository): - """Test get_memory_value with key facts dictionary""" - # Empty key facts should return empty string - assert get_memory_value("key_facts") == "" - - # Add some facts through the mocked repository - fact1 = mock_repository.create("First fact") - fact2 = mock_repository.create("Second fact") - - # Mock get_facts_dict to return our test data - mock_repository.get_facts_dict.return_value = { - fact1.id: "First fact", - fact2.id: "Second fact" - } - - # Should return markdown formatted list - expected = f"## 🔑 Key Fact #{fact1.id}\n\nFirst fact\n\n## 🔑 Key Fact #{fact2.id}\n\nSecond fact" - assert get_memory_value("key_facts") == expected - - def test_get_memory_value_other_types(reset_memory): """Test get_memory_value remains compatible with other memory types""" # Add some research notes