# Source: RA.Aid/ra_aid/database/repositories/key_fact_repository.py
# Size: 255 lines, 7.9 KiB
# Language: Python

"""
Key fact repository implementation for database access.
This module provides a repository implementation for the KeyFact model,
following the repository pattern for data access abstraction.
"""
from typing import Dict, List, Optional
import contextvars
from contextlib import contextmanager
import peewee
from ra_aid.database.models import KeyFact
from ra_aid.logging_config import get_logger
# Logger for this module, named after the module via __name__.
logger = get_logger(__name__)
# Context variable holding the active KeyFactRepository instance.
# It defaults to None (no repository available). KeyFactRepositoryManager
# sets it on entry and resets it on exit; get_key_fact_repository() reads it.
key_fact_repo_var = contextvars.ContextVar("key_fact_repo", default=None)
class KeyFactRepositoryManager:
    """
    Context manager that publishes a KeyFactRepository via ``key_fact_repo_var``.

    On entry a new KeyFactRepository is built from the supplied database
    connection and stored in the module-level context variable. On exit the
    value the variable held *before* entry is restored, so nested managers
    unwind correctly instead of wiping out the enclosing repository.

    Example:
        with DatabaseManager() as db:
            with KeyFactRepositoryManager(db) as repo:
                # Use the repository
                fact = repo.create("Important fact about the project")
                all_facts = repo.get_all()
    """

    def __init__(self, db):
        """
        Initialize the KeyFactRepositoryManager.

        Args:
            db: Database connection to use (required)
        """
        self.db = db
        # One ContextVar token per active __enter__, most recent last, so the
        # same manager instance can be entered more than once.
        self._tokens = []

    def __enter__(self) -> 'KeyFactRepository':
        """
        Create a KeyFactRepository and make it the current one.

        Returns:
            KeyFactRepository: The initialized repository

        Raises:
            ValueError: If the manager was constructed with ``db=None``
                (raised by the KeyFactRepository constructor)
        """
        repo = KeyFactRepository(self.db)
        # Keep the token so __exit__ can restore the previous value.
        self._tokens.append(key_fact_repo_var.set(repo))
        return repo

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[object],
    ) -> bool:
        """
        Restore the repository that was current before this context was entered.

        Args:
            exc_type: The exception type if an exception was raised
            exc_val: The exception value if an exception was raised
            exc_tb: The traceback if an exception was raised

        Returns:
            bool: Always False, so exceptions raised in the body propagate
        """
        if self._tokens:
            token = self._tokens.pop()
            try:
                key_fact_repo_var.reset(token)
            except ValueError:
                # The token belongs to a different Context than the one we are
                # exiting in; fall back to clearing the variable.
                key_fact_repo_var.set(None)
        else:
            # __exit__ called without a matching __enter__: just clear.
            key_fact_repo_var.set(None)
        # Don't suppress exceptions
        return False
def get_key_fact_repository() -> 'KeyFactRepository':
    """
    Return the KeyFactRepository stored in ``key_fact_repo_var``.

    Returns:
        KeyFactRepository: The repository currently held by the context variable

    Raises:
        RuntimeError: If the context variable holds None, i.e. no repository
            has been initialized with KeyFactRepositoryManager
    """
    current = key_fact_repo_var.get()
    if current is not None:
        return current
    # Nothing has been published for this context.
    raise RuntimeError(
        "No KeyFactRepository available. "
        "Make sure to initialize one with KeyFactRepositoryManager first."
    )
