From 96093e8dfcaa1c37c2563bbe33a6e78d854ce26a Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 11 Mar 2025 08:34:08 -0400 Subject: [PATCH] trajectory for stage transitions --- ra_aid/console/formatting.py | 18 ++++++++++++++++++ ra_aid/database/models.py | 10 +++++----- .../repositories/trajectory_repository.py | 17 ++++++++++------- ...007_20250310_184046_add_trajectory_model.py | 10 +++++----- 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/ra_aid/console/formatting.py b/ra_aid/console/formatting.py index d9a0657..229204e 100644 --- a/ra_aid/console/formatting.py +++ b/ra_aid/console/formatting.py @@ -1,6 +1,10 @@ from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel +from typing import Optional + +from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository +from ra_aid.database.repositories.human_input_repository import get_human_input_repository console = Console() @@ -33,6 +37,20 @@ def print_stage_header(stage: str) -> None: # Create styled panel with icon panel_content = f" {icon} {stage_title}" console.print(Panel(panel_content, style="green bold", padding=0)) + + # Record trajectory event - focus on semantic meaning + trajectory_repo = get_trajectory_repository() + human_input_id = get_human_input_repository().get_most_recent_id() + + trajectory_repo.create( + step_data={ + "stage": stage_key, + "display_icon": icon, + "display_title": stage_title, + }, + record_type="stage_transition", + human_input_id=human_input_id + ) def print_task_header(task: str) -> None: diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py index f83acc5..3fc7033 100644 --- a/ra_aid/database/models.py +++ b/ra_aid/database/models.py @@ -182,11 +182,11 @@ class Trajectory(BaseModel): - Error information (when a tool execution fails) """ human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True) - tool_name = peewee.TextField() - tool_parameters = peewee.TextField() # JSON-encoded parameters - tool_result = peewee.TextField() # JSON-encoded result - step_data = peewee.TextField() # JSON-encoded UI rendering data - record_type = peewee.TextField() # Type of trajectory record + tool_name = peewee.TextField(null=True) + tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters + tool_result = peewee.TextField(null=True) # JSON-encoded result + step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data + record_type = peewee.TextField(null=True) # Type of trajectory record cost = peewee.FloatField(null=True) # Placeholder for cost tracking tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error diff --git a/ra_aid/database/repositories/trajectory_repository.py b/ra_aid/database/repositories/trajectory_repository.py index 4b0f8ec..792ff79 100644 --- a/ra_aid/database/repositories/trajectory_repository.py +++ b/ra_aid/database/repositories/trajectory_repository.py @@ -132,8 +132,8 @@ class TrajectoryRepository: def create( self, - tool_name: str, - tool_parameters: Dict[str, Any], + tool_name: Optional[str] = None, + tool_parameters: Optional[Dict[str, Any]] = None, tool_result: Optional[Dict[str, Any]] = None, step_data: Optional[Dict[str, Any]] = None, record_type: str = "tool_execution", @@ -149,8 +149,8 @@ class TrajectoryRepository: Create a new trajectory record in the database. Args: - tool_name: Name of the tool that was executed - tool_parameters: Parameters passed to the tool (will be JSON encoded) + tool_name: Optional name of the tool that was executed + tool_parameters: Optional parameters passed to the tool (will be JSON encoded) tool_result: Result returned by the tool (will be JSON encoded) step_data: UI rendering data (will be JSON encoded) record_type: Type of trajectory record @@ -170,7 +170,7 @@ class TrajectoryRepository: """ try: # Serialize JSON fields - tool_parameters_json = json.dumps(tool_parameters) + tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None tool_result_json = json.dumps(tool_result) if tool_result is not None else None step_data_json = json.dumps(step_data) if step_data is not None else None @@ -185,7 +185,7 @@ class TrajectoryRepository: # Create the trajectory record trajectory = Trajectory.create( human_input=human_input, - tool_name=tool_name, + tool_name=tool_name or "", # Use empty string if tool_name is None tool_parameters=tool_parameters_json, tool_result=tool_result_json, step_data=step_data_json, @@ -197,7 +197,10 @@ class TrajectoryRepository: error_type=error_type, error_details=error_details ) - logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}") + if tool_name: + logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}") + else: + logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}") return trajectory except peewee.DatabaseError as e: logger.error(f"Failed to create trajectory record: {str(e)}") diff --git a/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py b/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py index 3eb20cd..e969d58 100644 --- a/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py +++ b/ra_aid/migrations/007_20250310_184046_add_trajectory_model.py @@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): id = pw.AutoField() created_at = pw.DateTimeField() updated_at = pw.DateTimeField() - tool_name = pw.TextField() - tool_parameters = pw.TextField() # JSON-encoded parameters - tool_result = pw.TextField() # JSON-encoded result - step_data = pw.TextField() # JSON-encoded UI rendering data - record_type = pw.TextField() # Type of trajectory record + tool_name = pw.TextField(null=True) # JSON-encoded parameters + tool_parameters = pw.TextField(null=True) # JSON-encoded parameters + tool_result = pw.TextField(null=True) # JSON-encoded result + step_data = pw.TextField(null=True) # JSON-encoded UI rendering data + record_type = pw.TextField(null=True) # Type of trajectory record cost = pw.FloatField(null=True) # Placeholder for cost tracking tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error