key facts repository

This commit is contained in:
AI Christianson 2025-03-02 16:00:55 -05:00
parent 8819f463a1
commit 935a013a4c
12 changed files with 569 additions and 105 deletions

View File

@ -84,6 +84,8 @@ from ra_aid.tool_configs import (
get_web_research_tools,
)
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.text.key_facts_formatter import format_key_facts_dict
from ra_aid.tools.memory import (
_global_memory,
get_memory_value,
@ -95,6 +97,9 @@ console = Console()
logger = get_logger(__name__)
# Initialize key fact repository
key_fact_repository = KeyFactRepository()
@tool
def output_markdown_message(message: str) -> str:
@ -381,7 +386,7 @@ def run_research_agent(
else ""
)
key_facts = _global_memory.get("key_facts", "")
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
@ -520,7 +525,7 @@ def run_web_research_agent(
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
key_facts = _global_memory.get("key_facts", "")
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
@ -640,7 +645,7 @@ def run_planning_agent(
project_info=formatted_project_info,
research_notes=get_memory_value("research_notes"),
related_files="\n".join(get_related_files()),
key_facts=get_memory_value("key_facts"),
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
key_snippets=get_memory_value("key_snippets"),
work_log=get_memory_value("work_log"),
research_only_note=(
@ -742,7 +747,7 @@ def run_task_implementation_agent(
tasks=tasks,
plan=plan,
related_files=related_files,
key_facts=get_memory_value("key_facts"),
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
key_snippets=get_memory_value("key_snippets"),
research_notes=get_memory_value("research_notes"),
work_log=get_memory_value("work_log"),

View File

@ -9,7 +9,7 @@ from typing import Dict, List, Optional
import peewee
from ra_aid.database.connection import DatabaseManager, get_db
from ra_aid.database.connection import get_db
from ra_aid.database.models import KeyFact
from ra_aid.logging_config import get_logger
@ -43,10 +43,10 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error creating the fact
"""
try:
with DatabaseManager() as db:
fact = KeyFact.create(content=content)
logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact
db = get_db()
fact = KeyFact.create(content=content)
logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact
except peewee.DatabaseError as e:
logger.error(f"Failed to create key fact: {str(e)}")
raise
@ -65,8 +65,8 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
with DatabaseManager() as db:
return KeyFact.get_or_none(KeyFact.id == fact_id)
db = get_db()
return KeyFact.get_or_none(KeyFact.id == fact_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
raise
@ -86,18 +86,18 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error updating the fact
"""
try:
with DatabaseManager() as db:
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
return None
# Update the fact
fact.content = content
fact.save()
logger.debug(f"Updated key fact ID {fact_id}: {content}")
return fact
db = get_db()
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
return None
# Update the fact
fact.content = content
fact.save()
logger.debug(f"Updated key fact ID {fact_id}: {content}")
return fact
except peewee.DatabaseError as e:
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
raise
@ -116,17 +116,17 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error deleting the fact
"""
try:
with DatabaseManager() as db:
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
return False
# Delete the fact
fact.delete_instance()
logger.debug(f"Deleted key fact ID {fact_id}")
return True
db = get_db()
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
return False
# Delete the fact
fact.delete_instance()
logger.debug(f"Deleted key fact ID {fact_id}")
return True
except peewee.DatabaseError as e:
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
raise
@ -142,8 +142,8 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
with DatabaseManager() as db:
return list(KeyFact.select().order_by(KeyFact.id))
db = get_db()
return list(KeyFact.select().order_by(KeyFact.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key facts: {str(e)}")
raise

View File

@ -37,6 +37,7 @@ Retention priority (from highest to lowest):
- Build and deployment information
- Testing approaches
- Low-level implementation details that are easily rediscovered
- If there are contradictory facts, that probably means that the older fact is no longer true and should be deleted.
For facts of similar importance, prefer to keep more recent facts if they supersede older information.

View File

@ -0,0 +1 @@
"""Tests for the text module."""

View File

@ -0,0 +1,68 @@
"""Unit tests for the key_facts_formatter module."""
from ra_aid.text.key_facts_formatter import format_key_fact, format_key_facts_dict
class TestKeyFactsFormatter:
"""Test cases for key facts formatting functions."""
def test_format_key_fact(self):
"""Test formatting a single key fact."""
# Test with valid input
formatted = format_key_fact(1, "This is an important fact")
assert formatted == "## 🔑 Key Fact #1\n\nThis is an important fact"
# Test with large ID number
formatted = format_key_fact(999, "Fact with large ID")
assert formatted == "## 🔑 Key Fact #999\n\nFact with large ID"
# Test with empty content
formatted = format_key_fact(5, "")
assert formatted == ""
# Test with multi-line content
multi_line = "Line 1\nLine 2\nLine 3"
formatted = format_key_fact(3, multi_line)
assert formatted == f"## 🔑 Key Fact #3\n\n{multi_line}"
def test_format_key_facts_dict(self):
"""Test formatting a dictionary of key facts."""
# Test with multiple facts
facts_dict = {
1: "First fact",
2: "Second fact",
5: "Fifth fact"
}
formatted = format_key_facts_dict(facts_dict)
expected = (
"## 🔑 Key Fact #1\n\nFirst fact\n\n"
"## 🔑 Key Fact #2\n\nSecond fact\n\n"
"## 🔑 Key Fact #5\n\nFifth fact"
)
assert formatted == expected
# Test with empty dictionary
formatted = format_key_facts_dict({})
assert formatted == ""
# Test with None value
formatted = format_key_facts_dict(None)
assert formatted == ""
# Test with single fact
formatted = format_key_facts_dict({3: "Only fact"})
assert formatted == "## 🔑 Key Fact #3\n\nOnly fact"
# Test ordering - should be ordered by key
unordered_dict = {
5: "Fifth",
1: "First",
3: "Third"
}
formatted = format_key_facts_dict(unordered_dict)
expected = (
"## 🔑 Key Fact #1\n\nFirst\n\n"
"## 🔑 Key Fact #3\n\nThird\n\n"
"## 🔑 Key Fact #5\n\nFifth"
)
assert formatted == expected

View File

@ -0,0 +1,58 @@
"""
Key facts text formatting module.
This module provides utility functions for formatting key facts
with consistent markdown styling.
"""
from typing import Dict, Optional
def format_key_fact(fact_id: int, content: str) -> str:
"""
Format a single key fact with markdown formatting.
Args:
fact_id: The identifier of the fact
content: The text content of the fact
Returns:
str: Formatted key fact as markdown
Example:
>>> format_key_fact(1, "This is an important fact")
'## 🔑 Key Fact #1\n\nThis is an important fact'
"""
if not content:
return ""
return f"## 🔑 Key Fact #{fact_id}\n\n{content}"
def format_key_facts_dict(facts_dict: Dict[int, str]) -> str:
"""
Format a dictionary of key facts with consistent markdown formatting.
Args:
facts_dict: Dictionary mapping fact IDs to content strings
Returns:
str: Formatted key facts as markdown with proper spacing and headings
Example:
>>> format_key_facts_dict({1: "First fact", 2: "Second fact"})
'## 🔑 Key Fact #1\n\nFirst fact\n\n## 🔑 Key Fact #2\n\nSecond fact'
"""
if not facts_dict:
return ""
# Sort by ID for consistent output and format as markdown sections
facts = []
for fact_id, content in sorted(facts_dict.items()):
facts.extend([
format_key_fact(fact_id, content),
"" # Empty line between facts
])
# Join all facts and remove trailing newline
return "\n".join(facts).rstrip()

View File

@ -12,7 +12,9 @@ from ra_aid.agent_context import (
reset_completion_flags,
)
from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.exceptions import AgentInterrupt
from ra_aid.text.key_facts_formatter import format_key_facts_dict
from ra_aid.tools.memory import _global_memory
from ..console import print_task_header
@ -27,6 +29,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user.
RESEARCH_AGENT_RECURSION_LIMIT = 3
console = Console()
key_fact_repository = KeyFactRepository()
@tool("request_research")
@ -53,7 +56,7 @@ def request_research(query: str) -> ResearchResult:
print_error("Maximum research recursion depth reached")
return {
"completion_message": "Research stopped - maximum recursion depth reached",
"key_facts": get_memory_value("key_facts"),
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": get_memory_value("key_snippets"),
@ -101,7 +104,7 @@ def request_research(query: str) -> ResearchResult:
response_data = {
"completion_message": completion_message,
"key_facts": get_memory_value("key_facts"),
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": get_memory_value("key_snippets"),
@ -235,7 +238,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
response_data = {
"completion_message": completion_message,
"key_facts": get_memory_value("key_facts"),
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": get_memory_value("key_snippets"),
@ -319,7 +322,7 @@ def request_task_implementation(task_spec: str) -> str:
crash_message = get_crash_message() if agent_crashed else None
response_data = {
"key_facts": get_memory_value("key_facts"),
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"related_files": get_related_files(),
"key_snippets": get_memory_value("key_snippets"),
"completion_message": completion_message,
@ -440,7 +443,7 @@ def request_implementation(task_spec: str) -> str:
response_data = {
"completion_message": completion_message,
"key_facts": get_memory_value("key_facts"),
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
"related_files": get_related_files(),
"key_snippets": get_memory_value("key_snippets"),
"success": success and not agent_crashed,
@ -497,4 +500,4 @@ def request_implementation(task_spec: str) -> str:
# Join all parts into a single markdown string
markdown_output = "".join(markdown_parts)
return markdown_output
return markdown_output

View File

@ -6,11 +6,14 @@ from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ..database.repositories.key_fact_repository import KeyFactRepository
from ..llm import initialize_expert_llm
from ..text.key_facts_formatter import format_key_facts_dict
from .memory import _global_memory, get_memory_value
console = Console()
_model = None
key_fact_repository = KeyFactRepository()
def get_model():
@ -148,7 +151,9 @@ def ask_expert(question: str) -> str:
file_paths = list(_global_memory["related_files"].values())
related_contents = read_related_files(file_paths)
key_snippets = get_memory_value("key_snippets")
key_facts = get_memory_value("key_facts")
# Get key facts directly from repository and format using the formatter
facts_dict = key_fact_repository.get_facts_dict()
key_facts = format_key_facts_dict(facts_dict)
research_notes = get_memory_value("research_notes")
# Build display query (just question)
@ -203,4 +208,4 @@ def ask_expert(question: str) -> str:
Panel(Markdown(response.content), title="Expert Response", border_style="blue")
)
return response.content
return response.content

View File

@ -577,10 +577,13 @@ def deregister_related_files(file_ids: List[int]) -> str:
def get_memory_value(key: str) -> str:
"""Get a value from global memory.
"""
Get a value from global memory.
Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module,
not through this function.
Different memory types return different formats:
- key_facts: Returns numbered list of facts in format '#ID: fact'
- key_snippets: Returns formatted snippets with file path, line number and content
- All other types: Returns newline-separated list of values
@ -589,48 +592,9 @@ def get_memory_value(key: str) -> str:
Returns:
String representation of the memory values:
- For key_facts: '#ID: fact' format, one per line
- For key_snippets: Formatted snippet blocks
- For other types: One value per line
"""
if key == "key_facts":
try:
# Get facts from repository as a dictionary
facts_dict = key_fact_repository.get_facts_dict()
# For empty dict, return empty string
if not facts_dict:
return ""
# Sort by ID for consistent output and format as markdown sections
facts = []
for k, v in sorted(facts_dict.items()):
facts.extend(
[
f"## 🔑 Key Fact #{k}",
"", # Empty line for better markdown spacing
v,
"", # Empty line between facts
]
)
return "\n".join(facts).rstrip() # Remove trailing newline
except Exception:
# Fallback to old memory if database access fails
values = _global_memory.get(key, {})
if not values:
return ""
facts = []
for k, v in sorted(values.items()):
facts.extend(
[
f"## 🔑 Key Fact #{k}",
"",
v,
"",
]
)
return "\n".join(facts).rstrip()
if key == "key_snippets":
values = _global_memory.get(key, {})
if not values:

View File

@ -0,0 +1,192 @@
"""
Tests for the KeyFactRepository class.
"""
import pytest
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeyFact
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
@pytest.fixture
def cleanup_db():
"""Reset the database contextvar and connection state after each test."""
# Reset before the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
@pytest.fixture
def setup_db(cleanup_db):
"""Set up an in-memory database with the KeyFact table."""
# Initialize an in-memory database connection
with DatabaseManager(in_memory=True) as db:
# Create the KeyFact table
with db.atomic():
db.create_tables([KeyFact], safe=True)
yield db
# Clean up
with db.atomic():
KeyFact.drop_table(safe=True)
def test_create_key_fact(setup_db):
"""Test creating a key fact."""
# Set up repository
repo = KeyFactRepository()
# Create a key fact
content = "Test key fact"
fact = repo.create(content)
# Verify the fact was created correctly
assert fact.id is not None
assert fact.content == content
# Verify we can retrieve it from the database
fact_from_db = KeyFact.get_by_id(fact.id)
assert fact_from_db.content == content
def test_get_key_fact(setup_db):
"""Test retrieving a key fact by ID."""
# Set up repository
repo = KeyFactRepository()
# Create a key fact
content = "Test key fact"
fact = repo.create(content)
# Retrieve the fact by ID
retrieved_fact = repo.get(fact.id)
# Verify the retrieved fact matches the original
assert retrieved_fact is not None
assert retrieved_fact.id == fact.id
assert retrieved_fact.content == content
# Try to retrieve a non-existent fact
non_existent_fact = repo.get(999)
assert non_existent_fact is None
def test_update_key_fact(setup_db):
"""Test updating a key fact."""
# Set up repository
repo = KeyFactRepository()
# Create a key fact
original_content = "Original content"
fact = repo.create(original_content)
# Update the fact
new_content = "Updated content"
updated_fact = repo.update(fact.id, new_content)
# Verify the fact was updated correctly
assert updated_fact is not None
assert updated_fact.id == fact.id
assert updated_fact.content == new_content
# Verify we can retrieve the updated content from the database
fact_from_db = KeyFact.get_by_id(fact.id)
assert fact_from_db.content == new_content
# Try to update a non-existent fact
non_existent_update = repo.update(999, "This shouldn't work")
assert non_existent_update is None
def test_delete_key_fact(setup_db):
"""Test deleting a key fact."""
# Set up repository
repo = KeyFactRepository()
# Create a key fact
content = "Test key fact to delete"
fact = repo.create(content)
# Verify the fact exists
assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None
# Delete the fact
delete_result = repo.delete(fact.id)
# Verify the delete operation was successful
assert delete_result is True
# Verify the fact no longer exists in the database
assert KeyFact.get_or_none(KeyFact.id == fact.id) is None
# Try to delete a non-existent fact
non_existent_delete = repo.delete(999)
assert non_existent_delete is False
def test_get_all_key_facts(setup_db):
"""Test retrieving all key facts."""
# Set up repository
repo = KeyFactRepository()
# Create some key facts
contents = ["Fact 1", "Fact 2", "Fact 3"]
for content in contents:
repo.create(content)
# Retrieve all facts
all_facts = repo.get_all()
# Verify we got the correct number of facts
assert len(all_facts) == len(contents)
# Verify the content of each fact
fact_contents = [fact.content for fact in all_facts]
for content in contents:
assert content in fact_contents
def test_get_facts_dict(setup_db):
"""Test retrieving key facts as a dictionary."""
# Set up repository
repo = KeyFactRepository()
# Create some key facts
facts = []
contents = ["Fact 1", "Fact 2", "Fact 3"]
for content in contents:
facts.append(repo.create(content))
# Retrieve facts as dictionary
facts_dict = repo.get_facts_dict()
# Verify we got the correct number of facts
assert len(facts_dict) == len(contents)
# Verify each fact is in the dictionary with the correct content
for fact in facts:
assert fact.id in facts_dict
assert facts_dict[fact.id] == fact.content

View File

@ -0,0 +1,187 @@
"""Tests for the agent.py module to verify KeyFactRepository integration."""
import pytest
from unittest.mock import patch, MagicMock
from ra_aid.tools.agent import (
request_research,
request_task_implementation,
request_implementation,
request_research_and_implementation,
)
from ra_aid.tools.memory import _global_memory
@pytest.fixture
def reset_memory():
"""Reset global memory before each test"""
_global_memory["key_facts"] = {}
_global_memory["key_fact_id_counter"] = 0
_global_memory["key_snippets"] = {}
_global_memory["key_snippet_id_counter"] = 0
_global_memory["research_notes"] = []
_global_memory["plans"] = []
_global_memory["tasks"] = {}
_global_memory["task_id_counter"] = 0
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 0
_global_memory["work_log"] = []
yield
# Clean up after test
_global_memory["key_facts"] = {}
_global_memory["key_fact_id_counter"] = 0
_global_memory["key_snippets"] = {}
_global_memory["key_snippet_id_counter"] = 0
_global_memory["research_notes"] = []
_global_memory["plans"] = []
_global_memory["tasks"] = {}
_global_memory["task_id_counter"] = 0
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 0
_global_memory["work_log"] = []
@pytest.fixture
def mock_functions():
"""Mock functions used in agent.py"""
with patch('ra_aid.tools.agent.key_fact_repository') as mock_repo, \
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_formatter, \
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
patch('ra_aid.tools.agent.get_memory_value') as mock_get_memory, \
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion:
# Setup mock return values
mock_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
mock_formatter.return_value = "Formatted facts"
mock_llm.return_value = MagicMock()
mock_get_files.return_value = ["file1.py", "file2.py"]
mock_get_memory.return_value = "Test memory value"
mock_get_work_log.return_value = "Test work log"
mock_get_completion.return_value = "Task completed"
# Return all mocks as a dictionary
yield {
'key_fact_repository': mock_repo,
'format_key_facts_dict': mock_formatter,
'initialize_llm': mock_llm,
'get_related_files': mock_get_files,
'get_memory_value': mock_get_memory,
'get_work_log': mock_get_work_log,
'reset_completion_flags': mock_reset,
'get_completion_message': mock_get_completion
}
def test_request_research_uses_key_fact_repository(reset_memory, mock_functions):
"""Test that request_research uses KeyFactRepository directly with formatting."""
# Mock running the research agent
with patch('ra_aid.agent_utils.run_research_agent'):
# Call the function
result = request_research("test query")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
)
# Verify formatted facts are used in response
assert result["key_facts"] == "Formatted facts"
# Verify get_memory_value is not called with "key_facts"
for call in mock_functions['get_memory_value'].call_args_list:
assert call[0][0] != "key_facts"
def test_request_research_max_depth(reset_memory, mock_functions):
"""Test that max recursion depth handling uses KeyFactRepository."""
# Set recursion depth to max
_global_memory["agent_depth"] = 3
# Call the function (should hit max depth case)
result = request_research("test query")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
)
# Verify formatted facts are used in response
assert result["key_facts"] == "Formatted facts"
# Verify get_memory_value is not called with "key_facts"
for call in mock_functions['get_memory_value'].call_args_list:
assert call[0][0] != "key_facts"
def test_request_research_and_implementation_uses_key_fact_repository(reset_memory, mock_functions):
"""Test that request_research_and_implementation uses KeyFactRepository correctly."""
# Mock running the research agent
with patch('ra_aid.agent_utils.run_research_agent'):
# Call the function
result = request_research_and_implementation("test query")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
)
# Verify formatted facts are used in response
assert result["key_facts"] == "Formatted facts"
# Verify get_memory_value is not called with "key_facts"
for call in mock_functions['get_memory_value'].call_args_list:
assert call[0][0] != "key_facts"
def test_request_implementation_uses_key_fact_repository(reset_memory, mock_functions):
"""Test that request_implementation uses KeyFactRepository correctly."""
# Mock running the planning agent
with patch('ra_aid.agent_utils.run_planning_agent'):
# Call the function
result = request_implementation("test task")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
)
# Check that the formatted key facts are included in the response
assert "Formatted facts" in result
def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock_functions):
"""Test that request_task_implementation uses KeyFactRepository correctly."""
# Set up _global_memory with required values
_global_memory["tasks"] = {0: "Task 1"}
_global_memory["base_task"] = "Base task"
# Mock running the implementation agent
with patch('ra_aid.agent_utils.run_task_implementation_agent'):
# Call the function
result = request_task_implementation("test task")
# Verify repository was called
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
# Verify formatter was called with repository results
mock_functions['format_key_facts_dict'].assert_called_once_with(
mock_functions['key_fact_repository'].get_facts_dict.return_value
)
# Check that the formatted key facts are included in the response
assert "Formatted facts" in result

View File

@ -123,26 +123,6 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
mock_repository.create.assert_called_once_with("First fact")
def test_get_memory_value_key_facts(reset_memory, mock_repository):
"""Test get_memory_value with key facts dictionary"""
# Empty key facts should return empty string
assert get_memory_value("key_facts") == ""
# Add some facts through the mocked repository
fact1 = mock_repository.create("First fact")
fact2 = mock_repository.create("Second fact")
# Mock get_facts_dict to return our test data
mock_repository.get_facts_dict.return_value = {
fact1.id: "First fact",
fact2.id: "Second fact"
}
# Should return markdown formatted list
expected = f"## 🔑 Key Fact #{fact1.id}\n\nFirst fact\n\n## 🔑 Key Fact #{fact2.id}\n\nSecond fact"
assert get_memory_value("key_facts") == expected
def test_get_memory_value_other_types(reset_memory):
"""Test get_memory_value remains compatible with other memory types"""
# Add some research notes