trajectory for stage transitions
This commit is contained in:
parent
b4b0fdd686
commit
96093e8dfc
|
|
@ -1,6 +1,10 @@
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -33,6 +37,20 @@ def print_stage_header(stage: str) -> None:
|
||||||
# Create styled panel with icon
|
# Create styled panel with icon
|
||||||
panel_content = f" {icon} {stage_title}"
|
panel_content = f" {icon} {stage_title}"
|
||||||
console.print(Panel(panel_content, style="green bold", padding=0))
|
console.print(Panel(panel_content, style="green bold", padding=0))
|
||||||
|
|
||||||
|
# Record trajectory event - focus on semantic meaning
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"stage": stage_key,
|
||||||
|
"display_icon": icon,
|
||||||
|
"display_title": stage_title,
|
||||||
|
},
|
||||||
|
record_type="stage_transition",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_task_header(task: str) -> None:
|
def print_task_header(task: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -182,11 +182,11 @@ class Trajectory(BaseModel):
|
||||||
- Error information (when a tool execution fails)
|
- Error information (when a tool execution fails)
|
||||||
"""
|
"""
|
||||||
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
||||||
tool_name = peewee.TextField()
|
tool_name = peewee.TextField(null=True)
|
||||||
tool_parameters = peewee.TextField() # JSON-encoded parameters
|
tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters
|
||||||
tool_result = peewee.TextField() # JSON-encoded result
|
tool_result = peewee.TextField(null=True) # JSON-encoded result
|
||||||
step_data = peewee.TextField() # JSON-encoded UI rendering data
|
step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data
|
||||||
record_type = peewee.TextField() # Type of trajectory record
|
record_type = peewee.TextField(null=True) # Type of trajectory record
|
||||||
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
||||||
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
||||||
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
|
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||||
|
|
|
||||||
|
|
@ -132,8 +132,8 @@ class TrajectoryRepository:
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: Optional[str] = None,
|
||||||
tool_parameters: Dict[str, Any],
|
tool_parameters: Optional[Dict[str, Any]] = None,
|
||||||
tool_result: Optional[Dict[str, Any]] = None,
|
tool_result: Optional[Dict[str, Any]] = None,
|
||||||
step_data: Optional[Dict[str, Any]] = None,
|
step_data: Optional[Dict[str, Any]] = None,
|
||||||
record_type: str = "tool_execution",
|
record_type: str = "tool_execution",
|
||||||
|
|
@ -149,8 +149,8 @@ class TrajectoryRepository:
|
||||||
Create a new trajectory record in the database.
|
Create a new trajectory record in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: Name of the tool that was executed
|
tool_name: Optional name of the tool that was executed
|
||||||
tool_parameters: Parameters passed to the tool (will be JSON encoded)
|
tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
|
||||||
tool_result: Result returned by the tool (will be JSON encoded)
|
tool_result: Result returned by the tool (will be JSON encoded)
|
||||||
step_data: UI rendering data (will be JSON encoded)
|
step_data: UI rendering data (will be JSON encoded)
|
||||||
record_type: Type of trajectory record
|
record_type: Type of trajectory record
|
||||||
|
|
@ -170,7 +170,7 @@ class TrajectoryRepository:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Serialize JSON fields
|
# Serialize JSON fields
|
||||||
tool_parameters_json = json.dumps(tool_parameters)
|
tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None
|
||||||
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
||||||
step_data_json = json.dumps(step_data) if step_data is not None else None
|
step_data_json = json.dumps(step_data) if step_data is not None else None
|
||||||
|
|
||||||
|
|
@ -185,7 +185,7 @@ class TrajectoryRepository:
|
||||||
# Create the trajectory record
|
# Create the trajectory record
|
||||||
trajectory = Trajectory.create(
|
trajectory = Trajectory.create(
|
||||||
human_input=human_input,
|
human_input=human_input,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name or "", # Use empty string if tool_name is None
|
||||||
tool_parameters=tool_parameters_json,
|
tool_parameters=tool_parameters_json,
|
||||||
tool_result=tool_result_json,
|
tool_result=tool_result_json,
|
||||||
step_data=step_data_json,
|
step_data=step_data_json,
|
||||||
|
|
@ -197,7 +197,10 @@ class TrajectoryRepository:
|
||||||
error_type=error_type,
|
error_type=error_type,
|
||||||
error_details=error_details
|
error_details=error_details
|
||||||
)
|
)
|
||||||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
if tool_name:
|
||||||
|
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
|
||||||
return trajectory
|
return trajectory
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to create trajectory record: {str(e)}")
|
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||||
|
|
|
||||||
|
|
@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
id = pw.AutoField()
|
id = pw.AutoField()
|
||||||
created_at = pw.DateTimeField()
|
created_at = pw.DateTimeField()
|
||||||
updated_at = pw.DateTimeField()
|
updated_at = pw.DateTimeField()
|
||||||
tool_name = pw.TextField()
|
tool_name = pw.TextField(null=True) # JSON-encoded parameters
|
||||||
tool_parameters = pw.TextField() # JSON-encoded parameters
|
tool_parameters = pw.TextField(null=True) # JSON-encoded parameters
|
||||||
tool_result = pw.TextField() # JSON-encoded result
|
tool_result = pw.TextField(null=True) # JSON-encoded result
|
||||||
step_data = pw.TextField() # JSON-encoded UI rendering data
|
step_data = pw.TextField(null=True) # JSON-encoded UI rendering data
|
||||||
record_type = pw.TextField() # Type of trajectory record
|
record_type = pw.TextField(null=True) # Type of trajectory record
|
||||||
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
||||||
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
||||||
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error
|
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue