key snippets context var

This commit is contained in:
AI Christianson 2025-03-03 17:27:41 -05:00
parent ffd1ef15d4
commit fc58aa0b77
2 changed files with 107 additions and 40 deletions

View File

@ -43,7 +43,9 @@ from ra_aid.config import (
VALID_PROVIDERS,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.database.repositories.key_snippet_repository import (
KeySnippetRepositoryManager, get_key_snippet_repository
)
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.console.output import cpm
@ -394,9 +396,10 @@ def main():
logger.error(f"Database migration error: {str(e)}")
# Initialize repositories with database connection
with KeyFactRepositoryManager(db) as key_fact_repo:
# This initializes the repository and makes it available via get_key_fact_repository()
with KeyFactRepositoryManager(db) as key_fact_repo, KeySnippetRepositoryManager(db) as key_snippet_repo:
# This initializes both repositories and makes them available via their respective get methods
logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
# Check dependencies before proceeding
check_dependencies()
@ -533,7 +536,7 @@ def main():
working_directory=working_directory,
current_date=current_date,
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()),
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
project_info=formatted_project_info,
),
config,

View File

@ -6,15 +6,98 @@ following the repository pattern for data access abstraction.
"""
from typing import Dict, List, Optional, Any
import contextvars
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import KeySnippet, initialize_database
from ra_aid.database.models import KeySnippet
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
# Create contextvar to hold the KeySnippetRepository instance
key_snippet_repo_var = contextvars.ContextVar("key_snippet_repo", default=None)
class KeySnippetRepositoryManager:
"""
Context manager for KeySnippetRepository.
This class provides a context manager interface for KeySnippetRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with KeySnippetRepositoryManager(db) as repo:
# Use the repository
snippet = repo.create(
filepath="main.py",
line_number=42,
snippet="def hello_world():",
description="Main function definition"
)
all_snippets = repo.get_all()
"""
def __init__(self, db):
"""
Initialize the KeySnippetRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'KeySnippetRepository':
"""
Initialize the KeySnippetRepository and return it.
Returns:
KeySnippetRepository: The initialized repository
"""
repo = KeySnippetRepository(self.db)
key_snippet_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
key_snippet_repo_var.set(None)
# Don't suppress exceptions
return False
def get_key_snippet_repository() -> 'KeySnippetRepository':
"""
Get the current KeySnippetRepository instance.
Returns:
KeySnippetRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with KeySnippetRepositoryManager
"""
repo = key_snippet_repo_var.get()
if repo is None:
raise RuntimeError(
"No KeySnippetRepository available. "
"Make sure to initialize one with KeySnippetRepositoryManager first."
)
return repo
class KeySnippetRepository:
"""
@ -24,23 +107,26 @@ class KeySnippetRepository:
abstracting the database access details from the business logic.
Example:
repo = KeySnippetRepository()
snippet = repo.create(
filepath="main.py",
line_number=42,
snippet="def hello_world():",
description="Main function definition"
)
all_snippets = repo.get_all()
with DatabaseManager() as db:
with KeySnippetRepositoryManager(db) as repo:
snippet = repo.create(
filepath="main.py",
line_number=42,
snippet="def hello_world():",
description="Main function definition"
)
all_snippets = repo.get_all()
"""
def __init__(self, db=None):
def __init__(self, db):
"""
Initialize the repository with an optional database connection.
Initialize the repository with a database connection.
Args:
db: Optional database connection to use. If None, will use initialize_database()
db: Database connection to use (required)
"""
if db is None:
raise ValueError("Database connection is required for KeySnippetRepository")
self.db = db
def create(
@ -64,7 +150,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error creating the snippet
"""
try:
db = self.db if self.db is not None else initialize_database()
key_snippet = KeySnippet.create(
filepath=filepath,
line_number=line_number,
@ -92,7 +177,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = self.db if self.db is not None else initialize_database()
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
@ -123,7 +207,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error updating the snippet
"""
try:
db = self.db if self.db is not None else initialize_database()
# First check if the snippet exists
key_snippet = self.get(snippet_id)
if not key_snippet:
@ -156,7 +239,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error deleting the snippet
"""
try:
db = self.db if self.db is not None else initialize_database()
# First check if the snippet exists
key_snippet = self.get(snippet_id)
if not key_snippet:
@ -182,7 +264,6 @@ class KeySnippetRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = self.db if self.db is not None else initialize_database()
return list(KeySnippet.select().order_by(KeySnippet.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key snippets: {str(e)}")
@ -214,21 +295,4 @@ class KeySnippetRepository:
}
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
raise
# Global singleton instance
_key_snippet_repository = None
def get_key_snippet_repository() -> KeySnippetRepository:
"""
Get or create a singleton instance of KeySnippetRepository.
Returns:
KeySnippetRepository: Singleton instance of the repository
"""
global _key_snippet_repository
if _key_snippet_repository is None:
_key_snippet_repository = KeySnippetRepository()
return _key_snippet_repository
raise