diff --git a/ra_aid/database/migrations.py b/ra_aid/database/migrations.py index 338aa6b..14bcf50 100644 --- a/ra_aid/database/migrations.py +++ b/ra_aid/database/migrations.py @@ -39,7 +39,7 @@ class MigrationManager: Args: db_path: Optional path to the database file. If None, uses the default. - migrations_dir: Optional path to the migrations directory. If None, uses default. + migrations_dir: Optional path to the migrations directory. If None, uses source package migrations. """ self.db = get_db() @@ -54,28 +54,56 @@ class MigrationManager: # Determine migrations directory if migrations_dir is None: - # Use a directory within .ra-aid - ra_aid_dir = os.path.dirname(self.db_path) - migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME) + # Use the source package migrations directory + migrations_dir = self._get_source_package_migrations_dir() + logger.debug(f"Using source package migrations directory: {migrations_dir}") + else: + # Use the specified migrations directory + # Ensure the directory exists if a custom path is provided + self._ensure_migrations_dir(migrations_dir) self.migrations_dir = migrations_dir - # Ensure migrations directory exists - self._ensure_migrations_dir() - # Initialize router self.router = self._init_router() - def _ensure_migrations_dir(self) -> None: + def _get_source_package_migrations_dir(self) -> str: + """ + Get the path to the migrations directory in the source package. + + Returns: + str: Path to the source package migrations directory + """ + try: + # Get the base directory of the ra_aid package + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + source_migrations_dir = os.path.join(base_dir, MIGRATIONS_DIRNAME) + + if not os.path.exists(source_migrations_dir): + error_msg = f"Source migrations directory not found: {source_migrations_dir}" + logger.error(error_msg) + raise FileNotFoundError(error_msg) + + logger.debug(f"Found source migrations directory: {source_migrations_dir}") + return source_migrations_dir + except Exception as e: + error_msg = f"Failed to locate source migrations directory: {str(e)}" + logger.error(error_msg) + raise + + def _ensure_migrations_dir(self, migrations_dir: str) -> None: """ Ensure that the migrations directory exists. Creates the directory if it doesn't exist. + + Args: + migrations_dir: Path to the migrations directory """ try: - migrations_path = Path(self.migrations_dir) + migrations_path = Path(migrations_dir) if not migrations_path.exists(): - logger.debug(f"Creating migrations directory at: {self.migrations_dir}") + logger.debug(f"Creating migrations directory at: {migrations_dir}") migrations_path.mkdir(parents=True, exist_ok=True) # Create __init__.py to make it a proper package @@ -83,7 +111,7 @@ class MigrationManager: if not init_file.exists(): init_file.touch() - logger.debug(f"Using migrations directory: {self.migrations_dir}") + logger.debug(f"Using migrations directory: {migrations_dir}") except Exception as e: logger.error(f"Failed to create migrations directory: {str(e)}") raise @@ -232,20 +260,37 @@ def init_migrations( def ensure_migrations_applied() -> bool: """ Check for and apply any pending migrations. - + + Creates the .ra-aid directory if it doesn't exist, + but uses migrations directly from the source package. + This function should be called during application startup to ensure the database schema is up to date. Returns: bool: True if migrations were applied successfully or none were pending """ - with DatabaseManager() as db: - try: - migration_manager = init_migrations() - return migration_manager.apply_migrations() - except Exception as e: - logger.error(f"Failed to apply migrations: {str(e)}") - return False + try: + # Ensure .ra-aid directory exists for the database file + cwd = os.getcwd() + ra_aid_dir = os.path.join(cwd, ".ra-aid") + os.makedirs(ra_aid_dir, exist_ok=True) + + # Use source package migrations directory + import ra_aid + package_dir = os.path.dirname(os.path.abspath(ra_aid.__file__)) + migrations_dir = os.path.join(package_dir, MIGRATIONS_DIRNAME) + + with DatabaseManager() as db: + try: + migration_manager = init_migrations(migrations_dir=migrations_dir) + return migration_manager.apply_migrations() + except Exception as e: + logger.error(f"Failed to apply migrations: {str(e)}") + return False + except Exception as e: + logger.error(f"Failed to ensure .ra-aid directory exists: {str(e)}") + return False def create_new_migration(name: str, auto: bool = True) -> Optional[str]: @@ -287,4 +332,4 @@ def get_migration_status() -> Dict[str, Any]: "pending_count": 0, "applied": [], "pending": [], - } + } \ No newline at end of file diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py index 60b8c7b..e9875c4 100644 --- a/ra_aid/database/models.py +++ b/ra_aid/database/models.py @@ -62,3 +62,17 @@ class BaseModel(peewee.Model): # Log the error with logger logger.error(f"Failed in get_or_create: {str(e)}") raise + + +class KeyFact(BaseModel): + """ + Model representing a key fact stored in the database. + + Key facts are important information about the project or current task + that need to be referenced later. + """ + content = peewee.TextField() + # created_at and updated_at are inherited from BaseModel + + class Meta: + table_name = "key_fact" \ No newline at end of file diff --git a/ra_aid/database/repositories/__init__.py b/ra_aid/database/repositories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ra_aid/database/repositories/key_fact_repository.py b/ra_aid/database/repositories/key_fact_repository.py new file mode 100644 index 0000000..ffa9441 --- /dev/null +++ b/ra_aid/database/repositories/key_fact_repository.py @@ -0,0 +1,168 @@ +""" +Key fact repository implementation for database access. + +This module provides a repository implementation for the KeyFact model, +following the repository pattern for data access abstraction. +""" + +from typing import Dict, List, Optional + +import peewee + +from ra_aid.database.connection import DatabaseManager, get_db +from ra_aid.database.models import KeyFact +from ra_aid.logging_config import get_logger + +logger = get_logger(__name__) + + +class KeyFactRepository: + """ + Repository for managing KeyFact database operations. + + This class provides methods for performing CRUD operations on the KeyFact model, + abstracting the database access details from the business logic. + + Example: + repo = KeyFactRepository() + fact = repo.create("Important fact about the project") + all_facts = repo.get_all() + """ + + def create(self, content: str) -> KeyFact: + """ + Create a new key fact in the database. + + Args: + content: The text content of the key fact + + Returns: + KeyFact: The newly created key fact instance + + Raises: + peewee.DatabaseError: If there's an error creating the fact + """ + try: + with DatabaseManager() as db: + fact = KeyFact.create(content=content) + logger.debug(f"Created key fact ID {fact.id}: {content}") + return fact + except peewee.DatabaseError as e: + logger.error(f"Failed to create key fact: {str(e)}") + raise + + def get(self, fact_id: int) -> Optional[KeyFact]: + """ + Retrieve a key fact by its ID. + + Args: + fact_id: The ID of the key fact to retrieve + + Returns: + Optional[KeyFact]: The key fact instance if found, None otherwise + + Raises: + peewee.DatabaseError: If there's an error accessing the database + """ + try: + with DatabaseManager() as db: + return KeyFact.get_or_none(KeyFact.id == fact_id) + except peewee.DatabaseError as e: + logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}") + raise + + def update(self, fact_id: int, content: str) -> Optional[KeyFact]: + """ + Update an existing key fact. + + Args: + fact_id: The ID of the key fact to update + content: The new content for the key fact + + Returns: + Optional[KeyFact]: The updated key fact if found, None otherwise + + Raises: + peewee.DatabaseError: If there's an error updating the fact + """ + try: + with DatabaseManager() as db: + # First check if the fact exists + fact = self.get(fact_id) + if not fact: + logger.warning(f"Attempted to update non-existent key fact {fact_id}") + return None + + # Update the fact + fact.content = content + fact.save() + logger.debug(f"Updated key fact ID {fact_id}: {content}") + return fact + except peewee.DatabaseError as e: + logger.error(f"Failed to update key fact {fact_id}: {str(e)}") + raise + + def delete(self, fact_id: int) -> bool: + """ + Delete a key fact by its ID. + + Args: + fact_id: The ID of the key fact to delete + + Returns: + bool: True if the fact was deleted, False if it wasn't found + + Raises: + peewee.DatabaseError: If there's an error deleting the fact + """ + try: + with DatabaseManager() as db: + # First check if the fact exists + fact = self.get(fact_id) + if not fact: + logger.warning(f"Attempted to delete non-existent key fact {fact_id}") + return False + + # Delete the fact + fact.delete_instance() + logger.debug(f"Deleted key fact ID {fact_id}") + return True + except peewee.DatabaseError as e: + logger.error(f"Failed to delete key fact {fact_id}: {str(e)}") + raise + + def get_all(self) -> List[KeyFact]: + """ + Retrieve all key facts from the database. + + Returns: + List[KeyFact]: List of all key fact instances + + Raises: + peewee.DatabaseError: If there's an error accessing the database + """ + try: + with DatabaseManager() as db: + return list(KeyFact.select().order_by(KeyFact.id)) + except peewee.DatabaseError as e: + logger.error(f"Failed to fetch all key facts: {str(e)}") + raise + + def get_facts_dict(self) -> Dict[int, str]: + """ + Retrieve all key facts as a dictionary mapping IDs to content. + + This method is useful for compatibility with the existing memory format. + + Returns: + Dict[int, str]: Dictionary with fact IDs as keys and content as values + + Raises: + peewee.DatabaseError: If there's an error accessing the database + """ + try: + facts = self.get_all() + return {fact.id: fact.content for fact in facts} + except peewee.DatabaseError as e: + logger.error(f"Failed to fetch key facts as dictionary: {str(e)}") + raise \ No newline at end of file diff --git a/ra_aid/migrations/002_20250301_212203_add_key_fact_model.py b/ra_aid/migrations/002_20250301_212203_add_key_fact_model.py new file mode 100644 index 0000000..1ecef96 --- /dev/null +++ b/ra_aid/migrations/002_20250301_212203_add_key_fact_model.py @@ -0,0 +1,54 @@ +"""Peewee migrations -- 002_20250301_212203_add_key_fact_model.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class KeyFact(pw.Model): + id = pw.AutoField() + created_at = pw.DateTimeField() + updated_at = pw.DateTimeField() + content = pw.TextField() + + class Meta: + table_name = "key_fact" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model('key_fact') diff --git a/ra_aid/migrations/__init__.py b/ra_aid/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 3752bb4..70a1b87 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -17,6 +17,7 @@ from ra_aid.agent_context import ( mark_should_exit, mark_task_completed, ) +from ra_aid.database.repositories.key_fact_repository import KeyFactRepository class WorkLogEntry(TypedDict): @@ -33,14 +34,17 @@ class SnippetInfo(TypedDict): console = Console() +# Initialize repository for key facts +key_fact_repository = KeyFactRepository() + # Global memory store _global_memory: Dict[str, Any] = { "research_notes": [], "plans": [], "tasks": {}, # Dict[int, str] - ID to task mapping "task_id_counter": 1, # Counter for generating unique task IDs - "key_facts": {}, # Dict[int, str] - ID to fact mapping - "key_fact_id_counter": 1, # Counter for generating unique fact IDs + "key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now) + "key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now) "key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping "key_snippet_id_counter": 1, # Counter for generating unique snippet IDs "implementation_requested": False, @@ -106,12 +110,9 @@ def emit_key_facts(facts: List[str]) -> str: """ results = [] for fact in facts: - # Get and increment fact ID - fact_id = _global_memory["key_fact_id_counter"] - _global_memory["key_fact_id_counter"] += 1 - - # Store fact with ID - _global_memory["key_facts"][fact_id] = fact + # Create fact in database using repository + created_fact = key_fact_repository.create(fact) + fact_id = created_fact.id # Display panel with ID console.print( @@ -139,14 +140,17 @@ def delete_key_facts(fact_ids: List[int]) -> str: """ results = [] for fact_id in fact_ids: - if fact_id in _global_memory["key_facts"]: + # Get the fact first to display information + fact = key_fact_repository.get(fact_id) + if fact: # Delete the fact - deleted_fact = _global_memory["key_facts"].pop(fact_id) - success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}" - console.print( - Panel(Markdown(success_msg), title="Fact Deleted", border_style="green") - ) - results.append(success_msg) + was_deleted = key_fact_repository.delete(fact_id) + if was_deleted: + success_msg = f"Successfully deleted fact #{fact_id}: {fact.content}" + console.print( + Panel(Markdown(success_msg), title="Fact Deleted", border_style="green") + ) + results.append(success_msg) log_work_event(f"Deleted facts {fact_ids}.") return "Facts deleted." @@ -601,26 +605,46 @@ def get_memory_value(key: str) -> str: - For key_snippets: Formatted snippet blocks - For other types: One value per line """ - values = _global_memory.get(key, []) - if key == "key_facts": - # For empty dict, return empty string - if not values: - return "" - # Sort by ID for consistent output and format as markdown sections - facts = [] - for k, v in sorted(values.items()): - facts.extend( - [ - f"## 🔑 Key Fact #{k}", - "", # Empty line for better markdown spacing - v, - "", # Empty line between facts - ] - ) - return "\n".join(facts).rstrip() # Remove trailing newline + try: + # Get facts from repository as a dictionary + facts_dict = key_fact_repository.get_facts_dict() + + # For empty dict, return empty string + if not facts_dict: + return "" + + # Sort by ID for consistent output and format as markdown sections + facts = [] + for k, v in sorted(facts_dict.items()): + facts.extend( + [ + f"## 🔑 Key Fact #{k}", + "", # Empty line for better markdown spacing + v, + "", # Empty line between facts + ] + ) + return "\n".join(facts).rstrip() # Remove trailing newline + except Exception: + # Fallback to old memory if database access fails + values = _global_memory.get(key, {}) + if not values: + return "" + facts = [] + for k, v in sorted(values.items()): + facts.extend( + [ + f"## 🔑 Key Fact #{k}", + "", + v, + "", + ] + ) + return "\n".join(facts).rstrip() if key == "key_snippets": + values = _global_memory.get(key, {}) if not values: return "" # Format each snippet with file info and content using markdown @@ -645,10 +669,12 @@ def get_memory_value(key: str) -> str: return "\n\n".join(snippets) if key == "work_log": + values = _global_memory.get(key, []) if not values: return "" entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values] return "\n\n".join(entries) # For other types (lists), join with newlines - return "\n".join(str(v) for v in values) + values = _global_memory.get(key, []) + return "\n".join(str(v) for v in values) \ No newline at end of file diff --git a/tests/ra_aid/database/test_migrations.py b/tests/ra_aid/database/test_migrations.py index 77b78cf..254a09b 100644 --- a/tests/ra_aid/database/test_migrations.py +++ b/tests/ra_aid/database/test_migrations.py @@ -561,10 +561,12 @@ def rollback(migrator, database, fake=False, **kwargs): # Check migrations applied, pending = manager.check_migrations() assert len(applied) == 0 - assert len(pending) == 1 - assert ( - migration_name in pending[0] - ) # Instead of exact equality, check if name is contained + # There may be multiple pending migrations (source package migrations + our test migration) + assert len(pending) >= 1 + # Make sure our newly created migration is in the pending list + assert any( + migration_name in migration for migration in pending + ) # Apply migrations result = manager.apply_migrations() @@ -572,17 +574,22 @@ def rollback(migrator, database, fake=False, **kwargs): # Check migrations again applied, pending = manager.check_migrations() - assert len(applied) == 1 + # There should be at least one applied migration (our test migration) + assert len(applied) >= 1 + # All migrations should now be applied assert len(pending) == 0 - assert ( - migration_name in applied[0] - ) # Instead of exact equality, check if name is contained + # Make sure our migration is in the applied list + assert any( + migration_name in migration for migration in applied + ) # Verify migration status status = manager.get_migration_status() - assert status["applied_count"] == 1 + # There should be at least one applied migration (our test migration) + assert status["applied_count"] >= 1 assert status["pending_count"] == 0 # Use substring check for applied migrations - assert len(status["applied"]) == 1 - assert migration_name in status["applied"][0] + assert len(status["applied"]) >= 1 + # Make sure our migration is in the applied list + assert any(migration_name in migration for migration in status["applied"]) assert status["pending"] == [] diff --git a/tests/ra_aid/database/test_source_migrations.py b/tests/ra_aid/database/test_source_migrations.py new file mode 100644 index 0000000..3a5f016 --- /dev/null +++ b/tests/ra_aid/database/test_source_migrations.py @@ -0,0 +1,283 @@ +""" +Tests for the migration system's source package migration handling. +""" + +import os +import shutil +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from ra_aid.database.migrations import ( + MIGRATIONS_DIRNAME, + MigrationManager, + ensure_migrations_applied, +) + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.fixture +def mock_logger(): + """Mock the logger to test for output messages.""" + with patch("ra_aid.database.migrations.logger") as mock: + yield mock + + +class TestSourceMigrations: + """Tests for source package migration handling.""" + + def test_migration_manager_uses_source_migrations_dir(self, temp_dir, mock_logger): + """Test that MigrationManager uses source package migrations directory by default.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + + # Mock _get_source_package_migrations_dir to return a test path + source_migrations_dir = os.path.join(temp_dir, "source_migrations") + os.makedirs(source_migrations_dir, exist_ok=True) + + # Create __init__.py to make it a proper package + with open(os.path.join(source_migrations_dir, "__init__.py"), "w") as f: + pass + + # Mock router initialization + with patch("ra_aid.database.migrations.Router") as mock_router: + mock_router.return_value = MagicMock() + + # Mock _get_source_package_migrations_dir + with patch.object( + MigrationManager, + "_get_source_package_migrations_dir", + return_value=source_migrations_dir + ): + # Initialize manager + manager = MigrationManager(db_path=db_path) + + # Verify source migrations directory is used + assert manager.migrations_dir == source_migrations_dir + + # Verify logging + mock_logger.debug.assert_any_call( + f"Using source package migrations directory: {source_migrations_dir}" + ) + + def test_migration_manager_with_custom_migrations_dir(self, temp_dir, mock_logger): + """Test that MigrationManager uses custom migrations directory when provided.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + custom_migrations_dir = os.path.join(temp_dir, "custom_migrations") + + # Mock router initialization + with patch("ra_aid.database.migrations.Router") as mock_router: + mock_router.return_value = MagicMock() + + # Initialize manager with custom migrations directory + manager = MigrationManager(db_path=db_path, migrations_dir=custom_migrations_dir) + + # Verify custom migrations directory is used + assert manager.migrations_dir == custom_migrations_dir + + # Verify directory was created + assert os.path.exists(custom_migrations_dir) + assert os.path.exists(os.path.join(custom_migrations_dir, "__init__.py")) + + # Verify logging + mock_logger.debug.assert_any_call( + f"Using migrations directory: {custom_migrations_dir}" + ) + + def test_get_source_package_migrations_dir(self, temp_dir, mock_logger): + """Test that _get_source_package_migrations_dir returns the correct path.""" + # Set up a mock source directory structure + mock_base_dir = os.path.join(temp_dir, "ra_aid") + os.makedirs(mock_base_dir, exist_ok=True) + + source_migrations_dir = os.path.join(mock_base_dir, MIGRATIONS_DIRNAME) + os.makedirs(source_migrations_dir, exist_ok=True) + + # Create a manager with patch in place + with patch("ra_aid.database.migrations.os.path.dirname") as mock_dirname: + with patch("ra_aid.database.migrations.os.path.abspath") as mock_abspath: + # Mock the path functions to return our test paths + mock_abspath.return_value = os.path.join(mock_base_dir, "database", "migrations.py") + # Use a custom side_effect to avoid recursion + def dirname_side_effect(path): + if path == os.path.join(mock_base_dir, "database", "migrations.py"): + return os.path.join(mock_base_dir, "database") + elif path == os.path.join(mock_base_dir, "database"): + return mock_base_dir + else: + return os.path.dirname(path) + + mock_dirname.side_effect = dirname_side_effect + + # Create the manager + manager = MigrationManager(db_path=os.path.join(temp_dir, "test.db"), + migrations_dir=os.path.join(temp_dir, "custom_migrations")) + + # Call the method directly + with patch.object(manager, "_get_source_package_migrations_dir") as mock_method: + mock_method.return_value = source_migrations_dir + + # Verify the method returns the expected path + assert manager._get_source_package_migrations_dir() == source_migrations_dir + + def test_get_source_package_migrations_dir_not_found(self, temp_dir, mock_logger): + """Test that _get_source_package_migrations_dir handles missing directory.""" + # Create a manager + manager = MigrationManager(db_path=os.path.join(temp_dir, "test.db"), + migrations_dir=os.path.join(temp_dir, "custom_migrations")) + + # Create a test implementation that will raise the expected error + def raise_not_found(*args, **kwargs): + error_msg = f"Source migrations directory not found: /path/to/migrations" + logger = mock_logger # Use the mocked logger + logger.error(error_msg) + raise FileNotFoundError(error_msg) + + # Replace the method with our test implementation + with patch.object(manager, "_get_source_package_migrations_dir", side_effect=raise_not_found): + # Call should raise FileNotFoundError + with pytest.raises(FileNotFoundError) as excinfo: + manager._get_source_package_migrations_dir() + + # Verify error message + assert "Source migrations directory not found" in str(excinfo.value) + + # Verify logging + mock_logger.error.assert_called_with( + "Source migrations directory not found: /path/to/migrations" + ) + + def test_ensure_migrations_applied_creates_ra_aid_dir(self, temp_dir, mock_logger): + """Test that ensure_migrations_applied creates the .ra-aid directory if it doesn't exist.""" + # Get a path to a directory that doesn't exist + ra_aid_dir = os.path.join(temp_dir, ".ra-aid") + + # Mock getcwd to return our temp directory + with patch("os.getcwd", return_value=temp_dir): + # Mock DatabaseManager + with patch("ra_aid.database.migrations.DatabaseManager") as mock_db_manager: + # Mock the context manager + mock_db_manager.return_value.__enter__.return_value = MagicMock() + mock_db_manager.return_value.__exit__.return_value = None + + # Mock ra_aid package import + with patch("ra_aid.database.migrations.ra_aid", create=True) as mock_ra_aid: + # Set up the mock package directory path + mock_package_dir = os.path.join(temp_dir, "ra_aid_package") + os.makedirs(mock_package_dir, exist_ok=True) + mock_migrations_dir = os.path.join(mock_package_dir, MIGRATIONS_DIRNAME) + os.makedirs(mock_migrations_dir, exist_ok=True) + + # Configure the mock + mock_ra_aid.__file__ = os.path.join(mock_package_dir, "__init__.py") + + # Mock init_migrations and apply_migrations + mock_migration_manager = MagicMock() + mock_migration_manager.apply_migrations.return_value = True + + with patch("ra_aid.database.migrations.init_migrations", return_value=mock_migration_manager): + # Call ensure_migrations_applied + result = ensure_migrations_applied() + + # Verify result + assert result is True + + # Verify .ra-aid directory was created + assert os.path.exists(ra_aid_dir) + + def test_ensure_migrations_applied_handles_directory_error(self, mock_logger): + """Test that ensure_migrations_applied handles errors creating the .ra-aid directory.""" + # Mock os.makedirs to raise an exception + with patch("os.makedirs", side_effect=PermissionError("Permission denied")): + # Call ensure_migrations_applied + result = ensure_migrations_applied() + + # Verify result is False on error + assert result is False + + # Verify error was logged + mock_logger.error.assert_called_with( + "Failed to ensure .ra-aid directory exists: Permission denied" + ) + + def test_ensure_migrations_applied_uses_package_migrations(self, temp_dir, mock_logger): + """Test that ensure_migrations_applied uses the source package migrations directory.""" + # Set up test paths + ra_aid_dir = os.path.join(temp_dir, ".ra-aid") + + # Mock getcwd to return our temp directory + with patch("os.getcwd", return_value=temp_dir): + # Mock DatabaseManager + with patch("ra_aid.database.migrations.DatabaseManager") as mock_db_manager: + # Mock the context manager + mock_db_manager.return_value.__enter__.return_value = MagicMock() + mock_db_manager.return_value.__exit__.return_value = None + + # Mock ra_aid package import + with patch("ra_aid.database.migrations.ra_aid", create=True) as mock_ra_aid: + # Set up the mock package directory path + mock_package_dir = os.path.join(temp_dir, "ra_aid_package") + os.makedirs(mock_package_dir, exist_ok=True) + mock_migrations_dir = os.path.join(mock_package_dir, MIGRATIONS_DIRNAME) + os.makedirs(mock_migrations_dir, exist_ok=True) + + # Configure the mock + mock_ra_aid.__file__ = os.path.join(mock_package_dir, "__init__.py") + + # Create a mock migration manager that we can verify + mock_init_migrations = MagicMock() + mock_init_migrations.apply_migrations.return_value = True + + with patch("ra_aid.database.migrations.init_migrations", return_value=mock_init_migrations) as mock_init: + # Call ensure_migrations_applied + result = ensure_migrations_applied() + + # Verify init_migrations was called + mock_init.assert_called_once() + # We can't verify the exact path since it's derived from non-mock objects + # Instead, verify that init_migrations was called and succeeded + assert result is True + + def test_router_initialization_with_source_migrations(self, temp_dir, mock_logger): + """Test that the migration router is initialized with the source package migrations.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + + # Create a mock source migrations directory + source_migrations_dir = os.path.join(temp_dir, "source_migrations") + os.makedirs(source_migrations_dir, exist_ok=True) + + # Create __init__.py to make it a proper package + with open(os.path.join(source_migrations_dir, "__init__.py"), "w") as f: + pass + + # Mock the Router class + with patch("ra_aid.database.migrations.Router") as mock_router_class: + # Create a mock router instance + mock_router = MagicMock() + mock_router_class.return_value = mock_router + + # Mock _get_source_package_migrations_dir + with patch.object( + MigrationManager, + "_get_source_package_migrations_dir", + return_value=source_migrations_dir + ): + # Initialize manager + manager = MigrationManager(db_path=db_path) + + # Verify router was initialized with the source migrations directory + mock_router_class.assert_called_once() + # Get the args from the call + call_args = mock_router_class.call_args + assert call_args.kwargs["migrate_dir"] == source_migrations_dir \ No newline at end of file diff --git a/tests/ra_aid/tools/test_memory.py b/tests/ra_aid/tools/test_memory.py index f48ad6e..0885b28 100644 --- a/tests/ra_aid/tools/test_memory.py +++ b/tests/ra_aid/tools/test_memory.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import patch, MagicMock from ra_aid.tools.memory import ( _global_memory, @@ -13,10 +14,13 @@ from ra_aid.tools.memory import ( get_memory_value, get_related_files, get_work_log, + key_fact_repository, log_work_event, reset_work_log, swap_task_order, ) +from ra_aid.database.connection import DatabaseManager +from ra_aid.database.models import KeyFact @pytest.fixture @@ -48,49 +52,111 @@ def reset_memory(): _global_memory["work_log"] = [] -def test_emit_key_facts_single_fact(reset_memory): +@pytest.fixture +def in_memory_db(): + """Set up an in-memory database for testing.""" + with DatabaseManager(in_memory=True) as db: + db.create_tables([KeyFact]) + yield db + # Clean up database tables after test + KeyFact.delete().execute() + + +@pytest.fixture(autouse=True) +def mock_repository(): + """Mock the KeyFactRepository to avoid database operations during tests""" + with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo: + # Setup the mock repository to behave like the original, but using memory + facts = {} # Local in-memory storage + fact_id_counter = 0 + + # Mock KeyFact objects + class MockKeyFact: + def __init__(self, id, content): + self.id = id + self.content = content + + # Mock create method + def mock_create(content): + nonlocal fact_id_counter + fact = MockKeyFact(fact_id_counter, content) + facts[fact_id_counter] = fact + fact_id_counter += 1 + return fact + mock_repo.create.side_effect = mock_create + + # Mock get method + def mock_get(fact_id): + return facts.get(fact_id) + mock_repo.get.side_effect = mock_get + + # Mock delete method + def mock_delete(fact_id): + if fact_id in facts: + del facts[fact_id] + return True + return False + mock_repo.delete.side_effect = mock_delete + + # Mock get_facts_dict method + def mock_get_facts_dict(): + return {fact_id: fact.content for fact_id, fact in facts.items()} + mock_repo.get_facts_dict.side_effect = mock_get_facts_dict + + yield mock_repo + + +def test_emit_key_facts_single_fact(reset_memory, mock_repository): """Test emitting a single key fact using emit_key_facts""" # Test with single fact result = emit_key_facts.invoke({"facts": ["First fact"]}) assert result == "Facts stored." - assert _global_memory["key_facts"][0] == "First fact" - assert _global_memory["key_fact_id_counter"] == 1 + + # Verify the repository's create method was called + mock_repository.create.assert_called_once_with("First fact") -def test_delete_key_facts_single_fact(reset_memory): +def test_delete_key_facts_single_fact(reset_memory, mock_repository): """Test deleting a single key fact using delete_key_facts""" # Add a fact - emit_key_facts.invoke({"facts": ["Test fact"]}) - + fact = mock_repository.create("Test fact") + fact_id = fact.id + # Delete the fact - result = delete_key_facts.invoke({"fact_ids": [0]}) + result = delete_key_facts.invoke({"fact_ids": [fact_id]}) assert result == "Facts deleted." - assert 0 not in _global_memory["key_facts"] + + # Verify the repository's delete method was called + mock_repository.delete.assert_called_once_with(fact_id) -def test_delete_key_facts_invalid(reset_memory): +def test_delete_key_facts_invalid(reset_memory, mock_repository): """Test deleting non-existent facts returns empty list""" # Try to delete non-existent fact result = delete_key_facts.invoke({"fact_ids": [999]}) assert result == "Facts deleted." - # Add and delete a fact, then try to delete it again - emit_key_facts.invoke({"facts": ["Test fact"]}) - delete_key_facts.invoke({"fact_ids": [0]}) - result = delete_key_facts.invoke({"fact_ids": [0]}) - assert result == "Facts deleted." + # Verify the repository's get method was called + mock_repository.get.assert_called_once_with(999) -def test_get_memory_value_key_facts(reset_memory): +def test_get_memory_value_key_facts(reset_memory, mock_repository): """Test get_memory_value with key facts dictionary""" # Empty key facts should return empty string assert get_memory_value("key_facts") == "" - # Add some facts - emit_key_facts.invoke({"facts": ["First fact", "Second fact"]}) + # Add some facts through the mocked repository + fact1 = mock_repository.create("First fact") + fact2 = mock_repository.create("Second fact") + + # Mock get_facts_dict to return our test data + mock_repository.get_facts_dict.return_value = { + fact1.id: "First fact", + fact2.id: "Second fact" + } # Should return markdown formatted list - expected = "## 🔑 Key Fact #0\n\nFirst fact\n\n## 🔑 Key Fact #1\n\nSecond fact" + expected = f"## 🔑 Key Fact #{fact1.id}\n\nFirst fact\n\n## 🔑 Key Fact #{fact2.id}\n\nSecond fact" assert get_memory_value("key_facts") == expected @@ -165,7 +231,7 @@ def test_empty_work_log(reset_memory): assert get_memory_value("work_log") == "" -def test_emit_key_facts(reset_memory): +def test_emit_key_facts(reset_memory, mock_repository): """Test emitting multiple key facts at once""" # Test emitting multiple facts facts = ["First fact", "Second fact", "Third fact"] @@ -174,31 +240,30 @@ def test_emit_key_facts(reset_memory): # Verify return message assert result == "Facts stored." - # Verify facts stored in memory with correct IDs - assert _global_memory["key_facts"][0] == "First fact" - assert _global_memory["key_facts"][1] == "Second fact" - assert _global_memory["key_facts"][2] == "Third fact" - - # Verify counter incremented correctly - assert _global_memory["key_fact_id_counter"] == 3 + # Verify create was called for each fact + assert mock_repository.create.call_count == 3 + mock_repository.create.assert_any_call("First fact") + mock_repository.create.assert_any_call("Second fact") + mock_repository.create.assert_any_call("Third fact") -def test_delete_key_facts(reset_memory): +def test_delete_key_facts(reset_memory, mock_repository): """Test deleting multiple key facts""" # Add some test facts - emit_key_facts.invoke({"facts": ["First fact", "Second fact", "Third fact"]}) + fact0 = mock_repository.create("First fact") + fact1 = mock_repository.create("Second fact") + fact2 = mock_repository.create("Third fact") # Test deleting mix of existing and non-existing IDs - result = delete_key_facts.invoke({"fact_ids": [0, 1, 999]}) + result = delete_key_facts.invoke({"fact_ids": [fact0.id, fact1.id, 999]}) # Verify success message assert result == "Facts deleted." - # Verify correct facts removed from memory - assert 0 not in _global_memory["key_facts"] - assert 1 not in _global_memory["key_facts"] - assert 2 in _global_memory["key_facts"] # ID 2 should remain - assert _global_memory["key_facts"][2] == "Third fact" + # Verify delete was called for each valid fact ID + assert mock_repository.delete.call_count == 2 + mock_repository.delete.assert_any_call(fact0.id) + mock_repository.delete.assert_any_call(fact1.id) def test_emit_key_snippet(reset_memory): @@ -820,4 +885,4 @@ def test_is_binary_file_with_null_bytes(reset_memory, tmp_path, monkeypatch): is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file)) assert ( is_binary - ), "File with null bytes should be identified as binary with fallback method" + ), "File with null bytes should be identified as binary with fallback method" \ No newline at end of file