key snippets context var

This commit is contained in:
AI Christianson 2025-03-03 17:27:41 -05:00
parent ffd1ef15d4
commit fc58aa0b77
2 changed files with 107 additions and 40 deletions

View File

@ -43,7 +43,9 @@ from ra_aid.config import (
VALID_PROVIDERS, VALID_PROVIDERS,
) )
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository from ra_aid.database.repositories.key_snippet_repository import (
KeySnippetRepositoryManager, get_key_snippet_repository
)
from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.console.output import cpm from ra_aid.console.output import cpm
@ -394,9 +396,10 @@ def main():
logger.error(f"Database migration error: {str(e)}") logger.error(f"Database migration error: {str(e)}")
# Initialize repositories with database connection # Initialize repositories with database connection
with KeyFactRepositoryManager(db) as key_fact_repo: with KeyFactRepositoryManager(db) as key_fact_repo, KeySnippetRepositoryManager(db) as key_snippet_repo:
# This initializes the repository and makes it available via get_key_fact_repository() # This initializes both repositories and makes them available via their respective get methods
logger.debug("Initialized KeyFactRepository") logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
# Check dependencies before proceeding # Check dependencies before proceeding
check_dependencies() check_dependencies()
@ -533,7 +536,7 @@ def main():
working_directory=working_directory, working_directory=working_directory,
current_date=current_date, current_date=current_date,
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()), key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()), key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
project_info=formatted_project_info, project_info=formatted_project_info,
), ),
config, config,

View File

@ -6,15 +6,98 @@ following the repository pattern for data access abstraction.
""" """
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
import contextvars
import peewee import peewee
from ra_aid.database.connection import get_db from ra_aid.database.models import KeySnippet
from ra_aid.database.models import KeySnippet, initialize_database
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
# Create contextvar to hold the KeySnippetRepository instance
key_snippet_repo_var = contextvars.ContextVar("key_snippet_repo", default=None)
class KeySnippetRepositoryManager:
"""
Context manager for KeySnippetRepository.
This class provides a context manager interface for KeySnippetRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with KeySnippetRepositoryManager(db) as repo:
# Use the repository
snippet = repo.create(
filepath="main.py",
line_number=42,
snippet="def hello_world():",
description="Main function definition"
)
all_snippets = repo.get_all()
"""
def __init__(self, db):
"""
Initialize the KeySnippetRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'KeySnippetRepository':
"""
Initialize the KeySnippetRepository and return it.
Returns:
KeySnippetRepository: The initialized repository
"""
repo = KeySnippetRepository(self.db)
key_snippet_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
key_snippet_repo_var.set(None)
# Don't suppress exceptions
return False
def get_key_snippet_repository() -> 'KeySnippetRepository':
"""
Get the current KeySnippetRepository instance.
Returns:
KeySnippetRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with KeySnippetRepositoryManager
"""
repo = key_snippet_repo_var.get()
if repo is None:
raise RuntimeError(
"No KeySnippetRepository available. "
"Make sure to initialize one with KeySnippetRepositoryManager first."
)
return repo
class KeySnippetRepository: class KeySnippetRepository:
""" """
@ -24,23 +107,26 @@ class KeySnippetRepository:
abstracting the database access details from the business logic. abstracting the database access details from the business logic.
Example: Example:
repo = KeySnippetRepository() with DatabaseManager() as db:
snippet = repo.create( with KeySnippetRepositoryManager(db) as repo:
filepath="main.py", snippet = repo.create(
line_number=42, filepath="main.py",
snippet="def hello_world():", line_number=42,
description="Main function definition" snippet="def hello_world():",
) description="Main function definition"
all_snippets = repo.get_all() )
all_snippets = repo.get_all()
""" """
def __init__(self, db=None): def __init__(self, db):
""" """
Initialize the repository with an optional database connection. Initialize the repository with a database connection.
Args: Args:
db: Optional database connection to use. If None, will use initialize_database() db: Database connection to use (required)
""" """
if db is None:
raise ValueError("Database connection is required for KeySnippetRepository")
self.db = db self.db = db
def create( def create(
@ -64,7 +150,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error creating the snippet peewee.DatabaseError: If there's an error creating the snippet
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
key_snippet = KeySnippet.create( key_snippet = KeySnippet.create(
filepath=filepath, filepath=filepath,
line_number=line_number, line_number=line_number,
@ -92,7 +177,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error accessing the database peewee.DatabaseError: If there's an error accessing the database
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
return KeySnippet.get_or_none(KeySnippet.id == snippet_id) return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
except peewee.DatabaseError as e: except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}") logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
@ -123,7 +207,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error updating the snippet peewee.DatabaseError: If there's an error updating the snippet
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
# First check if the snippet exists # First check if the snippet exists
key_snippet = self.get(snippet_id) key_snippet = self.get(snippet_id)
if not key_snippet: if not key_snippet:
@ -156,7 +239,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error deleting the snippet peewee.DatabaseError: If there's an error deleting the snippet
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
# First check if the snippet exists # First check if the snippet exists
key_snippet = self.get(snippet_id) key_snippet = self.get(snippet_id)
if not key_snippet: if not key_snippet:
@ -182,7 +264,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error accessing the database peewee.DatabaseError: If there's an error accessing the database
""" """
try: try:
db = self.db if self.db is not None else initialize_database()
return list(KeySnippet.select().order_by(KeySnippet.id)) return list(KeySnippet.select().order_by(KeySnippet.id))
except peewee.DatabaseError as e: except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key snippets: {str(e)}") logger.error(f"Failed to fetch all key snippets: {str(e)}")
@ -214,21 +295,4 @@ class KeySnippetRepository:
} }
except peewee.DatabaseError as e: except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}") logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
raise raise
# Global singleton instance
_key_snippet_repository = None
def get_key_snippet_repository() -> KeySnippetRepository:
"""
Get or create a singleton instance of KeySnippetRepository.
Returns:
KeySnippetRepository: Singleton instance of the repository
"""
global _key_snippet_repository
if _key_snippet_repository is None:
_key_snippet_repository = KeySnippetRepository()
return _key_snippet_repository