key facts repository
This commit is contained in:
parent
8819f463a1
commit
935a013a4c
|
|
@ -84,6 +84,8 @@ from ra_aid.tool_configs import (
|
|||
get_web_research_tools,
|
||||
)
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.text.key_facts_formatter import format_key_facts_dict
|
||||
from ra_aid.tools.memory import (
|
||||
_global_memory,
|
||||
get_memory_value,
|
||||
|
|
@ -95,6 +97,9 @@ console = Console()
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Initialize key fact repository
|
||||
key_fact_repository = KeyFactRepository()
|
||||
|
||||
|
||||
@tool
|
||||
def output_markdown_message(message: str) -> str:
|
||||
|
|
@ -381,7 +386,7 @@ def run_research_agent(
|
|||
else ""
|
||||
)
|
||||
|
||||
key_facts = _global_memory.get("key_facts", "")
|
||||
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
|
||||
code_snippets = _global_memory.get("code_snippets", "")
|
||||
related_files = _global_memory.get("related_files", "")
|
||||
|
||||
|
|
@ -520,7 +525,7 @@ def run_web_research_agent(
|
|||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
||||
|
||||
key_facts = _global_memory.get("key_facts", "")
|
||||
key_facts = format_key_facts_dict(key_fact_repository.get_facts_dict())
|
||||
code_snippets = _global_memory.get("code_snippets", "")
|
||||
related_files = _global_memory.get("related_files", "")
|
||||
|
||||
|
|
@ -640,7 +645,7 @@ def run_planning_agent(
|
|||
project_info=formatted_project_info,
|
||||
research_notes=get_memory_value("research_notes"),
|
||||
related_files="\n".join(get_related_files()),
|
||||
key_facts=get_memory_value("key_facts"),
|
||||
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
key_snippets=get_memory_value("key_snippets"),
|
||||
work_log=get_memory_value("work_log"),
|
||||
research_only_note=(
|
||||
|
|
@ -742,7 +747,7 @@ def run_task_implementation_agent(
|
|||
tasks=tasks,
|
||||
plan=plan,
|
||||
related_files=related_files,
|
||||
key_facts=get_memory_value("key_facts"),
|
||||
key_facts=format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
key_snippets=get_memory_value("key_snippets"),
|
||||
research_notes=get_memory_value("research_notes"),
|
||||
work_log=get_memory_value("work_log"),
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict, List, Optional
|
|||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, get_db
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
|
|
@ -43,10 +43,10 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error creating the fact
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
fact = KeyFact.create(content=content)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return fact
|
||||
db = get_db()
|
||||
fact = KeyFact.create(content=content)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return fact
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key fact: {str(e)}")
|
||||
raise
|
||||
|
|
@ -65,8 +65,8 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
db = get_db()
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -86,18 +86,18 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error updating the fact
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
|
||||
return None
|
||||
|
||||
# Update the fact
|
||||
fact.content = content
|
||||
fact.save()
|
||||
logger.debug(f"Updated key fact ID {fact_id}: {content}")
|
||||
return fact
|
||||
db = get_db()
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
|
||||
return None
|
||||
|
||||
# Update the fact
|
||||
fact.content = content
|
||||
fact.save()
|
||||
logger.debug(f"Updated key fact ID {fact_id}: {content}")
|
||||
return fact
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -116,17 +116,17 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error deleting the fact
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
|
||||
return False
|
||||
|
||||
# Delete the fact
|
||||
fact.delete_instance()
|
||||
logger.debug(f"Deleted key fact ID {fact_id}")
|
||||
return True
|
||||
db = get_db()
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
|
||||
return False
|
||||
|
||||
# Delete the fact
|
||||
fact.delete_instance()
|
||||
logger.debug(f"Deleted key fact ID {fact_id}")
|
||||
return True
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -142,8 +142,8 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
db = get_db()
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ Retention priority (from highest to lowest):
|
|||
- Build and deployment information
|
||||
- Testing approaches
|
||||
- Low-level implementation details that are easily rediscovered
|
||||
- If there are contradictory facts, that probably means that the older fact is no longer true and should be deleted.
|
||||
|
||||
For facts of similar importance, prefer to keep more recent facts if they supersede older information.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for the text module."""
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
"""Unit tests for the key_facts_formatter module."""
|
||||
|
||||
from ra_aid.text.key_facts_formatter import format_key_fact, format_key_facts_dict
|
||||
|
||||
|
||||
class TestKeyFactsFormatter:
|
||||
"""Test cases for key facts formatting functions."""
|
||||
|
||||
def test_format_key_fact(self):
|
||||
"""Test formatting a single key fact."""
|
||||
# Test with valid input
|
||||
formatted = format_key_fact(1, "This is an important fact")
|
||||
assert formatted == "## 🔑 Key Fact #1\n\nThis is an important fact"
|
||||
|
||||
# Test with large ID number
|
||||
formatted = format_key_fact(999, "Fact with large ID")
|
||||
assert formatted == "## 🔑 Key Fact #999\n\nFact with large ID"
|
||||
|
||||
# Test with empty content
|
||||
formatted = format_key_fact(5, "")
|
||||
assert formatted == ""
|
||||
|
||||
# Test with multi-line content
|
||||
multi_line = "Line 1\nLine 2\nLine 3"
|
||||
formatted = format_key_fact(3, multi_line)
|
||||
assert formatted == f"## 🔑 Key Fact #3\n\n{multi_line}"
|
||||
|
||||
def test_format_key_facts_dict(self):
|
||||
"""Test formatting a dictionary of key facts."""
|
||||
# Test with multiple facts
|
||||
facts_dict = {
|
||||
1: "First fact",
|
||||
2: "Second fact",
|
||||
5: "Fifth fact"
|
||||
}
|
||||
formatted = format_key_facts_dict(facts_dict)
|
||||
expected = (
|
||||
"## 🔑 Key Fact #1\n\nFirst fact\n\n"
|
||||
"## 🔑 Key Fact #2\n\nSecond fact\n\n"
|
||||
"## 🔑 Key Fact #5\n\nFifth fact"
|
||||
)
|
||||
assert formatted == expected
|
||||
|
||||
# Test with empty dictionary
|
||||
formatted = format_key_facts_dict({})
|
||||
assert formatted == ""
|
||||
|
||||
# Test with None value
|
||||
formatted = format_key_facts_dict(None)
|
||||
assert formatted == ""
|
||||
|
||||
# Test with single fact
|
||||
formatted = format_key_facts_dict({3: "Only fact"})
|
||||
assert formatted == "## 🔑 Key Fact #3\n\nOnly fact"
|
||||
|
||||
# Test ordering - should be ordered by key
|
||||
unordered_dict = {
|
||||
5: "Fifth",
|
||||
1: "First",
|
||||
3: "Third"
|
||||
}
|
||||
formatted = format_key_facts_dict(unordered_dict)
|
||||
expected = (
|
||||
"## 🔑 Key Fact #1\n\nFirst\n\n"
|
||||
"## 🔑 Key Fact #3\n\nThird\n\n"
|
||||
"## 🔑 Key Fact #5\n\nFifth"
|
||||
)
|
||||
assert formatted == expected
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Key facts text formatting module.
|
||||
|
||||
This module provides utility functions for formatting key facts
|
||||
with consistent markdown styling.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def format_key_fact(fact_id: int, content: str) -> str:
|
||||
"""
|
||||
Format a single key fact with markdown formatting.
|
||||
|
||||
Args:
|
||||
fact_id: The identifier of the fact
|
||||
content: The text content of the fact
|
||||
|
||||
Returns:
|
||||
str: Formatted key fact as markdown
|
||||
|
||||
Example:
|
||||
>>> format_key_fact(1, "This is an important fact")
|
||||
'## 🔑 Key Fact #1\n\nThis is an important fact'
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
return f"## 🔑 Key Fact #{fact_id}\n\n{content}"
|
||||
|
||||
|
||||
def format_key_facts_dict(facts_dict: Dict[int, str]) -> str:
|
||||
"""
|
||||
Format a dictionary of key facts with consistent markdown formatting.
|
||||
|
||||
Args:
|
||||
facts_dict: Dictionary mapping fact IDs to content strings
|
||||
|
||||
Returns:
|
||||
str: Formatted key facts as markdown with proper spacing and headings
|
||||
|
||||
Example:
|
||||
>>> format_key_facts_dict({1: "First fact", 2: "Second fact"})
|
||||
'## 🔑 Key Fact #1\n\nFirst fact\n\n## 🔑 Key Fact #2\n\nSecond fact'
|
||||
"""
|
||||
if not facts_dict:
|
||||
return ""
|
||||
|
||||
# Sort by ID for consistent output and format as markdown sections
|
||||
facts = []
|
||||
for fact_id, content in sorted(facts_dict.items()):
|
||||
facts.extend([
|
||||
format_key_fact(fact_id, content),
|
||||
"" # Empty line between facts
|
||||
])
|
||||
|
||||
# Join all facts and remove trailing newline
|
||||
return "\n".join(facts).rstrip()
|
||||
|
|
@ -12,7 +12,9 @@ from ra_aid.agent_context import (
|
|||
reset_completion_flags,
|
||||
)
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.text.key_facts_formatter import format_key_facts_dict
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
from ..console import print_task_header
|
||||
|
|
@ -27,6 +29,7 @@ CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user.
|
|||
RESEARCH_AGENT_RECURSION_LIMIT = 3
|
||||
|
||||
console = Console()
|
||||
key_fact_repository = KeyFactRepository()
|
||||
|
||||
|
||||
@tool("request_research")
|
||||
|
|
@ -53,7 +56,7 @@ def request_research(query: str) -> ResearchResult:
|
|||
print_error("Maximum research recursion depth reached")
|
||||
return {
|
||||
"completion_message": "Research stopped - maximum recursion depth reached",
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"related_files": get_related_files(),
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
|
|
@ -101,7 +104,7 @@ def request_research(query: str) -> ResearchResult:
|
|||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"related_files": get_related_files(),
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
|
|
@ -235,7 +238,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"related_files": get_related_files(),
|
||||
"research_notes": get_memory_value("research_notes"),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
|
|
@ -319,7 +322,7 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
crash_message = get_crash_message() if agent_crashed else None
|
||||
|
||||
response_data = {
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"related_files": get_related_files(),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
"completion_message": completion_message,
|
||||
|
|
@ -440,7 +443,7 @@ def request_implementation(task_spec: str) -> str:
|
|||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"key_facts": format_key_facts_dict(key_fact_repository.get_facts_dict()),
|
||||
"related_files": get_related_files(),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
"success": success and not agent_crashed,
|
||||
|
|
@ -497,4 +500,4 @@ def request_implementation(task_spec: str) -> str:
|
|||
# Join all parts into a single markdown string
|
||||
markdown_output = "".join(markdown_parts)
|
||||
|
||||
return markdown_output
|
||||
return markdown_output
|
||||
|
|
@ -6,11 +6,14 @@ from rich.console import Console
|
|||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ..database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ..llm import initialize_expert_llm
|
||||
from ..text.key_facts_formatter import format_key_facts_dict
|
||||
from .memory import _global_memory, get_memory_value
|
||||
|
||||
console = Console()
|
||||
_model = None
|
||||
key_fact_repository = KeyFactRepository()
|
||||
|
||||
|
||||
def get_model():
|
||||
|
|
@ -148,7 +151,9 @@ def ask_expert(question: str) -> str:
|
|||
file_paths = list(_global_memory["related_files"].values())
|
||||
related_contents = read_related_files(file_paths)
|
||||
key_snippets = get_memory_value("key_snippets")
|
||||
key_facts = get_memory_value("key_facts")
|
||||
# Get key facts directly from repository and format using the formatter
|
||||
facts_dict = key_fact_repository.get_facts_dict()
|
||||
key_facts = format_key_facts_dict(facts_dict)
|
||||
research_notes = get_memory_value("research_notes")
|
||||
|
||||
# Build display query (just question)
|
||||
|
|
@ -203,4 +208,4 @@ def ask_expert(question: str) -> str:
|
|||
Panel(Markdown(response.content), title="Expert Response", border_style="blue")
|
||||
)
|
||||
|
||||
return response.content
|
||||
return response.content
|
||||
|
|
@ -577,10 +577,13 @@ def deregister_related_files(file_ids: List[int]) -> str:
|
|||
|
||||
|
||||
def get_memory_value(key: str) -> str:
|
||||
"""Get a value from global memory.
|
||||
"""
|
||||
Get a value from global memory.
|
||||
|
||||
Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module,
|
||||
not through this function.
|
||||
|
||||
Different memory types return different formats:
|
||||
- key_facts: Returns numbered list of facts in format '#ID: fact'
|
||||
- key_snippets: Returns formatted snippets with file path, line number and content
|
||||
- All other types: Returns newline-separated list of values
|
||||
|
||||
|
|
@ -589,48 +592,9 @@ def get_memory_value(key: str) -> str:
|
|||
|
||||
Returns:
|
||||
String representation of the memory values:
|
||||
- For key_facts: '#ID: fact' format, one per line
|
||||
- For key_snippets: Formatted snippet blocks
|
||||
- For other types: One value per line
|
||||
"""
|
||||
if key == "key_facts":
|
||||
try:
|
||||
# Get facts from repository as a dictionary
|
||||
facts_dict = key_fact_repository.get_facts_dict()
|
||||
|
||||
# For empty dict, return empty string
|
||||
if not facts_dict:
|
||||
return ""
|
||||
|
||||
# Sort by ID for consistent output and format as markdown sections
|
||||
facts = []
|
||||
for k, v in sorted(facts_dict.items()):
|
||||
facts.extend(
|
||||
[
|
||||
f"## 🔑 Key Fact #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
v,
|
||||
"", # Empty line between facts
|
||||
]
|
||||
)
|
||||
return "\n".join(facts).rstrip() # Remove trailing newline
|
||||
except Exception:
|
||||
# Fallback to old memory if database access fails
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
facts = []
|
||||
for k, v in sorted(values.items()):
|
||||
facts.extend(
|
||||
[
|
||||
f"## 🔑 Key Fact #{k}",
|
||||
"",
|
||||
v,
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(facts).rstrip()
|
||||
|
||||
if key == "key_snippets":
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,192 @@
|
|||
"""
|
||||
Tests for the KeyFactRepository class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the KeyFact table."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Create the KeyFact table
|
||||
with db.atomic():
|
||||
db.create_tables([KeyFact], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
KeyFact.drop_table(safe=True)
|
||||
|
||||
|
||||
def test_create_key_fact(setup_db):
|
||||
"""Test creating a key fact."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
|
||||
# Create a key fact
|
||||
content = "Test key fact"
|
||||
fact = repo.create(content)
|
||||
|
||||
# Verify the fact was created correctly
|
||||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
# Verify we can retrieve it from the database
|
||||
fact_from_db = KeyFact.get_by_id(fact.id)
|
||||
assert fact_from_db.content == content
|
||||
|
||||
|
||||
def test_get_key_fact(setup_db):
|
||||
"""Test retrieving a key fact by ID."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
|
||||
# Create a key fact
|
||||
content = "Test key fact"
|
||||
fact = repo.create(content)
|
||||
|
||||
# Retrieve the fact by ID
|
||||
retrieved_fact = repo.get(fact.id)
|
||||
|
||||
# Verify the retrieved fact matches the original
|
||||
assert retrieved_fact is not None
|
||||
assert retrieved_fact.id == fact.id
|
||||
assert retrieved_fact.content == content
|
||||
|
||||
# Try to retrieve a non-existent fact
|
||||
non_existent_fact = repo.get(999)
|
||||
assert non_existent_fact is None
|
||||
|
||||
|
||||
def test_update_key_fact(setup_db):
|
||||
"""Test updating a key fact."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
|
||||
# Create a key fact
|
||||
original_content = "Original content"
|
||||
fact = repo.create(original_content)
|
||||
|
||||
# Update the fact
|
||||
new_content = "Updated content"
|
||||
updated_fact = repo.update(fact.id, new_content)
|
||||
|
||||
# Verify the fact was updated correctly
|
||||
assert updated_fact is not None
|
||||
assert updated_fact.id == fact.id
|
||||
assert updated_fact.content == new_content
|
||||
|
||||
# Verify we can retrieve the updated content from the database
|
||||
fact_from_db = KeyFact.get_by_id(fact.id)
|
||||
assert fact_from_db.content == new_content
|
||||
|
||||
# Try to update a non-existent fact
|
||||
non_existent_update = repo.update(999, "This shouldn't work")
|
||||
assert non_existent_update is None
|
||||
|
||||
|
||||
def test_delete_key_fact(setup_db):
|
||||
"""Test deleting a key fact."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
|
||||
# Create a key fact
|
||||
content = "Test key fact to delete"
|
||||
fact = repo.create(content)
|
||||
|
||||
# Verify the fact exists
|
||||
assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None
|
||||
|
||||
# Delete the fact
|
||||
delete_result = repo.delete(fact.id)
|
||||
|
||||
# Verify the delete operation was successful
|
||||
assert delete_result is True
|
||||
|
||||
# Verify the fact no longer exists in the database
|
||||
assert KeyFact.get_or_none(KeyFact.id == fact.id) is None
|
||||
|
||||
# Try to delete a non-existent fact
|
||||
non_existent_delete = repo.delete(999)
|
||||
assert non_existent_delete is False
|
||||
|
||||
|
||||
def test_get_all_key_facts(setup_db):
|
||||
"""Test retrieving all key facts."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
|
||||
# Create some key facts
|
||||
contents = ["Fact 1", "Fact 2", "Fact 3"]
|
||||
for content in contents:
|
||||
repo.create(content)
|
||||
|
||||
# Retrieve all facts
|
||||
all_facts = repo.get_all()
|
||||
|
||||
# Verify we got the correct number of facts
|
||||
assert len(all_facts) == len(contents)
|
||||
|
||||
# Verify the content of each fact
|
||||
fact_contents = [fact.content for fact in all_facts]
|
||||
for content in contents:
|
||||
assert content in fact_contents
|
||||
|
||||
|
||||
def test_get_facts_dict(setup_db):
|
||||
"""Test retrieving key facts as a dictionary."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
|
||||
# Create some key facts
|
||||
facts = []
|
||||
contents = ["Fact 1", "Fact 2", "Fact 3"]
|
||||
for content in contents:
|
||||
facts.append(repo.create(content))
|
||||
|
||||
# Retrieve facts as dictionary
|
||||
facts_dict = repo.get_facts_dict()
|
||||
|
||||
# Verify we got the correct number of facts
|
||||
assert len(facts_dict) == len(contents)
|
||||
|
||||
# Verify each fact is in the dictionary with the correct content
|
||||
for fact in facts:
|
||||
assert fact.id in facts_dict
|
||||
assert facts_dict[fact.id] == fact.content
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for the agent.py module to verify KeyFactRepository integration."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from ra_aid.tools.agent import (
|
||||
request_research,
|
||||
request_task_implementation,
|
||||
request_implementation,
|
||||
request_research_and_implementation,
|
||||
)
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_memory():
|
||||
"""Reset global memory before each test"""
|
||||
_global_memory["key_facts"] = {}
|
||||
_global_memory["key_fact_id_counter"] = 0
|
||||
_global_memory["key_snippets"] = {}
|
||||
_global_memory["key_snippet_id_counter"] = 0
|
||||
_global_memory["research_notes"] = []
|
||||
_global_memory["plans"] = []
|
||||
_global_memory["tasks"] = {}
|
||||
_global_memory["task_id_counter"] = 0
|
||||
_global_memory["related_files"] = {}
|
||||
_global_memory["related_file_id_counter"] = 0
|
||||
_global_memory["work_log"] = []
|
||||
yield
|
||||
# Clean up after test
|
||||
_global_memory["key_facts"] = {}
|
||||
_global_memory["key_fact_id_counter"] = 0
|
||||
_global_memory["key_snippets"] = {}
|
||||
_global_memory["key_snippet_id_counter"] = 0
|
||||
_global_memory["research_notes"] = []
|
||||
_global_memory["plans"] = []
|
||||
_global_memory["tasks"] = {}
|
||||
_global_memory["task_id_counter"] = 0
|
||||
_global_memory["related_files"] = {}
|
||||
_global_memory["related_file_id_counter"] = 0
|
||||
_global_memory["work_log"] = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_functions():
|
||||
"""Mock functions used in agent.py"""
|
||||
with patch('ra_aid.tools.agent.key_fact_repository') as mock_repo, \
|
||||
patch('ra_aid.tools.agent.format_key_facts_dict') as mock_formatter, \
|
||||
patch('ra_aid.tools.agent.initialize_llm') as mock_llm, \
|
||||
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
||||
patch('ra_aid.tools.agent.get_memory_value') as mock_get_memory, \
|
||||
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
|
||||
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
|
||||
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion:
|
||||
|
||||
# Setup mock return values
|
||||
mock_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
|
||||
mock_formatter.return_value = "Formatted facts"
|
||||
mock_llm.return_value = MagicMock()
|
||||
mock_get_files.return_value = ["file1.py", "file2.py"]
|
||||
mock_get_memory.return_value = "Test memory value"
|
||||
mock_get_work_log.return_value = "Test work log"
|
||||
mock_get_completion.return_value = "Task completed"
|
||||
|
||||
# Return all mocks as a dictionary
|
||||
yield {
|
||||
'key_fact_repository': mock_repo,
|
||||
'format_key_facts_dict': mock_formatter,
|
||||
'initialize_llm': mock_llm,
|
||||
'get_related_files': mock_get_files,
|
||||
'get_memory_value': mock_get_memory,
|
||||
'get_work_log': mock_get_work_log,
|
||||
'reset_completion_flags': mock_reset,
|
||||
'get_completion_message': mock_get_completion
|
||||
}
|
||||
|
||||
|
||||
def test_request_research_uses_key_fact_repository(reset_memory, mock_functions):
|
||||
"""Test that request_research uses KeyFactRepository directly with formatting."""
|
||||
# Mock running the research agent
|
||||
with patch('ra_aid.agent_utils.run_research_agent'):
|
||||
# Call the function
|
||||
result = request_research("test query")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Verify formatted facts are used in response
|
||||
assert result["key_facts"] == "Formatted facts"
|
||||
|
||||
# Verify get_memory_value is not called with "key_facts"
|
||||
for call in mock_functions['get_memory_value'].call_args_list:
|
||||
assert call[0][0] != "key_facts"
|
||||
|
||||
|
||||
def test_request_research_max_depth(reset_memory, mock_functions):
|
||||
"""Test that max recursion depth handling uses KeyFactRepository."""
|
||||
# Set recursion depth to max
|
||||
_global_memory["agent_depth"] = 3
|
||||
|
||||
# Call the function (should hit max depth case)
|
||||
result = request_research("test query")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Verify formatted facts are used in response
|
||||
assert result["key_facts"] == "Formatted facts"
|
||||
|
||||
# Verify get_memory_value is not called with "key_facts"
|
||||
for call in mock_functions['get_memory_value'].call_args_list:
|
||||
assert call[0][0] != "key_facts"
|
||||
|
||||
|
||||
def test_request_research_and_implementation_uses_key_fact_repository(reset_memory, mock_functions):
|
||||
"""Test that request_research_and_implementation uses KeyFactRepository correctly."""
|
||||
# Mock running the research agent
|
||||
with patch('ra_aid.agent_utils.run_research_agent'):
|
||||
# Call the function
|
||||
result = request_research_and_implementation("test query")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Verify formatted facts are used in response
|
||||
assert result["key_facts"] == "Formatted facts"
|
||||
|
||||
# Verify get_memory_value is not called with "key_facts"
|
||||
for call in mock_functions['get_memory_value'].call_args_list:
|
||||
assert call[0][0] != "key_facts"
|
||||
|
||||
|
||||
def test_request_implementation_uses_key_fact_repository(reset_memory, mock_functions):
|
||||
"""Test that request_implementation uses KeyFactRepository correctly."""
|
||||
# Mock running the planning agent
|
||||
with patch('ra_aid.agent_utils.run_planning_agent'):
|
||||
# Call the function
|
||||
result = request_implementation("test task")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Check that the formatted key facts are included in the response
|
||||
assert "Formatted facts" in result
|
||||
|
||||
|
||||
def test_request_task_implementation_uses_key_fact_repository(reset_memory, mock_functions):
|
||||
"""Test that request_task_implementation uses KeyFactRepository correctly."""
|
||||
# Set up _global_memory with required values
|
||||
_global_memory["tasks"] = {0: "Task 1"}
|
||||
_global_memory["base_task"] = "Base task"
|
||||
|
||||
# Mock running the implementation agent
|
||||
with patch('ra_aid.agent_utils.run_task_implementation_agent'):
|
||||
# Call the function
|
||||
result = request_task_implementation("test task")
|
||||
|
||||
# Verify repository was called
|
||||
mock_functions['key_fact_repository'].get_facts_dict.assert_called_once()
|
||||
|
||||
# Verify formatter was called with repository results
|
||||
mock_functions['format_key_facts_dict'].assert_called_once_with(
|
||||
mock_functions['key_fact_repository'].get_facts_dict.return_value
|
||||
)
|
||||
|
||||
# Check that the formatted key facts are included in the response
|
||||
assert "Formatted facts" in result
|
||||
|
|
@ -123,26 +123,6 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
|||
mock_repository.create.assert_called_once_with("First fact")
|
||||
|
||||
|
||||
def test_get_memory_value_key_facts(reset_memory, mock_repository):
|
||||
"""Test get_memory_value with key facts dictionary"""
|
||||
# Empty key facts should return empty string
|
||||
assert get_memory_value("key_facts") == ""
|
||||
|
||||
# Add some facts through the mocked repository
|
||||
fact1 = mock_repository.create("First fact")
|
||||
fact2 = mock_repository.create("Second fact")
|
||||
|
||||
# Mock get_facts_dict to return our test data
|
||||
mock_repository.get_facts_dict.return_value = {
|
||||
fact1.id: "First fact",
|
||||
fact2.id: "Second fact"
|
||||
}
|
||||
|
||||
# Should return markdown formatted list
|
||||
expected = f"## 🔑 Key Fact #{fact1.id}\n\nFirst fact\n\n## 🔑 Key Fact #{fact2.id}\n\nSecond fact"
|
||||
assert get_memory_value("key_facts") == expected
|
||||
|
||||
|
||||
def test_get_memory_value_other_types(reset_memory):
|
||||
"""Test get_memory_value remains compatible with other memory types"""
|
||||
# Add some research notes
|
||||
|
|
|
|||
Loading…
Reference in New Issue