use pydantic models

This commit is contained in:
AI Christianson 2025-03-15 14:18:00 -04:00
parent 5d07a7f7b8
commit e0aab1021b
14 changed files with 1810 additions and 147 deletions

View File

@ -0,0 +1,376 @@
"""
Pydantic models for ra_aid database entities.
This module defines Pydantic models that correspond to Peewee ORM models,
providing validation, serialization, and deserialization capabilities.
"""
import datetime
import json
from typing import Dict, List, Any, Optional
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class SessionModel(BaseModel):
"""
Pydantic model representing a Session.
This model corresponds to the Session Peewee ORM model and provides
validation and serialization capabilities. It handles the conversion
between JSON-encoded strings and Python dictionaries for the machine_info field.
Attributes:
id: Unique identifier for the session
created_at: When the session record was created
updated_at: When the session record was last updated
start_time: When the program session started
command_line: Command line arguments used to start the program
program_version: Version of the program
machine_info: Dictionary containing machine-specific metadata
"""
id: Optional[int] = None
created_at: datetime.datetime
updated_at: datetime.datetime
start_time: datetime.datetime
command_line: Optional[str] = None
program_version: Optional[str] = None
machine_info: Optional[Dict[str, Any]] = None
# Configure the model to work with ORM objects
model_config = ConfigDict(from_attributes=True)
@field_validator("machine_info", mode="before")
@classmethod
def parse_machine_info(cls, value: Any) -> Optional[Dict[str, Any]]:
"""
Parse the machine_info field from a JSON string to a dictionary.
Args:
value: The value to parse, can be a string, dict, or None
Returns:
Optional[Dict[str, Any]]: The parsed dictionary or None
Raises:
ValueError: If the JSON string is invalid
"""
if value is None:
return None
if isinstance(value, dict):
return value
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in machine_info: {e}")
raise ValueError(f"Unexpected type for machine_info: {type(value)}")
@field_serializer("machine_info")
def serialize_machine_info(self, machine_info: Optional[Dict[str, Any]]) -> Optional[str]:
"""
Serialize the machine_info dictionary to a JSON string for storage.
Args:
machine_info: Dictionary to serialize
Returns:
Optional[str]: JSON-encoded string or None
"""
if machine_info is None:
return None
return json.dumps(machine_info)
class HumanInputModel(BaseModel):
"""
Pydantic model representing a HumanInput.
This model corresponds to the HumanInput Peewee ORM model and provides
validation and serialization capabilities.
Attributes:
id: Unique identifier for the human input
created_at: When the record was created
updated_at: When the record was last updated
content: The text content of the input
source: The source of the input ('cli', 'chat', or 'hil')
session_id: Optional reference to the associated session
"""
id: Optional[int] = None
created_at: datetime.datetime
updated_at: datetime.datetime
content: str
source: str
session_id: Optional[int] = None
# Configure the model to work with ORM objects
model_config = ConfigDict(from_attributes=True)
class KeyFactModel(BaseModel):
"""
Pydantic model representing a KeyFact.
This model corresponds to the KeyFact Peewee ORM model and provides
validation and serialization capabilities.
Attributes:
id: Unique identifier for the key fact
created_at: When the record was created
updated_at: When the record was last updated
content: The text content of the key fact
human_input_id: Optional reference to the associated human input
session_id: Optional reference to the associated session
"""
id: Optional[int] = None
created_at: datetime.datetime
updated_at: datetime.datetime
content: str
human_input_id: Optional[int] = None
session_id: Optional[int] = None
# Configure the model to work with ORM objects
model_config = ConfigDict(from_attributes=True)
class KeySnippetModel(BaseModel):
"""
Pydantic model representing a KeySnippet.
This model corresponds to the KeySnippet Peewee ORM model and provides
validation and serialization capabilities.
Attributes:
id: Unique identifier for the key snippet
created_at: When the record was created
updated_at: When the record was last updated
filepath: Path to the source file
line_number: Line number where the snippet starts
snippet: The source code snippet text
description: Optional description of the significance
human_input_id: Optional reference to the associated human input
session_id: Optional reference to the associated session
"""
id: Optional[int] = None
created_at: datetime.datetime
updated_at: datetime.datetime
filepath: str
line_number: int
snippet: str
description: Optional[str] = None
human_input_id: Optional[int] = None
session_id: Optional[int] = None
# Configure the model to work with ORM objects
model_config = ConfigDict(from_attributes=True)
class ResearchNoteModel(BaseModel):
"""
Pydantic model representing a ResearchNote.
This model corresponds to the ResearchNote Peewee ORM model and provides
validation and serialization capabilities.
Attributes:
id: Unique identifier for the research note
created_at: When the record was created
updated_at: When the record was last updated
content: The text content of the research note
human_input_id: Optional reference to the associated human input
session_id: Optional reference to the associated session
"""
id: Optional[int] = None
created_at: datetime.datetime
updated_at: datetime.datetime
content: str
human_input_id: Optional[int] = None
session_id: Optional[int] = None
# Configure the model to work with ORM objects
model_config = ConfigDict(from_attributes=True)
class TrajectoryModel(BaseModel):
"""
Pydantic model representing a Trajectory.
This model corresponds to the Trajectory Peewee ORM model and provides
validation and serialization capabilities. It handles the conversion
between JSON-encoded strings and Python dictionaries for the tool_parameters,
tool_result, and step_data fields.
Attributes:
id: Unique identifier for the trajectory
created_at: When the record was created
updated_at: When the record was last updated
human_input_id: Optional reference to the associated human input
tool_name: Name of the tool that was executed
tool_parameters: Dictionary containing the parameters passed to the tool
tool_result: Dictionary containing the result returned by the tool
step_data: Dictionary containing UI rendering data
record_type: Type of trajectory record
cost: Optional cost of the tool execution
tokens: Optional token usage of the tool execution
is_error: Flag indicating if this record represents an error
error_message: The error message if is_error is True
error_type: The type/class of the error if is_error is True
error_details: Additional error details if is_error is True
session_id: Optional reference to the associated session
"""
id: Optional[int] = None
created_at: datetime.datetime
updated_at: datetime.datetime
human_input_id: Optional[int] = None
tool_name: Optional[str] = None
tool_parameters: Optional[Dict[str, Any]] = None
tool_result: Optional[Any] = None
step_data: Optional[Dict[str, Any]] = None
record_type: Optional[str] = None
cost: Optional[float] = None
tokens: Optional[int] = None
is_error: bool = False
error_message: Optional[str] = None
error_type: Optional[str] = None
error_details: Optional[str] = None
session_id: Optional[int] = None
# Configure the model to work with ORM objects
model_config = ConfigDict(from_attributes=True)
@field_validator("tool_parameters", mode="before")
@classmethod
def parse_tool_parameters(cls, value: Any) -> Optional[Dict[str, Any]]:
"""
Parse the tool_parameters field from a JSON string to a dictionary.
Args:
value: The value to parse, can be a string, dict, or None
Returns:
Optional[Dict[str, Any]]: The parsed dictionary or None
Raises:
ValueError: If the JSON string is invalid
"""
if value is None:
return None
if isinstance(value, dict):
return value
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in tool_parameters: {e}")
raise ValueError(f"Unexpected type for tool_parameters: {type(value)}")
@field_validator("tool_result", mode="before")
@classmethod
def parse_tool_result(cls, value: Any) -> Optional[Any]:
"""
Parse the tool_result field from a JSON string to a Python object.
Args:
value: The value to parse, can be a string, dict, list, or None
Returns:
Optional[Any]: The parsed object or None
Raises:
ValueError: If the JSON string is invalid
"""
if value is None:
return None
if not isinstance(value, str):
return value
try:
return json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in tool_result: {e}")
@field_validator("step_data", mode="before")
@classmethod
def parse_step_data(cls, value: Any) -> Optional[Dict[str, Any]]:
"""
Parse the step_data field from a JSON string to a dictionary.
Args:
value: The value to parse, can be a string, dict, or None
Returns:
Optional[Dict[str, Any]]: The parsed dictionary or None
Raises:
ValueError: If the JSON string is invalid
"""
if value is None:
return None
if isinstance(value, dict):
return value
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in step_data: {e}")
raise ValueError(f"Unexpected type for step_data: {type(value)}")
@field_serializer("tool_parameters")
def serialize_tool_parameters(self, tool_parameters: Optional[Dict[str, Any]]) -> Optional[str]:
"""
Serialize the tool_parameters dictionary to a JSON string for storage.
Args:
tool_parameters: Dictionary to serialize
Returns:
Optional[str]: JSON-encoded string or None
"""
if tool_parameters is None:
return None
return json.dumps(tool_parameters)
@field_serializer("tool_result")
def serialize_tool_result(self, tool_result: Optional[Any]) -> Optional[str]:
"""
Serialize the tool_result object to a JSON string for storage.
Args:
tool_result: Object to serialize
Returns:
Optional[str]: JSON-encoded string or None
"""
if tool_result is None:
return None
return json.dumps(tool_result)
@field_serializer("step_data")
def serialize_step_data(self, step_data: Optional[Dict[str, Any]]) -> Optional[str]:
"""
Serialize the step_data dictionary to a JSON string for storage.
Args:
step_data: Dictionary to serialize
Returns:
Optional[str]: JSON-encoded string or None
"""
if step_data is None:
return None
return json.dumps(step_data)

