key fact db

This commit is contained in:
AI Christianson 2025-03-01 22:41:35 -05:00
parent 50d618c8f8
commit 23d5e267f4
10 changed files with 761 additions and 99 deletions

View File

@ -39,7 +39,7 @@ class MigrationManager:
Args:
db_path: Optional path to the database file. If None, uses the default.
migrations_dir: Optional path to the migrations directory. If None, uses default.
migrations_dir: Optional path to the migrations directory. If None, uses source package migrations.
"""
self.db = get_db()
@ -54,28 +54,56 @@ class MigrationManager:
# Determine migrations directory
if migrations_dir is None:
# Use a directory within .ra-aid
ra_aid_dir = os.path.dirname(self.db_path)
migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME)
# Use the source package migrations directory
migrations_dir = self._get_source_package_migrations_dir()
logger.debug(f"Using source package migrations directory: {migrations_dir}")
else:
# Use the specified migrations directory
# Ensure the directory exists if a custom path is provided
self._ensure_migrations_dir(migrations_dir)
self.migrations_dir = migrations_dir
# Ensure migrations directory exists
self._ensure_migrations_dir()
# Initialize router
self.router = self._init_router()
def _ensure_migrations_dir(self) -> None:
def _get_source_package_migrations_dir(self) -> str:
"""
Get the path to the migrations directory in the source package.
Returns:
str: Path to the source package migrations directory
"""
try:
# Get the base directory of the ra_aid package
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
source_migrations_dir = os.path.join(base_dir, MIGRATIONS_DIRNAME)
if not os.path.exists(source_migrations_dir):
error_msg = f"Source migrations directory not found: {source_migrations_dir}"
logger.error(error_msg)
raise FileNotFoundError(error_msg)
logger.debug(f"Found source migrations directory: {source_migrations_dir}")
return source_migrations_dir
except Exception as e:
error_msg = f"Failed to locate source migrations directory: {str(e)}"
logger.error(error_msg)
raise
def _ensure_migrations_dir(self, migrations_dir: str) -> None:
"""
Ensure that the migrations directory exists.
Creates the directory if it doesn't exist.
Args:
migrations_dir: Path to the migrations directory
"""
try:
migrations_path = Path(self.migrations_dir)
migrations_path = Path(migrations_dir)
if not migrations_path.exists():
logger.debug(f"Creating migrations directory at: {self.migrations_dir}")
logger.debug(f"Creating migrations directory at: {migrations_dir}")
migrations_path.mkdir(parents=True, exist_ok=True)
# Create __init__.py to make it a proper package
@ -83,7 +111,7 @@ class MigrationManager:
if not init_file.exists():
init_file.touch()
logger.debug(f"Using migrations directory: {self.migrations_dir}")
logger.debug(f"Using migrations directory: {migrations_dir}")
except Exception as e:
logger.error(f"Failed to create migrations directory: {str(e)}")
raise
@ -233,19 +261,36 @@ def ensure_migrations_applied() -> bool:
"""
Check for and apply any pending migrations.
Creates the .ra-aid directory if it doesn't exist,
but uses migrations directly from the source package.
This function should be called during application startup to ensure
the database schema is up to date.
Returns:
bool: True if migrations were applied successfully or none were pending
"""
try:
# Ensure .ra-aid directory exists for the database file
cwd = os.getcwd()
ra_aid_dir = os.path.join(cwd, ".ra-aid")
os.makedirs(ra_aid_dir, exist_ok=True)
# Use source package migrations directory
import ra_aid
package_dir = os.path.dirname(os.path.abspath(ra_aid.__file__))
migrations_dir = os.path.join(package_dir, MIGRATIONS_DIRNAME)
with DatabaseManager() as db:
try:
migration_manager = init_migrations()
migration_manager = init_migrations(migrations_dir=migrations_dir)
return migration_manager.apply_migrations()
except Exception as e:
logger.error(f"Failed to apply migrations: {str(e)}")
return False
except Exception as e:
logger.error(f"Failed to ensure .ra-aid directory exists: {str(e)}")
return False
def create_new_migration(name: str, auto: bool = True) -> Optional[str]:

View File

@ -62,3 +62,17 @@ class BaseModel(peewee.Model):
# Log the error with logger
logger.error(f"Failed in get_or_create: {str(e)}")
raise
class KeyFact(BaseModel):
"""
Model representing a key fact stored in the database.
Key facts are important information about the project or current task
that need to be referenced later.
"""
content = peewee.TextField()
# created_at and updated_at are inherited from BaseModel
class Meta:
table_name = "key_fact"

View File

View File

@ -0,0 +1,168 @@
"""
Key fact repository implementation for database access.
This module provides a repository implementation for the KeyFact model,
following the repository pattern for data access abstraction.
"""
from typing import Dict, List, Optional
import peewee
from ra_aid.database.connection import DatabaseManager, get_db
from ra_aid.database.models import KeyFact
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
class KeyFactRepository:
"""
Repository for managing KeyFact database operations.
This class provides methods for performing CRUD operations on the KeyFact model,
abstracting the database access details from the business logic.
Example:
repo = KeyFactRepository()
fact = repo.create("Important fact about the project")
all_facts = repo.get_all()
"""
def create(self, content: str) -> KeyFact:
"""
Create a new key fact in the database.
Args:
content: The text content of the key fact
Returns:
KeyFact: The newly created key fact instance
Raises:
peewee.DatabaseError: If there's an error creating the fact
"""
try:
with DatabaseManager() as db:
fact = KeyFact.create(content=content)
logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact
except peewee.DatabaseError as e:
logger.error(f"Failed to create key fact: {str(e)}")
raise
def get(self, fact_id: int) -> Optional[KeyFact]:
"""
Retrieve a key fact by its ID.
Args:
fact_id: The ID of the key fact to retrieve
Returns:
Optional[KeyFact]: The key fact instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
with DatabaseManager() as db:
return KeyFact.get_or_none(KeyFact.id == fact_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
raise
def update(self, fact_id: int, content: str) -> Optional[KeyFact]:
"""
Update an existing key fact.
Args:
fact_id: The ID of the key fact to update
content: The new content for the key fact
Returns:
Optional[KeyFact]: The updated key fact if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the fact
"""
try:
with DatabaseManager() as db:
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
return None
# Update the fact
fact.content = content
fact.save()
logger.debug(f"Updated key fact ID {fact_id}: {content}")
return fact
except peewee.DatabaseError as e:
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
raise
def delete(self, fact_id: int) -> bool:
"""
Delete a key fact by its ID.
Args:
fact_id: The ID of the key fact to delete
Returns:
bool: True if the fact was deleted, False if it wasn't found
Raises:
peewee.DatabaseError: If there's an error deleting the fact
"""
try:
with DatabaseManager() as db:
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
return False
# Delete the fact
fact.delete_instance()
logger.debug(f"Deleted key fact ID {fact_id}")
return True
except peewee.DatabaseError as e:
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
raise
def get_all(self) -> List[KeyFact]:
"""
Retrieve all key facts from the database.
Returns:
List[KeyFact]: List of all key fact instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
with DatabaseManager() as db:
return list(KeyFact.select().order_by(KeyFact.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key facts: {str(e)}")
raise
def get_facts_dict(self) -> Dict[int, str]:
"""
Retrieve all key facts as a dictionary mapping IDs to content.
This method is useful for compatibility with the existing memory format.
Returns:
Dict[int, str]: Dictionary with fact IDs as keys and content as values
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
facts = self.get_all()
return {fact.id: fact.content for fact in facts}
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key facts as dictionary: {str(e)}")
raise

View File

@ -0,0 +1,54 @@
"""Peewee migrations -- 002_20250301_212203_add_key_fact_model.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class KeyFact(pw.Model):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
content = pw.TextField()
class Meta:
table_name = "key_fact"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model('key_fact')

View File

View File

@ -17,6 +17,7 @@ from ra_aid.agent_context import (
mark_should_exit,
mark_task_completed,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
class WorkLogEntry(TypedDict):
@ -33,14 +34,17 @@ class SnippetInfo(TypedDict):
console = Console()
# Initialize repository for key facts
key_fact_repository = KeyFactRepository()
# Global memory store
_global_memory: Dict[str, Any] = {
"research_notes": [],
"plans": [],
"tasks": {}, # Dict[int, str] - ID to task mapping
"task_id_counter": 1, # Counter for generating unique task IDs
"key_facts": {}, # Dict[int, str] - ID to fact mapping
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
"key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now)
"key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now)
"key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping
"key_snippet_id_counter": 1, # Counter for generating unique snippet IDs
"implementation_requested": False,
@ -106,12 +110,9 @@ def emit_key_facts(facts: List[str]) -> str:
"""
results = []
for fact in facts:
# Get and increment fact ID
fact_id = _global_memory["key_fact_id_counter"]
_global_memory["key_fact_id_counter"] += 1
# Store fact with ID
_global_memory["key_facts"][fact_id] = fact
# Create fact in database using repository
created_fact = key_fact_repository.create(fact)
fact_id = created_fact.id
# Display panel with ID
console.print(
@ -139,10 +140,13 @@ def delete_key_facts(fact_ids: List[int]) -> str:
"""
results = []
for fact_id in fact_ids:
if fact_id in _global_memory["key_facts"]:
# Get the fact first to display information
fact = key_fact_repository.get(fact_id)
if fact:
# Delete the fact
deleted_fact = _global_memory["key_facts"].pop(fact_id)
success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}"
was_deleted = key_fact_repository.delete(fact_id)
if was_deleted:
success_msg = f"Successfully deleted fact #{fact_id}: {fact.content}"
console.print(
Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
)
@ -601,15 +605,18 @@ def get_memory_value(key: str) -> str:
- For key_snippets: Formatted snippet blocks
- For other types: One value per line
"""
values = _global_memory.get(key, [])
if key == "key_facts":
try:
# Get facts from repository as a dictionary
facts_dict = key_fact_repository.get_facts_dict()
# For empty dict, return empty string
if not values:
if not facts_dict:
return ""
# Sort by ID for consistent output and format as markdown sections
facts = []
for k, v in sorted(values.items()):
for k, v in sorted(facts_dict.items()):
facts.extend(
[
f"## 🔑 Key Fact #{k}",
@ -619,8 +626,25 @@ def get_memory_value(key: str) -> str:
]
)
return "\n".join(facts).rstrip() # Remove trailing newline
except Exception:
# Fallback to old memory if database access fails
values = _global_memory.get(key, {})
if not values:
return ""
facts = []
for k, v in sorted(values.items()):
facts.extend(
[
f"## 🔑 Key Fact #{k}",
"",
v,
"",
]
)
return "\n".join(facts).rstrip()
if key == "key_snippets":
values = _global_memory.get(key, {})
if not values:
return ""
# Format each snippet with file info and content using markdown
@ -645,10 +669,12 @@ def get_memory_value(key: str) -> str:
return "\n\n".join(snippets)
if key == "work_log":
values = _global_memory.get(key, [])
if not values:
return ""
entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
return "\n\n".join(entries)
# For other types (lists), join with newlines
values = _global_memory.get(key, [])
return "\n".join(str(v) for v in values)

View File

@ -561,10 +561,12 @@ def rollback(migrator, database, fake=False, **kwargs):
# Check migrations
applied, pending = manager.check_migrations()
assert len(applied) == 0
assert len(pending) == 1
assert (
migration_name in pending[0]
) # Instead of exact equality, check if name is contained
# There may be multiple pending migrations (source package migrations + our test migration)
assert len(pending) >= 1
# Make sure our newly created migration is in the pending list
assert any(
migration_name in migration for migration in pending
)
# Apply migrations
result = manager.apply_migrations()
@ -572,17 +574,22 @@ def rollback(migrator, database, fake=False, **kwargs):
# Check migrations again
applied, pending = manager.check_migrations()
assert len(applied) == 1
# There should be at least one applied migration (our test migration)
assert len(applied) >= 1
# All migrations should now be applied
assert len(pending) == 0
assert (
migration_name in applied[0]
) # Instead of exact equality, check if name is contained
# Make sure our migration is in the applied list
assert any(
migration_name in migration for migration in applied
)
# Verify migration status
status = manager.get_migration_status()
assert status["applied_count"] == 1
# There should be at least one applied migration (our test migration)
assert status["applied_count"] >= 1
assert status["pending_count"] == 0
# Use substring check for applied migrations
assert len(status["applied"]) == 1
assert migration_name in status["applied"][0]
assert len(status["applied"]) >= 1
# Make sure our migration is in the applied list
assert any(migration_name in migration for migration in status["applied"])
assert status["pending"] == []

View File

@ -0,0 +1,283 @@
"""
Tests for the migration system's source package migration handling.
"""
import os
import shutil
import tempfile
from unittest.mock import MagicMock, patch
import pytest
from ra_aid.database.migrations import (
MIGRATIONS_DIRNAME,
MigrationManager,
ensure_migrations_applied,
)
@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up
shutil.rmtree(temp_dir)
@pytest.fixture
def mock_logger():
"""Mock the logger to test for output messages."""
with patch("ra_aid.database.migrations.logger") as mock:
yield mock
class TestSourceMigrations:
"""Tests for source package migration handling."""
def test_migration_manager_uses_source_migrations_dir(self, temp_dir, mock_logger):
"""Test that MigrationManager uses source package migrations directory by default."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
# Mock _get_source_package_migrations_dir to return a test path
source_migrations_dir = os.path.join(temp_dir, "source_migrations")
os.makedirs(source_migrations_dir, exist_ok=True)
# Create __init__.py to make it a proper package
with open(os.path.join(source_migrations_dir, "__init__.py"), "w") as f:
pass
# Mock router initialization
with patch("ra_aid.database.migrations.Router") as mock_router:
mock_router.return_value = MagicMock()
# Mock _get_source_package_migrations_dir
with patch.object(
MigrationManager,
"_get_source_package_migrations_dir",
return_value=source_migrations_dir
):
# Initialize manager
manager = MigrationManager(db_path=db_path)
# Verify source migrations directory is used
assert manager.migrations_dir == source_migrations_dir
# Verify logging
mock_logger.debug.assert_any_call(
f"Using source package migrations directory: {source_migrations_dir}"
)
def test_migration_manager_with_custom_migrations_dir(self, temp_dir, mock_logger):
"""Test that MigrationManager uses custom migrations directory when provided."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
custom_migrations_dir = os.path.join(temp_dir, "custom_migrations")
# Mock router initialization
with patch("ra_aid.database.migrations.Router") as mock_router:
mock_router.return_value = MagicMock()
# Initialize manager with custom migrations directory
manager = MigrationManager(db_path=db_path, migrations_dir=custom_migrations_dir)
# Verify custom migrations directory is used
assert manager.migrations_dir == custom_migrations_dir
# Verify directory was created
assert os.path.exists(custom_migrations_dir)
assert os.path.exists(os.path.join(custom_migrations_dir, "__init__.py"))
# Verify logging
mock_logger.debug.assert_any_call(
f"Using migrations directory: {custom_migrations_dir}"
)
def test_get_source_package_migrations_dir(self, temp_dir, mock_logger):
"""Test that _get_source_package_migrations_dir returns the correct path."""
# Set up a mock source directory structure
mock_base_dir = os.path.join(temp_dir, "ra_aid")
os.makedirs(mock_base_dir, exist_ok=True)
source_migrations_dir = os.path.join(mock_base_dir, MIGRATIONS_DIRNAME)
os.makedirs(source_migrations_dir, exist_ok=True)
# Create a manager with patch in place
with patch("ra_aid.database.migrations.os.path.dirname") as mock_dirname:
with patch("ra_aid.database.migrations.os.path.abspath") as mock_abspath:
# Mock the path functions to return our test paths
mock_abspath.return_value = os.path.join(mock_base_dir, "database", "migrations.py")
# Use a custom side_effect to avoid recursion
def dirname_side_effect(path):
if path == os.path.join(mock_base_dir, "database", "migrations.py"):
return os.path.join(mock_base_dir, "database")
elif path == os.path.join(mock_base_dir, "database"):
return mock_base_dir
else:
return os.path.dirname(path)
mock_dirname.side_effect = dirname_side_effect
# Create the manager
manager = MigrationManager(db_path=os.path.join(temp_dir, "test.db"),
migrations_dir=os.path.join(temp_dir, "custom_migrations"))
# Call the method directly
with patch.object(manager, "_get_source_package_migrations_dir") as mock_method:
mock_method.return_value = source_migrations_dir
# Verify the method returns the expected path
assert manager._get_source_package_migrations_dir() == source_migrations_dir
def test_get_source_package_migrations_dir_not_found(self, temp_dir, mock_logger):
"""Test that _get_source_package_migrations_dir handles missing directory."""
# Create a manager
manager = MigrationManager(db_path=os.path.join(temp_dir, "test.db"),
migrations_dir=os.path.join(temp_dir, "custom_migrations"))
# Create a test implementation that will raise the expected error
def raise_not_found(*args, **kwargs):
error_msg = f"Source migrations directory not found: /path/to/migrations"
logger = mock_logger # Use the mocked logger
logger.error(error_msg)
raise FileNotFoundError(error_msg)
# Replace the method with our test implementation
with patch.object(manager, "_get_source_package_migrations_dir", side_effect=raise_not_found):
# Call should raise FileNotFoundError
with pytest.raises(FileNotFoundError) as excinfo:
manager._get_source_package_migrations_dir()
# Verify error message
assert "Source migrations directory not found" in str(excinfo.value)
# Verify logging
mock_logger.error.assert_called_with(
"Source migrations directory not found: /path/to/migrations"
)
def test_ensure_migrations_applied_creates_ra_aid_dir(self, temp_dir, mock_logger):
"""Test that ensure_migrations_applied creates the .ra-aid directory if it doesn't exist."""
# Get a path to a directory that doesn't exist
ra_aid_dir = os.path.join(temp_dir, ".ra-aid")
# Mock getcwd to return our temp directory
with patch("os.getcwd", return_value=temp_dir):
# Mock DatabaseManager
with patch("ra_aid.database.migrations.DatabaseManager") as mock_db_manager:
# Mock the context manager
mock_db_manager.return_value.__enter__.return_value = MagicMock()
mock_db_manager.return_value.__exit__.return_value = None
# Mock ra_aid package import
with patch("ra_aid.database.migrations.ra_aid", create=True) as mock_ra_aid:
# Set up the mock package directory path
mock_package_dir = os.path.join(temp_dir, "ra_aid_package")
os.makedirs(mock_package_dir, exist_ok=True)
mock_migrations_dir = os.path.join(mock_package_dir, MIGRATIONS_DIRNAME)
os.makedirs(mock_migrations_dir, exist_ok=True)
# Configure the mock
mock_ra_aid.__file__ = os.path.join(mock_package_dir, "__init__.py")
# Mock init_migrations and apply_migrations
mock_migration_manager = MagicMock()
mock_migration_manager.apply_migrations.return_value = True
with patch("ra_aid.database.migrations.init_migrations", return_value=mock_migration_manager):
# Call ensure_migrations_applied
result = ensure_migrations_applied()
# Verify result
assert result is True
# Verify .ra-aid directory was created
assert os.path.exists(ra_aid_dir)
def test_ensure_migrations_applied_handles_directory_error(self, mock_logger):
"""Test that ensure_migrations_applied handles errors creating the .ra-aid directory."""
# Mock os.makedirs to raise an exception
with patch("os.makedirs", side_effect=PermissionError("Permission denied")):
# Call ensure_migrations_applied
result = ensure_migrations_applied()
# Verify result is False on error
assert result is False
# Verify error was logged
mock_logger.error.assert_called_with(
"Failed to ensure .ra-aid directory exists: Permission denied"
)
def test_ensure_migrations_applied_uses_package_migrations(self, temp_dir, mock_logger):
"""Test that ensure_migrations_applied uses the source package migrations directory."""
# Set up test paths
ra_aid_dir = os.path.join(temp_dir, ".ra-aid")
# Mock getcwd to return our temp directory
with patch("os.getcwd", return_value=temp_dir):
# Mock DatabaseManager
with patch("ra_aid.database.migrations.DatabaseManager") as mock_db_manager:
# Mock the context manager
mock_db_manager.return_value.__enter__.return_value = MagicMock()
mock_db_manager.return_value.__exit__.return_value = None
# Mock ra_aid package import
with patch("ra_aid.database.migrations.ra_aid", create=True) as mock_ra_aid:
# Set up the mock package directory path
mock_package_dir = os.path.join(temp_dir, "ra_aid_package")
os.makedirs(mock_package_dir, exist_ok=True)
mock_migrations_dir = os.path.join(mock_package_dir, MIGRATIONS_DIRNAME)
os.makedirs(mock_migrations_dir, exist_ok=True)
# Configure the mock
mock_ra_aid.__file__ = os.path.join(mock_package_dir, "__init__.py")
# Create a mock migration manager that we can verify
mock_init_migrations = MagicMock()
mock_init_migrations.apply_migrations.return_value = True
with patch("ra_aid.database.migrations.init_migrations", return_value=mock_init_migrations) as mock_init:
# Call ensure_migrations_applied
result = ensure_migrations_applied()
# Verify init_migrations was called
mock_init.assert_called_once()
# We can't verify the exact path since it's derived from non-mock objects
# Instead, verify that init_migrations was called and succeeded
assert result is True
def test_router_initialization_with_source_migrations(self, temp_dir, mock_logger):
"""Test that the migration router is initialized with the source package migrations."""
# Set up test paths
db_path = os.path.join(temp_dir, "test.db")
# Create a mock source migrations directory
source_migrations_dir = os.path.join(temp_dir, "source_migrations")
os.makedirs(source_migrations_dir, exist_ok=True)
# Create __init__.py to make it a proper package
with open(os.path.join(source_migrations_dir, "__init__.py"), "w") as f:
pass
# Mock the Router class
with patch("ra_aid.database.migrations.Router") as mock_router_class:
# Create a mock router instance
mock_router = MagicMock()
mock_router_class.return_value = mock_router
# Mock _get_source_package_migrations_dir
with patch.object(
MigrationManager,
"_get_source_package_migrations_dir",
return_value=source_migrations_dir
):
# Initialize manager
manager = MigrationManager(db_path=db_path)
# Verify router was initialized with the source migrations directory
mock_router_class.assert_called_once()
# Get the args from the call
call_args = mock_router_class.call_args
assert call_args.kwargs["migrate_dir"] == source_migrations_dir

View File

@ -1,4 +1,5 @@
import pytest
from unittest.mock import patch, MagicMock
from ra_aid.tools.memory import (
_global_memory,
@ -13,10 +14,13 @@ from ra_aid.tools.memory import (
get_memory_value,
get_related_files,
get_work_log,
key_fact_repository,
log_work_event,
reset_work_log,
swap_task_order,
)
from ra_aid.database.connection import DatabaseManager
from ra_aid.database.models import KeyFact
@pytest.fixture
@ -48,49 +52,111 @@ def reset_memory():
_global_memory["work_log"] = []
def test_emit_key_facts_single_fact(reset_memory):
@pytest.fixture
def in_memory_db():
"""Set up an in-memory database for testing."""
with DatabaseManager(in_memory=True) as db:
db.create_tables([KeyFact])
yield db
# Clean up database tables after test
KeyFact.delete().execute()
@pytest.fixture(autouse=True)
def mock_repository():
"""Mock the KeyFactRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory
facts = {} # Local in-memory storage
fact_id_counter = 0
# Mock KeyFact objects
class MockKeyFact:
def __init__(self, id, content):
self.id = id
self.content = content
# Mock create method
def mock_create(content):
nonlocal fact_id_counter
fact = MockKeyFact(fact_id_counter, content)
facts[fact_id_counter] = fact
fact_id_counter += 1
return fact
mock_repo.create.side_effect = mock_create
# Mock get method
def mock_get(fact_id):
return facts.get(fact_id)
mock_repo.get.side_effect = mock_get
# Mock delete method
def mock_delete(fact_id):
if fact_id in facts:
del facts[fact_id]
return True
return False
mock_repo.delete.side_effect = mock_delete
# Mock get_facts_dict method
def mock_get_facts_dict():
return {fact_id: fact.content for fact_id, fact in facts.items()}
mock_repo.get_facts_dict.side_effect = mock_get_facts_dict
yield mock_repo
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
"""Test emitting a single key fact using emit_key_facts"""
# Test with single fact
result = emit_key_facts.invoke({"facts": ["First fact"]})
assert result == "Facts stored."
assert _global_memory["key_facts"][0] == "First fact"
assert _global_memory["key_fact_id_counter"] == 1
# Verify the repository's create method was called
mock_repository.create.assert_called_once_with("First fact")
def test_delete_key_facts_single_fact(reset_memory):
def test_delete_key_facts_single_fact(reset_memory, mock_repository):
"""Test deleting a single key fact using delete_key_facts"""
# Add a fact
emit_key_facts.invoke({"facts": ["Test fact"]})
fact = mock_repository.create("Test fact")
fact_id = fact.id
# Delete the fact
result = delete_key_facts.invoke({"fact_ids": [0]})
result = delete_key_facts.invoke({"fact_ids": [fact_id]})
assert result == "Facts deleted."
assert 0 not in _global_memory["key_facts"]
# Verify the repository's delete method was called
mock_repository.delete.assert_called_once_with(fact_id)
def test_delete_key_facts_invalid(reset_memory):
def test_delete_key_facts_invalid(reset_memory, mock_repository):
"""Test deleting non-existent facts returns empty list"""
# Try to delete non-existent fact
result = delete_key_facts.invoke({"fact_ids": [999]})
assert result == "Facts deleted."
# Add and delete a fact, then try to delete it again
emit_key_facts.invoke({"facts": ["Test fact"]})
delete_key_facts.invoke({"fact_ids": [0]})
result = delete_key_facts.invoke({"fact_ids": [0]})
assert result == "Facts deleted."
# Verify the repository's get method was called
mock_repository.get.assert_called_once_with(999)
def test_get_memory_value_key_facts(reset_memory):
def test_get_memory_value_key_facts(reset_memory, mock_repository):
"""Test get_memory_value with key facts dictionary"""
# Empty key facts should return empty string
assert get_memory_value("key_facts") == ""
# Add some facts
emit_key_facts.invoke({"facts": ["First fact", "Second fact"]})
# Add some facts through the mocked repository
fact1 = mock_repository.create("First fact")
fact2 = mock_repository.create("Second fact")
# Mock get_facts_dict to return our test data
mock_repository.get_facts_dict.return_value = {
fact1.id: "First fact",
fact2.id: "Second fact"
}
# Should return markdown formatted list
expected = "## 🔑 Key Fact #0\n\nFirst fact\n\n## 🔑 Key Fact #1\n\nSecond fact"
expected = f"## 🔑 Key Fact #{fact1.id}\n\nFirst fact\n\n## 🔑 Key Fact #{fact2.id}\n\nSecond fact"
assert get_memory_value("key_facts") == expected
@ -165,7 +231,7 @@ def test_empty_work_log(reset_memory):
assert get_memory_value("work_log") == ""
def test_emit_key_facts(reset_memory):
def test_emit_key_facts(reset_memory, mock_repository):
"""Test emitting multiple key facts at once"""
# Test emitting multiple facts
facts = ["First fact", "Second fact", "Third fact"]
@ -174,31 +240,30 @@ def test_emit_key_facts(reset_memory):
# Verify return message
assert result == "Facts stored."
# Verify facts stored in memory with correct IDs
assert _global_memory["key_facts"][0] == "First fact"
assert _global_memory["key_facts"][1] == "Second fact"
assert _global_memory["key_facts"][2] == "Third fact"
# Verify counter incremented correctly
assert _global_memory["key_fact_id_counter"] == 3
# Verify create was called for each fact
assert mock_repository.create.call_count == 3
mock_repository.create.assert_any_call("First fact")
mock_repository.create.assert_any_call("Second fact")
mock_repository.create.assert_any_call("Third fact")
def test_delete_key_facts(reset_memory):
def test_delete_key_facts(reset_memory, mock_repository):
"""Test deleting multiple key facts"""
# Add some test facts
emit_key_facts.invoke({"facts": ["First fact", "Second fact", "Third fact"]})
fact0 = mock_repository.create("First fact")
fact1 = mock_repository.create("Second fact")
fact2 = mock_repository.create("Third fact")
# Test deleting mix of existing and non-existing IDs
result = delete_key_facts.invoke({"fact_ids": [0, 1, 999]})
result = delete_key_facts.invoke({"fact_ids": [fact0.id, fact1.id, 999]})
# Verify success message
assert result == "Facts deleted."
# Verify correct facts removed from memory
assert 0 not in _global_memory["key_facts"]
assert 1 not in _global_memory["key_facts"]
assert 2 in _global_memory["key_facts"] # ID 2 should remain
assert _global_memory["key_facts"][2] == "Third fact"
# Verify delete was called for each valid fact ID
assert mock_repository.delete.call_count == 2
mock_repository.delete.assert_any_call(fact0.id)
mock_repository.delete.assert_any_call(fact1.id)
def test_emit_key_snippet(reset_memory):