420 lines
15 KiB
Python
420 lines
15 KiB
Python
"""
|
|
Trajectory repository implementation for database access.
|
|
|
|
This module provides a repository implementation for the Trajectory model,
|
|
following the repository pattern for data access abstraction. It handles
|
|
operations for storing and retrieving agent action trajectories.
|
|
"""
|
|
|
|
from typing import Dict, List, Optional, Any, Union
|
|
import contextvars
|
|
import json
|
|
import logging
|
|
|
|
import peewee
|
|
|
|
from ra_aid.database.models import Trajectory, HumanInput
|
|
from ra_aid.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# Create contextvar to hold the TrajectoryRepository instance
|
|
trajectory_repo_var = contextvars.ContextVar("trajectory_repo", default=None)
|
|
|
|
|
|
class TrajectoryRepositoryManager:
    """
    Context manager for TrajectoryRepository.

    On entry a TrajectoryRepository is created for the supplied database and
    published through the module-level ``trajectory_repo_var`` context
    variable. On exit the variable is restored to the value it held before
    entry, so a nested manager does not wipe out the repository installed by
    an enclosing one.

    Example:
        with DatabaseManager() as db:
            with TrajectoryRepositoryManager(db) as repo:
                # Use the repository
                trajectory = repo.create(
                    tool_name="ripgrep_search",
                    tool_parameters={"pattern": "example"}
                )
                all_trajectories = repo.get_all()
    """

    def __init__(self, db):
        """
        Initialize the TrajectoryRepositoryManager.

        Args:
            db: Database connection to use (required)
        """
        self.db = db
        # Tokens returned by ContextVar.set(), one per active __enter__ call.
        # Kept as a stack so the same manager object can be entered more than
        # once and still unwind in the right order.
        self._tokens = []

    def __enter__(self) -> 'TrajectoryRepository':
        """
        Create the TrajectoryRepository, publish it, and return it.

        Returns:
            TrajectoryRepository: The initialized repository

        Raises:
            ValueError: If the manager was constructed with ``db=None``
                (raised by the TrajectoryRepository constructor)
        """
        repo = TrajectoryRepository(self.db)
        # Remember the token so __exit__ can restore the previous value
        # instead of blindly overwriting it.
        self._tokens.append(trajectory_repo_var.set(repo))
        return repo

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[object],
    ) -> bool:
        """
        Restore the previously active repository when exiting the context.

        Args:
            exc_type: The exception type if an exception was raised
            exc_val: The exception value if an exception was raised
            exc_tb: The traceback if an exception was raised

        Returns:
            bool: Always False, so exceptions raised inside the ``with``
                body are never suppressed
        """
        token = self._tokens.pop() if self._tokens else None
        if token is None:
            # __exit__ called without a matching __enter__: fall back to
            # clearing the variable.
            trajectory_repo_var.set(None)
        else:
            try:
                trajectory_repo_var.reset(token)
            except (ValueError, RuntimeError):
                # The token cannot be used (created in a different Context,
                # or already consumed); clear the variable instead of
                # raising from __exit__.
                trajectory_repo_var.set(None)

        # Don't suppress exceptions
        return False
|
|
|
|
|
|
def get_trajectory_repository() -> 'TrajectoryRepository':
    """
    Return the TrajectoryRepository stored in ``trajectory_repo_var``.

    Returns:
        TrajectoryRepository: The repository currently held by the context
            variable

    Raises:
        RuntimeError: If the context variable holds None, i.e. no repository
            has been initialized with TrajectoryRepositoryManager
    """
    current = trajectory_repo_var.get()
    if current is not None:
        return current

    raise RuntimeError(
        "No TrajectoryRepository available. "
        "Make sure to initialize one with TrajectoryRepositoryManager first."
    )
