"""
Tests for the KeyFactRepository class.
"""

import pytest
from unittest.mock import patch

import peewee

from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeyFact, BaseModel
from ra_aid.database.repositories.key_fact_repository import (
    KeyFactRepository,
    KeyFactRepositoryManager,
    get_key_fact_repository,
    key_fact_repo_var,
)
from ra_aid.database.pydantic_models import KeyFactModel


def _close_and_reset_db_var():
    """Close any connection stored in ``db_var`` (best-effort) and clear the var.

    Errors raised while checking or closing the connection are ignored so that
    a broken connection left behind by one test cannot break the setup or
    teardown of another; ``db_var`` is reset to ``None`` in every case.
    """
    db = db_var.get()
    if db is not None:
        try:
            if not db.is_closed():
                db.close()
        except Exception:
            # Deliberately best-effort: the contextvar is cleared below
            # regardless of whether closing succeeded.
            pass
    db_var.set(None)


@pytest.fixture
def cleanup_db():
    """Reset the database contextvar and connection state before and after each test."""
    _close_and_reset_db_var()
    yield
    _close_and_reset_db_var()


@pytest.fixture
def cleanup_repo():
    """Reset the repository contextvar before and after each test."""
    key_fact_repo_var.set(None)
    yield
    key_fact_repo_var.set(None)


@pytest.fixture
def setup_db(cleanup_db):
    """Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database.

    Yields:
        The in-memory database created by ``DatabaseManager(in_memory=True)``.
        The KeyFact table is dropped again during teardown.
    """
    with DatabaseManager(in_memory=True) as db:
        # Patch BaseModel's metadata so model operations such as
        # KeyFact.create() are intended to use the test database.
        # NOTE(review): peewee copies Meta.database into each subclass's
        # _meta when the class is created, so patching BaseModel._meta may
        # not rebind KeyFact itself; Database.create_tables() also delegates
        # to each model's own database. Consider db.bind_ctx([KeyFact]) —
        # confirm against ra_aid's BaseModel before changing.
        with patch.object(BaseModel._meta, "database", db):
            with db.atomic():
                db.create_tables([KeyFact], safe=True)

            yield db

            # Teardown: drop the table created above.
            with db.atomic():
                KeyFact.drop_table(safe=True)


def test_create_key_fact(setup_db):
    """Test creating a key fact."""
    repo = KeyFactRepository(db=setup_db)

    content = "Test key fact"
    fact = repo.create(content)

    # The repository returns a Pydantic model, not a Peewee row.
    assert isinstance(fact, KeyFactModel)
    assert fact.id is not None
    assert fact.content == content

    # Round-trip through the repository; guard against None so a missing row
    # fails the assertion instead of raising AttributeError.
    fact_from_db = repo.get(fact.id)
    assert fact_from_db is not None
    assert fact_from_db.content == content


def test_get_key_fact(setup_db):
    """Test retrieving a key fact by ID."""
    repo = KeyFactRepository(db=setup_db)

    content = "Test key fact"
    fact = repo.create(content)

    retrieved_fact = repo.get(fact.id)

    # Check for None first; otherwise isinstance() fails first and the None
    # check can never report.
    assert retrieved_fact is not None
    assert isinstance(retrieved_fact, KeyFactModel)
    assert retrieved_fact.id == fact.id
    assert retrieved_fact.content == content

    # An ID that was never created yields None rather than raising.
    non_existent_fact = repo.get(999)
    assert non_existent_fact is None


def test_update_key_fact(setup_db):
    """Test updating a key fact."""
    repo = KeyFactRepository(db=setup_db)

    original_content = "Original content"
    fact = repo.create(original_content)

    new_content = "Updated content"
    updated_fact = repo.update(fact.id, new_content)

    # Check for None first so that failure mode is reported clearly.
    assert updated_fact is not None
    assert isinstance(updated_fact, KeyFactModel)
    assert updated_fact.id == fact.id
    assert updated_fact.content == new_content

    # The update must be visible when re-reading through the repository.
    fact_from_db = repo.get(fact.id)
    assert fact_from_db is not None
    assert fact_from_db.content == new_content

    # Updating an ID that was never created yields None.
    non_existent_update = repo.update(999, "This shouldn't work")
    assert non_existent_update is None


