# Source: RA.Aid/ra_aid/database/models.py (Python; listed as 223 lines, 8.1 KiB)

"""
Database models for ra_aid.
This module defines the base model class that all models will inherit from.
"""
import datetime
from typing import Any, Type, TypeVar
import peewee
from ra_aid.database.connection import get_db
from ra_aid.logging_config import get_logger
# Type variable bound to BaseModel; BaseModel.get_or_create uses it so the
# annotated return type is the calling subclass rather than BaseModel itself.
T = TypeVar("T", bound="BaseModel")
# Logger for this module, obtained through the project's get_logger helper.
logger = get_logger(__name__)
# Create a database proxy that will be initialized later.
# BaseModel.Meta points at this proxy, and initialize_database() binds it to
# the connection returned by get_db().
database_proxy = peewee.DatabaseProxy()
def initialize_database():
    """
    Bind the module-level database proxy to a real connection and make sure
    the model tables exist.

    Call this before performing any database operation so that the models,
    which all point at ``database_proxy``, have a usable database behind them.
    Table creation is best-effort: any failure is logged and the connection is
    still returned.

    Returns:
        The connection object returned by ``get_db()`` (the original docs
        describe it as a ``peewee.SqliteDatabase``).
    """
    connection = get_db()

    # The proxy's ``obj`` attribute is inspected directly to decide whether
    # the proxy has already been bound to a database.
    already_bound = getattr(database_proxy, 'obj', None) is not None
    if already_bound:
        logger.debug("Database proxy already initialized")
    else:
        logger.debug("Initializing database proxy")
        database_proxy.initialize(connection)

    # Create tables if they don't exist yet. The models are imported here,
    # at call time, rather than at the top level.
    # NOTE(review): this file appears to be ra_aid/database/models.py itself,
    # so this import re-imports the current module - confirm it is still needed.
    try:
        from ra_aid.database.models import (
            KeyFact,
            KeySnippet,
            HumanInput,
            ResearchNote,
            Trajectory,
            Session,
        )

        model_classes = [KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session]
        connection.create_tables(model_classes, safe=True)
        logger.debug("Ensured database tables exist")
    except Exception as exc:
        # Deliberately swallowed: the error is only logged.
        logger.error(f"Error creating tables: {str(exc)}")

    return connection
class BaseModel(peewee.Model):
    """
    Common parent for every ra_aid model.

    Subclasses inherit the ``created_at`` / ``updated_at`` timestamp columns
    and the binding to the shared ``database_proxy``.
    """

    # Both timestamps default to datetime.datetime.now; updated_at is
    # additionally refreshed on every save() call.
    created_at = peewee.DateTimeField(default=datetime.datetime.now)
    updated_at = peewee.DateTimeField(default=datetime.datetime.now)

    class Meta:
        database = database_proxy

    def save(self, *args: Any, **kwargs: Any) -> int:
        """
        Refresh ``updated_at`` and delegate to the parent ``save``.

        Args:
            *args: Positional arguments forwarded to the parent save method
            **kwargs: Keyword arguments forwarded to the parent save method

        Returns:
            int: Whatever the parent ``save`` returns (peewee documents this
                as the number of rows modified)
        """
        now = datetime.datetime.now()
        self.updated_at = now
        return super().save(*args, **kwargs)

    @classmethod
    def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]:
        """
        Fetch a matching instance, creating it when none exists.

        Database errors are logged and then re-raised unchanged.

        Args:
            **kwargs: Fields to use for lookup and creation

        Returns:
            tuple: (instance, created) where created is a boolean indicating
                whether a new instance was created

        Raises:
            peewee.DatabaseError: Propagated from the parent implementation
                after being logged
        """
        try:
            outcome = super().get_or_create(**kwargs)
        except peewee.DatabaseError as exc:
            # Record the failure, then let the caller handle it.
            logger.error(f"Failed in get_or_create: {str(exc)}")
            raise
        return outcome
class Session(BaseModel):
    """
    One program run, persisted in the ``session`` table.

    A session groups related records such as human inputs, trajectories and
    key facts. Each row records when the program was started, the command
    line arguments that were used, and information about the environment.
    """

    # Start of the run; the default is datetime.datetime.now.
    start_time = peewee.DateTimeField(default=datetime.datetime.now)
    # Command line used for this run (nullable).
    command_line = peewee.TextField(null=True)
    # Version of the program (nullable).
    program_version = peewee.TextField(null=True)
    # JSON-encoded machine information (nullable).
    machine_info = peewee.TextField(null=True)

    class Meta:
        table_name = "session"