|
|
|
|
|
|
class TrajectoryRepository:
    """
    Repository for managing Trajectory database operations.

    This class provides methods for performing CRUD operations on the Trajectory model,
    abstracting the database access details from the business logic. It handles
    serialization and deserialization of JSON fields for tool parameters, results,
    and UI rendering data.

    Example:
        with DatabaseManager() as db:
            with TrajectoryRepositoryManager(db) as repo:
                trajectory = repo.create(
                    tool_name="ripgrep_search",
                    tool_parameters={"pattern": "example"}
                )
                all_trajectories = repo.get_all()
    """

    def __init__(self, db):
        """
        Initialize the repository with a database connection.

        The connection is stored on the instance; the query methods of this
        class go through the Trajectory / HumanInput model classes directly.

        Args:
            db: Database connection to use (required)

        Raises:
            ValueError: If ``db`` is None
        """
        if db is None:
            raise ValueError("Database connection is required for TrajectoryRepository")
        self.db = db

    @staticmethod
    def _dumps_or_none(value: Optional[Dict[str, Any]]) -> Optional[str]:
        """Return ``value`` JSON-encoded, or None when ``value`` is None."""
        return json.dumps(value) if value is not None else None

    def create(
        self,
        tool_name: Optional[str] = None,
        tool_parameters: Optional[Dict[str, Any]] = None,
        tool_result: Optional[Dict[str, Any]] = None,
        step_data: Optional[Dict[str, Any]] = None,
        record_type: str = "tool_execution",
        human_input_id: Optional[int] = None,
        cost: Optional[float] = None,
        tokens: Optional[int] = None,
        is_error: bool = False,
        error_message: Optional[str] = None,
        error_type: Optional[str] = None,
        error_details: Optional[str] = None
    ) -> Trajectory:
        """
        Create a new trajectory record in the database.

        Args:
            tool_name: Optional name of the tool that was executed; stored as
                an empty string when omitted
            tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
            tool_result: Result returned by the tool (will be JSON encoded)
            step_data: UI rendering data (will be JSON encoded)
            record_type: Type of trajectory record
            human_input_id: Optional ID of the associated human input. If no
                HumanInput with this ID exists, a warning is logged and the
                record is created without the association.
            cost: Optional cost of the operation (placeholder)
            tokens: Optional token usage (placeholder)
            is_error: Flag indicating if this record represents an error (default: False)
            error_message: The error message (if is_error is True)
            error_type: The type/class of the error (if is_error is True)
            error_details: Additional error details like stack traces (if is_error is True)

        Returns:
            Trajectory: The newly created trajectory instance

        Raises:
            peewee.DatabaseError: If there's an error creating the record
            TypeError, ValueError: Propagated unchanged from ``json.dumps``
                when a payload cannot be JSON encoded
        """
        try:
            # Serialize JSON fields
            tool_parameters_json = self._dumps_or_none(tool_parameters)
            tool_result_json = self._dumps_or_none(tool_result)
            step_data_json = self._dumps_or_none(step_data)

            # Create human input reference if provided
            human_input = None
            if human_input_id is not None:
                try:
                    human_input = HumanInput.get_by_id(human_input_id)
                except peewee.DoesNotExist:
                    # Best effort: keep going without the association.
                    logger.warning(f"Human input with ID {human_input_id} not found")

            # Create the trajectory record
            trajectory = Trajectory.create(
                human_input=human_input,
                tool_name=tool_name or "",  # Use empty string if tool_name is None
                tool_parameters=tool_parameters_json,
                tool_result=tool_result_json,
                step_data=step_data_json,
                record_type=record_type,
                cost=cost,
                tokens=tokens,
                is_error=is_error,
                error_message=error_message,
                error_type=error_type,
                error_details=error_details
            )
            if tool_name:
                logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
            else:
                logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
            return trajectory
        except peewee.DatabaseError as e:
            logger.error(f"Failed to create trajectory record: {str(e)}")
            raise

    def get(self, trajectory_id: int) -> Optional[Trajectory]:
        """
        Retrieve a trajectory record by its ID.

        Args:
            trajectory_id: The ID of the trajectory record to retrieve

        Returns:
            Optional[Trajectory]: The trajectory instance if found, None otherwise

        Raises:
            peewee.DatabaseError: If there's an error accessing the database
        """
        try:
            return Trajectory.get_or_none(Trajectory.id == trajectory_id)
        except peewee.DatabaseError as e:
            logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
            raise

    def update(
        self,
        trajectory_id: int,
        tool_result: Optional[Dict[str, Any]] = None,
        step_data: Optional[Dict[str, Any]] = None,
        cost: Optional[float] = None,
        tokens: Optional[int] = None,
        is_error: Optional[bool] = None,
        error_message: Optional[str] = None,
        error_type: Optional[str] = None,
        error_details: Optional[str] = None
    ) -> Optional[Trajectory]:
        """
        Update an existing trajectory record.

        This is typically used to update the result or metrics after tool execution completes.
        An argument left as None means "leave this column unchanged", so this
        method cannot be used to reset a column back to NULL.

        Args:
            trajectory_id: The ID of the trajectory record to update
            tool_result: Updated tool result (will be JSON encoded)
            step_data: Updated UI rendering data (will be JSON encoded)
            cost: Updated cost information
            tokens: Updated token usage information
            is_error: Flag indicating if this record represents an error
            error_message: The error message
            error_type: The type/class of the error
            error_details: Additional error details like stack traces

        Returns:
            Optional[Trajectory]: The updated trajectory if found, None otherwise.
                When no field was supplied, the record is returned as fetched.

        Raises:
            peewee.DatabaseError: If there's an error updating the record
            TypeError, ValueError: Propagated unchanged from ``json.dumps``
                when a payload cannot be JSON encoded
        """
        try:
            # First check if the trajectory exists
            trajectory = self.get(trajectory_id)
            if trajectory is None:
                logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
                return None

            # Candidate column values; JSON payloads are encoded up front and
            # anything still None is treated as "not provided".
            candidates = {
                "tool_result": self._dumps_or_none(tool_result),
                "step_data": self._dumps_or_none(step_data),
                "cost": cost,
                "tokens": tokens,
                "is_error": is_error,
                "error_message": error_message,
                "error_type": error_type,
                "error_details": error_details,
            }
            update_data = {
                column: value
                for column, value in candidates.items()
                if value is not None
            }

            if not update_data:
                # Nothing to write; hand back the record as it was fetched.
                return trajectory

            query = Trajectory.update(**update_data).where(Trajectory.id == trajectory_id)
            query.execute()
            logger.debug(f"Updated trajectory record ID {trajectory_id}")
            # Re-read so the caller sees the stored values.
            return self.get(trajectory_id)
        except peewee.DatabaseError as e:
            logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
            raise

    def delete(self, trajectory_id: int) -> bool:
        """
        Delete a trajectory record by its ID.

        Args:
            trajectory_id: The ID of the trajectory record to delete

        Returns:
            bool: True if the record was deleted, False if it wasn't found

        Raises:
            peewee.DatabaseError: If there's an error deleting the record
        """
        try:
            # First check if the trajectory exists
            trajectory = self.get(trajectory_id)
            if trajectory is None:
                logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
                return False

            # Delete the trajectory
            trajectory.delete_instance()
            logger.debug(f"Deleted trajectory record ID {trajectory_id}")
            return True
        except peewee.DatabaseError as e:
            logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
            raise

    def get_all(self) -> Dict[int, Trajectory]:
        """
        Retrieve all trajectory records from the database.

        Returns:
            Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory
                instances, built from a query ordered by ID

        Raises:
            peewee.DatabaseError: If there's an error accessing the database
        """
        try:
            return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
        except peewee.DatabaseError as e:
            logger.error(f"Failed to fetch all trajectories: {str(e)}")
            raise

    def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
        """
        Retrieve all trajectory records associated with a specific human input.

        Args:
            human_input_id: The ID of the human input to get trajectories for

        Returns:
            List[Trajectory]: List of trajectory instances associated with the
                human input, ordered by ID

        Raises:
            peewee.DatabaseError: If there's an error accessing the database
        """
        try:
            return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
        except peewee.DatabaseError as e:
            logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
            raise

    def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
        """
        Parse a JSON string into a Python dictionary.

        Args:
            json_str: JSON string to parse

        Returns:
            Optional[Dict[str, Any]]: Parsed value, or None if the input is
                None, empty, or cannot be decoded (the failure is logged)
        """
        if not json_str:
            return None

        try:
            return json.loads(json_str)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers values that are not str/bytes/bytearray, so an
            # unexpected column value yields None instead of an exception.
            logger.error(f"Error parsing JSON field: {str(e)}")
            return None

    def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
        """
        Get a trajectory record with JSON fields parsed into dictionaries.

        Args:
            trajectory_id: ID of the trajectory to retrieve

        Returns:
            Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
                or None if not found

        Raises:
            peewee.DatabaseError: If there's an error accessing the database
        """
        trajectory = self.get(trajectory_id)
        if trajectory is None:
            return None

        return {
            "id": trajectory.id,
            "created_at": trajectory.created_at,
            "updated_at": trajectory.updated_at,
            "tool_name": trajectory.tool_name,
            "tool_parameters": self.parse_json_field(trajectory.tool_parameters),
            "tool_result": self.parse_json_field(trajectory.tool_result),
            "step_data": self.parse_json_field(trajectory.step_data),
            "record_type": trajectory.record_type,
            "cost": trajectory.cost,
            "tokens": trajectory.tokens,
            "human_input_id": trajectory.human_input.id if trajectory.human_input else None,
            "is_error": trajectory.is_error,
            "error_message": trajectory.error_message,
            "error_type": trajectory.error_type,
            "error_details": trajectory.error_details,
        }