use pydantic models
This commit is contained in:
parent
5d07a7f7b8
commit
e0aab1021b
|
|
@ -0,0 +1,376 @@
|
|||
"""
|
||||
Pydantic models for ra_aid database entities.
|
||||
|
||||
This module defines Pydantic models that correspond to Peewee ORM models,
|
||||
providing validation, serialization, and deserialization capabilities.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
class SessionModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a Session.
|
||||
|
||||
This model corresponds to the Session Peewee ORM model and provides
|
||||
validation and serialization capabilities. It handles the conversion
|
||||
between JSON-encoded strings and Python dictionaries for the machine_info field.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the session
|
||||
created_at: When the session record was created
|
||||
updated_at: When the session record was last updated
|
||||
start_time: When the program session started
|
||||
command_line: Command line arguments used to start the program
|
||||
program_version: Version of the program
|
||||
machine_info: Dictionary containing machine-specific metadata
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
start_time: datetime.datetime
|
||||
command_line: Optional[str] = None
|
||||
program_version: Optional[str] = None
|
||||
machine_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_validator("machine_info", mode="before")
|
||||
@classmethod
|
||||
def parse_machine_info(cls, value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse the machine_info field from a JSON string to a dictionary.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, or None
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed dictionary or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in machine_info: {e}")
|
||||
|
||||
raise ValueError(f"Unexpected type for machine_info: {type(value)}")
|
||||
|
||||
@field_serializer("machine_info")
|
||||
def serialize_machine_info(self, machine_info: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the machine_info dictionary to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
machine_info: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if machine_info is None:
|
||||
return None
|
||||
|
||||
return json.dumps(machine_info)
|
||||
|
||||
|
||||
class HumanInputModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a HumanInput.
|
||||
|
||||
This model corresponds to the HumanInput Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the human input
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
content: The text content of the input
|
||||
source: The source of the input ('cli', 'chat', or 'hil')
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
content: str
|
||||
source: str
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class KeyFactModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a KeyFact.
|
||||
|
||||
This model corresponds to the KeyFact Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the key fact
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
content: The text content of the key fact
|
||||
human_input_id: Optional reference to the associated human input
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
content: str
|
||||
human_input_id: Optional[int] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class KeySnippetModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a KeySnippet.
|
||||
|
||||
This model corresponds to the KeySnippet Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the key snippet
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
filepath: Path to the source file
|
||||
line_number: Line number where the snippet starts
|
||||
snippet: The source code snippet text
|
||||
description: Optional description of the significance
|
||||
human_input_id: Optional reference to the associated human input
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
filepath: str
|
||||
line_number: int
|
||||
snippet: str
|
||||
description: Optional[str] = None
|
||||
human_input_id: Optional[int] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ResearchNoteModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a ResearchNote.
|
||||
|
||||
This model corresponds to the ResearchNote Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the research note
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
content: The text content of the research note
|
||||
human_input_id: Optional reference to the associated human input
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
content: str
|
||||
human_input_id: Optional[int] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TrajectoryModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a Trajectory.
|
||||
|
||||
This model corresponds to the Trajectory Peewee ORM model and provides
|
||||
validation and serialization capabilities. It handles the conversion
|
||||
between JSON-encoded strings and Python dictionaries for the tool_parameters,
|
||||
tool_result, and step_data fields.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the trajectory
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
human_input_id: Optional reference to the associated human input
|
||||
tool_name: Name of the tool that was executed
|
||||
tool_parameters: Dictionary containing the parameters passed to the tool
|
||||
tool_result: Dictionary containing the result returned by the tool
|
||||
step_data: Dictionary containing UI rendering data
|
||||
record_type: Type of trajectory record
|
||||
cost: Optional cost of the tool execution
|
||||
tokens: Optional token usage of the tool execution
|
||||
is_error: Flag indicating if this record represents an error
|
||||
error_message: The error message if is_error is True
|
||||
error_type: The type/class of the error if is_error is True
|
||||
error_details: Additional error details if is_error is True
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
human_input_id: Optional[int] = None
|
||||
tool_name: Optional[str] = None
|
||||
tool_parameters: Optional[Dict[str, Any]] = None
|
||||
tool_result: Optional[Any] = None
|
||||
step_data: Optional[Dict[str, Any]] = None
|
||||
record_type: Optional[str] = None
|
||||
cost: Optional[float] = None
|
||||
tokens: Optional[int] = None
|
||||
is_error: bool = False
|
||||
error_message: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_details: Optional[str] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def parse_tool_parameters(cls, value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse the tool_parameters field from a JSON string to a dictionary.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, or None
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed dictionary or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in tool_parameters: {e}")
|
||||
|
||||
raise ValueError(f"Unexpected type for tool_parameters: {type(value)}")
|
||||
|
||||
@field_validator("tool_result", mode="before")
|
||||
@classmethod
|
||||
def parse_tool_result(cls, value: Any) -> Optional[Any]:
|
||||
"""
|
||||
Parse the tool_result field from a JSON string to a Python object.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, list, or None
|
||||
|
||||
Returns:
|
||||
Optional[Any]: The parsed object or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in tool_result: {e}")
|
||||
|
||||
@field_validator("step_data", mode="before")
|
||||
@classmethod
|
||||
def parse_step_data(cls, value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse the step_data field from a JSON string to a dictionary.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, or None
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed dictionary or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in step_data: {e}")
|
||||
|
||||
raise ValueError(f"Unexpected type for step_data: {type(value)}")
|
||||
|
||||
@field_serializer("tool_parameters")
|
||||
def serialize_tool_parameters(self, tool_parameters: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the tool_parameters dictionary to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
tool_parameters: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if tool_parameters is None:
|
||||
return None
|
||||
|
||||
return json.dumps(tool_parameters)
|
||||
|
||||
@field_serializer("tool_result")
|
||||
def serialize_tool_result(self, tool_result: Optional[Any]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the tool_result object to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
tool_result: Object to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if tool_result is None:
|
||||
return None
|
||||
|
||||
return json.dumps(tool_result)
|
||||
|
||||
@field_serializer("step_data")
|
||||
def serialize_step_data(self, step_data: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the step_data dictionary to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
step_data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if step_data is None:
|
||||
return None
|
||||
|
||||
return json.dumps(step_data)
|
||||
|
|
@ -11,6 +11,7 @@ import contextvars
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import HumanInput
|
||||
from ra_aid.database.pydantic_models import HumanInputModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -119,7 +120,22 @@ class HumanInputRepository:
|
|||
raise ValueError("Database connection is required for HumanInputRepository")
|
||||
self.db = db
|
||||
|
||||
def create(self, content: str, source: str) -> HumanInput:
|
||||
def _to_model(self, human_input: Optional[HumanInput]) -> Optional[HumanInputModel]:
|
||||
"""
|
||||
Convert a Peewee HumanInput object to a Pydantic HumanInputModel.
|
||||
|
||||
Args:
|
||||
human_input: Peewee HumanInput instance or None
|
||||
|
||||
Returns:
|
||||
Optional[HumanInputModel]: Pydantic model representation or None if human_input is None
|
||||
"""
|
||||
if human_input is None:
|
||||
return None
|
||||
|
||||
return HumanInputModel.model_validate(human_input, from_attributes=True)
|
||||
|
||||
def create(self, content: str, source: str) -> HumanInputModel:
|
||||
"""
|
||||
Create a new human input record in the database.
|
||||
|
||||
|
|
@ -128,7 +144,7 @@ class HumanInputRepository:
|
|||
source: The source of the input (e.g., "cli", "chat", "hil")
|
||||
|
||||
Returns:
|
||||
HumanInput: The newly created human input instance
|
||||
HumanInputModel: The newly created human input instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
|
|
@ -136,12 +152,12 @@ class HumanInputRepository:
|
|||
try:
|
||||
input_record = HumanInput.create(content=content, source=source)
|
||||
logger.debug(f"Created human input ID {input_record.id} from {source}")
|
||||
return input_record
|
||||
return self._to_model(input_record)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create human input record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, input_id: int) -> Optional[HumanInput]:
|
||||
def get(self, input_id: int) -> Optional[HumanInputModel]:
|
||||
"""
|
||||
Retrieve a human input record by its ID.
|
||||
|
||||
|
|
@ -149,18 +165,19 @@ class HumanInputRepository:
|
|||
input_id: The ID of the human input to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[HumanInput]: The human input instance if found, None otherwise
|
||||
Optional[HumanInputModel]: The human input instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return HumanInput.get_or_none(HumanInput.id == input_id)
|
||||
human_input = HumanInput.get_or_none(HumanInput.id == input_id)
|
||||
return self._to_model(human_input)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch human input {input_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInput]:
|
||||
def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInputModel]:
|
||||
"""
|
||||
Update an existing human input record.
|
||||
|
||||
|
|
@ -170,14 +187,14 @@ class HumanInputRepository:
|
|||
source: The new source for the human input
|
||||
|
||||
Returns:
|
||||
Optional[HumanInput]: The updated human input if found, None otherwise
|
||||
Optional[HumanInputModel]: The updated human input if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the record
|
||||
"""
|
||||
try:
|
||||
# First check if the record exists
|
||||
input_record = self.get(input_id)
|
||||
# We need to get the raw Peewee object for updating
|
||||
input_record = HumanInput.get_or_none(HumanInput.id == input_id)
|
||||
if not input_record:
|
||||
logger.warning(f"Attempted to update non-existent human input {input_id}")
|
||||
return None
|
||||
|
|
@ -190,7 +207,7 @@ class HumanInputRepository:
|
|||
|
||||
input_record.save()
|
||||
logger.debug(f"Updated human input ID {input_id}")
|
||||
return input_record
|
||||
return self._to_model(input_record)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update human input {input_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -223,23 +240,24 @@ class HumanInputRepository:
|
|||
logger.error(f"Failed to delete human input {input_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[HumanInput]:
|
||||
def get_all(self) -> List[HumanInputModel]:
|
||||
"""
|
||||
Retrieve all human input records from the database.
|
||||
|
||||
Returns:
|
||||
List[HumanInput]: List of all human input instances
|
||||
List[HumanInputModel]: List of all human input instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(HumanInput.select().order_by(HumanInput.created_at.desc()))
|
||||
human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()))
|
||||
return [self._to_model(input) for input in human_inputs]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all human inputs: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[HumanInput]:
|
||||
def get_recent(self, limit: int = 10) -> List[HumanInputModel]:
|
||||
"""
|
||||
Retrieve the most recent human input records.
|
||||
|
||||
|
|
@ -247,13 +265,14 @@ class HumanInputRepository:
|
|||
limit: Maximum number of records to retrieve (default: 10)
|
||||
|
||||
Returns:
|
||||
List[HumanInput]: List of the most recent human input records
|
||||
List[HumanInputModel]: List of the most recent human input records
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
|
||||
human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
|
||||
return [self._to_model(input) for input in human_inputs]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch recent human inputs: {str(e)}")
|
||||
raise
|
||||
|
|
@ -277,7 +296,7 @@ class HumanInputRepository:
|
|||
logger.error(f"Failed to fetch most recent human input ID: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_source(self, source: str) -> List[HumanInput]:
|
||||
def get_by_source(self, source: str) -> List[HumanInputModel]:
|
||||
"""
|
||||
Retrieve human input records by source.
|
||||
|
||||
|
|
@ -285,13 +304,14 @@ class HumanInputRepository:
|
|||
source: The source to filter by (e.g., "cli", "chat", "hil")
|
||||
|
||||
Returns:
|
||||
List[HumanInput]: List of human input records from the specified source
|
||||
List[HumanInputModel]: List of human input records from the specified source
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
|
||||
human_inputs = list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
|
||||
return [self._to_model(input) for input in human_inputs]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch human inputs by source {source}: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from contextlib import contextmanager
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.database.pydantic_models import KeyFactModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -120,7 +121,22 @@ class KeyFactRepository:
|
|||
raise ValueError("Database connection is required for KeyFactRepository")
|
||||
self.db = db
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
|
||||
def _to_model(self, fact: Optional[KeyFact]) -> Optional[KeyFactModel]:
|
||||
"""
|
||||
Convert a Peewee KeyFact object to a Pydantic KeyFactModel.
|
||||
|
||||
Args:
|
||||
fact: Peewee KeyFact instance or None
|
||||
|
||||
Returns:
|
||||
Optional[KeyFactModel]: Pydantic model representation or None if fact is None
|
||||
"""
|
||||
if fact is None:
|
||||
return None
|
||||
|
||||
return KeyFactModel.model_validate(fact, from_attributes=True)
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFactModel:
|
||||
"""
|
||||
Create a new key fact in the database.
|
||||
|
||||
|
|
@ -129,7 +145,7 @@ class KeyFactRepository:
|
|||
human_input_id: Optional ID of the associated human input
|
||||
|
||||
Returns:
|
||||
KeyFact: The newly created key fact instance
|
||||
KeyFactModel: The newly created key fact instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the fact
|
||||
|
|
@ -137,12 +153,12 @@ class KeyFactRepository:
|
|||
try:
|
||||
fact = KeyFact.create(content=content, human_input_id=human_input_id)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return fact
|
||||
return self._to_model(fact)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key fact: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, fact_id: int) -> Optional[KeyFact]:
|
||||
def get(self, fact_id: int) -> Optional[KeyFactModel]:
|
||||
"""
|
||||
Retrieve a key fact by its ID.
|
||||
|
||||
|
|
@ -150,18 +166,19 @@ class KeyFactRepository:
|
|||
fact_id: The ID of the key fact to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[KeyFact]: The key fact instance if found, None otherwise
|
||||
Optional[KeyFactModel]: The key fact instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
return self._to_model(fact)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, fact_id: int, content: str) -> Optional[KeyFact]:
|
||||
def update(self, fact_id: int, content: str) -> Optional[KeyFactModel]:
|
||||
"""
|
||||
Update an existing key fact.
|
||||
|
||||
|
|
@ -170,14 +187,14 @@ class KeyFactRepository:
|
|||
content: The new content for the key fact
|
||||
|
||||
Returns:
|
||||
Optional[KeyFact]: The updated key fact if found, None otherwise
|
||||
Optional[KeyFactModel]: The updated key fact if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the fact
|
||||
"""
|
||||
try:
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
|
||||
return None
|
||||
|
|
@ -186,7 +203,7 @@ class KeyFactRepository:
|
|||
fact.content = content
|
||||
fact.save()
|
||||
logger.debug(f"Updated key fact ID {fact_id}: {content}")
|
||||
return fact
|
||||
return self._to_model(fact)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -206,7 +223,7 @@ class KeyFactRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
|
||||
return False
|
||||
|
|
@ -219,18 +236,19 @@ class KeyFactRepository:
|
|||
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[KeyFact]:
|
||||
def get_all(self) -> List[KeyFactModel]:
|
||||
"""
|
||||
Retrieve all key facts from the database.
|
||||
|
||||
Returns:
|
||||
List[KeyFact]: List of all key fact instances
|
||||
List[KeyFactModel]: List of all key fact instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
facts = list(KeyFact.select().order_by(KeyFact.id))
|
||||
return [self._to_model(fact) for fact in facts]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import contextvars
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import KeySnippet
|
||||
from ra_aid.database.pydantic_models import KeySnippetModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -129,10 +130,25 @@ class KeySnippetRepository:
|
|||
raise ValueError("Database connection is required for KeySnippetRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, snippet: Optional[KeySnippet]) -> Optional[KeySnippetModel]:
|
||||
"""
|
||||
Convert a Peewee KeySnippet object to a Pydantic KeySnippetModel.
|
||||
|
||||
Args:
|
||||
snippet: Peewee KeySnippet instance or None
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippetModel]: Pydantic model representation or None if snippet is None
|
||||
"""
|
||||
if snippet is None:
|
||||
return None
|
||||
|
||||
return KeySnippetModel.model_validate(snippet, from_attributes=True)
|
||||
|
||||
def create(
|
||||
self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None,
|
||||
human_input_id: Optional[int] = None
|
||||
) -> KeySnippet:
|
||||
) -> KeySnippetModel:
|
||||
"""
|
||||
Create a new key snippet in the database.
|
||||
|
||||
|
|
@ -144,7 +160,7 @@ class KeySnippetRepository:
|
|||
human_input_id: Optional ID of the associated human input
|
||||
|
||||
Returns:
|
||||
KeySnippet: The newly created key snippet instance
|
||||
KeySnippetModel: The newly created key snippet instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the snippet
|
||||
|
|
@ -158,12 +174,12 @@ class KeySnippetRepository:
|
|||
human_input_id=human_input_id
|
||||
)
|
||||
logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}")
|
||||
return key_snippet
|
||||
return self._to_model(key_snippet)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key snippet: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, snippet_id: int) -> Optional[KeySnippet]:
|
||||
def get(self, snippet_id: int) -> Optional[KeySnippetModel]:
|
||||
"""
|
||||
Retrieve a key snippet by its ID.
|
||||
|
||||
|
|
@ -171,13 +187,14 @@ class KeySnippetRepository:
|
|||
snippet_id: The ID of the key snippet to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippet]: The key snippet instance if found, None otherwise
|
||||
Optional[KeySnippetModel]: The key snippet instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
return self._to_model(snippet)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -189,7 +206,7 @@ class KeySnippetRepository:
|
|||
line_number: int,
|
||||
snippet: str,
|
||||
description: Optional[str] = None
|
||||
) -> Optional[KeySnippet]:
|
||||
) -> Optional[KeySnippetModel]:
|
||||
"""
|
||||
Update an existing key snippet.
|
||||
|
||||
|
|
@ -201,14 +218,14 @@ class KeySnippetRepository:
|
|||
description: Optional description of the significance
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippet]: The updated key snippet if found, None otherwise
|
||||
Optional[KeySnippetModel]: The updated key snippet if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the snippet
|
||||
"""
|
||||
try:
|
||||
# First check if the snippet exists
|
||||
key_snippet = self.get(snippet_id)
|
||||
key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
if not key_snippet:
|
||||
logger.warning(f"Attempted to update non-existent key snippet {snippet_id}")
|
||||
return None
|
||||
|
|
@ -220,7 +237,7 @@ class KeySnippetRepository:
|
|||
key_snippet.description = description
|
||||
key_snippet.save()
|
||||
logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}")
|
||||
return key_snippet
|
||||
return self._to_model(key_snippet)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -240,7 +257,7 @@ class KeySnippetRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the snippet exists
|
||||
key_snippet = self.get(snippet_id)
|
||||
key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
if not key_snippet:
|
||||
logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}")
|
||||
return False
|
||||
|
|
@ -253,18 +270,19 @@ class KeySnippetRepository:
|
|||
logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[KeySnippet]:
|
||||
def get_all(self) -> List[KeySnippetModel]:
|
||||
"""
|
||||
Retrieve all key snippets from the database.
|
||||
|
||||
Returns:
|
||||
List[KeySnippet]: List of all key snippet instances
|
||||
List[KeySnippetModel]: List of all key snippet instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(KeySnippet.select().order_by(KeySnippet.id))
|
||||
snippets = list(KeySnippet.select().order_by(KeySnippet.id))
|
||||
return [self._to_model(snippet) for snippet in snippets]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key snippets: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from contextlib import contextmanager
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import ResearchNote
|
||||
from ra_aid.database.pydantic_models import ResearchNoteModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -120,7 +121,22 @@ class ResearchNoteRepository:
|
|||
raise ValueError("Database connection is required for ResearchNoteRepository")
|
||||
self.db = db
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNote:
|
||||
def _to_model(self, note: Optional[ResearchNote]) -> Optional[ResearchNoteModel]:
|
||||
"""
|
||||
Convert a Peewee ResearchNote object to a Pydantic ResearchNoteModel.
|
||||
|
||||
Args:
|
||||
note: Peewee ResearchNote instance or None
|
||||
|
||||
Returns:
|
||||
Optional[ResearchNoteModel]: Pydantic model representation or None if note is None
|
||||
"""
|
||||
if note is None:
|
||||
return None
|
||||
|
||||
return ResearchNoteModel.model_validate(note, from_attributes=True)
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNoteModel:
|
||||
"""
|
||||
Create a new research note in the database.
|
||||
|
||||
|
|
@ -129,7 +145,7 @@ class ResearchNoteRepository:
|
|||
human_input_id: Optional ID of the associated human input
|
||||
|
||||
Returns:
|
||||
ResearchNote: The newly created research note instance
|
||||
ResearchNoteModel: The newly created research note instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the note
|
||||
|
|
@ -137,12 +153,12 @@ class ResearchNoteRepository:
|
|||
try:
|
||||
note = ResearchNote.create(content=content, human_input_id=human_input_id)
|
||||
logger.debug(f"Created research note ID {note.id}: {content[:50]}...")
|
||||
return note
|
||||
return self._to_model(note)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create research note: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, note_id: int) -> Optional[ResearchNote]:
|
||||
def get(self, note_id: int) -> Optional[ResearchNoteModel]:
|
||||
"""
|
||||
Retrieve a research note by its ID.
|
||||
|
||||
|
|
@ -150,18 +166,19 @@ class ResearchNoteRepository:
|
|||
note_id: The ID of the research note to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[ResearchNote]: The research note instance if found, None otherwise
|
||||
Optional[ResearchNoteModel]: The research note instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
return self._to_model(note)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch research note {note_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, note_id: int, content: str) -> Optional[ResearchNote]:
|
||||
def update(self, note_id: int, content: str) -> Optional[ResearchNoteModel]:
|
||||
"""
|
||||
Update an existing research note.
|
||||
|
||||
|
|
@ -170,14 +187,14 @@ class ResearchNoteRepository:
|
|||
content: The new content for the research note
|
||||
|
||||
Returns:
|
||||
Optional[ResearchNote]: The updated research note if found, None otherwise
|
||||
Optional[ResearchNoteModel]: The updated research note if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the note
|
||||
"""
|
||||
try:
|
||||
# First check if the note exists
|
||||
note = self.get(note_id)
|
||||
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
if not note:
|
||||
logger.warning(f"Attempted to update non-existent research note {note_id}")
|
||||
return None
|
||||
|
|
@ -186,7 +203,7 @@ class ResearchNoteRepository:
|
|||
note.content = content
|
||||
note.save()
|
||||
logger.debug(f"Updated research note ID {note_id}: {content[:50]}...")
|
||||
return note
|
||||
return self._to_model(note)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update research note {note_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -206,7 +223,7 @@ class ResearchNoteRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the note exists
|
||||
note = self.get(note_id)
|
||||
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
if not note:
|
||||
logger.warning(f"Attempted to delete non-existent research note {note_id}")
|
||||
return False
|
||||
|
|
@ -219,18 +236,19 @@ class ResearchNoteRepository:
|
|||
logger.error(f"Failed to delete research note {note_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[ResearchNote]:
|
||||
def get_all(self) -> List[ResearchNoteModel]:
|
||||
"""
|
||||
Retrieve all research notes from the database.
|
||||
|
||||
Returns:
|
||||
List[ResearchNote]: List of all research note instances
|
||||
List[ResearchNoteModel]: List of all research note instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(ResearchNote.select().order_by(ResearchNote.id))
|
||||
notes = list(ResearchNote.select().order_by(ResearchNote.id))
|
||||
return [self._to_model(note) for note in notes]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all research notes: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import sys
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import Session
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
from ra_aid.__version__ import __version__
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
|
|
@ -121,7 +122,22 @@ class SessionRepository:
|
|||
self.db = db
|
||||
self.current_session = None
|
||||
|
||||
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
|
||||
def _to_model(self, session: Optional[Session]) -> Optional[SessionModel]:
|
||||
"""
|
||||
Convert a Peewee Session object to a Pydantic SessionModel.
|
||||
|
||||
Args:
|
||||
session: Peewee Session instance or None
|
||||
|
||||
Returns:
|
||||
Optional[SessionModel]: Pydantic model representation or None if session is None
|
||||
"""
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
return SessionModel.model_validate(session, from_attributes=True)
|
||||
|
||||
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> SessionModel:
|
||||
"""
|
||||
Create a new session record in the database.
|
||||
|
||||
|
|
@ -129,7 +145,7 @@ class SessionRepository:
|
|||
metadata: Optional dictionary of additional metadata to store with the session
|
||||
|
||||
Returns:
|
||||
Session: The newly created session instance
|
||||
SessionModel: The newly created session instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
|
|
@ -155,12 +171,12 @@ class SessionRepository:
|
|||
self.current_session = session
|
||||
|
||||
logger.debug(f"Created new session with ID {session.id}")
|
||||
return session
|
||||
return self._to_model(session)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create session record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_session(self) -> Optional[Session]:
|
||||
def get_current_session(self) -> Optional[SessionModel]:
|
||||
"""
|
||||
Get the current active session.
|
||||
|
||||
|
|
@ -168,17 +184,17 @@ class SessionRepository:
|
|||
retrieves the most recent session from the database.
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The current session or None if no sessions exist
|
||||
Optional[SessionModel]: The current session or None if no sessions exist
|
||||
"""
|
||||
if self.current_session is not None:
|
||||
return self.current_session
|
||||
return self._to_model(self.current_session)
|
||||
|
||||
try:
|
||||
# Find the most recent session
|
||||
session = Session.select().order_by(Session.created_at.desc()).first()
|
||||
if session:
|
||||
self.current_session = session
|
||||
return session
|
||||
return self._to_model(session)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get current session: {str(e)}")
|
||||
return None
|
||||
|
|
@ -193,7 +209,7 @@ class SessionRepository:
|
|||
session = self.get_current_session()
|
||||
return session.id if session else None
|
||||
|
||||
def get(self, session_id: int) -> Optional[Session]:
|
||||
def get(self, session_id: int) -> Optional[SessionModel]:
|
||||
"""
|
||||
Get a session by its ID.
|
||||
|
||||
|
|
@ -201,28 +217,30 @@ class SessionRepository:
|
|||
session_id: The ID of the session to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The session with the given ID or None if not found
|
||||
Optional[SessionModel]: The session with the given ID or None if not found
|
||||
"""
|
||||
try:
|
||||
return Session.get_or_none(Session.id == session_id)
|
||||
session = Session.get_or_none(Session.id == session_id)
|
||||
return self._to_model(session)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_all(self) -> List[Session]:
|
||||
def get_all(self) -> List[SessionModel]:
|
||||
"""
|
||||
Get all sessions from the database.
|
||||
|
||||
Returns:
|
||||
List[Session]: List of all sessions
|
||||
List[SessionModel]: List of all sessions
|
||||
"""
|
||||
try:
|
||||
return list(Session.select().order_by(Session.created_at.desc()))
|
||||
sessions = list(Session.select().order_by(Session.created_at.desc()))
|
||||
return [self._to_model(session) for session in sessions]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get all sessions: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[Session]:
|
||||
def get_recent(self, limit: int = 10) -> List[SessionModel]:
|
||||
"""
|
||||
Get the most recent sessions from the database.
|
||||
|
||||
|
|
@ -230,14 +248,15 @@ class SessionRepository:
|
|||
limit: Maximum number of sessions to return (default: 10)
|
||||
|
||||
Returns:
|
||||
List[Session]: List of the most recent sessions
|
||||
List[SessionModel]: List of the most recent sessions
|
||||
"""
|
||||
try:
|
||||
return list(
|
||||
sessions = list(
|
||||
Session.select()
|
||||
.order_by(Session.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [self._to_model(session) for session in sessions]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get recent sessions: {str(e)}")
|
||||
return []
|
||||
|
|
@ -14,6 +14,7 @@ import logging
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import Trajectory, HumanInput
|
||||
from ra_aid.database.pydantic_models import TrajectoryModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -130,6 +131,21 @@ class TrajectoryRepository:
|
|||
raise ValueError("Database connection is required for TrajectoryRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, trajectory: Optional[Trajectory]) -> Optional[TrajectoryModel]:
|
||||
"""
|
||||
Convert a Peewee Trajectory object to a Pydantic TrajectoryModel.
|
||||
|
||||
Args:
|
||||
trajectory: Peewee Trajectory instance or None
|
||||
|
||||
Returns:
|
||||
Optional[TrajectoryModel]: Pydantic model representation or None if trajectory is None
|
||||
"""
|
||||
if trajectory is None:
|
||||
return None
|
||||
|
||||
return TrajectoryModel.model_validate(trajectory, from_attributes=True)
|
||||
|
||||
def create(
|
||||
self,
|
||||
tool_name: Optional[str] = None,
|
||||
|
|
@ -144,7 +160,7 @@ class TrajectoryRepository:
|
|||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
error_details: Optional[str] = None
|
||||
) -> Trajectory:
|
||||
) -> TrajectoryModel:
|
||||
"""
|
||||
Create a new trajectory record in the database.
|
||||
|
||||
|
|
@ -163,7 +179,7 @@ class TrajectoryRepository:
|
|||
error_details: Additional error details like stack traces (if is_error is True)
|
||||
|
||||
Returns:
|
||||
Trajectory: The newly created trajectory instance
|
||||
TrajectoryModel: The newly created trajectory instance as a Pydantic model
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
|
|
@ -201,12 +217,12 @@ class TrajectoryRepository:
|
|||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||
else:
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
|
||||
return trajectory
|
||||
return self._to_model(trajectory)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, trajectory_id: int) -> Optional[Trajectory]:
|
||||
def get(self, trajectory_id: int) -> Optional[TrajectoryModel]:
|
||||
"""
|
||||
Retrieve a trajectory record by its ID.
|
||||
|
||||
|
|
@ -214,13 +230,14 @@ class TrajectoryRepository:
|
|||
trajectory_id: The ID of the trajectory record to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Trajectory]: The trajectory instance if found, None otherwise
|
||||
Optional[TrajectoryModel]: The trajectory instance as a Pydantic model if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
return self._to_model(trajectory)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -236,7 +253,7 @@ class TrajectoryRepository:
|
|||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
error_details: Optional[str] = None
|
||||
) -> Optional[Trajectory]:
|
||||
) -> Optional[TrajectoryModel]:
|
||||
"""
|
||||
Update an existing trajectory record.
|
||||
|
||||
|
|
@ -254,15 +271,15 @@ class TrajectoryRepository:
|
|||
error_details: Additional error details like stack traces
|
||||
|
||||
Returns:
|
||||
Optional[Trajectory]: The updated trajectory if found, None otherwise
|
||||
Optional[TrajectoryModel]: The updated trajectory as a Pydantic model if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the record
|
||||
"""
|
||||
try:
|
||||
# First check if the trajectory exists
|
||||
trajectory = self.get(trajectory_id)
|
||||
if not trajectory:
|
||||
peewee_trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
if not peewee_trajectory:
|
||||
logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
|
||||
return None
|
||||
|
||||
|
|
@ -299,7 +316,7 @@ class TrajectoryRepository:
|
|||
logger.debug(f"Updated trajectory record ID {trajectory_id}")
|
||||
return self.get(trajectory_id)
|
||||
|
||||
return trajectory
|
||||
return self._to_model(peewee_trajectory)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -319,7 +336,7 @@ class TrajectoryRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the trajectory exists
|
||||
trajectory = self.get(trajectory_id)
|
||||
trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
if not trajectory:
|
||||
logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
|
||||
return False
|
||||
|
|
@ -332,23 +349,24 @@ class TrajectoryRepository:
|
|||
logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> Dict[int, Trajectory]:
|
||||
def get_all(self) -> Dict[int, TrajectoryModel]:
|
||||
"""
|
||||
Retrieve all trajectory records from the database.
|
||||
|
||||
Returns:
|
||||
Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances
|
||||
Dict[int, TrajectoryModel]: Dictionary mapping trajectory IDs to trajectory Pydantic models
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
|
||||
trajectories = Trajectory.select().order_by(Trajectory.id)
|
||||
return {trajectory.id: self._to_model(trajectory) for trajectory in trajectories}
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all trajectories: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
|
||||
def get_trajectories_by_human_input(self, human_input_id: int) -> List[TrajectoryModel]:
|
||||
"""
|
||||
Retrieve all trajectory records associated with a specific human input.
|
||||
|
||||
|
|
@ -356,37 +374,19 @@ class TrajectoryRepository:
|
|||
human_input_id: The ID of the human input to get trajectories for
|
||||
|
||||
Returns:
|
||||
List[Trajectory]: List of trajectory instances associated with the human input
|
||||
List[TrajectoryModel]: List of trajectory Pydantic models associated with the human input
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
|
||||
trajectories = list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
|
||||
return [self._to_model(trajectory) for trajectory in trajectories]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse a JSON string into a Python dictionary.
|
||||
|
||||
Args:
|
||||
json_str: JSON string to parse
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid
|
||||
"""
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error parsing JSON field: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
|
||||
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[TrajectoryModel]:
|
||||
"""
|
||||
Get a trajectory record with JSON fields parsed into dictionaries.
|
||||
|
||||
|
|
@ -394,27 +394,7 @@ class TrajectoryRepository:
|
|||
trajectory_id: ID of the trajectory to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
|
||||
Optional[TrajectoryModel]: The trajectory as a Pydantic model with parsed JSON fields,
|
||||
or None if not found
|
||||
"""
|
||||
trajectory = self.get(trajectory_id)
|
||||
if trajectory is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": trajectory.id,
|
||||
"created_at": trajectory.created_at,
|
||||
"updated_at": trajectory.updated_at,
|
||||
"tool_name": trajectory.tool_name,
|
||||
"tool_parameters": self.parse_json_field(trajectory.tool_parameters),
|
||||
"tool_result": self.parse_json_field(trajectory.tool_result),
|
||||
"step_data": self.parse_json_field(trajectory.step_data),
|
||||
"record_type": trajectory.record_type,
|
||||
"cost": trajectory.cost,
|
||||
"tokens": trajectory.tokens,
|
||||
"human_input_id": trajectory.human_input.id if trajectory.human_input else None,
|
||||
"is_error": trajectory.is_error,
|
||||
"error_message": trajectory.error_message,
|
||||
"error_type": trajectory.error_type,
|
||||
"error_details": trajectory.error_details,
|
||||
}
|
||||
return self.get(trajectory_id)
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
"""
|
||||
Tests for the human input repository.
|
||||
|
||||
This module provides tests for the HumanInputRepository class,
|
||||
ensuring it correctly interfaces with the database and returns
|
||||
appropriate Pydantic models.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import pytest
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
from ra_aid.database.models import HumanInput, Session, database_proxy
|
||||
from ra_aid.database.pydantic_models import HumanInputModel, SessionModel
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.repositories.session_repository import SessionRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""Fixture for creating a test database."""
|
||||
# Create an in-memory SQLite database for testing
|
||||
test_db = SqliteDatabase(':memory:')
|
||||
|
||||
# Register the models with the test database
|
||||
with test_db.bind_ctx([HumanInput, Session]):
|
||||
# Create the tables
|
||||
test_db.create_tables([HumanInput, Session])
|
||||
|
||||
# Return the test database for use in the tests
|
||||
yield test_db
|
||||
|
||||
# Drop the tables after the tests
|
||||
test_db.drop_tables([HumanInput, Session])
|
||||
|
||||
|
||||
class TestHumanInputRepository(unittest.TestCase):
|
||||
"""Test case for the HumanInputRepository class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up the test case with a test database and repositories."""
|
||||
# Create an in-memory database for testing
|
||||
self.db = SqliteDatabase(':memory:')
|
||||
|
||||
# Register the models with the test database
|
||||
self.models = [HumanInput, Session]
|
||||
self.db.bind(self.models)
|
||||
|
||||
# Create the tables
|
||||
self.db.create_tables(self.models)
|
||||
|
||||
# Create repository instances for testing
|
||||
self.repository = HumanInputRepository(self.db)
|
||||
self.session_repository = SessionRepository(self.db)
|
||||
|
||||
# Bind the test database to the repository model
|
||||
database_proxy.initialize(self.db)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after the test case."""
|
||||
# Close the database connection
|
||||
self.db.close()
|
||||
|
||||
def test_create(self):
|
||||
"""Test creating a human input record."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create a human input
|
||||
content = "Test human input"
|
||||
source = "cli"
|
||||
human_input = self.repository.create(content=content, source=source)
|
||||
|
||||
# Verify the human input was created
|
||||
self.assertIsInstance(human_input, HumanInputModel)
|
||||
self.assertEqual(human_input.content, content)
|
||||
self.assertEqual(human_input.source, source)
|
||||
|
||||
def test_get(self):
|
||||
"""Test retrieving a human input record by ID."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create a human input
|
||||
content = "Test human input"
|
||||
source = "chat"
|
||||
created_input = self.repository.create(content=content, source=source)
|
||||
|
||||
# Get the human input by ID
|
||||
retrieved_input = self.repository.get(created_input.id)
|
||||
|
||||
# Verify the human input was retrieved correctly
|
||||
self.assertIsInstance(retrieved_input, HumanInputModel)
|
||||
self.assertEqual(retrieved_input.id, created_input.id)
|
||||
self.assertEqual(retrieved_input.content, content)
|
||||
self.assertEqual(retrieved_input.source, source)
|
||||
|
||||
def test_update(self):
|
||||
"""Test updating a human input record."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create a human input
|
||||
content = "Original content"
|
||||
source = "cli"
|
||||
created_input = self.repository.create(content=content, source=source)
|
||||
|
||||
# Update the human input
|
||||
new_content = "Updated content"
|
||||
updated_input = self.repository.update(created_input.id, content=new_content)
|
||||
|
||||
# Verify the human input was updated correctly
|
||||
self.assertIsInstance(updated_input, HumanInputModel)
|
||||
self.assertEqual(updated_input.id, created_input.id)
|
||||
self.assertEqual(updated_input.content, new_content)
|
||||
self.assertEqual(updated_input.source, source)
|
||||
|
||||
def test_get_all(self):
|
||||
"""Test retrieving all human input records."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create multiple human inputs
|
||||
self.repository.create(content="Input 1", source="cli")
|
||||
self.repository.create(content="Input 2", source="chat")
|
||||
self.repository.create(content="Input 3", source="hil")
|
||||
|
||||
# Get all human inputs
|
||||
all_inputs = self.repository.get_all()
|
||||
|
||||
# Verify all human inputs were retrieved
|
||||
self.assertEqual(len(all_inputs), 3)
|
||||
self.assertIsInstance(all_inputs[0], HumanInputModel)
|
||||
|
||||
# Verify the inputs are ordered by created_at in descending order
|
||||
self.assertEqual(all_inputs[0].content, "Input 3")
|
||||
self.assertEqual(all_inputs[1].content, "Input 2")
|
||||
self.assertEqual(all_inputs[2].content, "Input 1")
|
||||
|
||||
def test_get_recent(self):
|
||||
"""Test retrieving the most recent human input records."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create multiple human inputs
|
||||
self.repository.create(content="Input 1", source="cli")
|
||||
self.repository.create(content="Input 2", source="chat")
|
||||
self.repository.create(content="Input 3", source="hil")
|
||||
self.repository.create(content="Input 4", source="cli")
|
||||
self.repository.create(content="Input 5", source="chat")
|
||||
|
||||
# Get recent human inputs with a limit of 3
|
||||
recent_inputs = self.repository.get_recent(limit=3)
|
||||
|
||||
# Verify only the 3 most recent inputs were retrieved
|
||||
self.assertEqual(len(recent_inputs), 3)
|
||||
self.assertIsInstance(recent_inputs[0], HumanInputModel)
|
||||
self.assertEqual(recent_inputs[0].content, "Input 5")
|
||||
self.assertEqual(recent_inputs[1].content, "Input 4")
|
||||
self.assertEqual(recent_inputs[2].content, "Input 3")
|
||||
|
||||
def test_get_by_source(self):
|
||||
"""Test retrieving human input records by source."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create human inputs with different sources
|
||||
self.repository.create(content="CLI Input 1", source="cli")
|
||||
self.repository.create(content="Chat Input 1", source="chat")
|
||||
self.repository.create(content="HIL Input", source="hil")
|
||||
self.repository.create(content="CLI Input 2", source="cli")
|
||||
self.repository.create(content="Chat Input 2", source="chat")
|
||||
|
||||
# Get human inputs for the 'cli' source
|
||||
cli_inputs = self.repository.get_by_source("cli")
|
||||
|
||||
# Verify only cli inputs were retrieved
|
||||
self.assertEqual(len(cli_inputs), 2)
|
||||
self.assertIsInstance(cli_inputs[0], HumanInputModel)
|
||||
self.assertEqual(cli_inputs[0].content, "CLI Input 2")
|
||||
self.assertEqual(cli_inputs[1].content, "CLI Input 1")
|
||||
|
||||
def test_get_most_recent_id(self):
|
||||
"""Test retrieving the ID of the most recent human input record."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create multiple human inputs
|
||||
self.repository.create(content="Input 1", source="cli")
|
||||
input2 = self.repository.create(content="Input 2", source="chat")
|
||||
|
||||
# Get the most recent ID
|
||||
most_recent_id = self.repository.get_most_recent_id()
|
||||
|
||||
# Verify the correct ID was retrieved
|
||||
self.assertEqual(most_recent_id, input2.id)
|
||||
|
|
@ -15,6 +15,7 @@ from ra_aid.database.repositories.key_fact_repository import (
|
|||
get_key_fact_repository,
|
||||
key_fact_repo_var
|
||||
)
|
||||
from ra_aid.database.pydantic_models import KeyFactModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -87,7 +88,8 @@ def test_create_key_fact(setup_db):
|
|||
content = "Test key fact"
|
||||
fact = repo.create(content)
|
||||
|
||||
# Verify the fact was created correctly
|
||||
# Verify the fact was created correctly and is a KeyFactModel
|
||||
assert isinstance(fact, KeyFactModel)
|
||||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
|
|
@ -108,7 +110,8 @@ def test_get_key_fact(setup_db):
|
|||
# Retrieve the fact by ID
|
||||
retrieved_fact = repo.get(fact.id)
|
||||
|
||||
# Verify the retrieved fact matches the original
|
||||
# Verify the retrieved fact matches the original and is a KeyFactModel
|
||||
assert isinstance(retrieved_fact, KeyFactModel)
|
||||
assert retrieved_fact is not None
|
||||
assert retrieved_fact.id == fact.id
|
||||
assert retrieved_fact.content == content
|
||||
|
|
@ -131,7 +134,8 @@ def test_update_key_fact(setup_db):
|
|||
new_content = "Updated content"
|
||||
updated_fact = repo.update(fact.id, new_content)
|
||||
|
||||
# Verify the fact was updated correctly
|
||||
# Verify the fact was updated correctly and is a KeyFactModel
|
||||
assert isinstance(updated_fact, KeyFactModel)
|
||||
assert updated_fact is not None
|
||||
assert updated_fact.id == fact.id
|
||||
assert updated_fact.content == new_content
|
||||
|
|
@ -184,8 +188,10 @@ def test_get_all_key_facts(setup_db):
|
|||
# Retrieve all facts
|
||||
all_facts = repo.get_all()
|
||||
|
||||
# Verify we got the correct number of facts
|
||||
# Verify we got the correct number of facts and they are KeyFactModel instances
|
||||
assert len(all_facts) == len(contents)
|
||||
for fact in all_facts:
|
||||
assert isinstance(fact, KeyFactModel)
|
||||
|
||||
# Verify the content of each fact
|
||||
fact_contents = [fact.content for fact in all_facts]
|
||||
|
|
@ -237,6 +243,7 @@ def test_key_fact_repository_manager(setup_db, cleanup_repo):
|
|||
# Verify we can use the repository
|
||||
content = "Test fact via context manager"
|
||||
fact = repo.create(content)
|
||||
assert isinstance(fact, KeyFactModel)
|
||||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
|
|
@ -259,3 +266,25 @@ def test_get_key_fact_repository_when_not_set(cleanup_repo):
|
|||
|
||||
# Verify the correct error message
|
||||
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_to_model_method(setup_db):
|
||||
"""Test the _to_model method converts KeyFact to KeyFactModel correctly."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create a Peewee KeyFact directly
|
||||
peewee_fact = KeyFact.create(content="Test fact for conversion")
|
||||
|
||||
# Convert to Pydantic model
|
||||
pydantic_fact = repo._to_model(peewee_fact)
|
||||
|
||||
# Verify conversion was correct
|
||||
assert isinstance(pydantic_fact, KeyFactModel)
|
||||
assert pydantic_fact.id == peewee_fact.id
|
||||
assert pydantic_fact.content == peewee_fact.content
|
||||
assert pydantic_fact.created_at == peewee_fact.created_at
|
||||
assert pydantic_fact.updated_at == peewee_fact.updated_at
|
||||
|
||||
# Test with None input
|
||||
assert repo._to_model(None) is None
|
||||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import KeySnippet
|
||||
from ra_aid.database.pydantic_models import KeySnippetModel
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
|
||||
|
||||
|
|
@ -79,6 +80,9 @@ def test_create_key_snippet(setup_db):
|
|||
assert key_snippet.snippet == snippet
|
||||
assert key_snippet.description == description
|
||||
|
||||
# Verify the return type is a Pydantic model
|
||||
assert isinstance(key_snippet, KeySnippetModel)
|
||||
|
||||
# Verify we can retrieve it from the database
|
||||
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||
assert snippet_from_db.filepath == filepath
|
||||
|
|
@ -116,6 +120,9 @@ def test_get_key_snippet(setup_db):
|
|||
assert retrieved_snippet.snippet == snippet
|
||||
assert retrieved_snippet.description == description
|
||||
|
||||
# Verify the return type is a Pydantic model
|
||||
assert isinstance(retrieved_snippet, KeySnippetModel)
|
||||
|
||||
# Try to retrieve a non-existent snippet
|
||||
non_existent_snippet = repo.get(999)
|
||||
assert non_existent_snippet is None
|
||||
|
|
@ -161,6 +168,9 @@ def test_update_key_snippet(setup_db):
|
|||
assert updated_snippet.snippet == new_snippet
|
||||
assert updated_snippet.description == new_description
|
||||
|
||||
# Verify the return type is a Pydantic model
|
||||
assert isinstance(updated_snippet, KeySnippetModel)
|
||||
|
||||
# Verify we can retrieve the updated content from the database
|
||||
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||
assert snippet_from_db.filepath == new_filepath
|
||||
|
|
@ -250,6 +260,9 @@ def test_get_all_key_snippets(setup_db):
|
|||
# Verify we got the correct number of snippets
|
||||
assert len(all_snippets) == len(snippets_data)
|
||||
|
||||
# Verify all returned snippets are Pydantic models
|
||||
assert all(isinstance(snippet, KeySnippetModel) for snippet in all_snippets)
|
||||
|
||||
# Verify the content of each snippet
|
||||
for i, snippet in enumerate(all_snippets):
|
||||
assert snippet.filepath == snippets_data[i]["filepath"]
|
||||
|
|
@ -302,3 +315,30 @@ def test_get_snippets_dict(setup_db):
|
|||
assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"]
|
||||
assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"]
|
||||
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]
|
||||
|
||||
|
||||
def test_to_model_conversion(setup_db):
|
||||
"""Test conversion from Peewee model to Pydantic model."""
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create a snippet in the database using Peewee directly
|
||||
peewee_snippet = KeySnippet.create(
|
||||
filepath="conversion_test.py",
|
||||
line_number=100,
|
||||
snippet="def conversion_test():",
|
||||
description="Test model conversion"
|
||||
)
|
||||
|
||||
# Use the _to_model method to convert it
|
||||
pydantic_snippet = repo._to_model(peewee_snippet)
|
||||
|
||||
# Verify the conversion was successful
|
||||
assert isinstance(pydantic_snippet, KeySnippetModel)
|
||||
assert pydantic_snippet.id == peewee_snippet.id
|
||||
assert pydantic_snippet.filepath == peewee_snippet.filepath
|
||||
assert pydantic_snippet.line_number == peewee_snippet.line_number
|
||||
assert pydantic_snippet.snippet == peewee_snippet.snippet
|
||||
assert pydantic_snippet.description == peewee_snippet.description
|
||||
|
||||
# Test conversion of None
|
||||
assert repo._to_model(None) is None
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
"""
|
||||
Tests for the Pydantic models in ra_aid.database.pydantic_models
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.models import Session
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
|
||||
class TestSessionModel:
|
||||
"""Tests for the SessionModel Pydantic model"""
|
||||
|
||||
def test_from_peewee_model(self):
|
||||
"""Test conversion from a Peewee model instance"""
|
||||
# Create a Peewee Session instance
|
||||
now = datetime.datetime.now()
|
||||
metadata = {"os": "Linux", "cpu_cores": 8, "memory_gb": 16}
|
||||
session = Session(
|
||||
id=1,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid run",
|
||||
program_version="1.0.0",
|
||||
machine_info=json.dumps(metadata)
|
||||
)
|
||||
|
||||
# Convert to Pydantic model
|
||||
session_model = SessionModel.model_validate(session, from_attributes=True)
|
||||
|
||||
# Verify fields
|
||||
assert session_model.id == 1
|
||||
assert session_model.created_at == now
|
||||
assert session_model.updated_at == now
|
||||
assert session_model.start_time == now
|
||||
assert session_model.command_line == "ra-aid run"
|
||||
assert session_model.program_version == "1.0.0"
|
||||
assert session_model.machine_info == metadata
|
||||
|
||||
def test_with_dict_machine_info(self):
|
||||
"""Test creating a model with a dict for machine_info"""
|
||||
# Create directly with a dict for machine_info
|
||||
now = datetime.datetime.now()
|
||||
metadata = {"os": "Windows", "cpu_cores": 4, "memory_gb": 8}
|
||||
|
||||
session_model = SessionModel(
|
||||
id=2,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid --debug",
|
||||
program_version="1.0.1",
|
||||
machine_info=metadata
|
||||
)
|
||||
|
||||
# Verify fields
|
||||
assert session_model.id == 2
|
||||
assert session_model.machine_info == metadata
|
||||
|
||||
def test_with_none_machine_info(self):
|
||||
"""Test creating a model with None for machine_info"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
session_model = SessionModel(
|
||||
id=3,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid",
|
||||
program_version="1.0.0",
|
||||
machine_info=None
|
||||
)
|
||||
|
||||
assert session_model.id == 3
|
||||
assert session_model.machine_info is None
|
||||
|
||||
def test_invalid_json_machine_info(self):
|
||||
"""Test error handling for invalid JSON in machine_info"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Invalid JSON string should raise ValueError
|
||||
with pytest.raises(ValueError):
|
||||
SessionModel(
|
||||
id=4,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid",
|
||||
program_version="1.0.0",
|
||||
machine_info="{invalid json}"
|
||||
)
|
||||
|
||||
def test_unexpected_type_machine_info(self):
|
||||
"""Test error handling for unexpected type in machine_info"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Integer type should raise ValueError
|
||||
with pytest.raises(ValueError):
|
||||
SessionModel(
|
||||
id=5,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid",
|
||||
program_version="1.0.0",
|
||||
machine_info=123 # Not a dict or string
|
||||
)
|
||||
|
|
@ -9,6 +9,7 @@ import peewee
|
|||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import ResearchNote, BaseModel
|
||||
from ra_aid.database.pydantic_models import ResearchNoteModel
|
||||
from ra_aid.database.repositories.research_note_repository import (
|
||||
ResearchNoteRepository,
|
||||
ResearchNoteRepositoryManager,
|
||||
|
|
@ -90,10 +91,12 @@ def test_create_research_note(setup_db):
|
|||
# Verify the note was created correctly
|
||||
assert note.id is not None
|
||||
assert note.content == content
|
||||
assert isinstance(note, ResearchNoteModel)
|
||||
|
||||
# Verify we can retrieve it from the database using the repository
|
||||
note_from_db = repo.get(note.id)
|
||||
assert note_from_db.content == content
|
||||
assert isinstance(note_from_db, ResearchNoteModel)
|
||||
|
||||
|
||||
def test_get_research_note(setup_db):
|
||||
|
|
@ -112,6 +115,7 @@ def test_get_research_note(setup_db):
|
|||
assert retrieved_note is not None
|
||||
assert retrieved_note.id == note.id
|
||||
assert retrieved_note.content == content
|
||||
assert isinstance(retrieved_note, ResearchNoteModel)
|
||||
|
||||
# Try to retrieve a non-existent note
|
||||
non_existent_note = repo.get(999)
|
||||
|
|
@ -135,10 +139,12 @@ def test_update_research_note(setup_db):
|
|||
assert updated_note is not None
|
||||
assert updated_note.id == note.id
|
||||
assert updated_note.content == new_content
|
||||
assert isinstance(updated_note, ResearchNoteModel)
|
||||
|
||||
# Verify we can retrieve the updated content from the database using the repository
|
||||
note_from_db = repo.get(note.id)
|
||||
assert note_from_db.content == new_content
|
||||
assert isinstance(note_from_db, ResearchNoteModel)
|
||||
|
||||
# Try to update a non-existent note
|
||||
non_existent_update = repo.update(999, "This shouldn't work")
|
||||
|
|
@ -187,8 +193,11 @@ def test_get_all_research_notes(setup_db):
|
|||
# Verify we got the correct number of notes
|
||||
assert len(all_notes) == len(contents)
|
||||
|
||||
# Verify the content of each note
|
||||
# Verify the content of each note and that they are Pydantic models
|
||||
note_contents = [note.content for note in all_notes]
|
||||
for note in all_notes:
|
||||
assert isinstance(note, ResearchNoteModel)
|
||||
|
||||
for content in contents:
|
||||
assert content in note_contents
|
||||
|
||||
|
|
@ -239,6 +248,7 @@ def test_research_note_repository_manager(setup_db, cleanup_repo):
|
|||
note = repo.create(content)
|
||||
assert note.id is not None
|
||||
assert note.content == content
|
||||
assert isinstance(note, ResearchNoteModel)
|
||||
|
||||
# Verify we can get the repository using get_research_note_repository
|
||||
repo_from_var = get_research_note_repository()
|
||||
|
|
@ -259,3 +269,25 @@ def test_get_research_note_repository_when_not_set(cleanup_repo):
|
|||
|
||||
# Verify the correct error message
|
||||
assert "No ResearchNoteRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_to_model_method(setup_db):
|
||||
"""Test the _to_model method converts Peewee models to Pydantic models correctly."""
|
||||
# Set up repository
|
||||
repo = ResearchNoteRepository(db=setup_db)
|
||||
|
||||
# Create a Peewee ResearchNote directly
|
||||
peewee_note = ResearchNote.create(content="Test note for conversion")
|
||||
|
||||
# Convert it using _to_model
|
||||
pydantic_note = repo._to_model(peewee_note)
|
||||
|
||||
# Verify the conversion
|
||||
assert isinstance(pydantic_note, ResearchNoteModel)
|
||||
assert pydantic_note.id == peewee_note.id
|
||||
assert pydantic_note.content == peewee_note.content
|
||||
assert pydantic_note.created_at == peewee_note.created_at
|
||||
assert pydantic_note.updated_at == peewee_note.updated_at
|
||||
|
||||
# Test with None
|
||||
assert repo._to_model(None) is None
|
||||
|
|
@ -0,0 +1,347 @@
|
|||
"""
|
||||
Tests for the SessionRepository class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import Session, BaseModel
|
||||
from ra_aid.database.repositories.session_repository import (
|
||||
SessionRepository,
|
||||
SessionRepositoryManager,
|
||||
get_session_repository,
|
||||
session_repo_var
|
||||
)
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_repo():
|
||||
"""Reset the repository contextvar after each test."""
|
||||
# Reset before the test
|
||||
session_repo_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
session_repo_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the Session table and patch the BaseModel.Meta.database."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Patch the BaseModel.Meta.database to use our in-memory database
|
||||
with patch.object(BaseModel._meta, 'database', db):
|
||||
# Create the Session table
|
||||
with db.atomic():
|
||||
db.create_tables([Session], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
Session.drop_table(safe=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_metadata():
|
||||
"""Return test metadata for sessions."""
|
||||
return {
|
||||
"os": "Test OS",
|
||||
"version": "1.0",
|
||||
"cpu_cores": 4,
|
||||
"memory_gb": 16,
|
||||
"additional_info": {
|
||||
"gpu": "Test GPU",
|
||||
"display_resolution": "1920x1080"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session(setup_db, test_metadata):
|
||||
"""Create a sample session in the database."""
|
||||
now = datetime.datetime.now()
|
||||
return Session.create(
|
||||
start_time=now,
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info=json.dumps(test_metadata)
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_with_metadata(setup_db, test_metadata):
|
||||
"""Test creating a session with metadata."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create a session with metadata
|
||||
session = repo.create_session(metadata=test_metadata)
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(session, SessionModel)
|
||||
|
||||
# Verify the session was created correctly
|
||||
assert session.id is not None
|
||||
assert session.command_line is not None
|
||||
assert session.program_version is not None
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(session.machine_info, dict)
|
||||
assert session.machine_info == test_metadata
|
||||
|
||||
# Verify the dictionary structure is preserved
|
||||
assert "additional_info" in session.machine_info
|
||||
assert session.machine_info["additional_info"]["gpu"] == "Test GPU"
|
||||
|
||||
|
||||
def test_create_session_without_metadata(setup_db):
|
||||
"""Test creating a session without metadata."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create a session without metadata
|
||||
session = repo.create_session()
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(session, SessionModel)
|
||||
|
||||
# Verify the session was created correctly
|
||||
assert session.id is not None
|
||||
assert session.command_line is not None
|
||||
assert session.program_version is not None
|
||||
|
||||
# Verify machine_info is None
|
||||
assert session.machine_info is None
|
||||
|
||||
|
||||
def test_get_current_session(setup_db, sample_session):
|
||||
"""Test retrieving the current session."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Set the current session
|
||||
repo.current_session = sample_session
|
||||
|
||||
# Get the current session
|
||||
current_session = repo.get_current_session()
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(current_session, SessionModel)
|
||||
|
||||
# Verify the retrieved session matches the original
|
||||
assert current_session.id == sample_session.id
|
||||
assert current_session.command_line == sample_session.command_line
|
||||
assert current_session.program_version == sample_session.program_version
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(current_session.machine_info, dict)
|
||||
|
||||
|
||||
def test_get_current_session_from_db(setup_db, sample_session):
|
||||
"""Test retrieving the current session from the database when no current session is set."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get the current session (should retrieve the most recent from DB)
|
||||
current_session = repo.get_current_session()
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(current_session, SessionModel)
|
||||
|
||||
# Verify the retrieved session matches the sample session
|
||||
assert current_session.id == sample_session.id
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(current_session.machine_info, dict)
|
||||
|
||||
|
||||
def test_get_by_id(setup_db, sample_session):
|
||||
"""Test retrieving a session by ID."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get the session by ID
|
||||
session = repo.get(sample_session.id)
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(session, SessionModel)
|
||||
|
||||
# Verify the retrieved session matches the original
|
||||
assert session.id == sample_session.id
|
||||
assert session.command_line == sample_session.command_line
|
||||
assert session.program_version == sample_session.program_version
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(session.machine_info, dict)
|
||||
|
||||
# Verify getting a non-existent session returns None
|
||||
non_existent_session = repo.get(999)
|
||||
assert non_existent_session is None
|
||||
|
||||
|
||||
def test_get_all(setup_db):
|
||||
"""Test retrieving all sessions."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create multiple sessions
|
||||
metadata1 = {"os": "Linux", "cpu_cores": 8}
|
||||
metadata2 = {"os": "Windows", "cpu_cores": 4}
|
||||
metadata3 = {"os": "macOS", "cpu_cores": 10}
|
||||
|
||||
repo.create_session(metadata=metadata1)
|
||||
repo.create_session(metadata=metadata2)
|
||||
repo.create_session(metadata=metadata3)
|
||||
|
||||
# Get all sessions
|
||||
sessions = repo.get_all()
|
||||
|
||||
# Verify we got a list of SessionModel objects
|
||||
assert len(sessions) == 3
|
||||
for session in sessions:
|
||||
assert isinstance(session, SessionModel)
|
||||
assert isinstance(session.machine_info, dict)
|
||||
|
||||
# Verify the sessions are in descending order of creation time
|
||||
assert sessions[0].created_at >= sessions[1].created_at
|
||||
assert sessions[1].created_at >= sessions[2].created_at
|
||||
|
||||
# Verify the machine_info fields
|
||||
os_values = [session.machine_info["os"] for session in sessions]
|
||||
assert "Linux" in os_values
|
||||
assert "Windows" in os_values
|
||||
assert "macOS" in os_values
|
||||
|
||||
|
||||
def test_get_all_empty(setup_db):
|
||||
"""Test retrieving all sessions when none exist."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get all sessions
|
||||
sessions = repo.get_all()
|
||||
|
||||
# Verify we got an empty list
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) == 0
|
||||
|
||||
|
||||
def test_get_recent(setup_db):
|
||||
"""Test retrieving recent sessions with a limit."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create multiple sessions
|
||||
for i in range(5):
|
||||
metadata = {"index": i, "os": f"OS {i}"}
|
||||
repo.create_session(metadata=metadata)
|
||||
|
||||
# Get recent sessions with limit=3
|
||||
sessions = repo.get_recent(limit=3)
|
||||
|
||||
# Verify we got the correct number of SessionModel objects
|
||||
assert len(sessions) == 3
|
||||
for session in sessions:
|
||||
assert isinstance(session, SessionModel)
|
||||
assert isinstance(session.machine_info, dict)
|
||||
|
||||
# Verify the sessions are in descending order and are the most recent ones
|
||||
indexes = [session.machine_info["index"] for session in sessions]
|
||||
assert indexes == [4, 3, 2] # Most recent first
|
||||
|
||||
|
||||
def test_session_repository_manager(setup_db, cleanup_repo):
|
||||
"""Test the SessionRepositoryManager context manager."""
|
||||
# Use the context manager to create a repository
|
||||
with SessionRepositoryManager(setup_db) as repo:
|
||||
# Verify the repository was created correctly
|
||||
assert isinstance(repo, SessionRepository)
|
||||
assert repo.db is setup_db
|
||||
|
||||
# Create a session and verify it's a SessionModel
|
||||
metadata = {"test": "manager"}
|
||||
session = repo.create_session(metadata=metadata)
|
||||
assert isinstance(session, SessionModel)
|
||||
assert session.machine_info["test"] == "manager"
|
||||
|
||||
# Verify we can get the repository using get_session_repository
|
||||
repo_from_var = get_session_repository()
|
||||
assert repo_from_var is repo
|
||||
|
||||
# Verify the repository was removed from the context var
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_session_repository()
|
||||
|
||||
assert "No SessionRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_repository_init_without_db():
|
||||
"""Test that SessionRepository raises an error when initialized without a db parameter."""
|
||||
# Attempt to create a repository without a database connection
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
SessionRepository(db=None)
|
||||
|
||||
# Verify the correct error message
|
||||
assert "Database connection is required" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_current_session_id(setup_db, sample_session):
|
||||
"""Test retrieving the ID of the current session."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Set the current session
|
||||
repo.current_session = sample_session
|
||||
|
||||
# Get the current session ID
|
||||
session_id = repo.get_current_session_id()
|
||||
|
||||
# Verify the ID matches
|
||||
assert session_id == sample_session.id
|
||||
|
||||
# Test when no current session exists
|
||||
repo.current_session = None
|
||||
# Delete all sessions
|
||||
Session.delete().execute()
|
||||
|
||||
# Verify None is returned when no session exists
|
||||
session_id = repo.get_current_session_id()
|
||||
assert session_id is None
|
||||
|
|
@ -0,0 +1,458 @@
|
|||
"""
|
||||
Tests for the TrajectoryRepository class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import Trajectory, HumanInput, Session, BaseModel
|
||||
from ra_aid.database.repositories.trajectory_repository import (
|
||||
TrajectoryRepository,
|
||||
TrajectoryRepositoryManager,
|
||||
get_trajectory_repository,
|
||||
trajectory_repo_var
|
||||
)
|
||||
from ra_aid.database.pydantic_models import TrajectoryModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_repo():
|
||||
"""Reset the repository contextvar after each test."""
|
||||
# Reset before the test
|
||||
trajectory_repo_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
trajectory_repo_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the necessary tables and patch the BaseModel.Meta.database."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Patch the BaseModel.Meta.database to use our in-memory database
|
||||
with patch.object(BaseModel._meta, 'database', db):
|
||||
# Create the required tables
|
||||
with db.atomic():
|
||||
db.create_tables([Trajectory, HumanInput, Session], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
Trajectory.drop_table(safe=True)
|
||||
HumanInput.drop_table(safe=True)
|
||||
Session.drop_table(safe=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_human_input(setup_db):
|
||||
"""Create a sample human input in the database."""
|
||||
return HumanInput.create(
|
||||
content="Test human input",
|
||||
source="test"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_tool_parameters():
|
||||
"""Return test tool parameters."""
|
||||
return {
|
||||
"pattern": "test pattern",
|
||||
"file_path": "/path/to/file",
|
||||
"options": {
|
||||
"case_sensitive": True,
|
||||
"whole_words": False
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_tool_result():
|
||||
"""Return test tool result."""
|
||||
return {
|
||||
"matches": [
|
||||
{"line": 10, "content": "This is a test pattern"},
|
||||
{"line": 20, "content": "Another test pattern here"}
|
||||
],
|
||||
"total_matches": 2,
|
||||
"execution_time": 0.5
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_step_data():
|
||||
"""Return test step data for UI rendering."""
|
||||
return {
|
||||
"display_type": "text",
|
||||
"content": "Tool execution results",
|
||||
"highlights": [
|
||||
{"start": 10, "end": 15, "color": "red"}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Create a sample trajectory in the database."""
|
||||
return Trajectory.create(
|
||||
human_input=sample_human_input,
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters=json.dumps(test_tool_parameters),
|
||||
tool_result=json.dumps(test_tool_result),
|
||||
step_data=json.dumps(test_step_data),
|
||||
record_type="tool_execution",
|
||||
cost=0.001,
|
||||
tokens=100,
|
||||
is_error=False
|
||||
)
|
||||
|
||||
|
||||
def test_create_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Test creating a trajectory with all fields."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create a trajectory
|
||||
trajectory = repo.create(
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters=test_tool_parameters,
|
||||
tool_result=test_tool_result,
|
||||
step_data=test_step_data,
|
||||
record_type="tool_execution",
|
||||
human_input_id=sample_human_input.id,
|
||||
cost=0.001,
|
||||
tokens=100
|
||||
)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the trajectory was created correctly
|
||||
assert trajectory.id is not None
|
||||
assert trajectory.tool_name == "ripgrep_search"
|
||||
|
||||
# Verify the JSON fields are dictionaries, not strings
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
assert isinstance(trajectory.tool_result, dict)
|
||||
assert isinstance(trajectory.step_data, dict)
|
||||
|
||||
# Verify the nested structure of tool parameters
|
||||
assert trajectory.tool_parameters["options"]["case_sensitive"] == True
|
||||
assert trajectory.tool_result["total_matches"] == 2
|
||||
assert trajectory.step_data["highlights"][0]["color"] == "red"
|
||||
|
||||
# Verify foreign key reference
|
||||
assert trajectory.human_input_id == sample_human_input.id
|
||||
|
||||
|
||||
def test_create_trajectory_minimal(setup_db):
|
||||
"""Test creating a trajectory with minimal fields."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create a trajectory with minimal fields
|
||||
trajectory = repo.create(
|
||||
tool_name="simple_tool"
|
||||
)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the trajectory was created correctly
|
||||
assert trajectory.id is not None
|
||||
assert trajectory.tool_name == "simple_tool"
|
||||
|
||||
# Verify optional fields are None
|
||||
assert trajectory.tool_parameters is None
|
||||
assert trajectory.tool_result is None
|
||||
assert trajectory.step_data is None
|
||||
assert trajectory.human_input_id is None
|
||||
assert trajectory.cost is None
|
||||
assert trajectory.tokens is None
|
||||
assert trajectory.is_error is False
|
||||
|
||||
|
||||
def test_get_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Test retrieving a trajectory by ID."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Get the trajectory by ID
|
||||
trajectory = repo.get(sample_trajectory.id)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the retrieved trajectory matches the original
|
||||
assert trajectory.id == sample_trajectory.id
|
||||
assert trajectory.tool_name == sample_trajectory.tool_name
|
||||
|
||||
# Verify the JSON fields are dictionaries, not strings
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
assert isinstance(trajectory.tool_result, dict)
|
||||
assert isinstance(trajectory.step_data, dict)
|
||||
|
||||
# Verify the content of JSON fields
|
||||
assert trajectory.tool_parameters == test_tool_parameters
|
||||
assert trajectory.tool_result == test_tool_result
|
||||
assert trajectory.step_data == test_step_data
|
||||
|
||||
# Verify non-existent trajectory returns None
|
||||
non_existent_trajectory = repo.get(999)
|
||||
assert non_existent_trajectory is None
|
||||
|
||||
|
||||
def test_update_trajectory(setup_db, sample_trajectory):
|
||||
"""Test updating a trajectory."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# New data for update
|
||||
new_tool_result = {
|
||||
"matches": [
|
||||
{"line": 15, "content": "Updated test pattern"}
|
||||
],
|
||||
"total_matches": 1,
|
||||
"execution_time": 0.3
|
||||
}
|
||||
|
||||
new_step_data = {
|
||||
"display_type": "html",
|
||||
"content": "Updated UI rendering",
|
||||
"highlights": []
|
||||
}
|
||||
|
||||
# Update the trajectory
|
||||
updated_trajectory = repo.update(
|
||||
trajectory_id=sample_trajectory.id,
|
||||
tool_result=new_tool_result,
|
||||
step_data=new_step_data,
|
||||
cost=0.002,
|
||||
tokens=200,
|
||||
is_error=True,
|
||||
error_message="Test error",
|
||||
error_type="TestErrorType",
|
||||
error_details="Detailed error information"
|
||||
)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(updated_trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the fields were updated
|
||||
assert updated_trajectory.tool_result == new_tool_result
|
||||
assert updated_trajectory.step_data == new_step_data
|
||||
assert updated_trajectory.cost == 0.002
|
||||
assert updated_trajectory.tokens == 200
|
||||
assert updated_trajectory.is_error is True
|
||||
assert updated_trajectory.error_message == "Test error"
|
||||
assert updated_trajectory.error_type == "TestErrorType"
|
||||
assert updated_trajectory.error_details == "Detailed error information"
|
||||
|
||||
# Original tool parameters should not change
|
||||
# We need to parse the JSON string from the Peewee object for comparison
|
||||
original_params = json.loads(sample_trajectory.tool_parameters)
|
||||
assert updated_trajectory.tool_parameters == original_params
|
||||
|
||||
# Verify updating a non-existent trajectory returns None
|
||||
non_existent_update = repo.update(trajectory_id=999, cost=0.005)
|
||||
assert non_existent_update is None
|
||||
|
||||
|
||||
def test_delete_trajectory(setup_db, sample_trajectory):
|
||||
"""Test deleting a trajectory."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Verify the trajectory exists
|
||||
assert repo.get(sample_trajectory.id) is not None
|
||||
|
||||
# Delete the trajectory
|
||||
result = repo.delete(sample_trajectory.id)
|
||||
|
||||
# Verify the trajectory was deleted
|
||||
assert result is True
|
||||
assert repo.get(sample_trajectory.id) is None
|
||||
|
||||
# Verify deleting a non-existent trajectory returns False
|
||||
result = repo.delete(999)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_get_all_trajectories(setup_db, sample_human_input):
|
||||
"""Test retrieving all trajectories."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create multiple trajectories
|
||||
for i in range(3):
|
||||
repo.create(
|
||||
tool_name=f"tool_{i}",
|
||||
tool_parameters={"index": i},
|
||||
human_input_id=sample_human_input.id
|
||||
)
|
||||
|
||||
# Get all trajectories
|
||||
trajectories = repo.get_all()
|
||||
|
||||
# Verify we got a dictionary of TrajectoryModel objects
|
||||
assert len(trajectories) == 3
|
||||
for trajectory_id, trajectory in trajectories.items():
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
|
||||
# Verify the trajectories have the correct tool names
|
||||
tool_names = {trajectory.tool_name for trajectory in trajectories.values()}
|
||||
assert "tool_0" in tool_names
|
||||
assert "tool_1" in tool_names
|
||||
assert "tool_2" in tool_names
|
||||
|
||||
|
||||
def test_get_trajectories_by_human_input(setup_db, sample_human_input):
|
||||
"""Test retrieving trajectories by human input ID."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create another human input
|
||||
other_human_input = HumanInput.create(
|
||||
content="Another human input",
|
||||
source="test"
|
||||
)
|
||||
|
||||
# Create trajectories for both human inputs
|
||||
for i in range(2):
|
||||
repo.create(
|
||||
tool_name=f"tool_1_{i}",
|
||||
human_input_id=sample_human_input.id
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
repo.create(
|
||||
tool_name=f"tool_2_{i}",
|
||||
human_input_id=other_human_input.id
|
||||
)
|
||||
|
||||
# Get trajectories for the first human input
|
||||
trajectories = repo.get_trajectories_by_human_input(sample_human_input.id)
|
||||
|
||||
# Verify we got a list of TrajectoryModel objects for the first human input
|
||||
assert len(trajectories) == 2
|
||||
for trajectory in trajectories:
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert trajectory.human_input_id == sample_human_input.id
|
||||
assert trajectory.tool_name.startswith("tool_1")
|
||||
|
||||
# Get trajectories for the second human input
|
||||
trajectories = repo.get_trajectories_by_human_input(other_human_input.id)
|
||||
|
||||
# Verify we got a list of TrajectoryModel objects for the second human input
|
||||
assert len(trajectories) == 3
|
||||
for trajectory in trajectories:
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert trajectory.human_input_id == other_human_input.id
|
||||
assert trajectory.tool_name.startswith("tool_2")
|
||||
|
||||
|
||||
def test_get_parsed_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Test retrieving a parsed trajectory."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Get the parsed trajectory
|
||||
trajectory = repo.get_parsed_trajectory(sample_trajectory.id)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the retrieved trajectory matches the original
|
||||
assert trajectory.id == sample_trajectory.id
|
||||
assert trajectory.tool_name == sample_trajectory.tool_name
|
||||
|
||||
# Verify the JSON fields are dictionaries, not strings
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
assert isinstance(trajectory.tool_result, dict)
|
||||
assert isinstance(trajectory.step_data, dict)
|
||||
|
||||
# Verify the content of JSON fields
|
||||
assert trajectory.tool_parameters == test_tool_parameters
|
||||
assert trajectory.tool_result == test_tool_result
|
||||
assert trajectory.step_data == test_step_data
|
||||
|
||||
# Verify non-existent trajectory returns None
|
||||
non_existent_trajectory = repo.get_parsed_trajectory(999)
|
||||
assert non_existent_trajectory is None
|
||||
|
||||
|
||||
def test_trajectory_repository_manager(setup_db, cleanup_repo):
|
||||
"""Test the TrajectoryRepositoryManager context manager."""
|
||||
# Use the context manager to create a repository
|
||||
with TrajectoryRepositoryManager(setup_db) as repo:
|
||||
# Verify the repository was created correctly
|
||||
assert isinstance(repo, TrajectoryRepository)
|
||||
assert repo.db is setup_db
|
||||
|
||||
# Create a trajectory and verify it's a TrajectoryModel
|
||||
tool_parameters = {"test": "manager"}
|
||||
trajectory = repo.create(
|
||||
tool_name="manager_test",
|
||||
tool_parameters=tool_parameters
|
||||
)
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert trajectory.tool_parameters["test"] == "manager"
|
||||
|
||||
# Verify we can get the repository using get_trajectory_repository
|
||||
repo_from_var = get_trajectory_repository()
|
||||
assert repo_from_var is repo
|
||||
|
||||
# Verify the repository was removed from the context var
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_trajectory_repository()
|
||||
|
||||
assert "No TrajectoryRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_repository_init_without_db():
|
||||
"""Test that TrajectoryRepository raises an error when initialized without a db parameter."""
|
||||
# Attempt to create a repository without a database connection
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
TrajectoryRepository(db=None)
|
||||
|
||||
# Verify the correct error message
|
||||
assert "Database connection is required" in str(excinfo.value)
|
||||
Loading…
Reference in New Issue