def test_delete_key_fact(setup_db):
    """Test deleting a key fact."""
    repo = KeyFactRepository(db=setup_db)

    content = "Test key fact to delete"
    fact = repo.create(content)

    assert repo.get(fact.id) is not None

    delete_result = repo.delete(fact.id)

    assert delete_result is True
    assert repo.get(fact.id) is None

    # Deleting an ID that was never created reports False.
    non_existent_delete = repo.delete(999)
    assert non_existent_delete is False


def test_get_all_key_facts(setup_db):
    """Test retrieving all key facts."""
    repo = KeyFactRepository(db=setup_db)

    contents = ["Fact 1", "Fact 2", "Fact 3"]
    for content in contents:
        repo.create(content)

    all_facts = repo.get_all()

    assert len(all_facts) == len(contents)
    for fact in all_facts:
        assert isinstance(fact, KeyFactModel)

    # Order-insensitive comparison of the stored contents.
    assert sorted(fact.content for fact in all_facts) == sorted(contents)


def test_get_facts_dict(setup_db):
    """Test retrieving key facts as a dictionary."""
    repo = KeyFactRepository(db=setup_db)

    contents = ["Fact 1", "Fact 2", "Fact 3"]
    facts = [repo.create(content) for content in contents]

    facts_dict = repo.get_facts_dict()

    assert len(facts_dict) == len(contents)

    # Each created fact must appear keyed by its id with its content as value.
    for fact in facts:
        assert fact.id in facts_dict
        assert facts_dict[fact.id] == fact.content


def test_repository_init_without_db():
    """Test that KeyFactRepository raises an error when initialized without a db parameter."""
    with pytest.raises(ValueError) as excinfo:
        KeyFactRepository(db=None)

    # Checked outside the raises-block so the assertion is actually reached.
    assert "Database connection is required" in str(excinfo.value)


def test_key_fact_repository_manager(setup_db, cleanup_repo):
    """Test the KeyFactRepositoryManager context manager."""
    with KeyFactRepositoryManager(setup_db) as repo:
        assert isinstance(repo, KeyFactRepository)
        assert repo.db is setup_db

        content = "Test fact via context manager"
        fact = repo.create(content)
        assert isinstance(fact, KeyFactModel)
        assert fact.id is not None
        assert fact.content == content

        # While the manager is active, the same instance is exposed via the
        # module-level accessor.
        repo_from_var = get_key_fact_repository()
        assert repo_from_var is repo

    # After the manager exits, the accessor must no longer find a repository.
    with pytest.raises(RuntimeError) as excinfo:
        get_key_fact_repository()

    assert "No KeyFactRepository available" in str(excinfo.value)


def test_get_key_fact_repository_when_not_set(cleanup_repo):
    """Test that get_key_fact_repository raises an error when no repository is in context."""
    with pytest.raises(RuntimeError) as excinfo:
        get_key_fact_repository()

    assert "No KeyFactRepository available" in str(excinfo.value)


def test_to_model_method(setup_db):
    """Test the _to_model method converts KeyFact to KeyFactModel correctly."""
    repo = KeyFactRepository(db=setup_db)

    # Create a Peewee row directly, bypassing the repository.
    peewee_fact = KeyFact.create(content="Test fact for conversion")

    pydantic_fact = repo._to_model(peewee_fact)

    assert isinstance(pydantic_fact, KeyFactModel)
    assert pydantic_fact.id == peewee_fact.id
    assert pydantic_fact.content == peewee_fact.content
    assert pydantic_fact.created_at == peewee_fact.created_at
    assert pydantic_fact.updated_at == peewee_fact.updated_at

    # None passes through unchanged.
    assert repo._to_model(None) is None