class HumanInput(BaseModel):
    """
    A piece of text entered by a user, persisted in the ``human_input`` table.

    Inputs arrive through interfaces such as the CLI, chat, or HIL
    (human-in-the-loop). The source is stored next to the content so inputs
    can be analysed and referenced later.
    """

    # The text the user provided.
    content = peewee.TextField()
    # Origin of the input: 'cli', 'chat', or 'hil'.
    source = peewee.TextField()
    # Optional owning session; reachable from Session as ``human_inputs``.
    session = peewee.ForeignKeyField(Session, backref="human_inputs", null=True)

    # created_at and updated_at come from BaseModel.

    class Meta:
        table_name = "human_input"
class KeyFact(BaseModel):
    """
    A key fact, persisted in the ``key_fact`` table.

    Key facts hold important information about the project or the current
    task that needs to be referenced later.
    """

    # The fact text.
    content = peewee.TextField()
    # Optional link to a human input; reverse accessor is ``key_facts``.
    human_input = peewee.ForeignKeyField(HumanInput, backref="key_facts", null=True)
    # Optional owning session; reverse accessor is ``key_facts``.
    session = peewee.ForeignKeyField(Session, backref="key_facts", null=True)

    # created_at and updated_at come from BaseModel.

    class Meta:
        table_name = "key_fact"
class KeySnippet(BaseModel):
    """
    A key code snippet, persisted in the ``key_snippet`` table.

    Key snippets are important code fragments from the project that need to
    be referenced later. A row stores the file location, the line number, the
    code itself, and an optional description of why the snippet matters.
    """

    # Path of the file the snippet was taken from.
    filepath = peewee.TextField()
    # Line number of the snippet within that file.
    line_number = peewee.IntegerField()
    # The code content.
    snippet = peewee.TextField()
    # Optional description of the snippet's significance.
    description = peewee.TextField(null=True)
    # Optional link to a human input; reverse accessor is ``key_snippets``.
    human_input = peewee.ForeignKeyField(HumanInput, backref="key_snippets", null=True)
    # Optional owning session; reverse accessor is ``key_snippets``.
    session = peewee.ForeignKeyField(Session, backref="key_snippets", null=True)

    # created_at and updated_at come from BaseModel.

    class Meta:
        table_name = "key_snippet"
class ResearchNote(BaseModel):
    """
    A research note, persisted in the ``research_note`` table.

    Research notes hold detailed information compiled during research
    activities and kept for future reference: context and findings about
    topics relevant to the project.
    """

    # The note text.
    content = peewee.TextField()
    # Optional link to a human input; reverse accessor is ``research_notes``.
    human_input = peewee.ForeignKeyField(HumanInput, backref="research_notes", null=True)
    # Optional owning session; reverse accessor is ``research_notes``.
    session = peewee.ForeignKeyField(Session, backref="research_notes", null=True)

    # created_at and updated_at come from BaseModel.

    class Meta:
        table_name = "research_note"
class Trajectory(BaseModel):
    """
    One agent trajectory record, persisted in the ``trajectory`` table.

    Trajectories record the sequence of actions taken by agents, including
    tool executions and their results, so that agent behaviour can be
    analysed, issues debugged, and the decision-making process reconstructed.

    A record describes a single tool execution:
    - the tool that was used
    - the parameters passed to the tool
    - the result the tool returned
    - UI rendering data for displaying the execution
    - cost and token usage metrics (placeholders for future implementation)
    - error information (when the execution fails)
    """

    # Optional link to a human input; reverse accessor is ``trajectories``.
    human_input = peewee.ForeignKeyField(HumanInput, backref="trajectories", null=True)
    tool_name = peewee.TextField(null=True)
    # JSON-encoded parameters
    tool_parameters = peewee.TextField(null=True)
    # JSON-encoded result
    tool_result = peewee.TextField(null=True)
    # JSON-encoded UI rendering data
    step_data = peewee.TextField(null=True)
    # Type of trajectory record
    record_type = peewee.TextField(null=True)
    # Placeholder for cost tracking
    cost = peewee.FloatField(null=True)
    # Placeholder for token usage tracking
    tokens = peewee.IntegerField(null=True)
    # Flag indicating if this record represents an error
    is_error = peewee.BooleanField(default=False)
    # The error message
    error_message = peewee.TextField(null=True)
    # The type/class of the error
    error_type = peewee.TextField(null=True)
    # Additional error details like stack traces or context
    error_details = peewee.TextField(null=True)
    # Optional owning session; reverse accessor is ``trajectories``.
    session = peewee.ForeignKeyField(Session, backref="trajectories", null=True)

    # created_at and updated_at come from BaseModel.

    class Meta:
        table_name = "trajectory"