trajectory for stage transitions
This commit is contained in:
parent
b4b0fdd686
commit
96093e8dfc
|
|
@ -1,6 +1,10 @@
|
|||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from typing import Optional
|
||||
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -34,6 +38,20 @@ def print_stage_header(stage: str) -> None:
|
|||
panel_content = f" {icon} {stage_title}"
|
||||
console.print(Panel(panel_content, style="green bold", padding=0))
|
||||
|
||||
# Record trajectory event - focus on semantic meaning
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"stage": stage_key,
|
||||
"display_icon": icon,
|
||||
"display_title": stage_title,
|
||||
},
|
||||
record_type="stage_transition",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
|
||||
def print_task_header(task: str) -> None:
|
||||
"""Print a task header with yellow styling and wrench emoji. Content is rendered as Markdown.
|
||||
|
|
|
|||
|
|
@ -182,11 +182,11 @@ class Trajectory(BaseModel):
|
|||
- Error information (when a tool execution fails)
|
||||
"""
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
||||
tool_name = peewee.TextField()
|
||||
tool_parameters = peewee.TextField() # JSON-encoded parameters
|
||||
tool_result = peewee.TextField() # JSON-encoded result
|
||||
step_data = peewee.TextField() # JSON-encoded UI rendering data
|
||||
record_type = peewee.TextField() # Type of trajectory record
|
||||
tool_name = peewee.TextField(null=True)
|
||||
tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters
|
||||
tool_result = peewee.TextField(null=True) # JSON-encoded result
|
||||
step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data
|
||||
record_type = peewee.TextField(null=True) # Type of trajectory record
|
||||
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
||||
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
||||
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||
|
|
|
|||
|
|
@ -132,8 +132,8 @@ class TrajectoryRepository:
|
|||
|
||||
def create(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_parameters: Dict[str, Any],
|
||||
tool_name: Optional[str] = None,
|
||||
tool_parameters: Optional[Dict[str, Any]] = None,
|
||||
tool_result: Optional[Dict[str, Any]] = None,
|
||||
step_data: Optional[Dict[str, Any]] = None,
|
||||
record_type: str = "tool_execution",
|
||||
|
|
@ -149,8 +149,8 @@ class TrajectoryRepository:
|
|||
Create a new trajectory record in the database.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was executed
|
||||
tool_parameters: Parameters passed to the tool (will be JSON encoded)
|
||||
tool_name: Optional name of the tool that was executed
|
||||
tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
|
||||
tool_result: Result returned by the tool (will be JSON encoded)
|
||||
step_data: UI rendering data (will be JSON encoded)
|
||||
record_type: Type of trajectory record
|
||||
|
|
@ -170,7 +170,7 @@ class TrajectoryRepository:
|
|||
"""
|
||||
try:
|
||||
# Serialize JSON fields
|
||||
tool_parameters_json = json.dumps(tool_parameters)
|
||||
tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None
|
||||
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
||||
step_data_json = json.dumps(step_data) if step_data is not None else None
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class TrajectoryRepository:
|
|||
# Create the trajectory record
|
||||
trajectory = Trajectory.create(
|
||||
human_input=human_input,
|
||||
tool_name=tool_name,
|
||||
tool_name=tool_name or "", # Use empty string if tool_name is None
|
||||
tool_parameters=tool_parameters_json,
|
||||
tool_result=tool_result_json,
|
||||
step_data=step_data_json,
|
||||
|
|
@ -197,7 +197,10 @@ class TrajectoryRepository:
|
|||
error_type=error_type,
|
||||
error_details=error_details
|
||||
)
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||
if tool_name:
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||
else:
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
|
||||
return trajectory
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
id = pw.AutoField()
|
||||
created_at = pw.DateTimeField()
|
||||
updated_at = pw.DateTimeField()
|
||||
tool_name = pw.TextField()
|
||||
tool_parameters = pw.TextField() # JSON-encoded parameters
|
||||
tool_result = pw.TextField() # JSON-encoded result
|
||||
step_data = pw.TextField() # JSON-encoded UI rendering data
|
||||
record_type = pw.TextField() # Type of trajectory record
|
||||
tool_name = pw.TextField(null=True) # JSON-encoded parameters
|
||||
tool_parameters = pw.TextField(null=True) # JSON-encoded parameters
|
||||
tool_result = pw.TextField(null=True) # JSON-encoded result
|
||||
step_data = pw.TextField(null=True) # JSON-encoded UI rendering data
|
||||
record_type = pw.TextField(null=True) # Type of trajectory record
|
||||
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
||||
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
||||
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||
|
|
|
|||
Loading…
Reference in New Issue