add trajectory table
This commit is contained in:
parent
a18998be0d
commit
78983ec20b
|
|
@ -53,6 +53,9 @@ from ra_aid.database.repositories.human_input_repository import (
|
||||||
from ra_aid.database.repositories.research_note_repository import (
|
from ra_aid.database.repositories.research_note_repository import (
|
||||||
ResearchNoteRepositoryManager, get_research_note_repository
|
ResearchNoteRepositoryManager, get_research_note_repository
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import (
|
||||||
|
TrajectoryRepositoryManager, get_trajectory_repository
|
||||||
|
)
|
||||||
from ra_aid.database.repositories.related_files_repository import (
|
from ra_aid.database.repositories.related_files_repository import (
|
||||||
RelatedFilesRepositoryManager
|
RelatedFilesRepositoryManager
|
||||||
)
|
)
|
||||||
|
|
@ -528,6 +531,7 @@ def main():
|
||||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||||
|
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||||
WorkLogRepositoryManager() as work_log_repo, \
|
WorkLogRepositoryManager() as work_log_repo, \
|
||||||
ConfigRepositoryManager(config) as config_repo, \
|
ConfigRepositoryManager(config) as config_repo, \
|
||||||
EnvInvManager(env_data) as env_inv:
|
EnvInvManager(env_data) as env_inv:
|
||||||
|
|
@ -537,6 +541,7 @@ def main():
|
||||||
logger.debug("Initialized HumanInputRepository")
|
logger.debug("Initialized HumanInputRepository")
|
||||||
logger.debug("Initialized ResearchNoteRepository")
|
logger.debug("Initialized ResearchNoteRepository")
|
||||||
logger.debug("Initialized RelatedFilesRepository")
|
logger.debug("Initialized RelatedFilesRepository")
|
||||||
|
logger.debug("Initialized TrajectoryRepository")
|
||||||
logger.debug("Initialized WorkLogRepository")
|
logger.debug("Initialized WorkLogRepository")
|
||||||
logger.debug("Initialized ConfigRepository")
|
logger.debug("Initialized ConfigRepository")
|
||||||
logger.debug("Initialized Environment Inventory")
|
logger.debug("Initialized Environment Inventory")
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,8 @@ def initialize_database():
|
||||||
# to avoid circular imports
|
# to avoid circular imports
|
||||||
# Note: This import needs to be here, not at the top level
|
# Note: This import needs to be here, not at the top level
|
||||||
try:
|
try:
|
||||||
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote
|
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
|
||||||
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote], safe=True)
|
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
|
||||||
logger.debug("Ensured database tables exist")
|
logger.debug("Ensured database tables exist")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating tables: {str(e)}")
|
logger.error(f"Error creating tables: {str(e)}")
|
||||||
|
|
@ -163,3 +163,32 @@ class ResearchNote(BaseModel):
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "research_note"
|
table_name = "research_note"
|
||||||
|
|
||||||
|
|
||||||
|
class Trajectory(BaseModel):
|
||||||
|
"""
|
||||||
|
Model representing an agent trajectory stored in the database.
|
||||||
|
|
||||||
|
Trajectories track the sequence of actions taken by agents, including
|
||||||
|
tool executions and their results. This enables analysis of agent behavior,
|
||||||
|
debugging of issues, and reconstruction of the decision-making process.
|
||||||
|
|
||||||
|
Each trajectory record captures details about a single tool execution:
|
||||||
|
- Which tool was used
|
||||||
|
- What parameters were passed to the tool
|
||||||
|
- What result was returned by the tool
|
||||||
|
- UI rendering data for displaying the tool execution
|
||||||
|
- Cost and token usage metrics (placeholders for future implementation)
|
||||||
|
"""
|
||||||
|
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
||||||
|
tool_name = peewee.TextField()
|
||||||
|
tool_parameters = peewee.TextField() # JSON-encoded parameters
|
||||||
|
tool_result = peewee.TextField() # JSON-encoded result
|
||||||
|
step_data = peewee.TextField() # JSON-encoded UI rendering data
|
||||||
|
record_type = peewee.TextField() # Type of trajectory record
|
||||||
|
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
||||||
|
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
||||||
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "trajectory"
|
||||||
|
|
@ -0,0 +1,381 @@
|
||||||
|
"""
|
||||||
|
Trajectory repository implementation for database access.
|
||||||
|
|
||||||
|
This module provides a repository implementation for the Trajectory model,
|
||||||
|
following the repository pattern for data access abstraction. It handles
|
||||||
|
operations for storing and retrieving agent action trajectories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Any, Union
|
||||||
|
import contextvars
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import peewee
|
||||||
|
|
||||||
|
from ra_aid.database.models import Trajectory, HumanInput
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Create contextvar to hold the TrajectoryRepository instance
|
||||||
|
trajectory_repo_var = contextvars.ContextVar("trajectory_repo", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class TrajectoryRepositoryManager:
|
||||||
|
"""
|
||||||
|
Context manager for TrajectoryRepository.
|
||||||
|
|
||||||
|
This class provides a context manager interface for TrajectoryRepository,
|
||||||
|
using the contextvars approach for thread safety.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with DatabaseManager() as db:
|
||||||
|
with TrajectoryRepositoryManager(db) as repo:
|
||||||
|
# Use the repository
|
||||||
|
trajectory = repo.create(
|
||||||
|
tool_name="ripgrep_search",
|
||||||
|
tool_parameters={"pattern": "example"}
|
||||||
|
)
|
||||||
|
all_trajectories = repo.get_all()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
"""
|
||||||
|
Initialize the TrajectoryRepositoryManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection to use (required)
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def __enter__(self) -> 'TrajectoryRepository':
|
||||||
|
"""
|
||||||
|
Initialize the TrajectoryRepository and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TrajectoryRepository: The initialized repository
|
||||||
|
"""
|
||||||
|
repo = TrajectoryRepository(self.db)
|
||||||
|
trajectory_repo_var.set(repo)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Reset the repository when exiting the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: The exception type if an exception was raised
|
||||||
|
exc_val: The exception value if an exception was raised
|
||||||
|
exc_tb: The traceback if an exception was raised
|
||||||
|
"""
|
||||||
|
# Reset the contextvar to None
|
||||||
|
trajectory_repo_var.set(None)
|
||||||
|
|
||||||
|
# Don't suppress exceptions
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_trajectory_repository() -> 'TrajectoryRepository':
|
||||||
|
"""
|
||||||
|
Get the current TrajectoryRepository instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TrajectoryRepository: The current repository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no repository has been initialized with TrajectoryRepositoryManager
|
||||||
|
"""
|
||||||
|
repo = trajectory_repo_var.get()
|
||||||
|
if repo is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No TrajectoryRepository available. "
|
||||||
|
"Make sure to initialize one with TrajectoryRepositoryManager first."
|
||||||
|
)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
class TrajectoryRepository:
|
||||||
|
"""
|
||||||
|
Repository for managing Trajectory database operations.
|
||||||
|
|
||||||
|
This class provides methods for performing CRUD operations on the Trajectory model,
|
||||||
|
abstracting the database access details from the business logic. It handles
|
||||||
|
serialization and deserialization of JSON fields for tool parameters, results,
|
||||||
|
and UI rendering data.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with DatabaseManager() as db:
|
||||||
|
with TrajectoryRepositoryManager(db) as repo:
|
||||||
|
trajectory = repo.create(
|
||||||
|
tool_name="ripgrep_search",
|
||||||
|
tool_parameters={"pattern": "example"}
|
||||||
|
)
|
||||||
|
all_trajectories = repo.get_all()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
"""
|
||||||
|
Initialize the repository with a database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection to use (required)
|
||||||
|
"""
|
||||||
|
if db is None:
|
||||||
|
raise ValueError("Database connection is required for TrajectoryRepository")
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_parameters: Dict[str, Any],
|
||||||
|
tool_result: Optional[Dict[str, Any]] = None,
|
||||||
|
step_data: Optional[Dict[str, Any]] = None,
|
||||||
|
record_type: str = "tool_execution",
|
||||||
|
human_input_id: Optional[int] = None,
|
||||||
|
cost: Optional[float] = None,
|
||||||
|
tokens: Optional[int] = None
|
||||||
|
) -> Trajectory:
|
||||||
|
"""
|
||||||
|
Create a new trajectory record in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool that was executed
|
||||||
|
tool_parameters: Parameters passed to the tool (will be JSON encoded)
|
||||||
|
tool_result: Result returned by the tool (will be JSON encoded)
|
||||||
|
step_data: UI rendering data (will be JSON encoded)
|
||||||
|
record_type: Type of trajectory record
|
||||||
|
human_input_id: Optional ID of the associated human input
|
||||||
|
cost: Optional cost of the operation (placeholder)
|
||||||
|
tokens: Optional token usage (placeholder)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Trajectory: The newly created trajectory instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error creating the record
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Serialize JSON fields
|
||||||
|
tool_parameters_json = json.dumps(tool_parameters)
|
||||||
|
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
||||||
|
step_data_json = json.dumps(step_data) if step_data is not None else None
|
||||||
|
|
||||||
|
# Create human input reference if provided
|
||||||
|
human_input = None
|
||||||
|
if human_input_id is not None:
|
||||||
|
try:
|
||||||
|
human_input = HumanInput.get_by_id(human_input_id)
|
||||||
|
except peewee.DoesNotExist:
|
||||||
|
logger.warning(f"Human input with ID {human_input_id} not found")
|
||||||
|
|
||||||
|
# Create the trajectory record
|
||||||
|
trajectory = Trajectory.create(
|
||||||
|
human_input=human_input,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_parameters=tool_parameters_json,
|
||||||
|
tool_result=tool_result_json,
|
||||||
|
step_data=step_data_json,
|
||||||
|
record_type=record_type,
|
||||||
|
cost=cost,
|
||||||
|
tokens=tokens
|
||||||
|
)
|
||||||
|
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||||
|
return trajectory
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get(self, trajectory_id: int) -> Optional[Trajectory]:
|
||||||
|
"""
|
||||||
|
Retrieve a trajectory record by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trajectory_id: The ID of the trajectory record to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Trajectory]: The trajectory instance if found, None otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
trajectory_id: int,
|
||||||
|
tool_result: Optional[Dict[str, Any]] = None,
|
||||||
|
step_data: Optional[Dict[str, Any]] = None,
|
||||||
|
cost: Optional[float] = None,
|
||||||
|
tokens: Optional[int] = None
|
||||||
|
) -> Optional[Trajectory]:
|
||||||
|
"""
|
||||||
|
Update an existing trajectory record.
|
||||||
|
|
||||||
|
This is typically used to update the result or metrics after tool execution completes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trajectory_id: The ID of the trajectory record to update
|
||||||
|
tool_result: Updated tool result (will be JSON encoded)
|
||||||
|
step_data: Updated UI rendering data (will be JSON encoded)
|
||||||
|
cost: Updated cost information
|
||||||
|
tokens: Updated token usage information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Trajectory]: The updated trajectory if found, None otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error updating the record
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# First check if the trajectory exists
|
||||||
|
trajectory = self.get(trajectory_id)
|
||||||
|
if not trajectory:
|
||||||
|
logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update the fields if provided
|
||||||
|
update_data = {}
|
||||||
|
|
||||||
|
if tool_result is not None:
|
||||||
|
update_data["tool_result"] = json.dumps(tool_result)
|
||||||
|
|
||||||
|
if step_data is not None:
|
||||||
|
update_data["step_data"] = json.dumps(step_data)
|
||||||
|
|
||||||
|
if cost is not None:
|
||||||
|
update_data["cost"] = cost
|
||||||
|
|
||||||
|
if tokens is not None:
|
||||||
|
update_data["tokens"] = tokens
|
||||||
|
|
||||||
|
if update_data:
|
||||||
|
query = Trajectory.update(**update_data).where(Trajectory.id == trajectory_id)
|
||||||
|
query.execute()
|
||||||
|
logger.debug(f"Updated trajectory record ID {trajectory_id}")
|
||||||
|
return self.get(trajectory_id)
|
||||||
|
|
||||||
|
return trajectory
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete(self, trajectory_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a trajectory record by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trajectory_id: The ID of the trajectory record to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the record was deleted, False if it wasn't found
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error deleting the record
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# First check if the trajectory exists
|
||||||
|
trajectory = self.get(trajectory_id)
|
||||||
|
if not trajectory:
|
||||||
|
logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Delete the trajectory
|
||||||
|
trajectory.delete_instance()
|
||||||
|
logger.debug(f"Deleted trajectory record ID {trajectory_id}")
|
||||||
|
return True
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_all(self) -> Dict[int, Trajectory]:
|
||||||
|
"""
|
||||||
|
Retrieve all trajectory records from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to fetch all trajectories: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
|
||||||
|
"""
|
||||||
|
Retrieve all trajectory records associated with a specific human input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
human_input_id: The ID of the human input to get trajectories for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Trajectory]: List of trajectory instances associated with the human input
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Parse a JSON string into a Python dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_str: JSON string to parse
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid
|
||||||
|
"""
|
||||||
|
if not json_str:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(json_str)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Error parsing JSON field: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get a trajectory record with JSON fields parsed into dictionaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trajectory_id: ID of the trajectory to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
|
||||||
|
or None if not found
|
||||||
|
"""
|
||||||
|
trajectory = self.get(trajectory_id)
|
||||||
|
if trajectory is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": trajectory.id,
|
||||||
|
"created_at": trajectory.created_at,
|
||||||
|
"updated_at": trajectory.updated_at,
|
||||||
|
"tool_name": trajectory.tool_name,
|
||||||
|
"tool_parameters": self.parse_json_field(trajectory.tool_parameters),
|
||||||
|
"tool_result": self.parse_json_field(trajectory.tool_result),
|
||||||
|
"step_data": self.parse_json_field(trajectory.step_data),
|
||||||
|
"record_type": trajectory.record_type,
|
||||||
|
"cost": trajectory.cost,
|
||||||
|
"tokens": trajectory.tokens,
|
||||||
|
"human_input_id": trajectory.human_input.id if trajectory.human_input else None,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
"""Peewee migrations -- 007_20250310_184046_add_trajectory_model.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Create the trajectory table for storing agent action trajectories."""
|
||||||
|
|
||||||
|
# Check if the table already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT id FROM trajectory LIMIT 1")
|
||||||
|
# If we reach here, the table exists
|
||||||
|
return
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Table doesn't exist, safe to create
|
||||||
|
pass
|
||||||
|
|
||||||
|
@migrator.create_model
|
||||||
|
class Trajectory(pw.Model):
|
||||||
|
id = pw.AutoField()
|
||||||
|
created_at = pw.DateTimeField()
|
||||||
|
updated_at = pw.DateTimeField()
|
||||||
|
tool_name = pw.TextField()
|
||||||
|
tool_parameters = pw.TextField() # JSON-encoded parameters
|
||||||
|
tool_result = pw.TextField() # JSON-encoded result
|
||||||
|
step_data = pw.TextField() # JSON-encoded UI rendering data
|
||||||
|
record_type = pw.TextField() # Type of trajectory record
|
||||||
|
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
||||||
|
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
||||||
|
# We'll add the human_input foreign key in a separate step for safety
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "trajectory"
|
||||||
|
|
||||||
|
# Check if HumanInput model exists before adding the foreign key
|
||||||
|
try:
|
||||||
|
HumanInput = migrator.orm['human_input']
|
||||||
|
|
||||||
|
# Only add the foreign key if the human_input_id column doesn't already exist
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT human_input_id FROM trajectory LIMIT 1")
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Column doesn't exist, safe to add
|
||||||
|
migrator.add_fields(
|
||||||
|
'trajectory',
|
||||||
|
human_input=pw.ForeignKeyField(
|
||||||
|
HumanInput,
|
||||||
|
null=True,
|
||||||
|
field='id',
|
||||||
|
on_delete='SET NULL'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
# HumanInput doesn't exist, we'll skip adding the foreign key
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove the trajectory table."""
|
||||||
|
|
||||||
|
# First remove any foreign key fields
|
||||||
|
try:
|
||||||
|
migrator.remove_fields('trajectory', 'human_input')
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Field might not exist, that's fine
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Then remove the model
|
||||||
|
migrator.remove_model('trajectory')
|
||||||
Loading…
Reference in New Issue