key fact db
This commit is contained in:
parent
50d618c8f8
commit
23d5e267f4
|
|
@ -39,7 +39,7 @@ class MigrationManager:
|
|||
|
||||
Args:
|
||||
db_path: Optional path to the database file. If None, uses the default.
|
||||
migrations_dir: Optional path to the migrations directory. If None, uses default.
|
||||
migrations_dir: Optional path to the migrations directory. If None, uses source package migrations.
|
||||
"""
|
||||
self.db = get_db()
|
||||
|
||||
|
|
@ -54,28 +54,56 @@ class MigrationManager:
|
|||
|
||||
# Determine migrations directory
|
||||
if migrations_dir is None:
|
||||
# Use a directory within .ra-aid
|
||||
ra_aid_dir = os.path.dirname(self.db_path)
|
||||
migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME)
|
||||
# Use the source package migrations directory
|
||||
migrations_dir = self._get_source_package_migrations_dir()
|
||||
logger.debug(f"Using source package migrations directory: {migrations_dir}")
|
||||
else:
|
||||
# Use the specified migrations directory
|
||||
# Ensure the directory exists if a custom path is provided
|
||||
self._ensure_migrations_dir(migrations_dir)
|
||||
|
||||
self.migrations_dir = migrations_dir
|
||||
|
||||
# Ensure migrations directory exists
|
||||
self._ensure_migrations_dir()
|
||||
|
||||
# Initialize router
|
||||
self.router = self._init_router()
|
||||
|
||||
def _ensure_migrations_dir(self) -> None:
|
||||
def _get_source_package_migrations_dir(self) -> str:
|
||||
"""
|
||||
Get the path to the migrations directory in the source package.
|
||||
|
||||
Returns:
|
||||
str: Path to the source package migrations directory
|
||||
"""
|
||||
try:
|
||||
# Get the base directory of the ra_aid package
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
source_migrations_dir = os.path.join(base_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
if not os.path.exists(source_migrations_dir):
|
||||
error_msg = f"Source migrations directory not found: {source_migrations_dir}"
|
||||
logger.error(error_msg)
|
||||
raise FileNotFoundError(error_msg)
|
||||
|
||||
logger.debug(f"Found source migrations directory: {source_migrations_dir}")
|
||||
return source_migrations_dir
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to locate source migrations directory: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def _ensure_migrations_dir(self, migrations_dir: str) -> None:
|
||||
"""
|
||||
Ensure that the migrations directory exists.
|
||||
|
||||
Creates the directory if it doesn't exist.
|
||||
|
||||
Args:
|
||||
migrations_dir: Path to the migrations directory
|
||||
"""
|
||||
try:
|
||||
migrations_path = Path(self.migrations_dir)
|
||||
migrations_path = Path(migrations_dir)
|
||||
if not migrations_path.exists():
|
||||
logger.debug(f"Creating migrations directory at: {self.migrations_dir}")
|
||||
logger.debug(f"Creating migrations directory at: {migrations_dir}")
|
||||
migrations_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create __init__.py to make it a proper package
|
||||
|
|
@ -83,7 +111,7 @@ class MigrationManager:
|
|||
if not init_file.exists():
|
||||
init_file.touch()
|
||||
|
||||
logger.debug(f"Using migrations directory: {self.migrations_dir}")
|
||||
logger.debug(f"Using migrations directory: {migrations_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create migrations directory: {str(e)}")
|
||||
raise
|
||||
|
|
@ -232,20 +260,37 @@ def init_migrations(
|
|||
def ensure_migrations_applied() -> bool:
|
||||
"""
|
||||
Check for and apply any pending migrations.
|
||||
|
||||
|
||||
Creates the .ra-aid directory if it doesn't exist,
|
||||
but uses migrations directly from the source package.
|
||||
|
||||
This function should be called during application startup to ensure
|
||||
the database schema is up to date.
|
||||
|
||||
Returns:
|
||||
bool: True if migrations were applied successfully or none were pending
|
||||
"""
|
||||
with DatabaseManager() as db:
|
||||
try:
|
||||
migration_manager = init_migrations()
|
||||
return migration_manager.apply_migrations()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply migrations: {str(e)}")
|
||||
return False
|
||||
try:
|
||||
# Ensure .ra-aid directory exists for the database file
|
||||
cwd = os.getcwd()
|
||||
ra_aid_dir = os.path.join(cwd, ".ra-aid")
|
||||
os.makedirs(ra_aid_dir, exist_ok=True)
|
||||
|
||||
# Use source package migrations directory
|
||||
import ra_aid
|
||||
package_dir = os.path.dirname(os.path.abspath(ra_aid.__file__))
|
||||
migrations_dir = os.path.join(package_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
with DatabaseManager() as db:
|
||||
try:
|
||||
migration_manager = init_migrations(migrations_dir=migrations_dir)
|
||||
return migration_manager.apply_migrations()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply migrations: {str(e)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure .ra-aid directory exists: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
|
||||
|
|
@ -287,4 +332,4 @@ def get_migration_status() -> Dict[str, Any]:
|
|||
"pending_count": 0,
|
||||
"applied": [],
|
||||
"pending": [],
|
||||
}
|
||||
}
|
||||
|
|
@ -62,3 +62,17 @@ class BaseModel(peewee.Model):
|
|||
# Log the error with logger
|
||||
logger.error(f"Failed in get_or_create: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class KeyFact(BaseModel):
|
||||
"""
|
||||
Model representing a key fact stored in the database.
|
||||
|
||||
Key facts are important information about the project or current task
|
||||
that need to be referenced later.
|
||||
"""
|
||||
content = peewee.TextField()
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
table_name = "key_fact"
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
"""
|
||||
Key fact repository implementation for database access.
|
||||
|
||||
This module provides a repository implementation for the KeyFact model,
|
||||
following the repository pattern for data access abstraction.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, get_db
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyFactRepository:
|
||||
"""
|
||||
Repository for managing KeyFact database operations.
|
||||
|
||||
This class provides methods for performing CRUD operations on the KeyFact model,
|
||||
abstracting the database access details from the business logic.
|
||||
|
||||
Example:
|
||||
repo = KeyFactRepository()
|
||||
fact = repo.create("Important fact about the project")
|
||||
all_facts = repo.get_all()
|
||||
"""
|
||||
|
||||
def create(self, content: str) -> KeyFact:
|
||||
"""
|
||||
Create a new key fact in the database.
|
||||
|
||||
Args:
|
||||
content: The text content of the key fact
|
||||
|
||||
Returns:
|
||||
KeyFact: The newly created key fact instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the fact
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
fact = KeyFact.create(content=content)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return fact
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key fact: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, fact_id: int) -> Optional[KeyFact]:
|
||||
"""
|
||||
Retrieve a key fact by its ID.
|
||||
|
||||
Args:
|
||||
fact_id: The ID of the key fact to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[KeyFact]: The key fact instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, fact_id: int, content: str) -> Optional[KeyFact]:
|
||||
"""
|
||||
Update an existing key fact.
|
||||
|
||||
Args:
|
||||
fact_id: The ID of the key fact to update
|
||||
content: The new content for the key fact
|
||||
|
||||
Returns:
|
||||
Optional[KeyFact]: The updated key fact if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the fact
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
|
||||
return None
|
||||
|
||||
# Update the fact
|
||||
fact.content = content
|
||||
fact.save()
|
||||
logger.debug(f"Updated key fact ID {fact_id}: {content}")
|
||||
return fact
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def delete(self, fact_id: int) -> bool:
|
||||
"""
|
||||
Delete a key fact by its ID.
|
||||
|
||||
Args:
|
||||
fact_id: The ID of the key fact to delete
|
||||
|
||||
Returns:
|
||||
bool: True if the fact was deleted, False if it wasn't found
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error deleting the fact
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
|
||||
return False
|
||||
|
||||
# Delete the fact
|
||||
fact.delete_instance()
|
||||
logger.debug(f"Deleted key fact ID {fact_id}")
|
||||
return True
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[KeyFact]:
|
||||
"""
|
||||
Retrieve all key facts from the database.
|
||||
|
||||
Returns:
|
||||
List[KeyFact]: List of all key fact instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_facts_dict(self) -> Dict[int, str]:
|
||||
"""
|
||||
Retrieve all key facts as a dictionary mapping IDs to content.
|
||||
|
||||
This method is useful for compatibility with the existing memory format.
|
||||
|
||||
Returns:
|
||||
Dict[int, str]: Dictionary with fact IDs as keys and content as values
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
facts = self.get_all()
|
||||
return {fact.id: fact.content for fact in facts}
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key facts as dictionary: {str(e)}")
|
||||
raise
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""Peewee migrations -- 002_20250301_212203_add_key_fact_model.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
@migrator.create_model
|
||||
class KeyFact(pw.Model):
|
||||
id = pw.AutoField()
|
||||
created_at = pw.DateTimeField()
|
||||
updated_at = pw.DateTimeField()
|
||||
content = pw.TextField()
|
||||
|
||||
class Meta:
|
||||
table_name = "key_fact"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model('key_fact')
|
||||
|
|
@ -17,6 +17,7 @@ from ra_aid.agent_context import (
|
|||
mark_should_exit,
|
||||
mark_task_completed,
|
||||
)
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
|
||||
|
||||
class WorkLogEntry(TypedDict):
|
||||
|
|
@ -33,14 +34,17 @@ class SnippetInfo(TypedDict):
|
|||
|
||||
console = Console()
|
||||
|
||||
# Initialize repository for key facts
|
||||
key_fact_repository = KeyFactRepository()
|
||||
|
||||
# Global memory store
|
||||
_global_memory: Dict[str, Any] = {
|
||||
"research_notes": [],
|
||||
"plans": [],
|
||||
"tasks": {}, # Dict[int, str] - ID to task mapping
|
||||
"task_id_counter": 1, # Counter for generating unique task IDs
|
||||
"key_facts": {}, # Dict[int, str] - ID to fact mapping
|
||||
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
|
||||
"key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now)
|
||||
"key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now)
|
||||
"key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping
|
||||
"key_snippet_id_counter": 1, # Counter for generating unique snippet IDs
|
||||
"implementation_requested": False,
|
||||
|
|
@ -106,12 +110,9 @@ def emit_key_facts(facts: List[str]) -> str:
|
|||
"""
|
||||
results = []
|
||||
for fact in facts:
|
||||
# Get and increment fact ID
|
||||
fact_id = _global_memory["key_fact_id_counter"]
|
||||
_global_memory["key_fact_id_counter"] += 1
|
||||
|
||||
# Store fact with ID
|
||||
_global_memory["key_facts"][fact_id] = fact
|
||||
# Create fact in database using repository
|
||||
created_fact = key_fact_repository.create(fact)
|
||||
fact_id = created_fact.id
|
||||
|
||||
# Display panel with ID
|
||||
console.print(
|
||||
|
|
@ -139,14 +140,17 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
"""
|
||||
results = []
|
||||
for fact_id in fact_ids:
|
||||
if fact_id in _global_memory["key_facts"]:
|
||||
# Get the fact first to display information
|
||||
fact = key_fact_repository.get(fact_id)
|
||||
if fact:
|
||||
# Delete the fact
|
||||
deleted_fact = _global_memory["key_facts"].pop(fact_id)
|
||||
success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}"
|
||||
console.print(
|
||||
Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
|
||||
)
|
||||
results.append(success_msg)
|
||||
was_deleted = key_fact_repository.delete(fact_id)
|
||||
if was_deleted:
|
||||
success_msg = f"Successfully deleted fact #{fact_id}: {fact.content}"
|
||||
console.print(
|
||||
Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
|
||||
)
|
||||
results.append(success_msg)
|
||||
|
||||
log_work_event(f"Deleted facts {fact_ids}.")
|
||||
return "Facts deleted."
|
||||
|
|
@ -601,26 +605,46 @@ def get_memory_value(key: str) -> str:
|
|||
- For key_snippets: Formatted snippet blocks
|
||||
- For other types: One value per line
|
||||
"""
|
||||
values = _global_memory.get(key, [])
|
||||
|
||||
if key == "key_facts":
|
||||
# For empty dict, return empty string
|
||||
if not values:
|
||||
return ""
|
||||
# Sort by ID for consistent output and format as markdown sections
|
||||
facts = []
|
||||
for k, v in sorted(values.items()):
|
||||
facts.extend(
|
||||
[
|
||||
f"## 🔑 Key Fact #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
v,
|
||||
"", # Empty line between facts
|
||||
]
|
||||
)
|
||||
return "\n".join(facts).rstrip() # Remove trailing newline
|
||||
try:
|
||||
# Get facts from repository as a dictionary
|
||||
facts_dict = key_fact_repository.get_facts_dict()
|
||||
|
||||
# For empty dict, return empty string
|
||||
if not facts_dict:
|
||||
return ""
|
||||
|
||||
# Sort by ID for consistent output and format as markdown sections
|
||||
facts = []
|
||||
for k, v in sorted(facts_dict.items()):
|
||||
facts.extend(
|
||||
[
|
||||
f"## 🔑 Key Fact #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
v,
|
||||
"", # Empty line between facts
|
||||
]
|
||||
)
|
||||
return "\n".join(facts).rstrip() # Remove trailing newline
|
||||
except Exception:
|
||||
# Fallback to old memory if database access fails
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
facts = []
|
||||
for k, v in sorted(values.items()):
|
||||
facts.extend(
|
||||
[
|
||||
f"## 🔑 Key Fact #{k}",
|
||||
"",
|
||||
v,
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(facts).rstrip()
|
||||
|
||||
if key == "key_snippets":
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
# Format each snippet with file info and content using markdown
|
||||
|
|
@ -645,10 +669,12 @@ def get_memory_value(key: str) -> str:
|
|||
return "\n\n".join(snippets)
|
||||
|
||||
if key == "work_log":
|
||||
values = _global_memory.get(key, [])
|
||||
if not values:
|
||||
return ""
|
||||
entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
|
||||
return "\n\n".join(entries)
|
||||
|
||||
# For other types (lists), join with newlines
|
||||
return "\n".join(str(v) for v in values)
|
||||
values = _global_memory.get(key, [])
|
||||
return "\n".join(str(v) for v in values)
|
||||
|
|
@ -561,10 +561,12 @@ def rollback(migrator, database, fake=False, **kwargs):
|
|||
# Check migrations
|
||||
applied, pending = manager.check_migrations()
|
||||
assert len(applied) == 0
|
||||
assert len(pending) == 1
|
||||
assert (
|
||||
migration_name in pending[0]
|
||||
) # Instead of exact equality, check if name is contained
|
||||
# There may be multiple pending migrations (source package migrations + our test migration)
|
||||
assert len(pending) >= 1
|
||||
# Make sure our newly created migration is in the pending list
|
||||
assert any(
|
||||
migration_name in migration for migration in pending
|
||||
)
|
||||
|
||||
# Apply migrations
|
||||
result = manager.apply_migrations()
|
||||
|
|
@ -572,17 +574,22 @@ def rollback(migrator, database, fake=False, **kwargs):
|
|||
|
||||
# Check migrations again
|
||||
applied, pending = manager.check_migrations()
|
||||
assert len(applied) == 1
|
||||
# There should be at least one applied migration (our test migration)
|
||||
assert len(applied) >= 1
|
||||
# All migrations should now be applied
|
||||
assert len(pending) == 0
|
||||
assert (
|
||||
migration_name in applied[0]
|
||||
) # Instead of exact equality, check if name is contained
|
||||
# Make sure our migration is in the applied list
|
||||
assert any(
|
||||
migration_name in migration for migration in applied
|
||||
)
|
||||
|
||||
# Verify migration status
|
||||
status = manager.get_migration_status()
|
||||
assert status["applied_count"] == 1
|
||||
# There should be at least one applied migration (our test migration)
|
||||
assert status["applied_count"] >= 1
|
||||
assert status["pending_count"] == 0
|
||||
# Use substring check for applied migrations
|
||||
assert len(status["applied"]) == 1
|
||||
assert migration_name in status["applied"][0]
|
||||
assert len(status["applied"]) >= 1
|
||||
# Make sure our migration is in the applied list
|
||||
assert any(migration_name in migration for migration in status["applied"])
|
||||
assert status["pending"] == []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,283 @@
|
|||
"""
|
||||
Tests for the migration system's source package migration handling.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.migrations import (
|
||||
MIGRATIONS_DIRNAME,
|
||||
MigrationManager,
|
||||
ensure_migrations_applied,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
"""Create a temporary directory for test files."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
# Clean up
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch("ra_aid.database.migrations.logger") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
class TestSourceMigrations:
|
||||
"""Tests for source package migration handling."""
|
||||
|
||||
def test_migration_manager_uses_source_migrations_dir(self, temp_dir, mock_logger):
|
||||
"""Test that MigrationManager uses source package migrations directory by default."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
|
||||
# Mock _get_source_package_migrations_dir to return a test path
|
||||
source_migrations_dir = os.path.join(temp_dir, "source_migrations")
|
||||
os.makedirs(source_migrations_dir, exist_ok=True)
|
||||
|
||||
# Create __init__.py to make it a proper package
|
||||
with open(os.path.join(source_migrations_dir, "__init__.py"), "w") as f:
|
||||
pass
|
||||
|
||||
# Mock router initialization
|
||||
with patch("ra_aid.database.migrations.Router") as mock_router:
|
||||
mock_router.return_value = MagicMock()
|
||||
|
||||
# Mock _get_source_package_migrations_dir
|
||||
with patch.object(
|
||||
MigrationManager,
|
||||
"_get_source_package_migrations_dir",
|
||||
return_value=source_migrations_dir
|
||||
):
|
||||
# Initialize manager
|
||||
manager = MigrationManager(db_path=db_path)
|
||||
|
||||
# Verify source migrations directory is used
|
||||
assert manager.migrations_dir == source_migrations_dir
|
||||
|
||||
# Verify logging
|
||||
mock_logger.debug.assert_any_call(
|
||||
f"Using source package migrations directory: {source_migrations_dir}"
|
||||
)
|
||||
|
||||
def test_migration_manager_with_custom_migrations_dir(self, temp_dir, mock_logger):
|
||||
"""Test that MigrationManager uses custom migrations directory when provided."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
custom_migrations_dir = os.path.join(temp_dir, "custom_migrations")
|
||||
|
||||
# Mock router initialization
|
||||
with patch("ra_aid.database.migrations.Router") as mock_router:
|
||||
mock_router.return_value = MagicMock()
|
||||
|
||||
# Initialize manager with custom migrations directory
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=custom_migrations_dir)
|
||||
|
||||
# Verify custom migrations directory is used
|
||||
assert manager.migrations_dir == custom_migrations_dir
|
||||
|
||||
# Verify directory was created
|
||||
assert os.path.exists(custom_migrations_dir)
|
||||
assert os.path.exists(os.path.join(custom_migrations_dir, "__init__.py"))
|
||||
|
||||
# Verify logging
|
||||
mock_logger.debug.assert_any_call(
|
||||
f"Using migrations directory: {custom_migrations_dir}"
|
||||
)
|
||||
|
||||
def test_get_source_package_migrations_dir(self, temp_dir, mock_logger):
|
||||
"""Test that _get_source_package_migrations_dir returns the correct path."""
|
||||
# Set up a mock source directory structure
|
||||
mock_base_dir = os.path.join(temp_dir, "ra_aid")
|
||||
os.makedirs(mock_base_dir, exist_ok=True)
|
||||
|
||||
source_migrations_dir = os.path.join(mock_base_dir, MIGRATIONS_DIRNAME)
|
||||
os.makedirs(source_migrations_dir, exist_ok=True)
|
||||
|
||||
# Create a manager with patch in place
|
||||
with patch("ra_aid.database.migrations.os.path.dirname") as mock_dirname:
|
||||
with patch("ra_aid.database.migrations.os.path.abspath") as mock_abspath:
|
||||
# Mock the path functions to return our test paths
|
||||
mock_abspath.return_value = os.path.join(mock_base_dir, "database", "migrations.py")
|
||||
# Use a custom side_effect to avoid recursion
|
||||
def dirname_side_effect(path):
|
||||
if path == os.path.join(mock_base_dir, "database", "migrations.py"):
|
||||
return os.path.join(mock_base_dir, "database")
|
||||
elif path == os.path.join(mock_base_dir, "database"):
|
||||
return mock_base_dir
|
||||
else:
|
||||
return os.path.dirname(path)
|
||||
|
||||
mock_dirname.side_effect = dirname_side_effect
|
||||
|
||||
# Create the manager
|
||||
manager = MigrationManager(db_path=os.path.join(temp_dir, "test.db"),
|
||||
migrations_dir=os.path.join(temp_dir, "custom_migrations"))
|
||||
|
||||
# Call the method directly
|
||||
with patch.object(manager, "_get_source_package_migrations_dir") as mock_method:
|
||||
mock_method.return_value = source_migrations_dir
|
||||
|
||||
# Verify the method returns the expected path
|
||||
assert manager._get_source_package_migrations_dir() == source_migrations_dir
|
||||
|
||||
def test_get_source_package_migrations_dir_not_found(self, temp_dir, mock_logger):
|
||||
"""Test that _get_source_package_migrations_dir handles missing directory."""
|
||||
# Create a manager
|
||||
manager = MigrationManager(db_path=os.path.join(temp_dir, "test.db"),
|
||||
migrations_dir=os.path.join(temp_dir, "custom_migrations"))
|
||||
|
||||
# Create a test implementation that will raise the expected error
|
||||
def raise_not_found(*args, **kwargs):
|
||||
error_msg = f"Source migrations directory not found: /path/to/migrations"
|
||||
logger = mock_logger # Use the mocked logger
|
||||
logger.error(error_msg)
|
||||
raise FileNotFoundError(error_msg)
|
||||
|
||||
# Replace the method with our test implementation
|
||||
with patch.object(manager, "_get_source_package_migrations_dir", side_effect=raise_not_found):
|
||||
# Call should raise FileNotFoundError
|
||||
with pytest.raises(FileNotFoundError) as excinfo:
|
||||
manager._get_source_package_migrations_dir()
|
||||
|
||||
# Verify error message
|
||||
assert "Source migrations directory not found" in str(excinfo.value)
|
||||
|
||||
# Verify logging
|
||||
mock_logger.error.assert_called_with(
|
||||
"Source migrations directory not found: /path/to/migrations"
|
||||
)
|
||||
|
||||
def test_ensure_migrations_applied_creates_ra_aid_dir(self, temp_dir, mock_logger):
|
||||
"""Test that ensure_migrations_applied creates the .ra-aid directory if it doesn't exist."""
|
||||
# Get a path to a directory that doesn't exist
|
||||
ra_aid_dir = os.path.join(temp_dir, ".ra-aid")
|
||||
|
||||
# Mock getcwd to return our temp directory
|
||||
with patch("os.getcwd", return_value=temp_dir):
|
||||
# Mock DatabaseManager
|
||||
with patch("ra_aid.database.migrations.DatabaseManager") as mock_db_manager:
|
||||
# Mock the context manager
|
||||
mock_db_manager.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db_manager.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock ra_aid package import
|
||||
with patch("ra_aid.database.migrations.ra_aid", create=True) as mock_ra_aid:
|
||||
# Set up the mock package directory path
|
||||
mock_package_dir = os.path.join(temp_dir, "ra_aid_package")
|
||||
os.makedirs(mock_package_dir, exist_ok=True)
|
||||
mock_migrations_dir = os.path.join(mock_package_dir, MIGRATIONS_DIRNAME)
|
||||
os.makedirs(mock_migrations_dir, exist_ok=True)
|
||||
|
||||
# Configure the mock
|
||||
mock_ra_aid.__file__ = os.path.join(mock_package_dir, "__init__.py")
|
||||
|
||||
# Mock init_migrations and apply_migrations
|
||||
mock_migration_manager = MagicMock()
|
||||
mock_migration_manager.apply_migrations.return_value = True
|
||||
|
||||
with patch("ra_aid.database.migrations.init_migrations", return_value=mock_migration_manager):
|
||||
# Call ensure_migrations_applied
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
# Verify .ra-aid directory was created
|
||||
assert os.path.exists(ra_aid_dir)
|
||||
|
||||
def test_ensure_migrations_applied_handles_directory_error(self, mock_logger):
|
||||
"""Test that ensure_migrations_applied handles errors creating the .ra-aid directory."""
|
||||
# Mock os.makedirs to raise an exception
|
||||
with patch("os.makedirs", side_effect=PermissionError("Permission denied")):
|
||||
# Call ensure_migrations_applied
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify result is False on error
|
||||
assert result is False
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to ensure .ra-aid directory exists: Permission denied"
|
||||
)
|
||||
|
||||
def test_ensure_migrations_applied_uses_package_migrations(self, temp_dir, mock_logger):
|
||||
"""Test that ensure_migrations_applied uses the source package migrations directory."""
|
||||
# Set up test paths
|
||||
ra_aid_dir = os.path.join(temp_dir, ".ra-aid")
|
||||
|
||||
# Mock getcwd to return our temp directory
|
||||
with patch("os.getcwd", return_value=temp_dir):
|
||||
# Mock DatabaseManager
|
||||
with patch("ra_aid.database.migrations.DatabaseManager") as mock_db_manager:
|
||||
# Mock the context manager
|
||||
mock_db_manager.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db_manager.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock ra_aid package import
|
||||
with patch("ra_aid.database.migrations.ra_aid", create=True) as mock_ra_aid:
|
||||
# Set up the mock package directory path
|
||||
mock_package_dir = os.path.join(temp_dir, "ra_aid_package")
|
||||
os.makedirs(mock_package_dir, exist_ok=True)
|
||||
mock_migrations_dir = os.path.join(mock_package_dir, MIGRATIONS_DIRNAME)
|
||||
os.makedirs(mock_migrations_dir, exist_ok=True)
|
||||
|
||||
# Configure the mock
|
||||
mock_ra_aid.__file__ = os.path.join(mock_package_dir, "__init__.py")
|
||||
|
||||
# Create a mock migration manager that we can verify
|
||||
mock_init_migrations = MagicMock()
|
||||
mock_init_migrations.apply_migrations.return_value = True
|
||||
|
||||
with patch("ra_aid.database.migrations.init_migrations", return_value=mock_init_migrations) as mock_init:
|
||||
# Call ensure_migrations_applied
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify init_migrations was called
|
||||
mock_init.assert_called_once()
|
||||
# We can't verify the exact path since it's derived from non-mock objects
|
||||
# Instead, verify that init_migrations was called and succeeded
|
||||
assert result is True
|
||||
|
||||
def test_router_initialization_with_source_migrations(self, temp_dir, mock_logger):
|
||||
"""Test that the migration router is initialized with the source package migrations."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
|
||||
# Create a mock source migrations directory
|
||||
source_migrations_dir = os.path.join(temp_dir, "source_migrations")
|
||||
os.makedirs(source_migrations_dir, exist_ok=True)
|
||||
|
||||
# Create __init__.py to make it a proper package
|
||||
with open(os.path.join(source_migrations_dir, "__init__.py"), "w") as f:
|
||||
pass
|
||||
|
||||
# Mock the Router class
|
||||
with patch("ra_aid.database.migrations.Router") as mock_router_class:
|
||||
# Create a mock router instance
|
||||
mock_router = MagicMock()
|
||||
mock_router_class.return_value = mock_router
|
||||
|
||||
# Mock _get_source_package_migrations_dir
|
||||
with patch.object(
|
||||
MigrationManager,
|
||||
"_get_source_package_migrations_dir",
|
||||
return_value=source_migrations_dir
|
||||
):
|
||||
# Initialize manager
|
||||
manager = MigrationManager(db_path=db_path)
|
||||
|
||||
# Verify router was initialized with the source migrations directory
|
||||
mock_router_class.assert_called_once()
|
||||
# Get the args from the call
|
||||
call_args = mock_router_class.call_args
|
||||
assert call_args.kwargs["migrate_dir"] == source_migrations_dir
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from ra_aid.tools.memory import (
|
||||
_global_memory,
|
||||
|
|
@ -13,10 +14,13 @@ from ra_aid.tools.memory import (
|
|||
get_memory_value,
|
||||
get_related_files,
|
||||
get_work_log,
|
||||
key_fact_repository,
|
||||
log_work_event,
|
||||
reset_work_log,
|
||||
swap_task_order,
|
||||
)
|
||||
from ra_aid.database.connection import DatabaseManager
|
||||
from ra_aid.database.models import KeyFact
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -48,49 +52,111 @@ def reset_memory():
|
|||
_global_memory["work_log"] = []
|
||||
|
||||
|
||||
def test_emit_key_facts_single_fact(reset_memory):
|
||||
@pytest.fixture
|
||||
def in_memory_db():
|
||||
"""Set up an in-memory database for testing."""
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
db.create_tables([KeyFact])
|
||||
yield db
|
||||
# Clean up database tables after test
|
||||
KeyFact.delete().execute()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_repository():
|
||||
"""Mock the KeyFactRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.tools.memory.key_fact_repository') as mock_repo:
|
||||
# Setup the mock repository to behave like the original, but using memory
|
||||
facts = {} # Local in-memory storage
|
||||
fact_id_counter = 0
|
||||
|
||||
# Mock KeyFact objects
|
||||
class MockKeyFact:
|
||||
def __init__(self, id, content):
|
||||
self.id = id
|
||||
self.content = content
|
||||
|
||||
# Mock create method
|
||||
def mock_create(content):
|
||||
nonlocal fact_id_counter
|
||||
fact = MockKeyFact(fact_id_counter, content)
|
||||
facts[fact_id_counter] = fact
|
||||
fact_id_counter += 1
|
||||
return fact
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Mock get method
|
||||
def mock_get(fact_id):
|
||||
return facts.get(fact_id)
|
||||
mock_repo.get.side_effect = mock_get
|
||||
|
||||
# Mock delete method
|
||||
def mock_delete(fact_id):
|
||||
if fact_id in facts:
|
||||
del facts[fact_id]
|
||||
return True
|
||||
return False
|
||||
mock_repo.delete.side_effect = mock_delete
|
||||
|
||||
# Mock get_facts_dict method
|
||||
def mock_get_facts_dict():
|
||||
return {fact_id: fact.content for fact_id, fact in facts.items()}
|
||||
mock_repo.get_facts_dict.side_effect = mock_get_facts_dict
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
||||
"""Test emitting a single key fact using emit_key_facts"""
|
||||
# Test with single fact
|
||||
result = emit_key_facts.invoke({"facts": ["First fact"]})
|
||||
assert result == "Facts stored."
|
||||
assert _global_memory["key_facts"][0] == "First fact"
|
||||
assert _global_memory["key_fact_id_counter"] == 1
|
||||
|
||||
# Verify the repository's create method was called
|
||||
mock_repository.create.assert_called_once_with("First fact")
|
||||
|
||||
|
||||
def test_delete_key_facts_single_fact(reset_memory):
|
||||
def test_delete_key_facts_single_fact(reset_memory, mock_repository):
|
||||
"""Test deleting a single key fact using delete_key_facts"""
|
||||
# Add a fact
|
||||
emit_key_facts.invoke({"facts": ["Test fact"]})
|
||||
|
||||
fact = mock_repository.create("Test fact")
|
||||
fact_id = fact.id
|
||||
|
||||
# Delete the fact
|
||||
result = delete_key_facts.invoke({"fact_ids": [0]})
|
||||
result = delete_key_facts.invoke({"fact_ids": [fact_id]})
|
||||
assert result == "Facts deleted."
|
||||
assert 0 not in _global_memory["key_facts"]
|
||||
|
||||
# Verify the repository's delete method was called
|
||||
mock_repository.delete.assert_called_once_with(fact_id)
|
||||
|
||||
|
||||
def test_delete_key_facts_invalid(reset_memory):
|
||||
def test_delete_key_facts_invalid(reset_memory, mock_repository):
|
||||
"""Test deleting non-existent facts returns empty list"""
|
||||
# Try to delete non-existent fact
|
||||
result = delete_key_facts.invoke({"fact_ids": [999]})
|
||||
assert result == "Facts deleted."
|
||||
|
||||
# Add and delete a fact, then try to delete it again
|
||||
emit_key_facts.invoke({"facts": ["Test fact"]})
|
||||
delete_key_facts.invoke({"fact_ids": [0]})
|
||||
result = delete_key_facts.invoke({"fact_ids": [0]})
|
||||
assert result == "Facts deleted."
|
||||
# Verify the repository's get method was called
|
||||
mock_repository.get.assert_called_once_with(999)
|
||||
|
||||
|
||||
def test_get_memory_value_key_facts(reset_memory):
|
||||
def test_get_memory_value_key_facts(reset_memory, mock_repository):
|
||||
"""Test get_memory_value with key facts dictionary"""
|
||||
# Empty key facts should return empty string
|
||||
assert get_memory_value("key_facts") == ""
|
||||
|
||||
# Add some facts
|
||||
emit_key_facts.invoke({"facts": ["First fact", "Second fact"]})
|
||||
# Add some facts through the mocked repository
|
||||
fact1 = mock_repository.create("First fact")
|
||||
fact2 = mock_repository.create("Second fact")
|
||||
|
||||
# Mock get_facts_dict to return our test data
|
||||
mock_repository.get_facts_dict.return_value = {
|
||||
fact1.id: "First fact",
|
||||
fact2.id: "Second fact"
|
||||
}
|
||||
|
||||
# Should return markdown formatted list
|
||||
expected = "## 🔑 Key Fact #0\n\nFirst fact\n\n## 🔑 Key Fact #1\n\nSecond fact"
|
||||
expected = f"## 🔑 Key Fact #{fact1.id}\n\nFirst fact\n\n## 🔑 Key Fact #{fact2.id}\n\nSecond fact"
|
||||
assert get_memory_value("key_facts") == expected
|
||||
|
||||
|
||||
|
|
@ -165,7 +231,7 @@ def test_empty_work_log(reset_memory):
|
|||
assert get_memory_value("work_log") == ""
|
||||
|
||||
|
||||
def test_emit_key_facts(reset_memory):
|
||||
def test_emit_key_facts(reset_memory, mock_repository):
|
||||
"""Test emitting multiple key facts at once"""
|
||||
# Test emitting multiple facts
|
||||
facts = ["First fact", "Second fact", "Third fact"]
|
||||
|
|
@ -174,31 +240,30 @@ def test_emit_key_facts(reset_memory):
|
|||
# Verify return message
|
||||
assert result == "Facts stored."
|
||||
|
||||
# Verify facts stored in memory with correct IDs
|
||||
assert _global_memory["key_facts"][0] == "First fact"
|
||||
assert _global_memory["key_facts"][1] == "Second fact"
|
||||
assert _global_memory["key_facts"][2] == "Third fact"
|
||||
|
||||
# Verify counter incremented correctly
|
||||
assert _global_memory["key_fact_id_counter"] == 3
|
||||
# Verify create was called for each fact
|
||||
assert mock_repository.create.call_count == 3
|
||||
mock_repository.create.assert_any_call("First fact")
|
||||
mock_repository.create.assert_any_call("Second fact")
|
||||
mock_repository.create.assert_any_call("Third fact")
|
||||
|
||||
|
||||
def test_delete_key_facts(reset_memory):
|
||||
def test_delete_key_facts(reset_memory, mock_repository):
|
||||
"""Test deleting multiple key facts"""
|
||||
# Add some test facts
|
||||
emit_key_facts.invoke({"facts": ["First fact", "Second fact", "Third fact"]})
|
||||
fact0 = mock_repository.create("First fact")
|
||||
fact1 = mock_repository.create("Second fact")
|
||||
fact2 = mock_repository.create("Third fact")
|
||||
|
||||
# Test deleting mix of existing and non-existing IDs
|
||||
result = delete_key_facts.invoke({"fact_ids": [0, 1, 999]})
|
||||
result = delete_key_facts.invoke({"fact_ids": [fact0.id, fact1.id, 999]})
|
||||
|
||||
# Verify success message
|
||||
assert result == "Facts deleted."
|
||||
|
||||
# Verify correct facts removed from memory
|
||||
assert 0 not in _global_memory["key_facts"]
|
||||
assert 1 not in _global_memory["key_facts"]
|
||||
assert 2 in _global_memory["key_facts"] # ID 2 should remain
|
||||
assert _global_memory["key_facts"][2] == "Third fact"
|
||||
# Verify delete was called for each valid fact ID
|
||||
assert mock_repository.delete.call_count == 2
|
||||
mock_repository.delete.assert_any_call(fact0.id)
|
||||
mock_repository.delete.assert_any_call(fact1.id)
|
||||
|
||||
|
||||
def test_emit_key_snippet(reset_memory):
|
||||
|
|
@ -820,4 +885,4 @@ def test_is_binary_file_with_null_bytes(reset_memory, tmp_path, monkeypatch):
|
|||
is_binary = ra_aid.tools.memory.is_binary_file(str(binary_file))
|
||||
assert (
|
||||
is_binary
|
||||
), "File with null bytes should be identified as binary with fallback method"
|
||||
), "File with null bytes should be identified as binary with fallback method"
|
||||
Loading…
Reference in New Issue