class KeyFactRepository:
    """
    Repository exposing CRUD operations for the KeyFact model.

    Callers use these methods instead of talking to the KeyFact model
    directly. Every method logs a peewee.DatabaseError and then re-raises it
    unchanged. The connection passed to the constructor is stored on the
    instance; the queries in this class are issued through the KeyFact model.

    Example:
        with DatabaseManager() as db:
            with KeyFactRepositoryManager(db) as repo:
                fact = repo.create("Important fact about the project")
                all_facts = repo.get_all()
    """

    def __init__(self, db):
        """
        Store the database connection for this repository.

        Args:
            db: Database connection to use (required)

        Raises:
            ValueError: If ``db`` is None
        """
        if db is None:
            raise ValueError("Database connection is required for KeyFactRepository")
        self.db = db

    def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
        """
        Insert a new key fact.

        Args:
            content: The text content of the key fact
            human_input_id: Optional ID of the associated human input

        Returns:
            KeyFact: The newly created key fact instance

        Raises:
            peewee.DatabaseError: If the insert fails (logged, then re-raised)
        """
        try:
            fact = KeyFact.create(content=content, human_input_id=human_input_id)
            logger.debug(f"Created key fact ID {fact.id}: {content}")
        except peewee.DatabaseError as e:
            logger.error(f"Failed to create key fact: {str(e)}")
            raise
        return fact

    def get(self, fact_id: int) -> Optional[KeyFact]:
        """
        Look up a single key fact by primary key.

        Args:
            fact_id: The ID of the key fact to retrieve

        Returns:
            Optional[KeyFact]: The matching key fact, or None when no row has that ID

        Raises:
            peewee.DatabaseError: If the lookup fails (logged, then re-raised)
        """
        try:
            match = KeyFact.get_or_none(KeyFact.id == fact_id)
        except peewee.DatabaseError as e:
            logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
            raise
        return match

    def update(self, fact_id: int, content: str) -> Optional[KeyFact]:
        """
        Replace the content of an existing key fact.

        Args:
            fact_id: The ID of the key fact to update
            content: The new content for the key fact

        Returns:
            Optional[KeyFact]: The saved key fact, or None when the ID does not exist

        Raises:
            peewee.DatabaseError: If the lookup or save fails (logged, then re-raised)
        """
        try:
            existing = self.get(fact_id)
            if existing:
                # Found it: overwrite the content and persist.
                existing.content = content
                existing.save()
                logger.debug(f"Updated key fact ID {fact_id}: {content}")
                return existing
            # Nothing to update; report it and signal "not found".
            logger.warning(f"Attempted to update non-existent key fact {fact_id}")
            return None
        except peewee.DatabaseError as e:
            logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
            raise

    def delete(self, fact_id: int) -> bool:
        """
        Remove a key fact by primary key.

        Args:
            fact_id: The ID of the key fact to delete

        Returns:
            bool: True when a fact was found and deleted, False when the ID does not exist

        Raises:
            peewee.DatabaseError: If the lookup or delete fails (logged, then re-raised)
        """
        try:
            target = self.get(fact_id)
            if target:
                target.delete_instance()
                logger.debug(f"Deleted key fact ID {fact_id}")
                return True
            # Nothing to delete; report it and signal "not found".
            logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
            return False
        except peewee.DatabaseError as e:
            logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
            raise

    def get_all(self) -> List[KeyFact]:
        """
        Fetch every key fact, ordered by ID.

        Returns:
            List[KeyFact]: All key fact instances in ascending ID order

        Raises:
            peewee.DatabaseError: If the query fails (logged, then re-raised)
        """
        try:
            ordered = KeyFact.select().order_by(KeyFact.id)
            # Materialize inside the try block so query errors are caught here.
            return list(ordered)
        except peewee.DatabaseError as e:
            logger.error(f"Failed to fetch all key facts: {str(e)}")
            raise

    def get_facts_dict(self) -> Dict[int, str]:
        """
        Fetch every key fact as an ``{id: content}`` mapping.

        This method is useful for compatibility with the existing memory format.

        Returns:
            Dict[int, str]: Dictionary with fact IDs as keys and content as values

        Raises:
            peewee.DatabaseError: If the underlying query fails (logged, then re-raised)
        """
        try:
            mapping: Dict[int, str] = {}
            for fact in self.get_all():
                mapping[fact.id] = fact.content
            return mapping
        except peewee.DatabaseError as e:
            logger.error(f"Failed to fetch key facts as dictionary: {str(e)}")
            raise