add trajectory table

This commit is contained in:
AI Christianson 2025-03-10 18:59:42 -04:00
parent a18998be0d
commit 78983ec20b
4 changed files with 518 additions and 3 deletions

View File

@ -53,6 +53,9 @@ from ra_aid.database.repositories.human_input_repository import (
from ra_aid.database.repositories.research_note_repository import (
ResearchNoteRepositoryManager, get_research_note_repository
)
from ra_aid.database.repositories.trajectory_repository import (
TrajectoryRepositoryManager, get_trajectory_repository
)
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepositoryManager
)
@ -528,6 +531,7 @@ def main():
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo, \
RelatedFilesRepositoryManager() as related_files_repo, \
TrajectoryRepositoryManager(db) as trajectory_repo, \
WorkLogRepositoryManager() as work_log_repo, \
ConfigRepositoryManager(config) as config_repo, \
EnvInvManager(env_data) as env_inv:
@ -537,6 +541,7 @@ def main():
logger.debug("Initialized HumanInputRepository")
logger.debug("Initialized ResearchNoteRepository")
logger.debug("Initialized RelatedFilesRepository")
logger.debug("Initialized TrajectoryRepository")
logger.debug("Initialized WorkLogRepository")
logger.debug("Initialized ConfigRepository")
logger.debug("Initialized Environment Inventory")

View File

@ -42,8 +42,8 @@ def initialize_database():
# to avoid circular imports
# Note: This import needs to be here, not at the top level
try:
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote], safe=True)
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
logger.debug("Ensured database tables exist")
except Exception as e:
logger.error(f"Error creating tables: {str(e)}")
@ -163,3 +163,32 @@ class ResearchNote(BaseModel):
class Meta:
table_name = "research_note"
class Trajectory(BaseModel):
"""
Model representing an agent trajectory stored in the database.
Trajectories track the sequence of actions taken by agents, including
tool executions and their results. This enables analysis of agent behavior,
debugging of issues, and reconstruction of the decision-making process.
Each trajectory record captures details about a single tool execution:
- Which tool was used
- What parameters were passed to the tool
- What result was returned by the tool
- UI rendering data for displaying the tool execution
- Cost and token usage metrics (placeholders for future implementation)
"""
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
tool_name = peewee.TextField()
tool_parameters = peewee.TextField() # JSON-encoded parameters
tool_result = peewee.TextField() # JSON-encoded result
step_data = peewee.TextField() # JSON-encoded UI rendering data
record_type = peewee.TextField() # Type of trajectory record
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
# created_at and updated_at are inherited from BaseModel
class Meta:
table_name = "trajectory"

View File

