key snippets context var
This commit is contained in:
parent
ffd1ef15d4
commit
fc58aa0b77
|
|
@ -43,7 +43,9 @@ from ra_aid.config import (
|
||||||
VALID_PROVIDERS,
|
VALID_PROVIDERS,
|
||||||
)
|
)
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
from ra_aid.database.repositories.key_snippet_repository import (
|
||||||
|
KeySnippetRepositoryManager, get_key_snippet_repository
|
||||||
|
)
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
|
|
@ -394,9 +396,10 @@ def main():
|
||||||
logger.error(f"Database migration error: {str(e)}")
|
logger.error(f"Database migration error: {str(e)}")
|
||||||
|
|
||||||
# Initialize repositories with database connection
|
# Initialize repositories with database connection
|
||||||
with KeyFactRepositoryManager(db) as key_fact_repo:
|
with KeyFactRepositoryManager(db) as key_fact_repo, KeySnippetRepositoryManager(db) as key_snippet_repo:
|
||||||
# This initializes the repository and makes it available via get_key_fact_repository()
|
# This initializes both repositories and makes them available via their respective get methods
|
||||||
logger.debug("Initialized KeyFactRepository")
|
logger.debug("Initialized KeyFactRepository")
|
||||||
|
logger.debug("Initialized KeySnippetRepository")
|
||||||
|
|
||||||
# Check dependencies before proceeding
|
# Check dependencies before proceeding
|
||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
@ -533,7 +536,7 @@ def main():
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
||||||
key_snippets=format_key_snippets_dict(KeySnippetRepository(db).get_snippets_dict()),
|
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||||
project_info=formatted_project_info,
|
project_info=formatted_project_info,
|
||||||
),
|
),
|
||||||
config,
|
config,
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,98 @@ following the repository pattern for data access abstraction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any
|
||||||
|
import contextvars
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import get_db
|
from ra_aid.database.models import KeySnippet
|
||||||
from ra_aid.database.models import KeySnippet, initialize_database
|
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Create contextvar to hold the KeySnippetRepository instance
|
||||||
|
key_snippet_repo_var = contextvars.ContextVar("key_snippet_repo", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class KeySnippetRepositoryManager:
|
||||||
|
"""
|
||||||
|
Context manager for KeySnippetRepository.
|
||||||
|
|
||||||
|
This class provides a context manager interface for KeySnippetRepository,
|
||||||
|
using the contextvars approach for thread safety.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with DatabaseManager() as db:
|
||||||
|
with KeySnippetRepositoryManager(db) as repo:
|
||||||
|
# Use the repository
|
||||||
|
snippet = repo.create(
|
||||||
|
filepath="main.py",
|
||||||
|
line_number=42,
|
||||||
|
snippet="def hello_world():",
|
||||||
|
description="Main function definition"
|
||||||
|
)
|
||||||
|
all_snippets = repo.get_all()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
"""
|
||||||
|
Initialize the KeySnippetRepositoryManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection to use (required)
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def __enter__(self) -> 'KeySnippetRepository':
|
||||||
|
"""
|
||||||
|
Initialize the KeySnippetRepository and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KeySnippetRepository: The initialized repository
|
||||||
|
"""
|
||||||
|
repo = KeySnippetRepository(self.db)
|
||||||
|
key_snippet_repo_var.set(repo)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Reset the repository when exiting the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: The exception type if an exception was raised
|
||||||
|
exc_val: The exception value if an exception was raised
|
||||||
|
exc_tb: The traceback if an exception was raised
|
||||||
|
"""
|
||||||
|
# Reset the contextvar to None
|
||||||
|
key_snippet_repo_var.set(None)
|
||||||
|
|
||||||
|
# Don't suppress exceptions
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_snippet_repository() -> 'KeySnippetRepository':
|
||||||
|
"""
|
||||||
|
Get the current KeySnippetRepository instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KeySnippetRepository: The current repository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no repository has been initialized with KeySnippetRepositoryManager
|
||||||
|
"""
|
||||||
|
repo = key_snippet_repo_var.get()
|
||||||
|
if repo is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No KeySnippetRepository available. "
|
||||||
|
"Make sure to initialize one with KeySnippetRepositoryManager first."
|
||||||
|
)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
class KeySnippetRepository:
|
class KeySnippetRepository:
|
||||||
"""
|
"""
|
||||||
|
|
@ -24,23 +107,26 @@ class KeySnippetRepository:
|
||||||
abstracting the database access details from the business logic.
|
abstracting the database access details from the business logic.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
repo = KeySnippetRepository()
|
with DatabaseManager() as db:
|
||||||
snippet = repo.create(
|
with KeySnippetRepositoryManager(db) as repo:
|
||||||
filepath="main.py",
|
snippet = repo.create(
|
||||||
line_number=42,
|
filepath="main.py",
|
||||||
snippet="def hello_world():",
|
line_number=42,
|
||||||
description="Main function definition"
|
snippet="def hello_world():",
|
||||||
)
|
description="Main function definition"
|
||||||
all_snippets = repo.get_all()
|
)
|
||||||
|
all_snippets = repo.get_all()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db=None):
|
def __init__(self, db):
|
||||||
"""
|
"""
|
||||||
Initialize the repository with an optional database connection.
|
Initialize the repository with a database connection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Optional database connection to use. If None, will use initialize_database()
|
db: Database connection to use (required)
|
||||||
"""
|
"""
|
||||||
|
if db is None:
|
||||||
|
raise ValueError("Database connection is required for KeySnippetRepository")
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
|
|
@ -64,7 +150,6 @@ class KeySnippetRepository:
|
||||||
peewee.DatabaseError: If there's an error creating the snippet
|
peewee.DatabaseError: If there's an error creating the snippet
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
key_snippet = KeySnippet.create(
|
key_snippet = KeySnippet.create(
|
||||||
filepath=filepath,
|
filepath=filepath,
|
||||||
line_number=line_number,
|
line_number=line_number,
|
||||||
|
|
@ -92,7 +177,6 @@ class KeySnippetRepository:
|
||||||
peewee.DatabaseError: If there's an error accessing the database
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
|
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
|
||||||
|
|
@ -123,7 +207,6 @@ class KeySnippetRepository:
|
||||||
peewee.DatabaseError: If there's an error updating the snippet
|
peewee.DatabaseError: If there's an error updating the snippet
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
# First check if the snippet exists
|
# First check if the snippet exists
|
||||||
key_snippet = self.get(snippet_id)
|
key_snippet = self.get(snippet_id)
|
||||||
if not key_snippet:
|
if not key_snippet:
|
||||||
|
|
@ -156,7 +239,6 @@ class KeySnippetRepository:
|
||||||
peewee.DatabaseError: If there's an error deleting the snippet
|
peewee.DatabaseError: If there's an error deleting the snippet
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
# First check if the snippet exists
|
# First check if the snippet exists
|
||||||
key_snippet = self.get(snippet_id)
|
key_snippet = self.get(snippet_id)
|
||||||
if not key_snippet:
|
if not key_snippet:
|
||||||
|
|
@ -182,7 +264,6 @@ class KeySnippetRepository:
|
||||||
peewee.DatabaseError: If there's an error accessing the database
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = self.db if self.db is not None else initialize_database()
|
|
||||||
return list(KeySnippet.select().order_by(KeySnippet.id))
|
return list(KeySnippet.select().order_by(KeySnippet.id))
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch all key snippets: {str(e)}")
|
logger.error(f"Failed to fetch all key snippets: {str(e)}")
|
||||||
|
|
@ -214,21 +295,4 @@ class KeySnippetRepository:
|
||||||
}
|
}
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# Global singleton instance
|
|
||||||
_key_snippet_repository = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_key_snippet_repository() -> KeySnippetRepository:
|
|
||||||
"""
|
|
||||||
Get or create a singleton instance of KeySnippetRepository.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
KeySnippetRepository: Singleton instance of the repository
|
|
||||||
"""
|
|
||||||
global _key_snippet_repository
|
|
||||||
if _key_snippet_repository is None:
|
|
||||||
_key_snippet_repository = KeySnippetRepository()
|
|
||||||
return _key_snippet_repository
|
|
||||||
Loading…
Reference in New Issue