diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index dfd7bfa..bf42ec8 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -53,6 +53,9 @@ from ra_aid.database.repositories.human_input_repository import (
 from ra_aid.database.repositories.research_note_repository import (
     ResearchNoteRepositoryManager, get_research_note_repository
 )
+from ra_aid.database.repositories.trajectory_repository import (
+    TrajectoryRepositoryManager, get_trajectory_repository
+)
 from ra_aid.database.repositories.related_files_repository import (
     RelatedFilesRepositoryManager
 )
@@ -528,6 +531,7 @@ def main():
                  HumanInputRepositoryManager(db) as human_input_repo, \
                  ResearchNoteRepositoryManager(db) as research_note_repo, \
                  RelatedFilesRepositoryManager() as related_files_repo, \
+                 TrajectoryRepositoryManager(db) as trajectory_repo, \
                  WorkLogRepositoryManager() as work_log_repo, \
                  ConfigRepositoryManager(config) as config_repo, \
                  EnvInvManager(env_data) as env_inv:
@@ -537,6 +541,7 @@ def main():
                 logger.debug("Initialized HumanInputRepository")
                 logger.debug("Initialized ResearchNoteRepository")
                 logger.debug("Initialized RelatedFilesRepository")
+                logger.debug("Initialized TrajectoryRepository")
                 logger.debug("Initialized WorkLogRepository")
                 logger.debug("Initialized ConfigRepository")
                 logger.debug("Initialized Environment Inventory")
diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py
index f40d0e1..9cfe0a5 100644
--- a/ra_aid/database/models.py
+++ b/ra_aid/database/models.py
@@ -42,8 +42,8 @@ def initialize_database():
     # to avoid circular imports
     # Note: This import needs to be here, not at the top level
     try:
-        from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote
-        db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote], safe=True)
+        from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
+        db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
         logger.debug("Ensured database tables exist")
     except Exception as e:
         logger.error(f"Error creating tables: {str(e)}")
@@ -162,4 +162,33 @@ class ResearchNote(BaseModel):
     # created_at and updated_at are inherited from BaseModel
 
     class Meta:
