From b3010bb649806e7d009bef2cae2cab01df724eec Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Mon, 10 Mar 2025 19:12:14 -0400 Subject: [PATCH] add error info to trajectory records --- ra_aid/database/models.py | 5 +++ .../repositories/trajectory_repository.py | 42 +++++++++++++++++-- ...07_20250310_184046_add_trajectory_model.py | 4 ++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py index 9cfe0a5..f83acc5 100644 --- a/ra_aid/database/models.py +++ b/ra_aid/database/models.py @@ -179,6 +179,7 @@ class Trajectory(BaseModel): - What result was returned by the tool - UI rendering data for displaying the tool execution - Cost and token usage metrics (placeholders for future implementation) + - Error information (when a tool execution fails) """ human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True) tool_name = peewee.TextField() @@ -188,6 +189,10 @@ class Trajectory(BaseModel): record_type = peewee.TextField() # Type of trajectory record cost = peewee.FloatField(null=True) # Placeholder for cost tracking tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking + is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error + error_message = peewee.TextField(null=True) # The error message + error_type = peewee.TextField(null=True) # The type/class of the error + error_details = peewee.TextField(null=True) # Additional error details like stack traces or context # created_at and updated_at are inherited from BaseModel class Meta: diff --git a/ra_aid/database/repositories/trajectory_repository.py b/ra_aid/database/repositories/trajectory_repository.py index 14498c0..4b0f8ec 100644 --- a/ra_aid/database/repositories/trajectory_repository.py +++ b/ra_aid/database/repositories/trajectory_repository.py @@ -139,7 +139,11 @@ class TrajectoryRepository: record_type: str = "tool_execution", human_input_id: Optional[int] = None, cost: Optional[float] = None, - tokens: Optional[int] = None + tokens: Optional[int] = None, + is_error: bool = False, + error_message: Optional[str] = None, + error_type: Optional[str] = None, + error_details: Optional[str] = None ) -> Trajectory: """ Create a new trajectory record in the database. @@ -153,6 +157,10 @@ class TrajectoryRepository: human_input_id: Optional ID of the associated human input cost: Optional cost of the operation (placeholder) tokens: Optional token usage (placeholder) + is_error: Flag indicating if this record represents an error (default: False) + error_message: The error message (if is_error is True) + error_type: The type/class of the error (if is_error is True) + error_details: Additional error details like stack traces (if is_error is True) Returns: Trajectory: The newly created trajectory instance @@ -183,7 +191,11 @@ class TrajectoryRepository: step_data=step_data_json, record_type=record_type, cost=cost, - tokens=tokens + tokens=tokens, + is_error=is_error, + error_message=error_message, + error_type=error_type, + error_details=error_details ) logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}") return trajectory @@ -216,7 +228,11 @@ class TrajectoryRepository: tool_result: Optional[Dict[str, Any]] = None, step_data: Optional[Dict[str, Any]] = None, cost: Optional[float] = None, - tokens: Optional[int] = None + tokens: Optional[int] = None, + is_error: Optional[bool] = None, + error_message: Optional[str] = None, + error_type: Optional[str] = None, + error_details: Optional[str] = None ) -> Optional[Trajectory]: """ Update an existing trajectory record. @@ -229,6 +245,10 @@ class TrajectoryRepository: step_data: Updated UI rendering data (will be JSON encoded) cost: Updated cost information tokens: Updated token usage information + is_error: Flag indicating if this record represents an error + error_message: The error message + error_type: The type/class of the error + error_details: Additional error details like stack traces Returns: Optional[Trajectory]: The updated trajectory if found, None otherwise @@ -257,6 +277,18 @@ class TrajectoryRepository: if tokens is not None: update_data["tokens"] = tokens + + if is_error is not None: + update_data["is_error"] = is_error + + if error_message is not None: + update_data["error_message"] = error_message + + if error_type is not None: + update_data["error_type"] = error_type + + if error_details is not None: + update_data["error_details"] = error_details if update_data: query = Trajectory.update(**update_data).where(Trajectory.id == trajectory_id) @@ -378,4 +410,8 @@ class TrajectoryRepository: "cost": trajectory.cost, "tokens": trajectory.tokens, "human_input_id": trajectory.human_input.id if trajectory.human_input else None, + "is_error": trajectory.is_error, + "error_message": trajectory.error_message, + "error_type": trajectory.error_type, + "error_details": trajectory.error_details, } \ No newline at end of file diff --git a/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py b/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py index b62ecc9..3eb20cd 100644 --- a/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py +++ b/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py @@ -58,6 +58,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): record_type = pw.TextField() # Type of trajectory record cost = pw.FloatField(null=True) # Placeholder for cost tracking tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking + is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error + error_message = pw.TextField(null=True) # The error message + error_type = pw.TextField(null=True) # The type/class of the error + error_details = pw.TextField(null=True) # Additional error details like stack traces or context # We'll add the human_input foreign key in a separate step for safety class Meta: