From fc58aa0b7783cfd3589519a27c47573c82b515ee Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Mon, 3 Mar 2025 17:27:41 -0500 Subject: [PATCH] key snippets context var --- ra_aid/__main__.py | 11 +- .../repositories/key_snippet_repository.py | 136 +++++++++++++----- 2 files changed, 107 insertions(+), 40 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index fd1deb2..fe4c048 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -43,7 +43,9 @@ from ra_aid.config import ( VALID_PROVIDERS, ) from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository -from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository +from ra_aid.database.repositories.key_snippet_repository import ( + KeySnippetRepositoryManager, get_key_snippet_repository +) from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.console.output import cpm @@ -394,9 +396,10 @@ def main(): logger.error(f"Database migration error: {str(e)}") # Initialize repositories with database connection - with KeyFactRepositoryManager(db) as key_fact_repo: - # This initializes the repository and makes it available via get_key_fact_repository() + with KeyFactRepositoryManager(db) as key_fact_repo, KeySnippetRepositoryManager(db) as key_snippet_repo: + # This initializes both repositories and makes them available via their respective get methods logger.debug("Initialized KeyFactRepository") + logger.debug("Initialized KeySnippetRepository") # Check dependencies before proceeding check_dependencies() @@ -533,7 +536,7 @@ def main(): working_directory=working_directory, current_date=current_date, key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()), - key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()), + key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()), project_info=formatted_project_info, ), config, diff --git a/ra_aid/database/repositories/key_snippet_repository.py b/ra_aid/database/repositories/key_snippet_repository.py index 7991838..e63826f 100644 --- a/ra_aid/database/repositories/key_snippet_repository.py +++ b/ra_aid/database/repositories/key_snippet_repository.py @@ -6,15 +6,98 @@ following the repository pattern for data access abstraction. """ from typing import Dict, List, Optional, Any +import contextvars import peewee -from ra_aid.database.connection import get_db -from ra_aid.database.models import KeySnippet, initialize_database +from ra_aid.database.models import KeySnippet from ra_aid.logging_config import get_logger logger = get_logger(__name__) +# Create contextvar to hold the KeySnippetRepository instance +key_snippet_repo_var = contextvars.ContextVar("key_snippet_repo", default=None) + + +class KeySnippetRepositoryManager: + """ + Context manager for KeySnippetRepository. + + This class provides a context manager interface for KeySnippetRepository, + using the contextvars approach for thread safety. + + Example: + with DatabaseManager() as db: + with KeySnippetRepositoryManager(db) as repo: + # Use the repository + snippet = repo.create( + filepath="main.py", + line_number=42, + snippet="def hello_world():", + description="Main function definition" + ) + all_snippets = repo.get_all() + """ + + def __init__(self, db): + """ + Initialize the KeySnippetRepositoryManager. + + Args: + db: Database connection to use (required) + """ + self.db = db + + def __enter__(self) -> 'KeySnippetRepository': + """ + Initialize the KeySnippetRepository and return it. + + Returns: + KeySnippetRepository: The initialized repository + """ + repo = KeySnippetRepository(self.db) + key_snippet_repo_var.set(repo) + return repo + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[Exception], + exc_tb: Optional[object], + ) -> None: + """ + Reset the repository when exiting the context. + + Args: + exc_type: The exception type if an exception was raised + exc_val: The exception value if an exception was raised + exc_tb: The traceback if an exception was raised + """ + # Reset the contextvar to None + key_snippet_repo_var.set(None) + + # Don't suppress exceptions + return False + + +def get_key_snippet_repository() -> 'KeySnippetRepository': + """ + Get the current KeySnippetRepository instance. + + Returns: + KeySnippetRepository: The current repository instance + + Raises: + RuntimeError: If no repository has been initialized with KeySnippetRepositoryManager + """ + repo = key_snippet_repo_var.get() + if repo is None: + raise RuntimeError( + "No KeySnippetRepository available. " + "Make sure to initialize one with KeySnippetRepositoryManager first." + ) + return repo + class KeySnippetRepository: """ @@ -24,23 +107,26 @@ class KeySnippetRepository: abstracting the database access details from the business logic. Example: - repo = KeySnippetRepository() - snippet = repo.create( - filepath="main.py", - line_number=42, - snippet="def hello_world():", - description="Main function definition" - ) - all_snippets = repo.get_all() + with DatabaseManager() as db: + with KeySnippetRepositoryManager(db) as repo: + snippet = repo.create( + filepath="main.py", + line_number=42, + snippet="def hello_world():", + description="Main function definition" + ) + all_snippets = repo.get_all() """ - def __init__(self, db=None): + def __init__(self, db): """ - Initialize the repository with an optional database connection. + Initialize the repository with a database connection. Args: - db: Optional database connection to use. If None, will use initialize_database() + db: Database connection to use (required) """ + if db is None: + raise ValueError("Database connection is required for KeySnippetRepository") self.db = db def create( @@ -64,7 +150,6 @@ class KeySnippetRepository: peewee.DatabaseError: If there's an error creating the snippet """ try: - db = self.db if self.db is not None else initialize_database() key_snippet = KeySnippet.create( filepath=filepath, line_number=line_number, @@ -92,7 +177,6 @@ class KeySnippetRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - db = self.db if self.db is not None else initialize_database() return KeySnippet.get_or_none(KeySnippet.id == snippet_id) except peewee.DatabaseError as e: logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}") @@ -123,7 +207,6 @@ class KeySnippetRepository: peewee.DatabaseError: If there's an error updating the snippet """ try: - db = self.db if self.db is not None else initialize_database() # First check if the snippet exists key_snippet = self.get(snippet_id) if not key_snippet: @@ -156,7 +239,6 @@ class KeySnippetRepository: peewee.DatabaseError: If there's an error deleting the snippet """ try: - db = self.db if self.db is not None else initialize_database() # First check if the snippet exists key_snippet = self.get(snippet_id) if not key_snippet: @@ -182,7 +264,6 @@ class KeySnippetRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - db = self.db if self.db is not None else initialize_database() return list(KeySnippet.select().order_by(KeySnippet.id)) except peewee.DatabaseError as e: logger.error(f"Failed to fetch all key snippets: {str(e)}") @@ -214,21 +295,4 @@ class KeySnippetRepository: } except peewee.DatabaseError as e: logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}") - raise - - -# Global singleton instance -_key_snippet_repository = None - - -def get_key_snippet_repository() -> KeySnippetRepository: - """ - Get or create a singleton instance of KeySnippetRepository. - - Returns: - KeySnippetRepository: Singleton instance of the repository - """ - global _key_snippet_repository - if _key_snippet_repository is None: - _key_snippet_repository = KeySnippetRepository() - return _key_snippet_repository \ No newline at end of file + raise \ No newline at end of file