View File

@ -11,6 +11,7 @@ import contextvars
import peewee
from ra_aid.database.models import HumanInput
from ra_aid.database.pydantic_models import HumanInputModel
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -119,7 +120,22 @@ class HumanInputRepository:
raise ValueError("Database connection is required for HumanInputRepository")
self.db = db
def create(self, content: str, source: str) -> HumanInput:
def _to_model(self, human_input: Optional[HumanInput]) -> Optional[HumanInputModel]:
"""
Convert a Peewee HumanInput object to a Pydantic HumanInputModel.
Args:
human_input: Peewee HumanInput instance or None
Returns:
Optional[HumanInputModel]: Pydantic model representation or None if human_input is None
"""
if human_input is None:
return None
return HumanInputModel.model_validate(human_input, from_attributes=True)
def create(self, content: str, source: str) -> HumanInputModel:
"""
Create a new human input record in the database.
@ -128,7 +144,7 @@ class HumanInputRepository:
source: The source of the input (e.g., "cli", "chat", "hil")
Returns:
HumanInput: The newly created human input instance
HumanInputModel: The newly created human input instance
Raises:
peewee.DatabaseError: If there's an error creating the record
@ -136,12 +152,12 @@ class HumanInputRepository:
try:
input_record = HumanInput.create(content=content, source=source)
logger.debug(f"Created human input ID {input_record.id} from {source}")
return input_record
return self._to_model(input_record)
except peewee.DatabaseError as e:
logger.error(f"Failed to create human input record: {str(e)}")
raise
def get(self, input_id: int) -> Optional[HumanInput]:
def get(self, input_id: int) -> Optional[HumanInputModel]:
"""
Retrieve a human input record by its ID.
@ -149,18 +165,19 @@ class HumanInputRepository:
input_id: The ID of the human input to retrieve
Returns:
Optional[HumanInput]: The human input instance if found, None otherwise
Optional[HumanInputModel]: The human input instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return HumanInput.get_or_none(HumanInput.id == input_id)
human_input = HumanInput.get_or_none(HumanInput.id == input_id)
return self._to_model(human_input)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch human input {input_id}: {str(e)}")
raise
def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInput]:
def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInputModel]:
"""
Update an existing human input record.
@ -170,14 +187,14 @@ class HumanInputRepository:
source: The new source for the human input
Returns:
Optional[HumanInput]: The updated human input if found, None otherwise
Optional[HumanInputModel]: The updated human input if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the record
"""
try:
# First check if the record exists
input_record = self.get(input_id)
# We need to get the raw Peewee object for updating
input_record = HumanInput.get_or_none(HumanInput.id == input_id)
if not input_record:
logger.warning(f"Attempted to update non-existent human input {input_id}")
return None
@ -190,7 +207,7 @@ class HumanInputRepository:
input_record.save()
logger.debug(f"Updated human input ID {input_id}")
return input_record
return self._to_model(input_record)
except peewee.DatabaseError as e:
logger.error(f"Failed to update human input {input_id}: {str(e)}")
raise
@ -223,23 +240,24 @@ class HumanInputRepository:
logger.error(f"Failed to delete human input {input_id}: {str(e)}")
raise
def get_all(self) -> List[HumanInput]:
def get_all(self) -> List[HumanInputModel]:
"""
Retrieve all human input records from the database.
Returns:
List[HumanInput]: List of all human input instances
List[HumanInputModel]: List of all human input instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(HumanInput.select().order_by(HumanInput.created_at.desc()))
human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()))
return [self._to_model(input) for input in human_inputs]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all human inputs: {str(e)}")
raise
def get_recent(self, limit: int = 10) -> List[HumanInput]:
def get_recent(self, limit: int = 10) -> List[HumanInputModel]:
"""
Retrieve the most recent human input records.
@ -247,13 +265,14 @@ class HumanInputRepository:
limit: Maximum number of records to retrieve (default: 10)
Returns:
List[HumanInput]: List of the most recent human input records
List[HumanInputModel]: List of the most recent human input records
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
return [self._to_model(input) for input in human_inputs]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch recent human inputs: {str(e)}")
raise
@ -277,7 +296,7 @@ class HumanInputRepository:
logger.error(f"Failed to fetch most recent human input ID: {str(e)}")
raise
def get_by_source(self, source: str) -> List[HumanInput]:
def get_by_source(self, source: str) -> List[HumanInputModel]:
"""
Retrieve human input records by source.
@ -285,13 +304,14 @@ class HumanInputRepository:
source: The source to filter by (e.g., "cli", "chat", "hil")
Returns:
List[HumanInput]: List of human input records from the specified source
List[HumanInputModel]: List of human input records from the specified source
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
human_inputs = list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
return [self._to_model(input) for input in human_inputs]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch human inputs by source {source}: {str(e)}")
raise

View File

@ -12,6 +12,7 @@ from contextlib import contextmanager
import peewee
from ra_aid.database.models import KeyFact
from ra_aid.database.pydantic_models import KeyFactModel
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -120,7 +121,22 @@ class KeyFactRepository:
raise ValueError("Database connection is required for KeyFactRepository")
self.db = db
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
def _to_model(self, fact: Optional[KeyFact]) -> Optional[KeyFactModel]:
"""
Convert a Peewee KeyFact object to a Pydantic KeyFactModel.
Args:
fact: Peewee KeyFact instance or None
Returns:
Optional[KeyFactModel]: Pydantic model representation or None if fact is None
"""
if fact is None:
return None
return KeyFactModel.model_validate(fact, from_attributes=True)
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFactModel:
"""
Create a new key fact in the database.
@ -129,7 +145,7 @@ class KeyFactRepository:
human_input_id: Optional ID of the associated human input
Returns:
KeyFact: The newly created key fact instance
KeyFactModel: The newly created key fact instance
Raises:
peewee.DatabaseError: If there's an error creating the fact
@ -137,12 +153,12 @@ class KeyFactRepository:
try:
fact = KeyFact.create(content=content, human_input_id=human_input_id)
logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact
return self._to_model(fact)
except peewee.DatabaseError as e:
logger.error(f"Failed to create key fact: {str(e)}")
raise
def get(self, fact_id: int) -> Optional[KeyFact]:
def get(self, fact_id: int) -> Optional[KeyFactModel]:
"""
Retrieve a key fact by its ID.
@ -150,18 +166,19 @@ class KeyFactRepository:
fact_id: The ID of the key fact to retrieve
Returns:
Optional[KeyFact]: The key fact instance if found, None otherwise
Optional[KeyFactModel]: The key fact instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return KeyFact.get_or_none(KeyFact.id == fact_id)
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
return self._to_model(fact)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
raise
def update(self, fact_id: int, content: str) -> Optional[KeyFact]:
def update(self, fact_id: int, content: str) -> Optional[KeyFactModel]:
"""
Update an existing key fact.
@ -170,14 +187,14 @@ class KeyFactRepository:
content: The new content for the key fact
Returns:
Optional[KeyFact]: The updated key fact if found, None otherwise
Optional[KeyFactModel]: The updated key fact if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the fact
"""
try:
# First check if the fact exists
fact = self.get(fact_id)
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
if not fact:
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
return None
@ -186,7 +203,7 @@ class KeyFactRepository:
fact.content = content
fact.save()
logger.debug(f"Updated key fact ID {fact_id}: {content}")
return fact
return self._to_model(fact)
except peewee.DatabaseError as e:
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
raise
@ -206,7 +223,7 @@ class KeyFactRepository:
"""
try:
# First check if the fact exists
fact = self.get(fact_id)
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
if not fact:
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
return False
@ -219,18 +236,19 @@ class KeyFactRepository:
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
raise
def get_all(self) -> List[KeyFact]:
def get_all(self) -> List[KeyFactModel]:
"""
Retrieve all key facts from the database.
Returns:
List[KeyFact]: List of all key fact instances
List[KeyFactModel]: List of all key fact instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(KeyFact.select().order_by(KeyFact.id))
facts = list(KeyFact.select().order_by(KeyFact.id))
return [self._to_model(fact) for fact in facts]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key facts: {str(e)}")
raise

View File

@ -11,6 +11,7 @@ import contextvars
import peewee
from ra_aid.database.models import KeySnippet
from ra_aid.database.pydantic_models import KeySnippetModel
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -129,10 +130,25 @@ class KeySnippetRepository:
raise ValueError("Database connection is required for KeySnippetRepository")
self.db = db
def _to_model(self, snippet: Optional[KeySnippet]) -> Optional[KeySnippetModel]:
"""
Convert a Peewee KeySnippet object to a Pydantic KeySnippetModel.
Args:
snippet: Peewee KeySnippet instance or None
Returns:
Optional[KeySnippetModel]: Pydantic model representation or None if snippet is None
"""
if snippet is None:
return None
return KeySnippetModel.model_validate(snippet, from_attributes=True)
def create(
self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None,
human_input_id: Optional[int] = None
) -> KeySnippet:
) -> KeySnippetModel:
"""
Create a new key snippet in the database.
@ -144,7 +160,7 @@ class KeySnippetRepository:
human_input_id: Optional ID of the associated human input
Returns:
KeySnippet: The newly created key snippet instance
KeySnippetModel: The newly created key snippet instance
Raises:
peewee.DatabaseError: If there's an error creating the snippet
@ -158,12 +174,12 @@ class KeySnippetRepository:
human_input_id=human_input_id
)
logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}")
return key_snippet
return self._to_model(key_snippet)
except peewee.DatabaseError as e:
logger.error(f"Failed to create key snippet: {str(e)}")
raise
def get(self, snippet_id: int) -> Optional[KeySnippet]:
def get(self, snippet_id: int) -> Optional[KeySnippetModel]:
"""
Retrieve a key snippet by its ID.
@ -171,13 +187,14 @@ class KeySnippetRepository:
snippet_id: The ID of the key snippet to retrieve
Returns:
Optional[KeySnippet]: The key snippet instance if found, None otherwise
Optional[KeySnippetModel]: The key snippet instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
return self._to_model(snippet)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
raise
@ -189,7 +206,7 @@ class KeySnippetRepository:
line_number: int,
snippet: str,
description: Optional[str] = None
) -> Optional[KeySnippet]:
) -> Optional[KeySnippetModel]:
"""
Update an existing key snippet.
@ -201,14 +218,14 @@ class KeySnippetRepository:
description: Optional description of the significance
Returns:
Optional[KeySnippet]: The updated key snippet if found, None otherwise
Optional[KeySnippetModel]: The updated key snippet if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the snippet
"""
try:
# First check if the snippet exists
key_snippet = self.get(snippet_id)
key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
if not key_snippet:
logger.warning(f"Attempted to update non-existent key snippet {snippet_id}")
return None
@ -220,7 +237,7 @@ class KeySnippetRepository:
key_snippet.description = description
key_snippet.save()
logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}")
return key_snippet
return self._to_model(key_snippet)
except peewee.DatabaseError as e:
logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}")
raise
@ -240,7 +257,7 @@ class KeySnippetRepository:
"""
try:
# First check if the snippet exists
key_snippet = self.get(snippet_id)
key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
if not key_snippet:
logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}")
return False
@ -253,18 +270,19 @@ class KeySnippetRepository:
logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}")
raise
def get_all(self) -> List[KeySnippet]:
def get_all(self) -> List[KeySnippetModel]:
"""
Retrieve all key snippets from the database.
Returns:
List[KeySnippet]: List of all key snippet instances
List[KeySnippetModel]: List of all key snippet instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(KeySnippet.select().order_by(KeySnippet.id))
snippets = list(KeySnippet.select().order_by(KeySnippet.id))
return [self._to_model(snippet) for snippet in snippets]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key snippets: {str(e)}")
raise

View File

@ -12,6 +12,7 @@ from contextlib import contextmanager
import peewee
from ra_aid.database.models import ResearchNote
from ra_aid.database.pydantic_models import ResearchNoteModel
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -120,7 +121,22 @@ class ResearchNoteRepository:
raise ValueError("Database connection is required for ResearchNoteRepository")
self.db = db
def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNote:
def _to_model(self, note: Optional[ResearchNote]) -> Optional[ResearchNoteModel]:
"""
Convert a Peewee ResearchNote object to a Pydantic ResearchNoteModel.
Args:
note: Peewee ResearchNote instance or None
Returns:
Optional[ResearchNoteModel]: Pydantic model representation or None if note is None
"""
if note is None:
return None
return ResearchNoteModel.model_validate(note, from_attributes=True)
def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNoteModel:
"""
Create a new research note in the database.
@ -129,7 +145,7 @@ class ResearchNoteRepository:
human_input_id: Optional ID of the associated human input
Returns:
ResearchNote: The newly created research note instance
ResearchNoteModel: The newly created research note instance
Raises:
peewee.DatabaseError: If there's an error creating the note
@ -137,12 +153,12 @@ class ResearchNoteRepository:
try:
note = ResearchNote.create(content=content, human_input_id=human_input_id)
logger.debug(f"Created research note ID {note.id}: {content[:50]}...")
return note
return self._to_model(note)
except peewee.DatabaseError as e:
logger.error(f"Failed to create research note: {str(e)}")
raise
def get(self, note_id: int) -> Optional[ResearchNote]:
def get(self, note_id: int) -> Optional[ResearchNoteModel]:
"""
Retrieve a research note by its ID.
@ -150,18 +166,19 @@ class ResearchNoteRepository:
note_id: The ID of the research note to retrieve
Returns:
Optional[ResearchNote]: The research note instance if found, None otherwise
Optional[ResearchNoteModel]: The research note instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return ResearchNote.get_or_none(ResearchNote.id == note_id)
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
return self._to_model(note)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch research note {note_id}: {str(e)}")
raise
def update(self, note_id: int, content: str) -> Optional[ResearchNote]:
def update(self, note_id: int, content: str) -> Optional[ResearchNoteModel]:
"""
Update an existing research note.
@ -170,14 +187,14 @@ class ResearchNoteRepository:
content: The new content for the research note
Returns:
Optional[ResearchNote]: The updated research note if found, None otherwise
Optional[ResearchNoteModel]: The updated research note if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the note
"""
try:
# First check if the note exists
note = self.get(note_id)
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
if not note:
logger.warning(f"Attempted to update non-existent research note {note_id}")
return None
@ -186,7 +203,7 @@ class ResearchNoteRepository:
note.content = content
note.save()
logger.debug(f"Updated research note ID {note_id}: {content[:50]}...")
return note
return self._to_model(note)
except peewee.DatabaseError as e:
logger.error(f"Failed to update research note {note_id}: {str(e)}")
raise
@ -206,7 +223,7 @@ class ResearchNoteRepository:
"""
try:
# First check if the note exists
note = self.get(note_id)
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
if not note:
logger.warning(f"Attempted to delete non-existent research note {note_id}")
return False
@ -219,18 +236,19 @@ class ResearchNoteRepository:
logger.error(f"Failed to delete research note {note_id}: {str(e)}")
raise
def get_all(self) -> List[ResearchNote]:
def get_all(self) -> List[ResearchNoteModel]:
"""
Retrieve all research notes from the database.
Returns:
List[ResearchNote]: List of all research note instances
List[ResearchNoteModel]: List of all research note instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(ResearchNote.select().order_by(ResearchNote.id))
notes = list(ResearchNote.select().order_by(ResearchNote.id))
return [self._to_model(note) for note in notes]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all research notes: {str(e)}")
raise

View File

@ -16,6 +16,7 @@ import sys
import peewee
from ra_aid.database.models import Session
from ra_aid.database.pydantic_models import SessionModel
from ra_aid.__version__ import __version__
from ra_aid.logging_config import get_logger
@ -121,7 +122,22 @@ class SessionRepository:
self.db = db
self.current_session = None
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
def _to_model(self, session: Optional[Session]) -> Optional[SessionModel]:
"""
Convert a Peewee Session object to a Pydantic SessionModel.
Args:
session: Peewee Session instance or None
Returns:
Optional[SessionModel]: Pydantic model representation or None if session is None
"""
if session is None:
return None
return SessionModel.model_validate(session, from_attributes=True)
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> SessionModel:
"""
Create a new session record in the database.
@ -129,7 +145,7 @@ class SessionRepository:
metadata: Optional dictionary of additional metadata to store with the session
Returns:
Session: The newly created session instance
SessionModel: The newly created session instance
Raises:
peewee.DatabaseError: If there's an error creating the record
@ -155,12 +171,12 @@ class SessionRepository:
self.current_session = session
logger.debug(f"Created new session with ID {session.id}")
return session
return self._to_model(session)
except peewee.DatabaseError as e:
logger.error(f"Failed to create session record: {str(e)}")
raise
def get_current_session(self) -> Optional[Session]:
def get_current_session(self) -> Optional[SessionModel]:
"""
Get the current active session.
@ -168,17 +184,17 @@ class SessionRepository:
retrieves the most recent session from the database.
Returns:
Optional[Session]: The current session or None if no sessions exist
Optional[SessionModel]: The current session or None if no sessions exist
"""
if self.current_session is not None:
return self.current_session
return self._to_model(self.current_session)
try:
# Find the most recent session
session = Session.select().order_by(Session.created_at.desc()).first()
if session:
self.current_session = session
return session
return self._to_model(session)
except peewee.DatabaseError as e:
logger.error(f"Failed to get current session: {str(e)}")
return None
@ -193,7 +209,7 @@ class SessionRepository:
session = self.get_current_session()
return session.id if session else None
def get(self, session_id: int) -> Optional[Session]:
def get(self, session_id: int) -> Optional[SessionModel]:
"""
Get a session by its ID.
@ -201,28 +217,30 @@ class SessionRepository:
session_id: The ID of the session to retrieve
Returns:
Optional[Session]: The session with the given ID or None if not found
Optional[SessionModel]: The session with the given ID or None if not found
"""
try:
return Session.get_or_none(Session.id == session_id)
session = Session.get_or_none(Session.id == session_id)
return self._to_model(session)
except peewee.DatabaseError as e:
logger.error(f"Database error getting session {session_id}: {str(e)}")
return None
def get_all(self) -> List[Session]:
def get_all(self) -> List[SessionModel]:
"""
Get all sessions from the database.
Returns:
List[Session]: List of all sessions
List[SessionModel]: List of all sessions
"""
try:
return list(Session.select().order_by(Session.created_at.desc()))
sessions = list(Session.select().order_by(Session.created_at.desc()))
return [self._to_model(session) for session in sessions]
except peewee.DatabaseError as e:
logger.error(f"Failed to get all sessions: {str(e)}")
return []
def get_recent(self, limit: int = 10) -> List[Session]:
def get_recent(self, limit: int = 10) -> List[SessionModel]:
"""
Get the most recent sessions from the database.
@ -230,14 +248,15 @@ class SessionRepository:
limit: Maximum number of sessions to return (default: 10)
Returns:
List[Session]: List of the most recent sessions
List[SessionModel]: List of the most recent sessions
"""
try:
return list(
sessions = list(
Session.select()
.order_by(Session.created_at.desc())
.limit(limit)
)
return [self._to_model(session) for session in sessions]
except peewee.DatabaseError as e:
logger.error(f"Failed to get recent sessions: {str(e)}")
return []

View File

@ -14,6 +14,7 @@ import logging
import peewee
from ra_aid.database.models import Trajectory, HumanInput
from ra_aid.database.pydantic_models import TrajectoryModel
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -130,6 +131,21 @@ class TrajectoryRepository:
raise ValueError("Database connection is required for TrajectoryRepository")
self.db = db
def _to_model(self, trajectory: Optional[Trajectory]) -> Optional[TrajectoryModel]:
"""
Convert a Peewee Trajectory object to a Pydantic TrajectoryModel.
Args:
trajectory: Peewee Trajectory instance or None
Returns:
Optional[TrajectoryModel]: Pydantic model representation or None if trajectory is None
"""
if trajectory is None:
return None
return TrajectoryModel.model_validate(trajectory, from_attributes=True)
def create(
self,
tool_name: Optional[str] = None,
@ -144,7 +160,7 @@ class TrajectoryRepository:
error_message: Optional[str] = None,
error_type: Optional[str] = None,
error_details: Optional[str] = None
) -> Trajectory:
) -> TrajectoryModel:
"""
Create a new trajectory record in the database.
@ -163,7 +179,7 @@ class TrajectoryRepository:
error_details: Additional error details like stack traces (if is_error is True)
Returns:
Trajectory: The newly created trajectory instance
TrajectoryModel: The newly created trajectory instance as a Pydantic model
Raises:
peewee.DatabaseError: If there's an error creating the record
@ -201,12 +217,12 @@ class TrajectoryRepository:
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
else:
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
return trajectory
return self._to_model(trajectory)
except peewee.DatabaseError as e:
logger.error(f"Failed to create trajectory record: {str(e)}")
raise
def get(self, trajectory_id: int) -> Optional[Trajectory]:
def get(self, trajectory_id: int) -> Optional[TrajectoryModel]:
"""
Retrieve a trajectory record by its ID.
@ -214,13 +230,14 @@ class TrajectoryRepository:
trajectory_id: The ID of the trajectory record to retrieve
Returns:
Optional[Trajectory]: The trajectory instance if found, None otherwise
Optional[TrajectoryModel]: The trajectory instance as a Pydantic model if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return Trajectory.get_or_none(Trajectory.id == trajectory_id)
trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
return self._to_model(trajectory)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
raise
@ -236,7 +253,7 @@ class TrajectoryRepository:
error_message: Optional[str] = None,
error_type: Optional[str] = None,
error_details: Optional[str] = None
) -> Optional[Trajectory]:
) -> Optional[TrajectoryModel]:
"""
Update an existing trajectory record.
@ -254,15 +271,15 @@ class TrajectoryRepository:
error_details: Additional error details like stack traces
Returns:
Optional[Trajectory]: The updated trajectory if found, None otherwise
Optional[TrajectoryModel]: The updated trajectory as a Pydantic model if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the record
"""
try:
# First check if the trajectory exists
trajectory = self.get(trajectory_id)
if not trajectory:
peewee_trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
if not peewee_trajectory:
logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
return None
@ -299,7 +316,7 @@ class TrajectoryRepository:
logger.debug(f"Updated trajectory record ID {trajectory_id}")
return self.get(trajectory_id)
return trajectory
return self._to_model(peewee_trajectory)
except peewee.DatabaseError as e:
logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
raise
@ -319,7 +336,7 @@ class TrajectoryRepository:
"""
try:
# First check if the trajectory exists
trajectory = self.get(trajectory_id)
trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
if not trajectory:
logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
return False
@ -332,23 +349,24 @@ class TrajectoryRepository:
logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
raise
def get_all(self) -> Dict[int, Trajectory]:
def get_all(self) -> Dict[int, TrajectoryModel]:
"""
Retrieve all trajectory records from the database.
Returns:
Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances
Dict[int, TrajectoryModel]: Dictionary mapping trajectory IDs to trajectory Pydantic models
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
trajectories = Trajectory.select().order_by(Trajectory.id)
return {trajectory.id: self._to_model(trajectory) for trajectory in trajectories}
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all trajectories: {str(e)}")
raise
def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
def get_trajectories_by_human_input(self, human_input_id: int) -> List[TrajectoryModel]:
"""
Retrieve all trajectory records associated with a specific human input.
@ -356,37 +374,19 @@ class TrajectoryRepository:
human_input_id: The ID of the human input to get trajectories for
Returns:
List[Trajectory]: List of trajectory instances associated with the human input
List[TrajectoryModel]: List of trajectory Pydantic models associated with the human input
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
trajectories = list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
return [self._to_model(trajectory) for trajectory in trajectories]
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
raise
def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
"""
Parse a JSON string into a Python dictionary.
Args:
json_str: JSON string to parse
Returns:
Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid
"""
if not json_str:
return None
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
logger.error(f"Error parsing JSON field: {str(e)}")
return None
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[TrajectoryModel]:
"""
Get a trajectory record with JSON fields parsed into dictionaries.
@ -394,27 +394,7 @@ class TrajectoryRepository:
trajectory_id: ID of the trajectory to retrieve
Returns:
Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
or None if not found
Optional[TrajectoryModel]: The trajectory as a Pydantic model with parsed JSON fields,
or None if not found
"""
trajectory = self.get(trajectory_id)
if trajectory is None:
return None
return {
"id": trajectory.id,
"created_at": trajectory.created_at,
"updated_at": trajectory.updated_at,
"tool_name": trajectory.tool_name,
"tool_parameters": self.parse_json_field(trajectory.tool_parameters),
"tool_result": self.parse_json_field(trajectory.tool_result),
"step_data": self.parse_json_field(trajectory.step_data),
"record_type": trajectory.record_type,
"cost": trajectory.cost,
"tokens": trajectory.tokens,
"human_input_id": trajectory.human_input.id if trajectory.human_input else None,
"is_error": trajectory.is_error,
"error_message": trajectory.error_message,
"error_type": trajectory.error_type,
"error_details": trajectory.error_details,
}
return self.get(trajectory_id)

View File

@ -0,0 +1,198 @@
"""
Tests for the human input repository.
This module provides tests for the HumanInputRepository class,
ensuring it correctly interfaces with the database and returns
appropriate Pydantic models.
"""
import unittest
from typing import List, Dict, Any
import pytest
from peewee import SqliteDatabase
from ra_aid.database.models import HumanInput, Session, database_proxy
from ra_aid.database.pydantic_models import HumanInputModel, SessionModel
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.repositories.session_repository import SessionRepository
@pytest.fixture
def test_db():
"""Fixture for creating a test database."""
# Create an in-memory SQLite database for testing
test_db = SqliteDatabase(':memory:')
# Register the models with the test database
with test_db.bind_ctx([HumanInput, Session]):
# Create the tables
test_db.create_tables([HumanInput, Session])
# Return the test database for use in the tests
yield test_db
# Drop the tables after the tests
test_db.drop_tables([HumanInput, Session])
class TestHumanInputRepository(unittest.TestCase):
"""Test case for the HumanInputRepository class."""
def setUp(self):
"""Set up the test case with a test database and repositories."""
# Create an in-memory database for testing
self.db = SqliteDatabase(':memory:')
# Register the models with the test database
self.models = [HumanInput, Session]
self.db.bind(self.models)
# Create the tables
self.db.create_tables(self.models)
# Create repository instances for testing
self.repository = HumanInputRepository(self.db)
self.session_repository = SessionRepository(self.db)
# Bind the test database to the repository model
database_proxy.initialize(self.db)
def tearDown(self):
"""Clean up after the test case."""
# Close the database connection
self.db.close()
def test_create(self):
"""Test creating a human input record."""
# Create a session first
session_model = self.session_repository.create_session()
# Create a human input
content = "Test human input"
source = "cli"
human_input = self.repository.create(content=content, source=source)
# Verify the human input was created
self.assertIsInstance(human_input, HumanInputModel)
self.assertEqual(human_input.content, content)
self.assertEqual(human_input.source, source)
def test_get(self):
"""Test retrieving a human input record by ID."""
# Create a session first
session_model = self.session_repository.create_session()
# Create a human input
content = "Test human input"
source = "chat"
created_input = self.repository.create(content=content, source=source)
# Get the human input by ID
retrieved_input = self.repository.get(created_input.id)
# Verify the human input was retrieved correctly
self.assertIsInstance(retrieved_input, HumanInputModel)
self.assertEqual(retrieved_input.id, created_input.id)
self.assertEqual(retrieved_input.content, content)
self.assertEqual(retrieved_input.source, source)
def test_update(self):
"""Test updating a human input record."""
# Create a session first
session_model = self.session_repository.create_session()
# Create a human input
content = "Original content"
source = "cli"
created_input = self.repository.create(content=content, source=source)
# Update the human input
new_content = "Updated content"
updated_input = self.repository.update(created_input.id, content=new_content)
# Verify the human input was updated correctly
self.assertIsInstance(updated_input, HumanInputModel)
self.assertEqual(updated_input.id, created_input.id)
self.assertEqual(updated_input.content, new_content)
self.assertEqual(updated_input.source, source)
def test_get_all(self):
"""Test retrieving all human input records."""
# Create a session first
session_model = self.session_repository.create_session()
# Create multiple human inputs
self.repository.create(content="Input 1", source="cli")
self.repository.create(content="Input 2", source="chat")
self.repository.create(content="Input 3", source="hil")
# Get all human inputs
all_inputs = self.repository.get_all()
# Verify all human inputs were retrieved
self.assertEqual(len(all_inputs), 3)
self.assertIsInstance(all_inputs[0], HumanInputModel)
# Verify the inputs are ordered by created_at in descending order
self.assertEqual(all_inputs[0].content, "Input 3")
self.assertEqual(all_inputs[1].content, "Input 2")
self.assertEqual(all_inputs[2].content, "Input 1")
def test_get_recent(self):
"""Test retrieving the most recent human input records."""
# Create a session first
session_model = self.session_repository.create_session()
# Create multiple human inputs
self.repository.create(content="Input 1", source="cli")
self.repository.create(content="Input 2", source="chat")
self.repository.create(content="Input 3", source="hil")
self.repository.create(content="Input 4", source="cli")
self.repository.create(content="Input 5", source="chat")
# Get recent human inputs with a limit of 3
recent_inputs = self.repository.get_recent(limit=3)
# Verify only the 3 most recent inputs were retrieved
self.assertEqual(len(recent_inputs), 3)
self.assertIsInstance(recent_inputs[0], HumanInputModel)
self.assertEqual(recent_inputs[0].content, "Input 5")
self.assertEqual(recent_inputs[1].content, "Input 4")
self.assertEqual(recent_inputs[2].content, "Input 3")
def test_get_by_source(self):
"""Test retrieving human input records by source."""
# Create a session first
session_model = self.session_repository.create_session()
# Create human inputs with different sources
self.repository.create(content="CLI Input 1", source="cli")
self.repository.create(content="Chat Input 1", source="chat")
self.repository.create(content="HIL Input", source="hil")
self.repository.create(content="CLI Input 2", source="cli")
self.repository.create(content="Chat Input 2", source="chat")
# Get human inputs for the 'cli' source
cli_inputs = self.repository.get_by_source("cli")
# Verify only cli inputs were retrieved
self.assertEqual(len(cli_inputs), 2)
self.assertIsInstance(cli_inputs[0], HumanInputModel)
self.assertEqual(cli_inputs[0].content, "CLI Input 2")
self.assertEqual(cli_inputs[1].content, "CLI Input 1")
def test_get_most_recent_id(self):
"""Test retrieving the ID of the most recent human input record."""
# Create a session first
session_model = self.session_repository.create_session()
# Create multiple human inputs
self.repository.create(content="Input 1", source="cli")
input2 = self.repository.create(content="Input 2", source="chat")
# Get the most recent ID
most_recent_id = self.repository.get_most_recent_id()
# Verify the correct ID was retrieved
self.assertEqual(most_recent_id, input2.id)

View File

@ -15,6 +15,7 @@ from ra_aid.database.repositories.key_fact_repository import (
get_key_fact_repository,
key_fact_repo_var
)
from ra_aid.database.pydantic_models import KeyFactModel
@pytest.fixture
@ -87,7 +88,8 @@ def test_create_key_fact(setup_db):
content = "Test key fact"
fact = repo.create(content)
# Verify the fact was created correctly
# Verify the fact was created correctly and is a KeyFactModel
assert isinstance(fact, KeyFactModel)
assert fact.id is not None
assert fact.content == content
@ -108,7 +110,8 @@ def test_get_key_fact(setup_db):
# Retrieve the fact by ID
retrieved_fact = repo.get(fact.id)
# Verify the retrieved fact matches the original
# Verify the retrieved fact matches the original and is a KeyFactModel
assert isinstance(retrieved_fact, KeyFactModel)
assert retrieved_fact is not None
assert retrieved_fact.id == fact.id
assert retrieved_fact.content == content
@ -131,7 +134,8 @@ def test_update_key_fact(setup_db):
new_content = "Updated content"
updated_fact = repo.update(fact.id, new_content)
# Verify the fact was updated correctly
# Verify the fact was updated correctly and is a KeyFactModel
assert isinstance(updated_fact, KeyFactModel)
assert updated_fact is not None
assert updated_fact.id == fact.id
assert updated_fact.content == new_content
@ -184,8 +188,10 @@ def test_get_all_key_facts(setup_db):
# Retrieve all facts
all_facts = repo.get_all()
# Verify we got the correct number of facts
# Verify we got the correct number of facts and they are KeyFactModel instances
assert len(all_facts) == len(contents)
for fact in all_facts:
assert isinstance(fact, KeyFactModel)
# Verify the content of each fact
fact_contents = [fact.content for fact in all_facts]
@ -237,6 +243,7 @@ def test_key_fact_repository_manager(setup_db, cleanup_repo):
# Verify we can use the repository
content = "Test fact via context manager"
fact = repo.create(content)
assert isinstance(fact, KeyFactModel)
assert fact.id is not None
assert fact.content == content
@ -259,3 +266,25 @@ def test_get_key_fact_repository_when_not_set(cleanup_repo):
# Verify the correct error message
assert "No KeyFactRepository available" in str(excinfo.value)
def test_to_model_method(setup_db):
"""Test the _to_model method converts KeyFact to KeyFactModel correctly."""
# Set up repository
repo = KeyFactRepository(db=setup_db)
# Create a Peewee KeyFact directly
peewee_fact = KeyFact.create(content="Test fact for conversion")
# Convert to Pydantic model
pydantic_fact = repo._to_model(peewee_fact)
# Verify conversion was correct
assert isinstance(pydantic_fact, KeyFactModel)
assert pydantic_fact.id == peewee_fact.id
assert pydantic_fact.content == peewee_fact.content
assert pydantic_fact.created_at == peewee_fact.created_at
assert pydantic_fact.updated_at == peewee_fact.updated_at
# Test with None input
assert repo._to_model(None) is None

View File

@ -6,6 +6,7 @@ import pytest
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeySnippet
from ra_aid.database.pydantic_models import KeySnippetModel
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
@ -79,6 +80,9 @@ def test_create_key_snippet(setup_db):
assert key_snippet.snippet == snippet
assert key_snippet.description == description
# Verify the return type is a Pydantic model
assert isinstance(key_snippet, KeySnippetModel)
# Verify we can retrieve it from the database
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
assert snippet_from_db.filepath == filepath
@ -116,6 +120,9 @@ def test_get_key_snippet(setup_db):
assert retrieved_snippet.snippet == snippet
assert retrieved_snippet.description == description
# Verify the return type is a Pydantic model
assert isinstance(retrieved_snippet, KeySnippetModel)
# Try to retrieve a non-existent snippet
non_existent_snippet = repo.get(999)
assert non_existent_snippet is None
@ -161,6 +168,9 @@ def test_update_key_snippet(setup_db):
assert updated_snippet.snippet == new_snippet
assert updated_snippet.description == new_description
# Verify the return type is a Pydantic model
assert isinstance(updated_snippet, KeySnippetModel)
# Verify we can retrieve the updated content from the database
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
assert snippet_from_db.filepath == new_filepath
@ -250,6 +260,9 @@ def test_get_all_key_snippets(setup_db):
# Verify we got the correct number of snippets
assert len(all_snippets) == len(snippets_data)
# Verify all returned snippets are Pydantic models
assert all(isinstance(snippet, KeySnippetModel) for snippet in all_snippets)
# Verify the content of each snippet
for i, snippet in enumerate(all_snippets):
assert snippet.filepath == snippets_data[i]["filepath"]
@ -302,3 +315,30 @@ def test_get_snippets_dict(setup_db):
assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"]
assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"]
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]
def test_to_model_conversion(setup_db):
"""Test conversion from Peewee model to Pydantic model."""
repo = KeySnippetRepository(db=setup_db)
# Create a snippet in the database using Peewee directly
peewee_snippet = KeySnippet.create(
filepath="conversion_test.py",
line_number=100,
snippet="def conversion_test():",
description="Test model conversion"
)
# Use the _to_model method to convert it
pydantic_snippet = repo._to_model(peewee_snippet)
# Verify the conversion was successful
assert isinstance(pydantic_snippet, KeySnippetModel)
assert pydantic_snippet.id == peewee_snippet.id
assert pydantic_snippet.filepath == peewee_snippet.filepath
assert pydantic_snippet.line_number == peewee_snippet.line_number
assert pydantic_snippet.snippet == peewee_snippet.snippet
assert pydantic_snippet.description == peewee_snippet.description
# Test conversion of None
assert repo._to_model(None) is None

View File

@ -0,0 +1,110 @@
"""
Tests for the Pydantic models in ra_aid.database.pydantic_models
"""
import datetime
import json
import pytest
from ra_aid.database.models import Session
from ra_aid.database.pydantic_models import SessionModel
class TestSessionModel:
"""Tests for the SessionModel Pydantic model"""
def test_from_peewee_model(self):
"""Test conversion from a Peewee model instance"""
# Create a Peewee Session instance
now = datetime.datetime.now()
metadata = {"os": "Linux", "cpu_cores": 8, "memory_gb": 16}
session = Session(
id=1,
created_at=now,
updated_at=now,
start_time=now,
command_line="ra-aid run",
program_version="1.0.0",
machine_info=json.dumps(metadata)
)
# Convert to Pydantic model
session_model = SessionModel.model_validate(session, from_attributes=True)
# Verify fields
assert session_model.id == 1
assert session_model.created_at == now
assert session_model.updated_at == now
assert session_model.start_time == now
assert session_model.command_line == "ra-aid run"
assert session_model.program_version == "1.0.0"
assert session_model.machine_info == metadata
def test_with_dict_machine_info(self):
"""Test creating a model with a dict for machine_info"""
# Create directly with a dict for machine_info
now = datetime.datetime.now()
metadata = {"os": "Windows", "cpu_cores": 4, "memory_gb": 8}
session_model = SessionModel(
id=2,
created_at=now,
updated_at=now,
start_time=now,
command_line="ra-aid --debug",
program_version="1.0.1",
machine_info=metadata
)
# Verify fields
assert session_model.id == 2
assert session_model.machine_info == metadata
def test_with_none_machine_info(self):
"""Test creating a model with None for machine_info"""
now = datetime.datetime.now()
session_model = SessionModel(
id=3,
created_at=now,
updated_at=now,
start_time=now,
command_line="ra-aid",
program_version="1.0.0",
machine_info=None
)
assert session_model.id == 3
assert session_model.machine_info is None
def test_invalid_json_machine_info(self):
"""Test error handling for invalid JSON in machine_info"""
now = datetime.datetime.now()
# Invalid JSON string should raise ValueError
with pytest.raises(ValueError):
SessionModel(
id=4,
created_at=now,
updated_at=now,
start_time=now,
command_line="ra-aid",
program_version="1.0.0",
machine_info="{invalid json}"
)
def test_unexpected_type_machine_info(self):
"""Test error handling for unexpected type in machine_info"""
now = datetime.datetime.now()
# Integer type should raise ValueError
with pytest.raises(ValueError):
SessionModel(
id=5,
created_at=now,
updated_at=now,
start_time=now,
command_line="ra-aid",
program_version="1.0.0",
machine_info=123 # Not a dict or string
)

View File

@ -9,6 +9,7 @@ import peewee
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import ResearchNote, BaseModel
from ra_aid.database.pydantic_models import ResearchNoteModel
from ra_aid.database.repositories.research_note_repository import (
ResearchNoteRepository,
ResearchNoteRepositoryManager,
@ -90,10 +91,12 @@ def test_create_research_note(setup_db):
# Verify the note was created correctly
assert note.id is not None
assert note.content == content
assert isinstance(note, ResearchNoteModel)
# Verify we can retrieve it from the database using the repository
note_from_db = repo.get(note.id)
assert note_from_db.content == content
assert isinstance(note_from_db, ResearchNoteModel)
def test_get_research_note(setup_db):
@ -112,6 +115,7 @@ def test_get_research_note(setup_db):
assert retrieved_note is not None
assert retrieved_note.id == note.id
assert retrieved_note.content == content
assert isinstance(retrieved_note, ResearchNoteModel)
# Try to retrieve a non-existent note
non_existent_note = repo.get(999)
@ -135,10 +139,12 @@ def test_update_research_note(setup_db):
assert updated_note is not None
assert updated_note.id == note.id
assert updated_note.content == new_content
assert isinstance(updated_note, ResearchNoteModel)
# Verify we can retrieve the updated content from the database using the repository
note_from_db = repo.get(note.id)
assert note_from_db.content == new_content
assert isinstance(note_from_db, ResearchNoteModel)
# Try to update a non-existent note
non_existent_update = repo.update(999, "This shouldn't work")
@ -187,8 +193,11 @@ def test_get_all_research_notes(setup_db):
# Verify we got the correct number of notes
assert len(all_notes) == len(contents)
# Verify the content of each note
# Verify the content of each note and that they are Pydantic models
note_contents = [note.content for note in all_notes]
for note in all_notes:
assert isinstance(note, ResearchNoteModel)
for content in contents:
assert content in note_contents
@ -239,6 +248,7 @@ def test_research_note_repository_manager(setup_db, cleanup_repo):
note = repo.create(content)
assert note.id is not None
assert note.content == content
assert isinstance(note, ResearchNoteModel)
# Verify we can get the repository using get_research_note_repository
repo_from_var = get_research_note_repository()
@ -259,3 +269,25 @@ def test_get_research_note_repository_when_not_set(cleanup_repo):
# Verify the correct error message
assert "No ResearchNoteRepository available" in str(excinfo.value)
def test_to_model_method(setup_db):
"""Test the _to_model method converts Peewee models to Pydantic models correctly."""
# Set up repository
repo = ResearchNoteRepository(db=setup_db)
# Create a Peewee ResearchNote directly
peewee_note = ResearchNote.create(content="Test note for conversion")
# Convert it using _to_model
pydantic_note = repo._to_model(peewee_note)
# Verify the conversion
assert isinstance(pydantic_note, ResearchNoteModel)
assert pydantic_note.id == peewee_note.id
assert pydantic_note.content == peewee_note.content
assert pydantic_note.created_at == peewee_note.created_at
assert pydantic_note.updated_at == peewee_note.updated_at
# Test with None
assert repo._to_model(None) is None

View File

@ -0,0 +1,347 @@
"""
Tests for the SessionRepository class.
"""
import pytest
import datetime
import json
from unittest.mock import patch
import peewee
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import Session, BaseModel
from ra_aid.database.repositories.session_repository import (
SessionRepository,
SessionRepositoryManager,
get_session_repository,
session_repo_var
)
from ra_aid.database.pydantic_models import SessionModel
@pytest.fixture
def cleanup_db():
"""Reset the database contextvar and connection state after each test."""
# Reset before the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
@pytest.fixture
def cleanup_repo():
"""Reset the repository contextvar after each test."""
# Reset before the test
session_repo_var.set(None)
# Run the test
yield
# Reset after the test
session_repo_var.set(None)
@pytest.fixture
def setup_db(cleanup_db):
"""Set up an in-memory database with the Session table and patch the BaseModel.Meta.database."""
# Initialize an in-memory database connection
with DatabaseManager(in_memory=True) as db:
# Patch the BaseModel.Meta.database to use our in-memory database
with patch.object(BaseModel._meta, 'database', db):
# Create the Session table
with db.atomic():
db.create_tables([Session], safe=True)
yield db
# Clean up
with db.atomic():
Session.drop_table(safe=True)
@pytest.fixture
def test_metadata():
"""Return test metadata for sessions."""
return {
"os": "Test OS",
"version": "1.0",
"cpu_cores": 4,
"memory_gb": 16,
"additional_info": {
"gpu": "Test GPU",
"display_resolution": "1920x1080"
}
}
@pytest.fixture
def sample_session(setup_db, test_metadata):
"""Create a sample session in the database."""
now = datetime.datetime.now()
return Session.create(
start_time=now,
command_line="ra-aid test",
program_version="1.0.0",
machine_info=json.dumps(test_metadata)
)
def test_create_session_with_metadata(setup_db, test_metadata):
"""Test creating a session with metadata."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Create a session with metadata
session = repo.create_session(metadata=test_metadata)
# Verify type is SessionModel, not Session
assert isinstance(session, SessionModel)
# Verify the session was created correctly
assert session.id is not None
assert session.command_line is not None
assert session.program_version is not None
# Verify machine_info is a dict, not a JSON string
assert isinstance(session.machine_info, dict)
assert session.machine_info == test_metadata
# Verify the dictionary structure is preserved
assert "additional_info" in session.machine_info
assert session.machine_info["additional_info"]["gpu"] == "Test GPU"
def test_create_session_without_metadata(setup_db):
"""Test creating a session without metadata."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Create a session without metadata
session = repo.create_session()
# Verify type is SessionModel, not Session
assert isinstance(session, SessionModel)
# Verify the session was created correctly
assert session.id is not None
assert session.command_line is not None
assert session.program_version is not None
# Verify machine_info is None
assert session.machine_info is None
def test_get_current_session(setup_db, sample_session):
"""Test retrieving the current session."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Set the current session
repo.current_session = sample_session
# Get the current session
current_session = repo.get_current_session()
# Verify type is SessionModel, not Session
assert isinstance(current_session, SessionModel)
# Verify the retrieved session matches the original
assert current_session.id == sample_session.id
assert current_session.command_line == sample_session.command_line
assert current_session.program_version == sample_session.program_version
# Verify machine_info is a dict, not a JSON string
assert isinstance(current_session.machine_info, dict)
def test_get_current_session_from_db(setup_db, sample_session):
"""Test retrieving the current session from the database when no current session is set."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Get the current session (should retrieve the most recent from DB)
current_session = repo.get_current_session()
# Verify type is SessionModel, not Session
assert isinstance(current_session, SessionModel)
# Verify the retrieved session matches the sample session
assert current_session.id == sample_session.id
# Verify machine_info is a dict, not a JSON string
assert isinstance(current_session.machine_info, dict)
def test_get_by_id(setup_db, sample_session):
"""Test retrieving a session by ID."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Get the session by ID
session = repo.get(sample_session.id)
# Verify type is SessionModel, not Session
assert isinstance(session, SessionModel)
# Verify the retrieved session matches the original
assert session.id == sample_session.id
assert session.command_line == sample_session.command_line
assert session.program_version == sample_session.program_version
# Verify machine_info is a dict, not a JSON string
assert isinstance(session.machine_info, dict)
# Verify getting a non-existent session returns None
non_existent_session = repo.get(999)
assert non_existent_session is None
def test_get_all(setup_db):
"""Test retrieving all sessions."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Create multiple sessions
metadata1 = {"os": "Linux", "cpu_cores": 8}
metadata2 = {"os": "Windows", "cpu_cores": 4}
metadata3 = {"os": "macOS", "cpu_cores": 10}
repo.create_session(metadata=metadata1)
repo.create_session(metadata=metadata2)
repo.create_session(metadata=metadata3)
# Get all sessions
sessions = repo.get_all()
# Verify we got a list of SessionModel objects
assert len(sessions) == 3
for session in sessions:
assert isinstance(session, SessionModel)
assert isinstance(session.machine_info, dict)
# Verify the sessions are in descending order of creation time
assert sessions[0].created_at >= sessions[1].created_at
assert sessions[1].created_at >= sessions[2].created_at
# Verify the machine_info fields
os_values = [session.machine_info["os"] for session in sessions]
assert "Linux" in os_values
assert "Windows" in os_values
assert "macOS" in os_values
def test_get_all_empty(setup_db):
"""Test retrieving all sessions when none exist."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Get all sessions
sessions = repo.get_all()
# Verify we got an empty list
assert isinstance(sessions, list)
assert len(sessions) == 0
def test_get_recent(setup_db):
"""Test retrieving recent sessions with a limit."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Create multiple sessions
for i in range(5):
metadata = {"index": i, "os": f"OS {i}"}
repo.create_session(metadata=metadata)
# Get recent sessions with limit=3
sessions = repo.get_recent(limit=3)
# Verify we got the correct number of SessionModel objects
assert len(sessions) == 3
for session in sessions:
assert isinstance(session, SessionModel)
assert isinstance(session.machine_info, dict)
# Verify the sessions are in descending order and are the most recent ones
indexes = [session.machine_info["index"] for session in sessions]
assert indexes == [4, 3, 2] # Most recent first
def test_session_repository_manager(setup_db, cleanup_repo):
"""Test the SessionRepositoryManager context manager."""
# Use the context manager to create a repository
with SessionRepositoryManager(setup_db) as repo:
# Verify the repository was created correctly
assert isinstance(repo, SessionRepository)
assert repo.db is setup_db
# Create a session and verify it's a SessionModel
metadata = {"test": "manager"}
session = repo.create_session(metadata=metadata)
assert isinstance(session, SessionModel)
assert session.machine_info["test"] == "manager"
# Verify we can get the repository using get_session_repository
repo_from_var = get_session_repository()
assert repo_from_var is repo
# Verify the repository was removed from the context var
with pytest.raises(RuntimeError) as excinfo:
get_session_repository()
assert "No SessionRepository available" in str(excinfo.value)
def test_repository_init_without_db():
"""Test that SessionRepository raises an error when initialized without a db parameter."""
# Attempt to create a repository without a database connection
with pytest.raises(ValueError) as excinfo:
SessionRepository(db=None)
# Verify the correct error message
assert "Database connection is required" in str(excinfo.value)
def test_get_current_session_id(setup_db, sample_session):
"""Test retrieving the ID of the current session."""
# Set up repository
repo = SessionRepository(db=setup_db)
# Set the current session
repo.current_session = sample_session
# Get the current session ID
session_id = repo.get_current_session_id()
# Verify the ID matches
assert session_id == sample_session.id
# Test when no current session exists
repo.current_session = None
# Delete all sessions
Session.delete().execute()
# Verify None is returned when no session exists
session_id = repo.get_current_session_id()
assert session_id is None

View File

@ -0,0 +1,458 @@
"""
Tests for the TrajectoryRepository class.
"""
import pytest
import datetime
import json
from unittest.mock import patch
import peewee
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import Trajectory, HumanInput, Session, BaseModel
from ra_aid.database.repositories.trajectory_repository import (
TrajectoryRepository,
TrajectoryRepositoryManager,
get_trajectory_repository,
trajectory_repo_var
)
from ra_aid.database.pydantic_models import TrajectoryModel
@pytest.fixture
def cleanup_db():
"""Reset the database contextvar and connection state after each test."""
# Reset before the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
@pytest.fixture
def cleanup_repo():
"""Reset the repository contextvar after each test."""
# Reset before the test
trajectory_repo_var.set(None)
# Run the test
yield
# Reset after the test
trajectory_repo_var.set(None)
@pytest.fixture
def setup_db(cleanup_db):
"""Set up an in-memory database with the necessary tables and patch the BaseModel.Meta.database."""
# Initialize an in-memory database connection
with DatabaseManager(in_memory=True) as db:
# Patch the BaseModel.Meta.database to use our in-memory database
with patch.object(BaseModel._meta, 'database', db):
# Create the required tables
with db.atomic():
db.create_tables([Trajectory, HumanInput, Session], safe=True)
yield db
# Clean up
with db.atomic():
Trajectory.drop_table(safe=True)
HumanInput.drop_table(safe=True)
Session.drop_table(safe=True)
@pytest.fixture
def sample_human_input(setup_db):
"""Create a sample human input in the database."""
return HumanInput.create(
content="Test human input",
source="test"
)
@pytest.fixture
def test_tool_parameters():
"""Return test tool parameters."""
return {
"pattern": "test pattern",
"file_path": "/path/to/file",
"options": {
"case_sensitive": True,
"whole_words": False
}
}
@pytest.fixture
def test_tool_result():
"""Return test tool result."""
return {
"matches": [
{"line": 10, "content": "This is a test pattern"},
{"line": 20, "content": "Another test pattern here"}
],
"total_matches": 2,
"execution_time": 0.5
}
@pytest.fixture
def test_step_data():
"""Return test step data for UI rendering."""
return {
"display_type": "text",
"content": "Tool execution results",
"highlights": [
{"start": 10, "end": 15, "color": "red"}
]
}
@pytest.fixture
def sample_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
"""Create a sample trajectory in the database."""
return Trajectory.create(
human_input=sample_human_input,
tool_name="ripgrep_search",
tool_parameters=json.dumps(test_tool_parameters),
tool_result=json.dumps(test_tool_result),
step_data=json.dumps(test_step_data),
record_type="tool_execution",
cost=0.001,
tokens=100,
is_error=False
)
def test_create_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
"""Test creating a trajectory with all fields."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Create a trajectory
trajectory = repo.create(
tool_name="ripgrep_search",
tool_parameters=test_tool_parameters,
tool_result=test_tool_result,
step_data=test_step_data,
record_type="tool_execution",
human_input_id=sample_human_input.id,
cost=0.001,
tokens=100
)
# Verify type is TrajectoryModel, not Trajectory
assert isinstance(trajectory, TrajectoryModel)
# Verify the trajectory was created correctly
assert trajectory.id is not None
assert trajectory.tool_name == "ripgrep_search"
# Verify the JSON fields are dictionaries, not strings
assert isinstance(trajectory.tool_parameters, dict)
assert isinstance(trajectory.tool_result, dict)
assert isinstance(trajectory.step_data, dict)
# Verify the nested structure of tool parameters
assert trajectory.tool_parameters["options"]["case_sensitive"] == True
assert trajectory.tool_result["total_matches"] == 2
assert trajectory.step_data["highlights"][0]["color"] == "red"
# Verify foreign key reference
assert trajectory.human_input_id == sample_human_input.id
def test_create_trajectory_minimal(setup_db):
"""Test creating a trajectory with minimal fields."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Create a trajectory with minimal fields
trajectory = repo.create(
tool_name="simple_tool"
)
# Verify type is TrajectoryModel, not Trajectory
assert isinstance(trajectory, TrajectoryModel)
# Verify the trajectory was created correctly
assert trajectory.id is not None
assert trajectory.tool_name == "simple_tool"
# Verify optional fields are None
assert trajectory.tool_parameters is None
assert trajectory.tool_result is None
assert trajectory.step_data is None
assert trajectory.human_input_id is None
assert trajectory.cost is None
assert trajectory.tokens is None
assert trajectory.is_error is False
def test_get_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
"""Test retrieving a trajectory by ID."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Get the trajectory by ID
trajectory = repo.get(sample_trajectory.id)
# Verify type is TrajectoryModel, not Trajectory
assert isinstance(trajectory, TrajectoryModel)
# Verify the retrieved trajectory matches the original
assert trajectory.id == sample_trajectory.id
assert trajectory.tool_name == sample_trajectory.tool_name
# Verify the JSON fields are dictionaries, not strings
assert isinstance(trajectory.tool_parameters, dict)
assert isinstance(trajectory.tool_result, dict)
assert isinstance(trajectory.step_data, dict)
# Verify the content of JSON fields
assert trajectory.tool_parameters == test_tool_parameters
assert trajectory.tool_result == test_tool_result
assert trajectory.step_data == test_step_data
# Verify non-existent trajectory returns None
non_existent_trajectory = repo.get(999)
assert non_existent_trajectory is None
def test_update_trajectory(setup_db, sample_trajectory):
"""Test updating a trajectory."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# New data for update
new_tool_result = {
"matches": [
{"line": 15, "content": "Updated test pattern"}
],
"total_matches": 1,
"execution_time": 0.3
}
new_step_data = {
"display_type": "html",
"content": "Updated UI rendering",
"highlights": []
}
# Update the trajectory
updated_trajectory = repo.update(
trajectory_id=sample_trajectory.id,
tool_result=new_tool_result,
step_data=new_step_data,
cost=0.002,
tokens=200,
is_error=True,
error_message="Test error",
error_type="TestErrorType",
error_details="Detailed error information"
)
# Verify type is TrajectoryModel, not Trajectory
assert isinstance(updated_trajectory, TrajectoryModel)
# Verify the fields were updated
assert updated_trajectory.tool_result == new_tool_result
assert updated_trajectory.step_data == new_step_data
assert updated_trajectory.cost == 0.002
assert updated_trajectory.tokens == 200
assert updated_trajectory.is_error is True
assert updated_trajectory.error_message == "Test error"
assert updated_trajectory.error_type == "TestErrorType"
assert updated_trajectory.error_details == "Detailed error information"
# Original tool parameters should not change
# We need to parse the JSON string from the Peewee object for comparison
original_params = json.loads(sample_trajectory.tool_parameters)
assert updated_trajectory.tool_parameters == original_params
# Verify updating a non-existent trajectory returns None
non_existent_update = repo.update(trajectory_id=999, cost=0.005)
assert non_existent_update is None
def test_delete_trajectory(setup_db, sample_trajectory):
"""Test deleting a trajectory."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Verify the trajectory exists
assert repo.get(sample_trajectory.id) is not None
# Delete the trajectory
result = repo.delete(sample_trajectory.id)
# Verify the trajectory was deleted
assert result is True
assert repo.get(sample_trajectory.id) is None
# Verify deleting a non-existent trajectory returns False
result = repo.delete(999)
assert result is False
def test_get_all_trajectories(setup_db, sample_human_input):
"""Test retrieving all trajectories."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Create multiple trajectories
for i in range(3):
repo.create(
tool_name=f"tool_{i}",
tool_parameters={"index": i},
human_input_id=sample_human_input.id
)
# Get all trajectories
trajectories = repo.get_all()
# Verify we got a dictionary of TrajectoryModel objects
assert len(trajectories) == 3
for trajectory_id, trajectory in trajectories.items():
assert isinstance(trajectory, TrajectoryModel)
assert isinstance(trajectory.tool_parameters, dict)
# Verify the trajectories have the correct tool names
tool_names = {trajectory.tool_name for trajectory in trajectories.values()}
assert "tool_0" in tool_names
assert "tool_1" in tool_names
assert "tool_2" in tool_names
def test_get_trajectories_by_human_input(setup_db, sample_human_input):
"""Test retrieving trajectories by human input ID."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Create another human input
other_human_input = HumanInput.create(
content="Another human input",
source="test"
)
# Create trajectories for both human inputs
for i in range(2):
repo.create(
tool_name=f"tool_1_{i}",
human_input_id=sample_human_input.id
)
for i in range(3):
repo.create(
tool_name=f"tool_2_{i}",
human_input_id=other_human_input.id
)
# Get trajectories for the first human input
trajectories = repo.get_trajectories_by_human_input(sample_human_input.id)
# Verify we got a list of TrajectoryModel objects for the first human input
assert len(trajectories) == 2
for trajectory in trajectories:
assert isinstance(trajectory, TrajectoryModel)
assert trajectory.human_input_id == sample_human_input.id
assert trajectory.tool_name.startswith("tool_1")
# Get trajectories for the second human input
trajectories = repo.get_trajectories_by_human_input(other_human_input.id)
# Verify we got a list of TrajectoryModel objects for the second human input
assert len(trajectories) == 3
for trajectory in trajectories:
assert isinstance(trajectory, TrajectoryModel)
assert trajectory.human_input_id == other_human_input.id
assert trajectory.tool_name.startswith("tool_2")
def test_get_parsed_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
"""Test retrieving a parsed trajectory."""
# Set up repository
repo = TrajectoryRepository(db=setup_db)
# Get the parsed trajectory
trajectory = repo.get_parsed_trajectory(sample_trajectory.id)
# Verify type is TrajectoryModel, not Trajectory
assert isinstance(trajectory, TrajectoryModel)
# Verify the retrieved trajectory matches the original
assert trajectory.id == sample_trajectory.id
assert trajectory.tool_name == sample_trajectory.tool_name
# Verify the JSON fields are dictionaries, not strings
assert isinstance(trajectory.tool_parameters, dict)
assert isinstance(trajectory.tool_result, dict)
assert isinstance(trajectory.step_data, dict)
# Verify the content of JSON fields
assert trajectory.tool_parameters == test_tool_parameters
assert trajectory.tool_result == test_tool_result
assert trajectory.step_data == test_step_data
# Verify non-existent trajectory returns None
non_existent_trajectory = repo.get_parsed_trajectory(999)
assert non_existent_trajectory is None
def test_trajectory_repository_manager(setup_db, cleanup_repo):
"""Test the TrajectoryRepositoryManager context manager."""
# Use the context manager to create a repository
with TrajectoryRepositoryManager(setup_db) as repo:
# Verify the repository was created correctly
assert isinstance(repo, TrajectoryRepository)
assert repo.db is setup_db
# Create a trajectory and verify it's a TrajectoryModel
tool_parameters = {"test": "manager"}
trajectory = repo.create(
tool_name="manager_test",
tool_parameters=tool_parameters
)
assert isinstance(trajectory, TrajectoryModel)
assert trajectory.tool_parameters["test"] == "manager"
# Verify we can get the repository using get_trajectory_repository
repo_from_var = get_trajectory_repository()
assert repo_from_var is repo
# Verify the repository was removed from the context var
with pytest.raises(RuntimeError) as excinfo:
get_trajectory_repository()
assert "No TrajectoryRepository available" in str(excinfo.value)
def test_repository_init_without_db():
"""Test that TrajectoryRepository raises an error when initialized without a db parameter."""
# Attempt to create a repository without a database connection
with pytest.raises(ValueError) as excinfo:
TrajectoryRepository(db=None)
# Verify the correct error message
assert "Database connection is required" in str(excinfo.value)