-        table_name = "research_note"
\ No newline at end of file
+        table_name = "research_note"
+
+
+class Trajectory(BaseModel):
+    """
+    Model representing an agent trajectory stored in the database.
+
+    Trajectories track the sequence of actions taken by agents, including
+    tool executions and their results. This enables analysis of agent behavior,
+    debugging of issues, and reconstruction of the decision-making process.
+
+    Each trajectory record captures details about a single tool execution:
+    - Which tool was used
+    - What parameters were passed to the tool
+    - What result was returned by the tool (nullable: may be filled in by a later update)
+    - UI rendering data for displaying the tool execution (nullable)
+    - Cost and token usage metrics (placeholders for future implementation)
+    """
+    human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
+    tool_name = peewee.TextField()
+    tool_parameters = peewee.TextField()  # JSON-encoded parameters
+    tool_result = peewee.TextField(null=True)  # JSON-encoded result; NULL when the repository is given no result
+    step_data = peewee.TextField(null=True)  # JSON-encoded UI rendering data; NULL when none is supplied
+    record_type = peewee.TextField()  # Type of trajectory record
+    cost = peewee.FloatField(null=True)  # Placeholder for cost tracking
+    tokens = peewee.IntegerField(null=True)  # Placeholder for token usage tracking
+    # created_at and updated_at are inherited from BaseModel
+
+    class Meta:
+        table_name = "trajectory"
\ No newline at end of file
diff --git a/ra_aid/database/repositories/trajectory_repository.py b/ra_aid/database/repositories/trajectory_repository.py
new file mode 100644
index 0000000..14498c0
--- /dev/null
+++ b/ra_aid/database/repositories/trajectory_repository.py
@@ -0,0 +1,381 @@
+"""
+Trajectory repository implementation for database access.
+
+This module provides a repository implementation for the Trajectory model,
+following the repository pattern for data access abstraction. It handles
+operations for storing and retrieving agent action trajectories.
+"""
+
+from typing import Dict, List, Optional, Any, Union
+import contextvars
+import json
+import logging
+
+import peewee
+
+from ra_aid.database.models import Trajectory, HumanInput
+from ra_aid.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+# Create contextvar to hold the TrajectoryRepository instance
+trajectory_repo_var = contextvars.ContextVar("trajectory_repo", default=None)
+
+
+class TrajectoryRepositoryManager:
+    """
+    Context manager for TrajectoryRepository.
+
+    This class provides a context manager interface for TrajectoryRepository,
+    using the contextvars approach for thread safety.
+
+    Example:
+        with DatabaseManager() as db:
+            with TrajectoryRepositoryManager(db) as repo:
+                # Use the repository
+                trajectory = repo.create(
+                    tool_name="ripgrep_search",
+                    tool_parameters={"pattern": "example"}
+                )
+                all_trajectories = repo.get_all()
+    """
+
+    def __init__(self, db):
+        """
+        Initialize the TrajectoryRepositoryManager.
+
+        Args:
+            db: Database connection to use (required)
+        """
+        self.db = db
+
+    def __enter__(self) -> 'TrajectoryRepository':
+        """
+        Initialize the TrajectoryRepository and return it.
+
+        Returns:
+            TrajectoryRepository: The initialized repository
+        """
+        repo = TrajectoryRepository(self.db)
+        trajectory_repo_var.set(repo)
+        return repo
+
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_val: Optional[Exception],
+        exc_tb: Optional[object],
+    ) -> bool:
+        """
+        Reset the repository when exiting the context.
+
+        Args:
+            exc_type: The exception type if an exception was raised
+            exc_val: The exception value if an exception was raised
+            exc_tb: The traceback if an exception was raised
+        """
+        # Reset the contextvar to None
+        trajectory_repo_var.set(None)
+
+        # Returning False tells Python not to suppress an in-flight exception
+        return False
+
+
+def get_trajectory_repository() -> 'TrajectoryRepository':
+    """
+    Get the current TrajectoryRepository instance.
+
+    Returns:
+        TrajectoryRepository: The current repository instance
+
+    Raises:
+        RuntimeError: If no repository has been initialized with TrajectoryRepositoryManager
+    """
+    repo = trajectory_repo_var.get()
+    if repo is None:
+        raise RuntimeError(
+            "No TrajectoryRepository available. "
+            "Make sure to initialize one with TrajectoryRepositoryManager first."
+        )
+    return repo
+
+
+class TrajectoryRepository:
+    """
+    Repository for managing Trajectory database operations.
+
+    This class provides methods for performing CRUD operations on the Trajectory model,
+    abstracting the database access details from the business logic. It handles
+    serialization and deserialization of JSON fields for tool parameters, results,
+    and UI rendering data.
+
+    Example:
+        with DatabaseManager() as db:
+            with TrajectoryRepositoryManager(db) as repo:
+                trajectory = repo.create(
+                    tool_name="ripgrep_search",
+                    tool_parameters={"pattern": "example"}
+                )
+                all_trajectories = repo.get_all()
+    """
+
+    def __init__(self, db):
+        """
+        Initialize the repository with a database connection.
+
+        Args:
+            db: Database connection to use (required)
+        """
+        if db is None:
+            raise ValueError("Database connection is required for TrajectoryRepository")
+        self.db = db
+
+    def create(
+        self,
+        tool_name: str,
+        tool_parameters: Dict[str, Any],
+        tool_result: Optional[Dict[str, Any]] = None,
+        step_data: Optional[Dict[str, Any]] = None,
+        record_type: str = "tool_execution",
+        human_input_id: Optional[int] = None,
+        cost: Optional[float] = None,
+        tokens: Optional[int] = None
+    ) -> Trajectory:
+        """
+        Create a new trajectory record in the database.
+
+        Args:
+            tool_name: Name of the tool that was executed
+            tool_parameters: Parameters passed to the tool (will be JSON encoded)
+            tool_result: Result returned by the tool (JSON encoded; stored as NULL when omitted)
+            step_data: UI rendering data (JSON encoded; stored as NULL when omitted)
+            record_type: Type of trajectory record
+            human_input_id: Optional ID of the associated human input (left NULL if no such row exists)
+            cost: Optional cost of the operation (placeholder)
+            tokens: Optional token usage (placeholder)
+
+        Returns:
+            Trajectory: The newly created trajectory instance
+
+        Raises:
+            peewee.DatabaseError: If there's an error creating the record
+        """
+        try:
+            # Serialize JSON fields
+            tool_parameters_json = json.dumps(tool_parameters)
+            tool_result_json = json.dumps(tool_result) if tool_result is not None else None
+            step_data_json = json.dumps(step_data) if step_data is not None else None
+
+            # Create human input reference if provided
+            human_input = None
+            if human_input_id is not None:
+                try:
+                    human_input = HumanInput.get_by_id(human_input_id)
+                except peewee.DoesNotExist:
+                    logger.warning(f"Human input with ID {human_input_id} not found")
+
+            # Create the trajectory record
+            trajectory = Trajectory.create(
+                human_input=human_input,
+                tool_name=tool_name,
+                tool_parameters=tool_parameters_json,
+                tool_result=tool_result_json,
+                step_data=step_data_json,
+                record_type=record_type,
+                cost=cost,
+                tokens=tokens
+            )
+            logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
+            return trajectory
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to create trajectory record: {str(e)}")
+            raise
+
+    def get(self, trajectory_id: int) -> Optional[Trajectory]:
+        """
+        Retrieve a trajectory record by its ID.
+
+        Args:
+            trajectory_id: The ID of the trajectory record to retrieve
+
+        Returns:
+            Optional[Trajectory]: The trajectory instance if found, None otherwise
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            return Trajectory.get_or_none(Trajectory.id == trajectory_id)
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
+            raise
+
+    def update(
+        self,
+        trajectory_id: int,
+        tool_result: Optional[Dict[str, Any]] = None,
+        step_data: Optional[Dict[str, Any]] = None,
+        cost: Optional[float] = None,
+        tokens: Optional[int] = None
+    ) -> Optional[Trajectory]:
+        """
+        Update an existing trajectory record.
+
+        This is typically used to update the result or metrics after tool execution completes.
+
+        Args:
+            trajectory_id: The ID of the trajectory record to update
+            tool_result: Updated tool result (will be JSON encoded)
+            step_data: Updated UI rendering data (will be JSON encoded)
+            cost: Updated cost information
+            tokens: Updated token usage information
+
+        Returns:
+            Optional[Trajectory]: The updated trajectory if found, None otherwise
+
+        Raises:
+            peewee.DatabaseError: If there's an error updating the record
+        """
+        try:
+            # First check if the trajectory exists
+            trajectory = self.get(trajectory_id)
+            if not trajectory:
+                logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
+                return None
+
+            # Only arguments that are not None are written; None means "leave unchanged"
+            update_data = {}
+
+            if tool_result is not None:
+                update_data["tool_result"] = json.dumps(tool_result)
+
+            if step_data is not None:
+                update_data["step_data"] = json.dumps(step_data)
+
+            if cost is not None:
+                update_data["cost"] = cost
+
+            if tokens is not None:
+                update_data["tokens"] = tokens
+
+            if update_data:
+                query = Trajectory.update(**update_data).where(Trajectory.id == trajectory_id)
+                query.execute()
+                logger.debug(f"Updated trajectory record ID {trajectory_id}")
+                return self.get(trajectory_id)
+
+            return trajectory
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
+            raise
+
+    def delete(self, trajectory_id: int) -> bool:
+        """
+        Delete a trajectory record by its ID.
+
+        Args:
+            trajectory_id: The ID of the trajectory record to delete
+
+        Returns:
+            bool: True if the record was deleted, False if it wasn't found
+
+        Raises:
+            peewee.DatabaseError: If there's an error deleting the record
+        """
+        try:
+            # First check if the trajectory exists
+            trajectory = self.get(trajectory_id)
+            if not trajectory:
+                logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
+                return False
+
+            # Delete the trajectory
+            trajectory.delete_instance()
+            logger.debug(f"Deleted trajectory record ID {trajectory_id}")
+            return True
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
+            raise
+
+    def get_all(self) -> Dict[int, Trajectory]:
+        """
+        Retrieve all trajectory records from the database.
+
+        Returns:
+            Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch all trajectories: {str(e)}")
+            raise
+
+    def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
+        """
+        Retrieve all trajectory records associated with a specific human input.
+
+        Args:
+            human_input_id: The ID of the human input to get trajectories for
+
+        Returns:
+            List[Trajectory]: List of trajectory instances associated with the human input
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
+            raise
+
+    def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
+        """
+        Parse a JSON string into a Python dictionary.
+
+        Args:
+            json_str: JSON string to parse
+
+        Returns:
+            Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid
+        """
+        if not json_str:
+            return None
+
+        try:
+            return json.loads(json_str)
+        except json.JSONDecodeError as e:
+            logger.error(f"Error parsing JSON field: {str(e)}")
+            return None
+
+    def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
+        """
+        Get a trajectory record with JSON fields parsed into dictionaries.
+
+        Args:
+            trajectory_id: ID of the trajectory to retrieve
+
+        Returns:
+            Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
+                or None if not found
+        """
+        trajectory = self.get(trajectory_id)
+        if trajectory is None:
+            return None
+
+        return {
+            "id": trajectory.id,
+            "created_at": trajectory.created_at,
+            "updated_at": trajectory.updated_at,
+            "tool_name": trajectory.tool_name,
+            "tool_parameters": self.parse_json_field(trajectory.tool_parameters),
+            "tool_result": self.parse_json_field(trajectory.tool_result),
+            "step_data": self.parse_json_field(trajectory.step_data),
+            "record_type": trajectory.record_type,
+            "cost": trajectory.cost,
+            "tokens": trajectory.tokens,
+            "human_input_id": trajectory.human_input_id,  # raw FK column; no lazy load of the HumanInput row
+        }
\ No newline at end of file
diff --git a/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py b/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py
new file mode 100644
index 0000000..b62ecc9
--- /dev/null
+++ b/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py
@@ -0,0 +1,100 @@
+"""Peewee migrations -- 007_20250310_184046_add_trajectory_model.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Create the trajectory table for storing agent action trajectories."""
+
+    # Check if the table already exists
+    try:
+        database.execute_sql("SELECT id FROM trajectory LIMIT 1")
+        # If we reach here, the table exists
+        return
+    except pw.OperationalError:
+        # Table doesn't exist, safe to create
+        pass
+
+    @migrator.create_model
+    class Trajectory(pw.Model):
+        id = pw.AutoField()
+        created_at = pw.DateTimeField()
+        updated_at = pw.DateTimeField()
+        tool_name = pw.TextField()
+        tool_parameters = pw.TextField()  # JSON-encoded parameters
+        tool_result = pw.TextField(null=True)  # JSON-encoded result; nullable to match the Trajectory model
+        step_data = pw.TextField(null=True)  # JSON-encoded UI rendering data; nullable to match the model
+        record_type = pw.TextField()  # Type of trajectory record
+        cost = pw.FloatField(null=True)  # Placeholder for cost tracking
+        tokens = pw.IntegerField(null=True)  # Placeholder for token usage tracking
+        # We'll add the human_input foreign key in a separate step for safety
+
+        class Meta:
+            table_name = "trajectory"
+
+    # Check if HumanInput model exists before adding the foreign key
+    try:
+        HumanInput = migrator.orm['human_input']
+
+        # Only add the foreign key if the human_input_id column doesn't already exist
+        try:
+            database.execute_sql("SELECT human_input_id FROM trajectory LIMIT 1")
+        except pw.OperationalError:
+            # Column doesn't exist, safe to add
+            migrator.add_fields(
+                'trajectory',
+                human_input=pw.ForeignKeyField(
+                    HumanInput,
+                    null=True,
+                    field='id',
+                    on_delete='SET NULL'
+                )
+            )
+    except KeyError:
+        # HumanInput doesn't exist, we'll skip adding the foreign key
+        pass
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Remove the trajectory table."""
+
+    # First remove any foreign key fields
+    try:
+        migrator.remove_fields('trajectory', 'human_input')
+    except pw.OperationalError:
+        # Field might not exist, that's fine
+        pass
+
+    # Then remove the model
+    migrator.remove_model('trajectory')
\ No newline at end of file