trajectory for stage transitions

This commit is contained in:
AI Christianson 2025-03-11 08:34:08 -04:00
parent b4b0fdd686
commit 96093e8dfc
4 changed files with 38 additions and 17 deletions

View File

@ -1,6 +1,10 @@
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from typing import Optional
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
console = Console()
@ -34,6 +38,20 @@ def print_stage_header(stage: str) -> None:
panel_content = f" {icon} {stage_title}"
console.print(Panel(panel_content, style="green bold", padding=0))
# Record trajectory event - focus on semantic meaning
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"stage": stage_key,
"display_icon": icon,
"display_title": stage_title,
},
record_type="stage_transition",
human_input_id=human_input_id
)
def print_task_header(task: str) -> None:
"""Print a task header with yellow styling and wrench emoji. Content is rendered as Markdown.

View File

@ -182,11 +182,11 @@ class Trajectory(BaseModel):
- Error information (when a tool execution fails)
"""
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
tool_name = peewee.TextField()
tool_parameters = peewee.TextField() # JSON-encoded parameters
tool_result = peewee.TextField() # JSON-encoded result
step_data = peewee.TextField() # JSON-encoded UI rendering data
record_type = peewee.TextField() # Type of trajectory record
tool_name = peewee.TextField(null=True)
tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters
tool_result = peewee.TextField(null=True) # JSON-encoded result
step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data
record_type = peewee.TextField(null=True) # Type of trajectory record
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error

View File

@ -132,8 +132,8 @@ class TrajectoryRepository:
def create(
self,
tool_name: str,
tool_parameters: Dict[str, Any],
tool_name: Optional[str] = None,
tool_parameters: Optional[Dict[str, Any]] = None,
tool_result: Optional[Dict[str, Any]] = None,
step_data: Optional[Dict[str, Any]] = None,
record_type: str = "tool_execution",
@ -149,8 +149,8 @@ class TrajectoryRepository:
Create a new trajectory record in the database.
Args:
tool_name: Name of the tool that was executed
tool_parameters: Parameters passed to the tool (will be JSON encoded)
tool_name: Optional name of the tool that was executed
tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
tool_result: Result returned by the tool (will be JSON encoded)
step_data: UI rendering data (will be JSON encoded)
record_type: Type of trajectory record
@ -170,7 +170,7 @@ class TrajectoryRepository:
"""
try:
# Serialize JSON fields
tool_parameters_json = json.dumps(tool_parameters)
tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
step_data_json = json.dumps(step_data) if step_data is not None else None
@ -185,7 +185,7 @@ class TrajectoryRepository:
# Create the trajectory record
trajectory = Trajectory.create(
human_input=human_input,
tool_name=tool_name,
tool_name=tool_name or "", # Use empty string if tool_name is None
tool_parameters=tool_parameters_json,
tool_result=tool_result_json,
step_data=step_data_json,
@ -197,7 +197,10 @@ class TrajectoryRepository:
error_type=error_type,
error_details=error_details
)
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
if tool_name:
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
else:
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
return trajectory
except peewee.DatabaseError as e:
logger.error(f"Failed to create trajectory record: {str(e)}")

View File

@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
tool_name = pw.TextField()
tool_parameters = pw.TextField() # JSON-encoded parameters
tool_result = pw.TextField() # JSON-encoded result
step_data = pw.TextField() # JSON-encoded UI rendering data
record_type = pw.TextField() # Type of trajectory record
tool_name = pw.TextField(null=True) # JSON-encoded parameters
tool_parameters = pw.TextField(null=True) # JSON-encoded parameters
tool_result = pw.TextField(null=True) # JSON-encoded result
step_data = pw.TextField(null=True) # JSON-encoded UI rendering data
record_type = pw.TextField(null=True) # Type of trajectory record
cost = pw.FloatField(null=True) # Placeholder for cost tracking
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error