@ -0,0 +1,381 @@
"""
Trajectory repository implementation for database access.
This module provides a repository implementation for the Trajectory model,
following the repository pattern for data access abstraction. It handles
operations for storing and retrieving agent action trajectories.
"""
from typing import Dict, List, Optional, Any, Union
import contextvars
import json
import logging
import peewee
from ra_aid.database.models import Trajectory, HumanInput
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
# Create contextvar to hold the TrajectoryRepository instance
trajectory_repo_var = contextvars.ContextVar("trajectory_repo", default=None)
class TrajectoryRepositoryManager:
"""
Context manager for TrajectoryRepository.
This class provides a context manager interface for TrajectoryRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with TrajectoryRepositoryManager(db) as repo:
# Use the repository
trajectory = repo.create(
tool_name="ripgrep_search",
tool_parameters={"pattern": "example"}
)
all_trajectories = repo.get_all()
"""
def __init__(self, db):
"""
Initialize the TrajectoryRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'TrajectoryRepository':
"""
Initialize the TrajectoryRepository and return it.
Returns:
TrajectoryRepository: The initialized repository
"""
repo = TrajectoryRepository(self.db)
trajectory_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
trajectory_repo_var.set(None)
# Don't suppress exceptions
return False
def get_trajectory_repository() -> 'TrajectoryRepository':
"""
Get the current TrajectoryRepository instance.
Returns:
TrajectoryRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with TrajectoryRepositoryManager
"""
repo = trajectory_repo_var.get()
if repo is None:
raise RuntimeError(
"No TrajectoryRepository available. "
"Make sure to initialize one with TrajectoryRepositoryManager first."
)
return repo
class TrajectoryRepository:
"""
Repository for managing Trajectory database operations.
This class provides methods for performing CRUD operations on the Trajectory model,
abstracting the database access details from the business logic. It handles
serialization and deserialization of JSON fields for tool parameters, results,
and UI rendering data.
Example:
with DatabaseManager() as db:
with TrajectoryRepositoryManager(db) as repo:
trajectory = repo.create(
tool_name="ripgrep_search",
tool_parameters={"pattern": "example"}
)
all_trajectories = repo.get_all()
"""
def __init__(self, db):
"""
Initialize the repository with a database connection.
Args:
db: Database connection to use (required)
"""
if db is None:
raise ValueError("Database connection is required for TrajectoryRepository")
self.db = db
def create(
self,
tool_name: str,
tool_parameters: Dict[str, Any],
tool_result: Optional[Dict[str, Any]] = None,
step_data: Optional[Dict[str, Any]] = None,
record_type: str = "tool_execution",
human_input_id: Optional[int] = None,
cost: Optional[float] = None,
tokens: Optional[int] = None
) -> Trajectory:
"""
Create a new trajectory record in the database.
Args:
tool_name: Name of the tool that was executed
tool_parameters: Parameters passed to the tool (will be JSON encoded)
tool_result: Result returned by the tool (will be JSON encoded)
step_data: UI rendering data (will be JSON encoded)
record_type: Type of trajectory record
human_input_id: Optional ID of the associated human input
cost: Optional cost of the operation (placeholder)
tokens: Optional token usage (placeholder)
Returns:
Trajectory: The newly created trajectory instance
Raises:
peewee.DatabaseError: If there's an error creating the record
"""
try:
# Serialize JSON fields
tool_parameters_json = json.dumps(tool_parameters)
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
step_data_json = json.dumps(step_data) if step_data is not None else None
# Create human input reference if provided
human_input = None
if human_input_id is not None:
try:
human_input = HumanInput.get_by_id(human_input_id)
except peewee.DoesNotExist:
logger.warning(f"Human input with ID {human_input_id} not found")
# Create the trajectory record
trajectory = Trajectory.create(
human_input=human_input,
tool_name=tool_name,
tool_parameters=tool_parameters_json,
tool_result=tool_result_json,
step_data=step_data_json,
record_type=record_type,
cost=cost,
tokens=tokens
)
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
return trajectory
except peewee.DatabaseError as e:
logger.error(f"Failed to create trajectory record: {str(e)}")
raise
def get(self, trajectory_id: int) -> Optional[Trajectory]:
"""
Retrieve a trajectory record by its ID.
Args:
trajectory_id: The ID of the trajectory record to retrieve
Returns:
Optional[Trajectory]: The trajectory instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return Trajectory.get_or_none(Trajectory.id == trajectory_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
raise
def update(
self,
trajectory_id: int,
tool_result: Optional[Dict[str, Any]] = None,
step_data: Optional[Dict[str, Any]] = None,
cost: Optional[float] = None,
tokens: Optional[int] = None
) -> Optional[Trajectory]:
"""
Update an existing trajectory record.
This is typically used to update the result or metrics after tool execution completes.
Args:
trajectory_id: The ID of the trajectory record to update
tool_result: Updated tool result (will be JSON encoded)
step_data: Updated UI rendering data (will be JSON encoded)
cost: Updated cost information
tokens: Updated token usage information
Returns:
Optional[Trajectory]: The updated trajectory if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the record
"""
try:
# First check if the trajectory exists
trajectory = self.get(trajectory_id)
if not trajectory:
logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
return None
# Update the fields if provided
update_data = {}
if tool_result is not None:
update_data["tool_result"] = json.dumps(tool_result)
if step_data is not None:
update_data["step_data"] = json.dumps(step_data)
if cost is not None:
update_data["cost"] = cost
if tokens is not None:
update_data["tokens"] = tokens
if update_data:
query = Trajectory.update(**update_data).where(Trajectory.id == trajectory_id)
query.execute()
logger.debug(f"Updated trajectory record ID {trajectory_id}")
return self.get(trajectory_id)
return trajectory
except peewee.DatabaseError as e:
logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
raise
def delete(self, trajectory_id: int) -> bool:
"""
Delete a trajectory record by its ID.
Args:
trajectory_id: The ID of the trajectory record to delete
Returns:
bool: True if the record was deleted, False if it wasn't found
Raises:
peewee.DatabaseError: If there's an error deleting the record
"""
try:
# First check if the trajectory exists
trajectory = self.get(trajectory_id)
if not trajectory:
logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
return False
# Delete the trajectory
trajectory.delete_instance()
logger.debug(f"Deleted trajectory record ID {trajectory_id}")
return True
except peewee.DatabaseError as e:
logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
raise
def get_all(self) -> Dict[int, Trajectory]:
"""
Retrieve all trajectory records from the database.
Returns:
Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all trajectories: {str(e)}")
raise
def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
"""
Retrieve all trajectory records associated with a specific human input.
Args:
human_input_id: The ID of the human input to get trajectories for
Returns:
List[Trajectory]: List of trajectory instances associated with the human input
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
raise
def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
"""
Parse a JSON string into a Python dictionary.
Args:
json_str: JSON string to parse
Returns:
Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid
"""
if not json_str:
return None
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
logger.error(f"Error parsing JSON field: {str(e)}")
return None
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
"""
Get a trajectory record with JSON fields parsed into dictionaries.
Args:
trajectory_id: ID of the trajectory to retrieve
Returns:
Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
or None if not found
"""
trajectory = self.get(trajectory_id)
if trajectory is None:
return None
return {
"id": trajectory.id,
"created_at": trajectory.created_at,
"updated_at": trajectory.updated_at,
"tool_name": trajectory.tool_name,
"tool_parameters": self.parse_json_field(trajectory.tool_parameters),
"tool_result": self.parse_json_field(trajectory.tool_result),
"step_data": self.parse_json_field(trajectory.step_data),
"record_type": trajectory.record_type,
"cost": trajectory.cost,
"tokens": trajectory.tokens,
"human_input_id": trajectory.human_input.id if trajectory.human_input else None,
}

View File

@ -0,0 +1,100 @@
"""Peewee migrations -- 007_20250310_184046_add_trajectory_model.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Create the trajectory table for storing agent action trajectories."""
# Check if the table already exists
try:
database.execute_sql("SELECT id FROM trajectory LIMIT 1")
# If we reach here, the table exists
return
except pw.OperationalError:
# Table doesn't exist, safe to create
pass
@migrator.create_model
class Trajectory(pw.Model):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
tool_name = pw.TextField()
tool_parameters = pw.TextField() # JSON-encoded parameters
tool_result = pw.TextField() # JSON-encoded result
step_data = pw.TextField() # JSON-encoded UI rendering data
record_type = pw.TextField() # Type of trajectory record
cost = pw.FloatField(null=True) # Placeholder for cost tracking
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
# We'll add the human_input foreign key in a separate step for safety
class Meta:
table_name = "trajectory"
# Check if HumanInput model exists before adding the foreign key
try:
HumanInput = migrator.orm['human_input']
# Only add the foreign key if the human_input_id column doesn't already exist
try:
database.execute_sql("SELECT human_input_id FROM trajectory LIMIT 1")
except pw.OperationalError:
# Column doesn't exist, safe to add
migrator.add_fields(
'trajectory',
human_input=pw.ForeignKeyField(
HumanInput,
null=True,
field='id',
on_delete='SET NULL'
)
)
except KeyError:
# HumanInput doesn't exist, we'll skip adding the foreign key
pass
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove the trajectory table."""
# First remove any foreign key fields
try:
migrator.remove_fields('trajectory', 'human_input')
except pw.OperationalError:
# Field might not exist, that's fine
pass
# Then remove the model
migrator.remove_model('trajectory')