diff --git a/ra_aid/database/pydantic_models.py b/ra_aid/database/pydantic_models.py new file mode 100644 index 0000000..aca7ee8 --- /dev/null +++ b/ra_aid/database/pydantic_models.py @@ -0,0 +1,376 @@ +""" +Pydantic models for ra_aid database entities. + +This module defines Pydantic models that correspond to Peewee ORM models, +providing validation, serialization, and deserialization capabilities. +""" + +import datetime +import json +from typing import Dict, List, Any, Optional + +from pydantic import BaseModel, ConfigDict, field_serializer, field_validator + + +class SessionModel(BaseModel): + """ + Pydantic model representing a Session. + + This model corresponds to the Session Peewee ORM model and provides + validation and serialization capabilities. It handles the conversion + between JSON-encoded strings and Python dictionaries for the machine_info field. + + Attributes: + id: Unique identifier for the session + created_at: When the session record was created + updated_at: When the session record was last updated + start_time: When the program session started + command_line: Command line arguments used to start the program + program_version: Version of the program + machine_info: Dictionary containing machine-specific metadata + """ + id: Optional[int] = None + created_at: datetime.datetime + updated_at: datetime.datetime + start_time: datetime.datetime + command_line: Optional[str] = None + program_version: Optional[str] = None + machine_info: Optional[Dict[str, Any]] = None + + # Configure the model to work with ORM objects + model_config = ConfigDict(from_attributes=True) + + @field_validator("machine_info", mode="before") + @classmethod + def parse_machine_info(cls, value: Any) -> Optional[Dict[str, Any]]: + """ + Parse the machine_info field from a JSON string to a dictionary. + + Args: + value: The value to parse, can be a string, dict, or None + + Returns: + Optional[Dict[str, Any]]: The parsed dictionary or None + + Raises: + ValueError: If the JSON string is invalid + """ + if value is None: + return None + + if isinstance(value, dict): + return value + + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in machine_info: {e}") + + raise ValueError(f"Unexpected type for machine_info: {type(value)}") + + @field_serializer("machine_info") + def serialize_machine_info(self, machine_info: Optional[Dict[str, Any]]) -> Optional[str]: + """ + Serialize the machine_info dictionary to a JSON string for storage. + + Args: + machine_info: Dictionary to serialize + + Returns: + Optional[str]: JSON-encoded string or None + """ + if machine_info is None: + return None + + return json.dumps(machine_info) + + +class HumanInputModel(BaseModel): + """ + Pydantic model representing a HumanInput. + + This model corresponds to the HumanInput Peewee ORM model and provides + validation and serialization capabilities. + + Attributes: + id: Unique identifier for the human input + created_at: When the record was created + updated_at: When the record was last updated + content: The text content of the input + source: The source of the input ('cli', 'chat', or 'hil') + session_id: Optional reference to the associated session + """ + id: Optional[int] = None + created_at: datetime.datetime + updated_at: datetime.datetime + content: str + source: str + session_id: Optional[int] = None + + # Configure the model to work with ORM objects + model_config = ConfigDict(from_attributes=True) + + +class KeyFactModel(BaseModel): + """ + Pydantic model representing a KeyFact. + + This model corresponds to the KeyFact Peewee ORM model and provides + validation and serialization capabilities. + + Attributes: + id: Unique identifier for the key fact + created_at: When the record was created + updated_at: When the record was last updated + content: The text content of the key fact + human_input_id: Optional reference to the associated human input + session_id: Optional reference to the associated session + """ + id: Optional[int] = None + created_at: datetime.datetime + updated_at: datetime.datetime + content: str + human_input_id: Optional[int] = None + session_id: Optional[int] = None + + # Configure the model to work with ORM objects + model_config = ConfigDict(from_attributes=True) + + +class KeySnippetModel(BaseModel): + """ + Pydantic model representing a KeySnippet. + + This model corresponds to the KeySnippet Peewee ORM model and provides + validation and serialization capabilities. + + Attributes: + id: Unique identifier for the key snippet + created_at: When the record was created + updated_at: When the record was last updated + filepath: Path to the source file + line_number: Line number where the snippet starts + snippet: The source code snippet text + description: Optional description of the significance + human_input_id: Optional reference to the associated human input + session_id: Optional reference to the associated session + """ + id: Optional[int] = None + created_at: datetime.datetime + updated_at: datetime.datetime + filepath: str + line_number: int + snippet: str + description: Optional[str] = None + human_input_id: Optional[int] = None + session_id: Optional[int] = None + + # Configure the model to work with ORM objects + model_config = ConfigDict(from_attributes=True) + + +class ResearchNoteModel(BaseModel): + """ + Pydantic model representing a ResearchNote. + + This model corresponds to the ResearchNote Peewee ORM model and provides + validation and serialization capabilities. + + Attributes: + id: Unique identifier for the research note + created_at: When the record was created + updated_at: When the record was last updated + content: The text content of the research note + human_input_id: Optional reference to the associated human input + session_id: Optional reference to the associated session + """ + id: Optional[int] = None + created_at: datetime.datetime + updated_at: datetime.datetime + content: str + human_input_id: Optional[int] = None + session_id: Optional[int] = None + + # Configure the model to work with ORM objects + model_config = ConfigDict(from_attributes=True) + + +class TrajectoryModel(BaseModel): + """ + Pydantic model representing a Trajectory. + + This model corresponds to the Trajectory Peewee ORM model and provides + validation and serialization capabilities. It handles the conversion + between JSON-encoded strings and Python dictionaries for the tool_parameters, + tool_result, and step_data fields. + + Attributes: + id: Unique identifier for the trajectory + created_at: When the record was created + updated_at: When the record was last updated + human_input_id: Optional reference to the associated human input + tool_name: Name of the tool that was executed + tool_parameters: Dictionary containing the parameters passed to the tool + tool_result: Dictionary containing the result returned by the tool + step_data: Dictionary containing UI rendering data + record_type: Type of trajectory record + cost: Optional cost of the tool execution + tokens: Optional token usage of the tool execution + is_error: Flag indicating if this record represents an error + error_message: The error message if is_error is True + error_type: The type/class of the error if is_error is True + error_details: Additional error details if is_error is True + session_id: Optional reference to the associated session + """ + id: Optional[int] = None + created_at: datetime.datetime + updated_at: datetime.datetime + human_input_id: Optional[int] = None + tool_name: Optional[str] = None + tool_parameters: Optional[Dict[str, Any]] = None + tool_result: Optional[Any] = None + step_data: Optional[Dict[str, Any]] = None + record_type: Optional[str] = None + cost: Optional[float] = None + tokens: Optional[int] = None + is_error: bool = False + error_message: Optional[str] = None + error_type: Optional[str] = None + error_details: Optional[str] = None + session_id: Optional[int] = None + + # Configure the model to work with ORM objects + model_config = ConfigDict(from_attributes=True) + + @field_validator("tool_parameters", mode="before") + @classmethod + def parse_tool_parameters(cls, value: Any) -> Optional[Dict[str, Any]]: + """ + Parse the tool_parameters field from a JSON string to a dictionary. + + Args: + value: The value to parse, can be a string, dict, or None + + Returns: + Optional[Dict[str, Any]]: The parsed dictionary or None + + Raises: + ValueError: If the JSON string is invalid + """ + if value is None: + return None + + if isinstance(value, dict): + return value + + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in tool_parameters: {e}") + + raise ValueError(f"Unexpected type for tool_parameters: {type(value)}") + + @field_validator("tool_result", mode="before") + @classmethod + def parse_tool_result(cls, value: Any) -> Optional[Any]: + """ + Parse the tool_result field from a JSON string to a Python object. + + Args: + value: The value to parse, can be a string, dict, list, or None + + Returns: + Optional[Any]: The parsed object or None + + Raises: + ValueError: If the JSON string is invalid + """ + if value is None: + return None + + if not isinstance(value, str): + return value + + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in tool_result: {e}") + + @field_validator("step_data", mode="before") + @classmethod + def parse_step_data(cls, value: Any) -> Optional[Dict[str, Any]]: + """ + Parse the step_data field from a JSON string to a dictionary. + + Args: + value: The value to parse, can be a string, dict, or None + + Returns: + Optional[Dict[str, Any]]: The parsed dictionary or None + + Raises: + ValueError: If the JSON string is invalid + """ + if value is None: + return None + + if isinstance(value, dict): + return value + + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in step_data: {e}") + + raise ValueError(f"Unexpected type for step_data: {type(value)}") + + @field_serializer("tool_parameters") + def serialize_tool_parameters(self, tool_parameters: Optional[Dict[str, Any]]) -> Optional[str]: + """ + Serialize the tool_parameters dictionary to a JSON string for storage. + + Args: + tool_parameters: Dictionary to serialize + + Returns: + Optional[str]: JSON-encoded string or None + """ + if tool_parameters is None: + return None + + return json.dumps(tool_parameters) + + @field_serializer("tool_result") + def serialize_tool_result(self, tool_result: Optional[Any]) -> Optional[str]: + """ + Serialize the tool_result object to a JSON string for storage. + + Args: + tool_result: Object to serialize + + Returns: + Optional[str]: JSON-encoded string or None + """ + if tool_result is None: + return None + + return json.dumps(tool_result) + + @field_serializer("step_data") + def serialize_step_data(self, step_data: Optional[Dict[str, Any]]) -> Optional[str]: + """ + Serialize the step_data dictionary to a JSON string for storage. + + Args: + step_data: Dictionary to serialize + + Returns: + Optional[str]: JSON-encoded string or None + """ + if step_data is None: + return None + + return json.dumps(step_data) \ No newline at end of file diff --git a/ra_aid/database/repositories/human_input_repository.py b/ra_aid/database/repositories/human_input_repository.py index f20853d..6384b90 100644 --- a/ra_aid/database/repositories/human_input_repository.py +++ b/ra_aid/database/repositories/human_input_repository.py @@ -11,6 +11,7 @@ import contextvars import peewee from ra_aid.database.models import HumanInput +from ra_aid.database.pydantic_models import HumanInputModel from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -118,8 +119,23 @@ class HumanInputRepository: if db is None: raise ValueError("Database connection is required for HumanInputRepository") self.db = db + + def _to_model(self, human_input: Optional[HumanInput]) -> Optional[HumanInputModel]: + """ + Convert a Peewee HumanInput object to a Pydantic HumanInputModel. + + Args: + human_input: Peewee HumanInput instance or None + + Returns: + Optional[HumanInputModel]: Pydantic model representation or None if human_input is None + """ + if human_input is None: + return None + + return HumanInputModel.model_validate(human_input, from_attributes=True) - def create(self, content: str, source: str) -> HumanInput: + def create(self, content: str, source: str) -> HumanInputModel: """ Create a new human input record in the database. @@ -128,7 +144,7 @@ class HumanInputRepository: source: The source of the input (e.g., "cli", "chat", "hil") Returns: - HumanInput: The newly created human input instance + HumanInputModel: The newly created human input instance Raises: peewee.DatabaseError: If there's an error creating the record @@ -136,12 +152,12 @@ class HumanInputRepository: try: input_record = HumanInput.create(content=content, source=source) logger.debug(f"Created human input ID {input_record.id} from {source}") - return input_record + return self._to_model(input_record) except peewee.DatabaseError as e: logger.error(f"Failed to create human input record: {str(e)}") raise - def get(self, input_id: int) -> Optional[HumanInput]: + def get(self, input_id: int) -> Optional[HumanInputModel]: """ Retrieve a human input record by its ID. @@ -149,18 +165,19 @@ class HumanInputRepository: input_id: The ID of the human input to retrieve Returns: - Optional[HumanInput]: The human input instance if found, None otherwise + Optional[HumanInputModel]: The human input instance if found, None otherwise Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return HumanInput.get_or_none(HumanInput.id == input_id) + human_input = HumanInput.get_or_none(HumanInput.id == input_id) + return self._to_model(human_input) except peewee.DatabaseError as e: logger.error(f"Failed to fetch human input {input_id}: {str(e)}") raise - def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInput]: + def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInputModel]: """ Update an existing human input record. @@ -170,14 +187,14 @@ class HumanInputRepository: source: The new source for the human input Returns: - Optional[HumanInput]: The updated human input if found, None otherwise + Optional[HumanInputModel]: The updated human input if found, None otherwise Raises: peewee.DatabaseError: If there's an error updating the record """ try: - # First check if the record exists - input_record = self.get(input_id) + # We need to get the raw Peewee object for updating + input_record = HumanInput.get_or_none(HumanInput.id == input_id) if not input_record: logger.warning(f"Attempted to update non-existent human input {input_id}") return None @@ -190,7 +207,7 @@ class HumanInputRepository: input_record.save() logger.debug(f"Updated human input ID {input_id}") - return input_record + return self._to_model(input_record) except peewee.DatabaseError as e: logger.error(f"Failed to update human input {input_id}: {str(e)}") raise @@ -223,23 +240,24 @@ class HumanInputRepository: logger.error(f"Failed to delete human input {input_id}: {str(e)}") raise - def get_all(self) -> List[HumanInput]: + def get_all(self) -> List[HumanInputModel]: """ Retrieve all human input records from the database. Returns: - List[HumanInput]: List of all human input instances + List[HumanInputModel]: List of all human input instances Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(HumanInput.select().order_by(HumanInput.created_at.desc())) + human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc())) + return [self._to_model(input) for input in human_inputs] except peewee.DatabaseError as e: logger.error(f"Failed to fetch all human inputs: {str(e)}") raise - def get_recent(self, limit: int = 10) -> List[HumanInput]: + def get_recent(self, limit: int = 10) -> List[HumanInputModel]: """ Retrieve the most recent human input records. @@ -247,13 +265,14 @@ class HumanInputRepository: limit: Maximum number of records to retrieve (default: 10) Returns: - List[HumanInput]: List of the most recent human input records + List[HumanInputModel]: List of the most recent human input records Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit)) + human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit)) + return [self._to_model(input) for input in human_inputs] except peewee.DatabaseError as e: logger.error(f"Failed to fetch recent human inputs: {str(e)}") raise @@ -277,7 +296,7 @@ class HumanInputRepository: logger.error(f"Failed to fetch most recent human input ID: {str(e)}") raise - def get_by_source(self, source: str) -> List[HumanInput]: + def get_by_source(self, source: str) -> List[HumanInputModel]: """ Retrieve human input records by source. @@ -285,13 +304,14 @@ class HumanInputRepository: source: The source to filter by (e.g., "cli", "chat", "hil") Returns: - List[HumanInput]: List of human input records from the specified source + List[HumanInputModel]: List of human input records from the specified source Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc())) + human_inputs = list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc())) + return [self._to_model(input) for input in human_inputs] except peewee.DatabaseError as e: logger.error(f"Failed to fetch human inputs by source {source}: {str(e)}") raise diff --git a/ra_aid/database/repositories/key_fact_repository.py b/ra_aid/database/repositories/key_fact_repository.py index 2c29c52..7398642 100644 --- a/ra_aid/database/repositories/key_fact_repository.py +++ b/ra_aid/database/repositories/key_fact_repository.py @@ -12,6 +12,7 @@ from contextlib import contextmanager import peewee from ra_aid.database.models import KeyFact +from ra_aid.database.pydantic_models import KeyFactModel from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -120,7 +121,22 @@ class KeyFactRepository: raise ValueError("Database connection is required for KeyFactRepository") self.db = db - def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact: + def _to_model(self, fact: Optional[KeyFact]) -> Optional[KeyFactModel]: + """ + Convert a Peewee KeyFact object to a Pydantic KeyFactModel. + + Args: + fact: Peewee KeyFact instance or None + + Returns: + Optional[KeyFactModel]: Pydantic model representation or None if fact is None + """ + if fact is None: + return None + + return KeyFactModel.model_validate(fact, from_attributes=True) + + def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFactModel: """ Create a new key fact in the database. @@ -129,7 +145,7 @@ class KeyFactRepository: human_input_id: Optional ID of the associated human input Returns: - KeyFact: The newly created key fact instance + KeyFactModel: The newly created key fact instance Raises: peewee.DatabaseError: If there's an error creating the fact @@ -137,12 +153,12 @@ class KeyFactRepository: try: fact = KeyFact.create(content=content, human_input_id=human_input_id) logger.debug(f"Created key fact ID {fact.id}: {content}") - return fact + return self._to_model(fact) except peewee.DatabaseError as e: logger.error(f"Failed to create key fact: {str(e)}") raise - def get(self, fact_id: int) -> Optional[KeyFact]: + def get(self, fact_id: int) -> Optional[KeyFactModel]: """ Retrieve a key fact by its ID. @@ -150,18 +166,19 @@ class KeyFactRepository: fact_id: The ID of the key fact to retrieve Returns: - Optional[KeyFact]: The key fact instance if found, None otherwise + Optional[KeyFactModel]: The key fact instance if found, None otherwise Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return KeyFact.get_or_none(KeyFact.id == fact_id) + fact = KeyFact.get_or_none(KeyFact.id == fact_id) + return self._to_model(fact) except peewee.DatabaseError as e: logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}") raise - def update(self, fact_id: int, content: str) -> Optional[KeyFact]: + def update(self, fact_id: int, content: str) -> Optional[KeyFactModel]: """ Update an existing key fact. @@ -170,14 +187,14 @@ class KeyFactRepository: content: The new content for the key fact Returns: - Optional[KeyFact]: The updated key fact if found, None otherwise + Optional[KeyFactModel]: The updated key fact if found, None otherwise Raises: peewee.DatabaseError: If there's an error updating the fact """ try: # First check if the fact exists - fact = self.get(fact_id) + fact = KeyFact.get_or_none(KeyFact.id == fact_id) if not fact: logger.warning(f"Attempted to update non-existent key fact {fact_id}") return None @@ -186,7 +203,7 @@ class KeyFactRepository: fact.content = content fact.save() logger.debug(f"Updated key fact ID {fact_id}: {content}") - return fact + return self._to_model(fact) except peewee.DatabaseError as e: logger.error(f"Failed to update key fact {fact_id}: {str(e)}") raise @@ -206,7 +223,7 @@ class KeyFactRepository: """ try: # First check if the fact exists - fact = self.get(fact_id) + fact = KeyFact.get_or_none(KeyFact.id == fact_id) if not fact: logger.warning(f"Attempted to delete non-existent key fact {fact_id}") return False @@ -219,18 +236,19 @@ class KeyFactRepository: logger.error(f"Failed to delete key fact {fact_id}: {str(e)}") raise - def get_all(self) -> List[KeyFact]: + def get_all(self) -> List[KeyFactModel]: """ Retrieve all key facts from the database. Returns: - List[KeyFact]: List of all key fact instances + List[KeyFactModel]: List of all key fact instances Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(KeyFact.select().order_by(KeyFact.id)) + facts = list(KeyFact.select().order_by(KeyFact.id)) + return [self._to_model(fact) for fact in facts] except peewee.DatabaseError as e: logger.error(f"Failed to fetch all key facts: {str(e)}") raise diff --git a/ra_aid/database/repositories/key_snippet_repository.py b/ra_aid/database/repositories/key_snippet_repository.py index e63826f..900a06f 100644 --- a/ra_aid/database/repositories/key_snippet_repository.py +++ b/ra_aid/database/repositories/key_snippet_repository.py @@ -11,6 +11,7 @@ import contextvars import peewee from ra_aid.database.models import KeySnippet +from ra_aid.database.pydantic_models import KeySnippetModel from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -129,10 +130,25 @@ class KeySnippetRepository: raise ValueError("Database connection is required for KeySnippetRepository") self.db = db + def _to_model(self, snippet: Optional[KeySnippet]) -> Optional[KeySnippetModel]: + """ + Convert a Peewee KeySnippet object to a Pydantic KeySnippetModel. + + Args: + snippet: Peewee KeySnippet instance or None + + Returns: + Optional[KeySnippetModel]: Pydantic model representation or None if snippet is None + """ + if snippet is None: + return None + + return KeySnippetModel.model_validate(snippet, from_attributes=True) + def create( self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None, human_input_id: Optional[int] = None - ) -> KeySnippet: + ) -> KeySnippetModel: """ Create a new key snippet in the database. @@ -144,7 +160,7 @@ class KeySnippetRepository: human_input_id: Optional ID of the associated human input Returns: - KeySnippet: The newly created key snippet instance + KeySnippetModel: The newly created key snippet instance Raises: peewee.DatabaseError: If there's an error creating the snippet @@ -158,12 +174,12 @@ class KeySnippetRepository: human_input_id=human_input_id ) logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}") - return key_snippet + return self._to_model(key_snippet) except peewee.DatabaseError as e: logger.error(f"Failed to create key snippet: {str(e)}") raise - def get(self, snippet_id: int) -> Optional[KeySnippet]: + def get(self, snippet_id: int) -> Optional[KeySnippetModel]: """ Retrieve a key snippet by its ID. @@ -171,13 +187,14 @@ class KeySnippetRepository: snippet_id: The ID of the key snippet to retrieve Returns: - Optional[KeySnippet]: The key snippet instance if found, None otherwise + Optional[KeySnippetModel]: The key snippet instance if found, None otherwise Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return KeySnippet.get_or_none(KeySnippet.id == snippet_id) + snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id) + return self._to_model(snippet) except peewee.DatabaseError as e: logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}") raise @@ -189,7 +206,7 @@ class KeySnippetRepository: line_number: int, snippet: str, description: Optional[str] = None - ) -> Optional[KeySnippet]: + ) -> Optional[KeySnippetModel]: """ Update an existing key snippet. @@ -201,14 +218,14 @@ class KeySnippetRepository: description: Optional description of the significance Returns: - Optional[KeySnippet]: The updated key snippet if found, None otherwise + Optional[KeySnippetModel]: The updated key snippet if found, None otherwise Raises: peewee.DatabaseError: If there's an error updating the snippet """ try: # First check if the snippet exists - key_snippet = self.get(snippet_id) + key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id) if not key_snippet: logger.warning(f"Attempted to update non-existent key snippet {snippet_id}") return None @@ -220,7 +237,7 @@ class KeySnippetRepository: key_snippet.description = description key_snippet.save() logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}") - return key_snippet + return self._to_model(key_snippet) except peewee.DatabaseError as e: logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}") raise @@ -240,7 +257,7 @@ class KeySnippetRepository: """ try: # First check if the snippet exists - key_snippet = self.get(snippet_id) + key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id) if not key_snippet: logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}") return False @@ -253,18 +270,19 @@ class KeySnippetRepository: logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}") raise - def get_all(self) -> List[KeySnippet]: + def get_all(self) -> List[KeySnippetModel]: """ Retrieve all key snippets from the database. Returns: - List[KeySnippet]: List of all key snippet instances + List[KeySnippetModel]: List of all key snippet instances Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(KeySnippet.select().order_by(KeySnippet.id)) + snippets = list(KeySnippet.select().order_by(KeySnippet.id)) + return [self._to_model(snippet) for snippet in snippets] except peewee.DatabaseError as e: logger.error(f"Failed to fetch all key snippets: {str(e)}") raise diff --git a/ra_aid/database/repositories/research_note_repository.py b/ra_aid/database/repositories/research_note_repository.py index 2eb84e4..230c59a 100644 --- a/ra_aid/database/repositories/research_note_repository.py +++ b/ra_aid/database/repositories/research_note_repository.py @@ -12,6 +12,7 @@ from contextlib import contextmanager import peewee from ra_aid.database.models import ResearchNote +from ra_aid.database.pydantic_models import ResearchNoteModel from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -120,7 +121,22 @@ class ResearchNoteRepository: raise ValueError("Database connection is required for ResearchNoteRepository") self.db = db - def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNote: + def _to_model(self, note: Optional[ResearchNote]) -> Optional[ResearchNoteModel]: + """ + Convert a Peewee ResearchNote object to a Pydantic ResearchNoteModel. + + Args: + note: Peewee ResearchNote instance or None + + Returns: + Optional[ResearchNoteModel]: Pydantic model representation or None if note is None + """ + if note is None: + return None + + return ResearchNoteModel.model_validate(note, from_attributes=True) + + def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNoteModel: """ Create a new research note in the database. @@ -129,7 +145,7 @@ class ResearchNoteRepository: human_input_id: Optional ID of the associated human input Returns: - ResearchNote: The newly created research note instance + ResearchNoteModel: The newly created research note instance Raises: peewee.DatabaseError: If there's an error creating the note @@ -137,12 +153,12 @@ class ResearchNoteRepository: try: note = ResearchNote.create(content=content, human_input_id=human_input_id) logger.debug(f"Created research note ID {note.id}: {content[:50]}...") - return note + return self._to_model(note) except peewee.DatabaseError as e: logger.error(f"Failed to create research note: {str(e)}") raise - def get(self, note_id: int) -> Optional[ResearchNote]: + def get(self, note_id: int) -> Optional[ResearchNoteModel]: """ Retrieve a research note by its ID. @@ -150,18 +166,19 @@ class ResearchNoteRepository: note_id: The ID of the research note to retrieve Returns: - Optional[ResearchNote]: The research note instance if found, None otherwise + Optional[ResearchNoteModel]: The research note instance if found, None otherwise Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return ResearchNote.get_or_none(ResearchNote.id == note_id) + note = ResearchNote.get_or_none(ResearchNote.id == note_id) + return self._to_model(note) except peewee.DatabaseError as e: logger.error(f"Failed to fetch research note {note_id}: {str(e)}") raise - def update(self, note_id: int, content: str) -> Optional[ResearchNote]: + def update(self, note_id: int, content: str) -> Optional[ResearchNoteModel]: """ Update an existing research note. @@ -170,14 +187,14 @@ class ResearchNoteRepository: content: The new content for the research note Returns: - Optional[ResearchNote]: The updated research note if found, None otherwise + Optional[ResearchNoteModel]: The updated research note if found, None otherwise Raises: peewee.DatabaseError: If there's an error updating the note """ try: # First check if the note exists - note = self.get(note_id) + note = ResearchNote.get_or_none(ResearchNote.id == note_id) if not note: logger.warning(f"Attempted to update non-existent research note {note_id}") return None @@ -186,7 +203,7 @@ class ResearchNoteRepository: note.content = content note.save() logger.debug(f"Updated research note ID {note_id}: {content[:50]}...") - return note + return self._to_model(note) except peewee.DatabaseError as e: logger.error(f"Failed to update research note {note_id}: {str(e)}") raise @@ -206,7 +223,7 @@ class ResearchNoteRepository: """ try: # First check if the note exists - note = self.get(note_id) + note = ResearchNote.get_or_none(ResearchNote.id == note_id) if not note: logger.warning(f"Attempted to delete non-existent research note {note_id}") return False @@ -219,18 +236,19 @@ class ResearchNoteRepository: logger.error(f"Failed to delete research note {note_id}: {str(e)}") raise - def get_all(self) -> List[ResearchNote]: + def get_all(self) -> List[ResearchNoteModel]: """ Retrieve all research notes from the database. Returns: - List[ResearchNote]: List of all research note instances + List[ResearchNoteModel]: List of all research note instances Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(ResearchNote.select().order_by(ResearchNote.id)) + notes = list(ResearchNote.select().order_by(ResearchNote.id)) + return [self._to_model(note) for note in notes] except peewee.DatabaseError as e: logger.error(f"Failed to fetch all research notes: {str(e)}") raise diff --git a/ra_aid/database/repositories/session_repository.py b/ra_aid/database/repositories/session_repository.py index 85509f4..9996f0e 100644 --- a/ra_aid/database/repositories/session_repository.py +++ b/ra_aid/database/repositories/session_repository.py @@ -16,6 +16,7 @@ import sys import peewee from ra_aid.database.models import Session +from ra_aid.database.pydantic_models import SessionModel from ra_aid.__version__ import __version__ from ra_aid.logging_config import get_logger @@ -120,8 +121,23 @@ class SessionRepository: raise ValueError("Database connection is required for SessionRepository") self.db = db self.current_session = None + + def _to_model(self, session: Optional[Session]) -> Optional[SessionModel]: + """ + Convert a Peewee Session object to a Pydantic SessionModel. + + Args: + session: Peewee Session instance or None + + Returns: + Optional[SessionModel]: Pydantic model representation or None if session is None + """ + if session is None: + return None + + return SessionModel.model_validate(session, from_attributes=True) - def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session: + def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> SessionModel: """ Create a new session record in the database. @@ -129,7 +145,7 @@ class SessionRepository: metadata: Optional dictionary of additional metadata to store with the session Returns: - Session: The newly created session instance + SessionModel: The newly created session instance Raises: peewee.DatabaseError: If there's an error creating the record @@ -155,12 +171,12 @@ class SessionRepository: self.current_session = session logger.debug(f"Created new session with ID {session.id}") - return session + return self._to_model(session) except peewee.DatabaseError as e: logger.error(f"Failed to create session record: {str(e)}") raise - def get_current_session(self) -> Optional[Session]: + def get_current_session(self) -> Optional[SessionModel]: """ Get the current active session. @@ -168,17 +184,17 @@ class SessionRepository: retrieves the most recent session from the database. Returns: - Optional[Session]: The current session or None if no sessions exist + Optional[SessionModel]: The current session or None if no sessions exist """ if self.current_session is not None: - return self.current_session + return self._to_model(self.current_session) try: # Find the most recent session session = Session.select().order_by(Session.created_at.desc()).first() if session: self.current_session = session - return session + return self._to_model(session) except peewee.DatabaseError as e: logger.error(f"Failed to get current session: {str(e)}") return None @@ -193,7 +209,7 @@ class SessionRepository: session = self.get_current_session() return session.id if session else None - def get(self, session_id: int) -> Optional[Session]: + def get(self, session_id: int) -> Optional[SessionModel]: """ Get a session by its ID. @@ -201,28 +217,30 @@ class SessionRepository: session_id: The ID of the session to retrieve Returns: - Optional[Session]: The session with the given ID or None if not found + Optional[SessionModel]: The session with the given ID or None if not found """ try: - return Session.get_or_none(Session.id == session_id) + session = Session.get_or_none(Session.id == session_id) + return self._to_model(session) except peewee.DatabaseError as e: logger.error(f"Database error getting session {session_id}: {str(e)}") return None - def get_all(self) -> List[Session]: + def get_all(self) -> List[SessionModel]: """ Get all sessions from the database. Returns: - List[Session]: List of all sessions + List[SessionModel]: List of all sessions """ try: - return list(Session.select().order_by(Session.created_at.desc())) + sessions = list(Session.select().order_by(Session.created_at.desc())) + return [self._to_model(session) for session in sessions] except peewee.DatabaseError as e: logger.error(f"Failed to get all sessions: {str(e)}") return [] - def get_recent(self, limit: int = 10) -> List[Session]: + def get_recent(self, limit: int = 10) -> List[SessionModel]: """ Get the most recent sessions from the database. @@ -230,14 +248,15 @@ class SessionRepository: limit: Maximum number of sessions to return (default: 10) Returns: - List[Session]: List of the most recent sessions + List[SessionModel]: List of the most recent sessions """ try: - return list( + sessions = list( Session.select() .order_by(Session.created_at.desc()) .limit(limit) ) + return [self._to_model(session) for session in sessions] except peewee.DatabaseError as e: logger.error(f"Failed to get recent sessions: {str(e)}") return [] \ No newline at end of file diff --git a/ra_aid/database/repositories/trajectory_repository.py b/ra_aid/database/repositories/trajectory_repository.py index 792ff79..71f9433 100644 --- a/ra_aid/database/repositories/trajectory_repository.py +++ b/ra_aid/database/repositories/trajectory_repository.py @@ -14,6 +14,7 @@ import logging import peewee from ra_aid.database.models import Trajectory, HumanInput +from ra_aid.database.pydantic_models import TrajectoryModel from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -130,6 +131,21 @@ class TrajectoryRepository: raise ValueError("Database connection is required for TrajectoryRepository") self.db = db + def _to_model(self, trajectory: Optional[Trajectory]) -> Optional[TrajectoryModel]: + """ + Convert a Peewee Trajectory object to a Pydantic TrajectoryModel. + + Args: + trajectory: Peewee Trajectory instance or None + + Returns: + Optional[TrajectoryModel]: Pydantic model representation or None if trajectory is None + """ + if trajectory is None: + return None + + return TrajectoryModel.model_validate(trajectory, from_attributes=True) + def create( self, tool_name: Optional[str] = None, @@ -144,7 +160,7 @@ class TrajectoryRepository: error_message: Optional[str] = None, error_type: Optional[str] = None, error_details: Optional[str] = None - ) -> Trajectory: + ) -> TrajectoryModel: """ Create a new trajectory record in the database. @@ -163,7 +179,7 @@ class TrajectoryRepository: error_details: Additional error details like stack traces (if is_error is True) Returns: - Trajectory: The newly created trajectory instance + TrajectoryModel: The newly created trajectory instance as a Pydantic model Raises: peewee.DatabaseError: If there's an error creating the record @@ -201,12 +217,12 @@ class TrajectoryRepository: logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}") else: logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}") - return trajectory + return self._to_model(trajectory) except peewee.DatabaseError as e: logger.error(f"Failed to create trajectory record: {str(e)}") raise - def get(self, trajectory_id: int) -> Optional[Trajectory]: + def get(self, trajectory_id: int) -> Optional[TrajectoryModel]: """ Retrieve a trajectory record by its ID. @@ -214,13 +230,14 @@ class TrajectoryRepository: trajectory_id: The ID of the trajectory record to retrieve Returns: - Optional[Trajectory]: The trajectory instance if found, None otherwise + Optional[TrajectoryModel]: The trajectory instance as a Pydantic model if found, None otherwise Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return Trajectory.get_or_none(Trajectory.id == trajectory_id) + trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id) + return self._to_model(trajectory) except peewee.DatabaseError as e: logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}") raise @@ -236,7 +253,7 @@ class TrajectoryRepository: error_message: Optional[str] = None, error_type: Optional[str] = None, error_details: Optional[str] = None - ) -> Optional[Trajectory]: + ) -> Optional[TrajectoryModel]: """ Update an existing trajectory record. @@ -254,15 +271,15 @@ class TrajectoryRepository: error_details: Additional error details like stack traces Returns: - Optional[Trajectory]: The updated trajectory if found, None otherwise + Optional[TrajectoryModel]: The updated trajectory as a Pydantic model if found, None otherwise Raises: peewee.DatabaseError: If there's an error updating the record """ try: # First check if the trajectory exists - trajectory = self.get(trajectory_id) - if not trajectory: + peewee_trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id) + if not peewee_trajectory: logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}") return None @@ -299,7 +316,7 @@ class TrajectoryRepository: logger.debug(f"Updated trajectory record ID {trajectory_id}") return self.get(trajectory_id) - return trajectory + return self._to_model(peewee_trajectory) except peewee.DatabaseError as e: logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}") raise @@ -319,7 +336,7 @@ class TrajectoryRepository: """ try: # First check if the trajectory exists - trajectory = self.get(trajectory_id) + trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id) if not trajectory: logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}") return False @@ -332,23 +349,24 @@ class TrajectoryRepository: logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}") raise - def get_all(self) -> Dict[int, Trajectory]: + def get_all(self) -> Dict[int, TrajectoryModel]: """ Retrieve all trajectory records from the database. Returns: - Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances + Dict[int, TrajectoryModel]: Dictionary mapping trajectory IDs to trajectory Pydantic models Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)} + trajectories = Trajectory.select().order_by(Trajectory.id) + return {trajectory.id: self._to_model(trajectory) for trajectory in trajectories} except peewee.DatabaseError as e: logger.error(f"Failed to fetch all trajectories: {str(e)}") raise - def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]: + def get_trajectories_by_human_input(self, human_input_id: int) -> List[TrajectoryModel]: """ Retrieve all trajectory records associated with a specific human input. @@ -356,37 +374,19 @@ class TrajectoryRepository: human_input_id: The ID of the human input to get trajectories for Returns: - List[Trajectory]: List of trajectory instances associated with the human input + List[TrajectoryModel]: List of trajectory Pydantic models associated with the human input Raises: peewee.DatabaseError: If there's an error accessing the database """ try: - return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id)) + trajectories = list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id)) + return [self._to_model(trajectory) for trajectory in trajectories] except peewee.DatabaseError as e: logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}") raise - def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]: - """ - Parse a JSON string into a Python dictionary. - - Args: - json_str: JSON string to parse - - Returns: - Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid - """ - if not json_str: - return None - - try: - return json.loads(json_str) - except json.JSONDecodeError as e: - logger.error(f"Error parsing JSON field: {str(e)}") - return None - - def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]: + def get_parsed_trajectory(self, trajectory_id: int) -> Optional[TrajectoryModel]: """ Get a trajectory record with JSON fields parsed into dictionaries. @@ -394,27 +394,7 @@ class TrajectoryRepository: trajectory_id: ID of the trajectory to retrieve Returns: - Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields, - or None if not found + Optional[TrajectoryModel]: The trajectory as a Pydantic model with parsed JSON fields, + or None if not found """ - trajectory = self.get(trajectory_id) - if trajectory is None: - return None - - return { - "id": trajectory.id, - "created_at": trajectory.created_at, - "updated_at": trajectory.updated_at, - "tool_name": trajectory.tool_name, - "tool_parameters": self.parse_json_field(trajectory.tool_parameters), - "tool_result": self.parse_json_field(trajectory.tool_result), - "step_data": self.parse_json_field(trajectory.step_data), - "record_type": trajectory.record_type, - "cost": trajectory.cost, - "tokens": trajectory.tokens, - "human_input_id": trajectory.human_input.id if trajectory.human_input else None, - "is_error": trajectory.is_error, - "error_message": trajectory.error_message, - "error_type": trajectory.error_type, - "error_details": trajectory.error_details, - } \ No newline at end of file + return self.get(trajectory_id) \ No newline at end of file diff --git a/tests/ra_aid/database/test_human_input_repository.py b/tests/ra_aid/database/test_human_input_repository.py new file mode 100644 index 0000000..c8e4b19 --- /dev/null +++ b/tests/ra_aid/database/test_human_input_repository.py @@ -0,0 +1,198 @@ +""" +Tests for the human input repository. + +This module provides tests for the HumanInputRepository class, +ensuring it correctly interfaces with the database and returns +appropriate Pydantic models. +""" + +import unittest +from typing import List, Dict, Any + +import pytest +from peewee import SqliteDatabase + +from ra_aid.database.models import HumanInput, Session, database_proxy +from ra_aid.database.pydantic_models import HumanInputModel, SessionModel +from ra_aid.database.repositories.human_input_repository import HumanInputRepository +from ra_aid.database.repositories.session_repository import SessionRepository + + +@pytest.fixture +def test_db(): + """Fixture for creating a test database.""" + # Create an in-memory SQLite database for testing + test_db = SqliteDatabase(':memory:') + + # Register the models with the test database + with test_db.bind_ctx([HumanInput, Session]): + # Create the tables + test_db.create_tables([HumanInput, Session]) + + # Return the test database for use in the tests + yield test_db + + # Drop the tables after the tests + test_db.drop_tables([HumanInput, Session]) + + +class TestHumanInputRepository(unittest.TestCase): + """Test case for the HumanInputRepository class.""" + + def setUp(self): + """Set up the test case with a test database and repositories.""" + # Create an in-memory database for testing + self.db = SqliteDatabase(':memory:') + + # Register the models with the test database + self.models = [HumanInput, Session] + self.db.bind(self.models) + + # Create the tables + self.db.create_tables(self.models) + + # Create repository instances for testing + self.repository = HumanInputRepository(self.db) + self.session_repository = SessionRepository(self.db) + + # Bind the test database to the repository model + database_proxy.initialize(self.db) + + def tearDown(self): + """Clean up after the test case.""" + # Close the database connection + self.db.close() + + def test_create(self): + """Test creating a human input record.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create a human input + content = "Test human input" + source = "cli" + human_input = self.repository.create(content=content, source=source) + + # Verify the human input was created + self.assertIsInstance(human_input, HumanInputModel) + self.assertEqual(human_input.content, content) + self.assertEqual(human_input.source, source) + + def test_get(self): + """Test retrieving a human input record by ID.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create a human input + content = "Test human input" + source = "chat" + created_input = self.repository.create(content=content, source=source) + + # Get the human input by ID + retrieved_input = self.repository.get(created_input.id) + + # Verify the human input was retrieved correctly + self.assertIsInstance(retrieved_input, HumanInputModel) + self.assertEqual(retrieved_input.id, created_input.id) + self.assertEqual(retrieved_input.content, content) + self.assertEqual(retrieved_input.source, source) + + def test_update(self): + """Test updating a human input record.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create a human input + content = "Original content" + source = "cli" + created_input = self.repository.create(content=content, source=source) + + # Update the human input + new_content = "Updated content" + updated_input = self.repository.update(created_input.id, content=new_content) + + # Verify the human input was updated correctly + self.assertIsInstance(updated_input, HumanInputModel) + self.assertEqual(updated_input.id, created_input.id) + self.assertEqual(updated_input.content, new_content) + self.assertEqual(updated_input.source, source) + + def test_get_all(self): + """Test retrieving all human input records.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create multiple human inputs + self.repository.create(content="Input 1", source="cli") + self.repository.create(content="Input 2", source="chat") + self.repository.create(content="Input 3", source="hil") + + # Get all human inputs + all_inputs = self.repository.get_all() + + # Verify all human inputs were retrieved + self.assertEqual(len(all_inputs), 3) + self.assertIsInstance(all_inputs[0], HumanInputModel) + + # Verify the inputs are ordered by created_at in descending order + self.assertEqual(all_inputs[0].content, "Input 3") + self.assertEqual(all_inputs[1].content, "Input 2") + self.assertEqual(all_inputs[2].content, "Input 1") + + def test_get_recent(self): + """Test retrieving the most recent human input records.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create multiple human inputs + self.repository.create(content="Input 1", source="cli") + self.repository.create(content="Input 2", source="chat") + self.repository.create(content="Input 3", source="hil") + self.repository.create(content="Input 4", source="cli") + self.repository.create(content="Input 5", source="chat") + + # Get recent human inputs with a limit of 3 + recent_inputs = self.repository.get_recent(limit=3) + + # Verify only the 3 most recent inputs were retrieved + self.assertEqual(len(recent_inputs), 3) + self.assertIsInstance(recent_inputs[0], HumanInputModel) + self.assertEqual(recent_inputs[0].content, "Input 5") + self.assertEqual(recent_inputs[1].content, "Input 4") + self.assertEqual(recent_inputs[2].content, "Input 3") + + def test_get_by_source(self): + """Test retrieving human input records by source.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create human inputs with different sources + self.repository.create(content="CLI Input 1", source="cli") + self.repository.create(content="Chat Input 1", source="chat") + self.repository.create(content="HIL Input", source="hil") + self.repository.create(content="CLI Input 2", source="cli") + self.repository.create(content="Chat Input 2", source="chat") + + # Get human inputs for the 'cli' source + cli_inputs = self.repository.get_by_source("cli") + + # Verify only cli inputs were retrieved + self.assertEqual(len(cli_inputs), 2) + self.assertIsInstance(cli_inputs[0], HumanInputModel) + self.assertEqual(cli_inputs[0].content, "CLI Input 2") + self.assertEqual(cli_inputs[1].content, "CLI Input 1") + + def test_get_most_recent_id(self): + """Test retrieving the ID of the most recent human input record.""" + # Create a session first + session_model = self.session_repository.create_session() + + # Create multiple human inputs + self.repository.create(content="Input 1", source="cli") + input2 = self.repository.create(content="Input 2", source="chat") + + # Get the most recent ID + most_recent_id = self.repository.get_most_recent_id() + + # Verify the correct ID was retrieved + self.assertEqual(most_recent_id, input2.id) \ No newline at end of file diff --git a/tests/ra_aid/database/test_key_fact_repository.py b/tests/ra_aid/database/test_key_fact_repository.py index d97a455..d4da855 100644 --- a/tests/ra_aid/database/test_key_fact_repository.py +++ b/tests/ra_aid/database/test_key_fact_repository.py @@ -15,6 +15,7 @@ from ra_aid.database.repositories.key_fact_repository import ( get_key_fact_repository, key_fact_repo_var ) +from ra_aid.database.pydantic_models import KeyFactModel @pytest.fixture @@ -87,7 +88,8 @@ def test_create_key_fact(setup_db): content = "Test key fact" fact = repo.create(content) - # Verify the fact was created correctly + # Verify the fact was created correctly and is a KeyFactModel + assert isinstance(fact, KeyFactModel) assert fact.id is not None assert fact.content == content @@ -108,7 +110,8 @@ def test_get_key_fact(setup_db): # Retrieve the fact by ID retrieved_fact = repo.get(fact.id) - # Verify the retrieved fact matches the original + # Verify the retrieved fact matches the original and is a KeyFactModel + assert isinstance(retrieved_fact, KeyFactModel) assert retrieved_fact is not None assert retrieved_fact.id == fact.id assert retrieved_fact.content == content @@ -131,7 +134,8 @@ def test_update_key_fact(setup_db): new_content = "Updated content" updated_fact = repo.update(fact.id, new_content) - # Verify the fact was updated correctly + # Verify the fact was updated correctly and is a KeyFactModel + assert isinstance(updated_fact, KeyFactModel) assert updated_fact is not None assert updated_fact.id == fact.id assert updated_fact.content == new_content @@ -184,8 +188,10 @@ def test_get_all_key_facts(setup_db): # Retrieve all facts all_facts = repo.get_all() - # Verify we got the correct number of facts + # Verify we got the correct number of facts and they are KeyFactModel instances assert len(all_facts) == len(contents) + for fact in all_facts: + assert isinstance(fact, KeyFactModel) # Verify the content of each fact fact_contents = [fact.content for fact in all_facts] @@ -237,6 +243,7 @@ def test_key_fact_repository_manager(setup_db, cleanup_repo): # Verify we can use the repository content = "Test fact via context manager" fact = repo.create(content) + assert isinstance(fact, KeyFactModel) assert fact.id is not None assert fact.content == content @@ -258,4 +265,26 @@ def test_get_key_fact_repository_when_not_set(cleanup_repo): get_key_fact_repository() # Verify the correct error message - assert "No KeyFactRepository available" in str(excinfo.value) \ No newline at end of file + assert "No KeyFactRepository available" in str(excinfo.value) + + +def test_to_model_method(setup_db): + """Test the _to_model method converts KeyFact to KeyFactModel correctly.""" + # Set up repository + repo = KeyFactRepository(db=setup_db) + + # Create a Peewee KeyFact directly + peewee_fact = KeyFact.create(content="Test fact for conversion") + + # Convert to Pydantic model + pydantic_fact = repo._to_model(peewee_fact) + + # Verify conversion was correct + assert isinstance(pydantic_fact, KeyFactModel) + assert pydantic_fact.id == peewee_fact.id + assert pydantic_fact.content == peewee_fact.content + assert pydantic_fact.created_at == peewee_fact.created_at + assert pydantic_fact.updated_at == peewee_fact.updated_at + + # Test with None input + assert repo._to_model(None) is None \ No newline at end of file diff --git a/tests/ra_aid/database/test_key_snippet_repository.py b/tests/ra_aid/database/test_key_snippet_repository.py index 1c83006..9ea3fac 100644 --- a/tests/ra_aid/database/test_key_snippet_repository.py +++ b/tests/ra_aid/database/test_key_snippet_repository.py @@ -6,6 +6,7 @@ import pytest from ra_aid.database.connection import DatabaseManager, db_var from ra_aid.database.models import KeySnippet +from ra_aid.database.pydantic_models import KeySnippetModel from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository @@ -79,6 +80,9 @@ def test_create_key_snippet(setup_db): assert key_snippet.snippet == snippet assert key_snippet.description == description + # Verify the return type is a Pydantic model + assert isinstance(key_snippet, KeySnippetModel) + # Verify we can retrieve it from the database snippet_from_db = KeySnippet.get_by_id(key_snippet.id) assert snippet_from_db.filepath == filepath @@ -116,6 +120,9 @@ def test_get_key_snippet(setup_db): assert retrieved_snippet.snippet == snippet assert retrieved_snippet.description == description + # Verify the return type is a Pydantic model + assert isinstance(retrieved_snippet, KeySnippetModel) + # Try to retrieve a non-existent snippet non_existent_snippet = repo.get(999) assert non_existent_snippet is None @@ -161,6 +168,9 @@ def test_update_key_snippet(setup_db): assert updated_snippet.snippet == new_snippet assert updated_snippet.description == new_description + # Verify the return type is a Pydantic model + assert isinstance(updated_snippet, KeySnippetModel) + # Verify we can retrieve the updated content from the database snippet_from_db = KeySnippet.get_by_id(key_snippet.id) assert snippet_from_db.filepath == new_filepath @@ -250,6 +260,9 @@ def test_get_all_key_snippets(setup_db): # Verify we got the correct number of snippets assert len(all_snippets) == len(snippets_data) + # Verify all returned snippets are Pydantic models + assert all(isinstance(snippet, KeySnippetModel) for snippet in all_snippets) + # Verify the content of each snippet for i, snippet in enumerate(all_snippets): assert snippet.filepath == snippets_data[i]["filepath"] @@ -301,4 +314,31 @@ def test_get_snippets_dict(setup_db): assert snippets_dict[snippet.id]["filepath"] == snippets_data[i]["filepath"] assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"] assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"] - assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"] \ No newline at end of file + assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"] + + +def test_to_model_conversion(setup_db): + """Test conversion from Peewee model to Pydantic model.""" + repo = KeySnippetRepository(db=setup_db) + + # Create a snippet in the database using Peewee directly + peewee_snippet = KeySnippet.create( + filepath="conversion_test.py", + line_number=100, + snippet="def conversion_test():", + description="Test model conversion" + ) + + # Use the _to_model method to convert it + pydantic_snippet = repo._to_model(peewee_snippet) + + # Verify the conversion was successful + assert isinstance(pydantic_snippet, KeySnippetModel) + assert pydantic_snippet.id == peewee_snippet.id + assert pydantic_snippet.filepath == peewee_snippet.filepath + assert pydantic_snippet.line_number == peewee_snippet.line_number + assert pydantic_snippet.snippet == peewee_snippet.snippet + assert pydantic_snippet.description == peewee_snippet.description + + # Test conversion of None + assert repo._to_model(None) is None \ No newline at end of file diff --git a/tests/ra_aid/database/test_pydantic_models.py b/tests/ra_aid/database/test_pydantic_models.py new file mode 100644 index 0000000..d3f8ab0 --- /dev/null +++ b/tests/ra_aid/database/test_pydantic_models.py @@ -0,0 +1,110 @@ +""" +Tests for the Pydantic models in ra_aid.database.pydantic_models +""" + +import datetime +import json +import pytest + +from ra_aid.database.models import Session +from ra_aid.database.pydantic_models import SessionModel + + +class TestSessionModel: + """Tests for the SessionModel Pydantic model""" + + def test_from_peewee_model(self): + """Test conversion from a Peewee model instance""" + # Create a Peewee Session instance + now = datetime.datetime.now() + metadata = {"os": "Linux", "cpu_cores": 8, "memory_gb": 16} + session = Session( + id=1, + created_at=now, + updated_at=now, + start_time=now, + command_line="ra-aid run", + program_version="1.0.0", + machine_info=json.dumps(metadata) + ) + + # Convert to Pydantic model + session_model = SessionModel.model_validate(session, from_attributes=True) + + # Verify fields + assert session_model.id == 1 + assert session_model.created_at == now + assert session_model.updated_at == now + assert session_model.start_time == now + assert session_model.command_line == "ra-aid run" + assert session_model.program_version == "1.0.0" + assert session_model.machine_info == metadata + + def test_with_dict_machine_info(self): + """Test creating a model with a dict for machine_info""" + # Create directly with a dict for machine_info + now = datetime.datetime.now() + metadata = {"os": "Windows", "cpu_cores": 4, "memory_gb": 8} + + session_model = SessionModel( + id=2, + created_at=now, + updated_at=now, + start_time=now, + command_line="ra-aid --debug", + program_version="1.0.1", + machine_info=metadata + ) + + # Verify fields + assert session_model.id == 2 + assert session_model.machine_info == metadata + + def test_with_none_machine_info(self): + """Test creating a model with None for machine_info""" + now = datetime.datetime.now() + + session_model = SessionModel( + id=3, + created_at=now, + updated_at=now, + start_time=now, + command_line="ra-aid", + program_version="1.0.0", + machine_info=None + ) + + assert session_model.id == 3 + assert session_model.machine_info is None + + def test_invalid_json_machine_info(self): + """Test error handling for invalid JSON in machine_info""" + now = datetime.datetime.now() + + # Invalid JSON string should raise ValueError + with pytest.raises(ValueError): + SessionModel( + id=4, + created_at=now, + updated_at=now, + start_time=now, + command_line="ra-aid", + program_version="1.0.0", + machine_info="{invalid json}" + ) + + def test_unexpected_type_machine_info(self): + """Test error handling for unexpected type in machine_info""" + now = datetime.datetime.now() + + # Integer type should raise ValueError + with pytest.raises(ValueError): + SessionModel( + id=5, + created_at=now, + updated_at=now, + start_time=now, + command_line="ra-aid", + program_version="1.0.0", + machine_info=123 # Not a dict or string + ) \ No newline at end of file diff --git a/tests/ra_aid/database/test_research_note_repository.py b/tests/ra_aid/database/test_research_note_repository.py index afbcc11..8653506 100644 --- a/tests/ra_aid/database/test_research_note_repository.py +++ b/tests/ra_aid/database/test_research_note_repository.py @@ -9,6 +9,7 @@ import peewee from ra_aid.database.connection import DatabaseManager, db_var from ra_aid.database.models import ResearchNote, BaseModel +from ra_aid.database.pydantic_models import ResearchNoteModel from ra_aid.database.repositories.research_note_repository import ( ResearchNoteRepository, ResearchNoteRepositoryManager, @@ -90,10 +91,12 @@ def test_create_research_note(setup_db): # Verify the note was created correctly assert note.id is not None assert note.content == content + assert isinstance(note, ResearchNoteModel) # Verify we can retrieve it from the database using the repository note_from_db = repo.get(note.id) assert note_from_db.content == content + assert isinstance(note_from_db, ResearchNoteModel) def test_get_research_note(setup_db): @@ -112,6 +115,7 @@ def test_get_research_note(setup_db): assert retrieved_note is not None assert retrieved_note.id == note.id assert retrieved_note.content == content + assert isinstance(retrieved_note, ResearchNoteModel) # Try to retrieve a non-existent note non_existent_note = repo.get(999) @@ -135,10 +139,12 @@ def test_update_research_note(setup_db): assert updated_note is not None assert updated_note.id == note.id assert updated_note.content == new_content + assert isinstance(updated_note, ResearchNoteModel) # Verify we can retrieve the updated content from the database using the repository note_from_db = repo.get(note.id) assert note_from_db.content == new_content + assert isinstance(note_from_db, ResearchNoteModel) # Try to update a non-existent note non_existent_update = repo.update(999, "This shouldn't work") @@ -187,8 +193,11 @@ def test_get_all_research_notes(setup_db): # Verify we got the correct number of notes assert len(all_notes) == len(contents) - # Verify the content of each note + # Verify the content of each note and that they are Pydantic models note_contents = [note.content for note in all_notes] + for note in all_notes: + assert isinstance(note, ResearchNoteModel) + for content in contents: assert content in note_contents @@ -239,6 +248,7 @@ def test_research_note_repository_manager(setup_db, cleanup_repo): note = repo.create(content) assert note.id is not None assert note.content == content + assert isinstance(note, ResearchNoteModel) # Verify we can get the repository using get_research_note_repository repo_from_var = get_research_note_repository() @@ -258,4 +268,26 @@ def test_get_research_note_repository_when_not_set(cleanup_repo): get_research_note_repository() # Verify the correct error message - assert "No ResearchNoteRepository available" in str(excinfo.value) \ No newline at end of file + assert "No ResearchNoteRepository available" in str(excinfo.value) + + +def test_to_model_method(setup_db): + """Test the _to_model method converts Peewee models to Pydantic models correctly.""" + # Set up repository + repo = ResearchNoteRepository(db=setup_db) + + # Create a Peewee ResearchNote directly + peewee_note = ResearchNote.create(content="Test note for conversion") + + # Convert it using _to_model + pydantic_note = repo._to_model(peewee_note) + + # Verify the conversion + assert isinstance(pydantic_note, ResearchNoteModel) + assert pydantic_note.id == peewee_note.id + assert pydantic_note.content == peewee_note.content + assert pydantic_note.created_at == peewee_note.created_at + assert pydantic_note.updated_at == peewee_note.updated_at + + # Test with None + assert repo._to_model(None) is None \ No newline at end of file diff --git a/tests/ra_aid/database/test_session_repository.py b/tests/ra_aid/database/test_session_repository.py new file mode 100644 index 0000000..2845395 --- /dev/null +++ b/tests/ra_aid/database/test_session_repository.py @@ -0,0 +1,347 @@ +""" +Tests for the SessionRepository class. +""" + +import pytest +import datetime +import json +from unittest.mock import patch + +import peewee + +from ra_aid.database.connection import DatabaseManager, db_var +from ra_aid.database.models import Session, BaseModel +from ra_aid.database.repositories.session_repository import ( + SessionRepository, + SessionRepositoryManager, + get_session_repository, + session_repo_var +) +from ra_aid.database.pydantic_models import SessionModel + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def cleanup_repo(): + """Reset the repository contextvar after each test.""" + # Reset before the test + session_repo_var.set(None) + + # Run the test + yield + + # Reset after the test + session_repo_var.set(None) + + +@pytest.fixture +def setup_db(cleanup_db): + """Set up an in-memory database with the Session table and patch the BaseModel.Meta.database.""" + # Initialize an in-memory database connection + with DatabaseManager(in_memory=True) as db: + # Patch the BaseModel.Meta.database to use our in-memory database + with patch.object(BaseModel._meta, 'database', db): + # Create the Session table + with db.atomic(): + db.create_tables([Session], safe=True) + + yield db + + # Clean up + with db.atomic(): + Session.drop_table(safe=True) + + +@pytest.fixture +def test_metadata(): + """Return test metadata for sessions.""" + return { + "os": "Test OS", + "version": "1.0", + "cpu_cores": 4, + "memory_gb": 16, + "additional_info": { + "gpu": "Test GPU", + "display_resolution": "1920x1080" + } + } + + +@pytest.fixture +def sample_session(setup_db, test_metadata): + """Create a sample session in the database.""" + now = datetime.datetime.now() + return Session.create( + start_time=now, + command_line="ra-aid test", + program_version="1.0.0", + machine_info=json.dumps(test_metadata) + ) + + +def test_create_session_with_metadata(setup_db, test_metadata): + """Test creating a session with metadata.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Create a session with metadata + session = repo.create_session(metadata=test_metadata) + + # Verify type is SessionModel, not Session + assert isinstance(session, SessionModel) + + # Verify the session was created correctly + assert session.id is not None + assert session.command_line is not None + assert session.program_version is not None + + # Verify machine_info is a dict, not a JSON string + assert isinstance(session.machine_info, dict) + assert session.machine_info == test_metadata + + # Verify the dictionary structure is preserved + assert "additional_info" in session.machine_info + assert session.machine_info["additional_info"]["gpu"] == "Test GPU" + + +def test_create_session_without_metadata(setup_db): + """Test creating a session without metadata.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Create a session without metadata + session = repo.create_session() + + # Verify type is SessionModel, not Session + assert isinstance(session, SessionModel) + + # Verify the session was created correctly + assert session.id is not None + assert session.command_line is not None + assert session.program_version is not None + + # Verify machine_info is None + assert session.machine_info is None + + +def test_get_current_session(setup_db, sample_session): + """Test retrieving the current session.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Set the current session + repo.current_session = sample_session + + # Get the current session + current_session = repo.get_current_session() + + # Verify type is SessionModel, not Session + assert isinstance(current_session, SessionModel) + + # Verify the retrieved session matches the original + assert current_session.id == sample_session.id + assert current_session.command_line == sample_session.command_line + assert current_session.program_version == sample_session.program_version + + # Verify machine_info is a dict, not a JSON string + assert isinstance(current_session.machine_info, dict) + + +def test_get_current_session_from_db(setup_db, sample_session): + """Test retrieving the current session from the database when no current session is set.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Get the current session (should retrieve the most recent from DB) + current_session = repo.get_current_session() + + # Verify type is SessionModel, not Session + assert isinstance(current_session, SessionModel) + + # Verify the retrieved session matches the sample session + assert current_session.id == sample_session.id + + # Verify machine_info is a dict, not a JSON string + assert isinstance(current_session.machine_info, dict) + + +def test_get_by_id(setup_db, sample_session): + """Test retrieving a session by ID.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Get the session by ID + session = repo.get(sample_session.id) + + # Verify type is SessionModel, not Session + assert isinstance(session, SessionModel) + + # Verify the retrieved session matches the original + assert session.id == sample_session.id + assert session.command_line == sample_session.command_line + assert session.program_version == sample_session.program_version + + # Verify machine_info is a dict, not a JSON string + assert isinstance(session.machine_info, dict) + + # Verify getting a non-existent session returns None + non_existent_session = repo.get(999) + assert non_existent_session is None + + +def test_get_all(setup_db): + """Test retrieving all sessions.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Create multiple sessions + metadata1 = {"os": "Linux", "cpu_cores": 8} + metadata2 = {"os": "Windows", "cpu_cores": 4} + metadata3 = {"os": "macOS", "cpu_cores": 10} + + repo.create_session(metadata=metadata1) + repo.create_session(metadata=metadata2) + repo.create_session(metadata=metadata3) + + # Get all sessions + sessions = repo.get_all() + + # Verify we got a list of SessionModel objects + assert len(sessions) == 3 + for session in sessions: + assert isinstance(session, SessionModel) + assert isinstance(session.machine_info, dict) + + # Verify the sessions are in descending order of creation time + assert sessions[0].created_at >= sessions[1].created_at + assert sessions[1].created_at >= sessions[2].created_at + + # Verify the machine_info fields + os_values = [session.machine_info["os"] for session in sessions] + assert "Linux" in os_values + assert "Windows" in os_values + assert "macOS" in os_values + + +def test_get_all_empty(setup_db): + """Test retrieving all sessions when none exist.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Get all sessions + sessions = repo.get_all() + + # Verify we got an empty list + assert isinstance(sessions, list) + assert len(sessions) == 0 + + +def test_get_recent(setup_db): + """Test retrieving recent sessions with a limit.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Create multiple sessions + for i in range(5): + metadata = {"index": i, "os": f"OS {i}"} + repo.create_session(metadata=metadata) + + # Get recent sessions with limit=3 + sessions = repo.get_recent(limit=3) + + # Verify we got the correct number of SessionModel objects + assert len(sessions) == 3 + for session in sessions: + assert isinstance(session, SessionModel) + assert isinstance(session.machine_info, dict) + + # Verify the sessions are in descending order and are the most recent ones + indexes = [session.machine_info["index"] for session in sessions] + assert indexes == [4, 3, 2] # Most recent first + + +def test_session_repository_manager(setup_db, cleanup_repo): + """Test the SessionRepositoryManager context manager.""" + # Use the context manager to create a repository + with SessionRepositoryManager(setup_db) as repo: + # Verify the repository was created correctly + assert isinstance(repo, SessionRepository) + assert repo.db is setup_db + + # Create a session and verify it's a SessionModel + metadata = {"test": "manager"} + session = repo.create_session(metadata=metadata) + assert isinstance(session, SessionModel) + assert session.machine_info["test"] == "manager" + + # Verify we can get the repository using get_session_repository + repo_from_var = get_session_repository() + assert repo_from_var is repo + + # Verify the repository was removed from the context var + with pytest.raises(RuntimeError) as excinfo: + get_session_repository() + + assert "No SessionRepository available" in str(excinfo.value) + + +def test_repository_init_without_db(): + """Test that SessionRepository raises an error when initialized without a db parameter.""" + # Attempt to create a repository without a database connection + with pytest.raises(ValueError) as excinfo: + SessionRepository(db=None) + + # Verify the correct error message + assert "Database connection is required" in str(excinfo.value) + + +def test_get_current_session_id(setup_db, sample_session): + """Test retrieving the ID of the current session.""" + # Set up repository + repo = SessionRepository(db=setup_db) + + # Set the current session + repo.current_session = sample_session + + # Get the current session ID + session_id = repo.get_current_session_id() + + # Verify the ID matches + assert session_id == sample_session.id + + # Test when no current session exists + repo.current_session = None + # Delete all sessions + Session.delete().execute() + + # Verify None is returned when no session exists + session_id = repo.get_current_session_id() + assert session_id is None \ No newline at end of file diff --git a/tests/ra_aid/database/test_trajectory_repository.py b/tests/ra_aid/database/test_trajectory_repository.py new file mode 100644 index 0000000..b5bbb4d --- /dev/null +++ b/tests/ra_aid/database/test_trajectory_repository.py @@ -0,0 +1,458 @@ +""" +Tests for the TrajectoryRepository class. +""" + +import pytest +import datetime +import json +from unittest.mock import patch + +import peewee + +from ra_aid.database.connection import DatabaseManager, db_var +from ra_aid.database.models import Trajectory, HumanInput, Session, BaseModel +from ra_aid.database.repositories.trajectory_repository import ( + TrajectoryRepository, + TrajectoryRepositoryManager, + get_trajectory_repository, + trajectory_repo_var +) +from ra_aid.database.pydantic_models import TrajectoryModel + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def cleanup_repo(): + """Reset the repository contextvar after each test.""" + # Reset before the test + trajectory_repo_var.set(None) + + # Run the test + yield + + # Reset after the test + trajectory_repo_var.set(None) + + +@pytest.fixture +def setup_db(cleanup_db): + """Set up an in-memory database with the necessary tables and patch the BaseModel.Meta.database.""" + # Initialize an in-memory database connection + with DatabaseManager(in_memory=True) as db: + # Patch the BaseModel.Meta.database to use our in-memory database + with patch.object(BaseModel._meta, 'database', db): + # Create the required tables + with db.atomic(): + db.create_tables([Trajectory, HumanInput, Session], safe=True) + + yield db + + # Clean up + with db.atomic(): + Trajectory.drop_table(safe=True) + HumanInput.drop_table(safe=True) + Session.drop_table(safe=True) + + +@pytest.fixture +def sample_human_input(setup_db): + """Create a sample human input in the database.""" + return HumanInput.create( + content="Test human input", + source="test" + ) + + +@pytest.fixture +def test_tool_parameters(): + """Return test tool parameters.""" + return { + "pattern": "test pattern", + "file_path": "/path/to/file", + "options": { + "case_sensitive": True, + "whole_words": False + } + } + + +@pytest.fixture +def test_tool_result(): + """Return test tool result.""" + return { + "matches": [ + {"line": 10, "content": "This is a test pattern"}, + {"line": 20, "content": "Another test pattern here"} + ], + "total_matches": 2, + "execution_time": 0.5 + } + + +@pytest.fixture +def test_step_data(): + """Return test step data for UI rendering.""" + return { + "display_type": "text", + "content": "Tool execution results", + "highlights": [ + {"start": 10, "end": 15, "color": "red"} + ] + } + + +@pytest.fixture +def sample_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data): + """Create a sample trajectory in the database.""" + return Trajectory.create( + human_input=sample_human_input, + tool_name="ripgrep_search", + tool_parameters=json.dumps(test_tool_parameters), + tool_result=json.dumps(test_tool_result), + step_data=json.dumps(test_step_data), + record_type="tool_execution", + cost=0.001, + tokens=100, + is_error=False + ) + + +def test_create_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data): + """Test creating a trajectory with all fields.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Create a trajectory + trajectory = repo.create( + tool_name="ripgrep_search", + tool_parameters=test_tool_parameters, + tool_result=test_tool_result, + step_data=test_step_data, + record_type="tool_execution", + human_input_id=sample_human_input.id, + cost=0.001, + tokens=100 + ) + + # Verify type is TrajectoryModel, not Trajectory + assert isinstance(trajectory, TrajectoryModel) + + # Verify the trajectory was created correctly + assert trajectory.id is not None + assert trajectory.tool_name == "ripgrep_search" + + # Verify the JSON fields are dictionaries, not strings + assert isinstance(trajectory.tool_parameters, dict) + assert isinstance(trajectory.tool_result, dict) + assert isinstance(trajectory.step_data, dict) + + # Verify the nested structure of tool parameters + assert trajectory.tool_parameters["options"]["case_sensitive"] == True + assert trajectory.tool_result["total_matches"] == 2 + assert trajectory.step_data["highlights"][0]["color"] == "red" + + # Verify foreign key reference + assert trajectory.human_input_id == sample_human_input.id + + +def test_create_trajectory_minimal(setup_db): + """Test creating a trajectory with minimal fields.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Create a trajectory with minimal fields + trajectory = repo.create( + tool_name="simple_tool" + ) + + # Verify type is TrajectoryModel, not Trajectory + assert isinstance(trajectory, TrajectoryModel) + + # Verify the trajectory was created correctly + assert trajectory.id is not None + assert trajectory.tool_name == "simple_tool" + + # Verify optional fields are None + assert trajectory.tool_parameters is None + assert trajectory.tool_result is None + assert trajectory.step_data is None + assert trajectory.human_input_id is None + assert trajectory.cost is None + assert trajectory.tokens is None + assert trajectory.is_error is False + + +def test_get_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data): + """Test retrieving a trajectory by ID.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Get the trajectory by ID + trajectory = repo.get(sample_trajectory.id) + + # Verify type is TrajectoryModel, not Trajectory + assert isinstance(trajectory, TrajectoryModel) + + # Verify the retrieved trajectory matches the original + assert trajectory.id == sample_trajectory.id + assert trajectory.tool_name == sample_trajectory.tool_name + + # Verify the JSON fields are dictionaries, not strings + assert isinstance(trajectory.tool_parameters, dict) + assert isinstance(trajectory.tool_result, dict) + assert isinstance(trajectory.step_data, dict) + + # Verify the content of JSON fields + assert trajectory.tool_parameters == test_tool_parameters + assert trajectory.tool_result == test_tool_result + assert trajectory.step_data == test_step_data + + # Verify non-existent trajectory returns None + non_existent_trajectory = repo.get(999) + assert non_existent_trajectory is None + + +def test_update_trajectory(setup_db, sample_trajectory): + """Test updating a trajectory.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # New data for update + new_tool_result = { + "matches": [ + {"line": 15, "content": "Updated test pattern"} + ], + "total_matches": 1, + "execution_time": 0.3 + } + + new_step_data = { + "display_type": "html", + "content": "Updated UI rendering", + "highlights": [] + } + + # Update the trajectory + updated_trajectory = repo.update( + trajectory_id=sample_trajectory.id, + tool_result=new_tool_result, + step_data=new_step_data, + cost=0.002, + tokens=200, + is_error=True, + error_message="Test error", + error_type="TestErrorType", + error_details="Detailed error information" + ) + + # Verify type is TrajectoryModel, not Trajectory + assert isinstance(updated_trajectory, TrajectoryModel) + + # Verify the fields were updated + assert updated_trajectory.tool_result == new_tool_result + assert updated_trajectory.step_data == new_step_data + assert updated_trajectory.cost == 0.002 + assert updated_trajectory.tokens == 200 + assert updated_trajectory.is_error is True + assert updated_trajectory.error_message == "Test error" + assert updated_trajectory.error_type == "TestErrorType" + assert updated_trajectory.error_details == "Detailed error information" + + # Original tool parameters should not change + # We need to parse the JSON string from the Peewee object for comparison + original_params = json.loads(sample_trajectory.tool_parameters) + assert updated_trajectory.tool_parameters == original_params + + # Verify updating a non-existent trajectory returns None + non_existent_update = repo.update(trajectory_id=999, cost=0.005) + assert non_existent_update is None + + +def test_delete_trajectory(setup_db, sample_trajectory): + """Test deleting a trajectory.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Verify the trajectory exists + assert repo.get(sample_trajectory.id) is not None + + # Delete the trajectory + result = repo.delete(sample_trajectory.id) + + # Verify the trajectory was deleted + assert result is True + assert repo.get(sample_trajectory.id) is None + + # Verify deleting a non-existent trajectory returns False + result = repo.delete(999) + assert result is False + + +def test_get_all_trajectories(setup_db, sample_human_input): + """Test retrieving all trajectories.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Create multiple trajectories + for i in range(3): + repo.create( + tool_name=f"tool_{i}", + tool_parameters={"index": i}, + human_input_id=sample_human_input.id + ) + + # Get all trajectories + trajectories = repo.get_all() + + # Verify we got a dictionary of TrajectoryModel objects + assert len(trajectories) == 3 + for trajectory_id, trajectory in trajectories.items(): + assert isinstance(trajectory, TrajectoryModel) + assert isinstance(trajectory.tool_parameters, dict) + + # Verify the trajectories have the correct tool names + tool_names = {trajectory.tool_name for trajectory in trajectories.values()} + assert "tool_0" in tool_names + assert "tool_1" in tool_names + assert "tool_2" in tool_names + + +def test_get_trajectories_by_human_input(setup_db, sample_human_input): + """Test retrieving trajectories by human input ID.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Create another human input + other_human_input = HumanInput.create( + content="Another human input", + source="test" + ) + + # Create trajectories for both human inputs + for i in range(2): + repo.create( + tool_name=f"tool_1_{i}", + human_input_id=sample_human_input.id + ) + + for i in range(3): + repo.create( + tool_name=f"tool_2_{i}", + human_input_id=other_human_input.id + ) + + # Get trajectories for the first human input + trajectories = repo.get_trajectories_by_human_input(sample_human_input.id) + + # Verify we got a list of TrajectoryModel objects for the first human input + assert len(trajectories) == 2 + for trajectory in trajectories: + assert isinstance(trajectory, TrajectoryModel) + assert trajectory.human_input_id == sample_human_input.id + assert trajectory.tool_name.startswith("tool_1") + + # Get trajectories for the second human input + trajectories = repo.get_trajectories_by_human_input(other_human_input.id) + + # Verify we got a list of TrajectoryModel objects for the second human input + assert len(trajectories) == 3 + for trajectory in trajectories: + assert isinstance(trajectory, TrajectoryModel) + assert trajectory.human_input_id == other_human_input.id + assert trajectory.tool_name.startswith("tool_2") + + +def test_get_parsed_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data): + """Test retrieving a parsed trajectory.""" + # Set up repository + repo = TrajectoryRepository(db=setup_db) + + # Get the parsed trajectory + trajectory = repo.get_parsed_trajectory(sample_trajectory.id) + + # Verify type is TrajectoryModel, not Trajectory + assert isinstance(trajectory, TrajectoryModel) + + # Verify the retrieved trajectory matches the original + assert trajectory.id == sample_trajectory.id + assert trajectory.tool_name == sample_trajectory.tool_name + + # Verify the JSON fields are dictionaries, not strings + assert isinstance(trajectory.tool_parameters, dict) + assert isinstance(trajectory.tool_result, dict) + assert isinstance(trajectory.step_data, dict) + + # Verify the content of JSON fields + assert trajectory.tool_parameters == test_tool_parameters + assert trajectory.tool_result == test_tool_result + assert trajectory.step_data == test_step_data + + # Verify non-existent trajectory returns None + non_existent_trajectory = repo.get_parsed_trajectory(999) + assert non_existent_trajectory is None + + +def test_trajectory_repository_manager(setup_db, cleanup_repo): + """Test the TrajectoryRepositoryManager context manager.""" + # Use the context manager to create a repository + with TrajectoryRepositoryManager(setup_db) as repo: + # Verify the repository was created correctly + assert isinstance(repo, TrajectoryRepository) + assert repo.db is setup_db + + # Create a trajectory and verify it's a TrajectoryModel + tool_parameters = {"test": "manager"} + trajectory = repo.create( + tool_name="manager_test", + tool_parameters=tool_parameters + ) + assert isinstance(trajectory, TrajectoryModel) + assert trajectory.tool_parameters["test"] == "manager" + + # Verify we can get the repository using get_trajectory_repository + repo_from_var = get_trajectory_repository() + assert repo_from_var is repo + + # Verify the repository was removed from the context var + with pytest.raises(RuntimeError) as excinfo: + get_trajectory_repository() + + assert "No TrajectoryRepository available" in str(excinfo.value) + + +def test_repository_init_without_db(): + """Test that TrajectoryRepository raises an error when initialized without a db parameter.""" + # Attempt to create a repository without a database connection + with pytest.raises(ValueError) as excinfo: + TrajectoryRepository(db=None) + + # Verify the correct error message + assert "Database connection is required" in str(excinfo.value) \ No newline at end of file