# RA.Aid/tests/ra_aid/database/test_trajectory_repository.py
#
# 458 lines
# 15 KiB
# Python

"""
Tests for the TrajectoryRepository class.
"""
import pytest
import datetime
import json
from unittest.mock import patch
import peewee
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import Trajectory, HumanInput, Session, BaseModel
from ra_aid.database.repositories.trajectory_repository import (
TrajectoryRepository,
TrajectoryRepositoryManager,
get_trajectory_repository,
trajectory_repo_var
)
from ra_aid.database.pydantic_models import TrajectoryModel
@pytest.fixture
def cleanup_db():
    """Isolate each test from any database connection held in ``db_var``.

    The connection currently stored in ``db_var`` (if any) is closed on a
    best-effort basis and the context variable is set back to ``None``,
    once before the test body runs and once after it finishes.
    """

    def _close_and_clear():
        # Close the registered connection when there is one, then unset it.
        current = db_var.get()
        if current is not None:
            try:
                if not current.is_closed():
                    current.close()
            except Exception:
                # Closing is best effort: a failure here is ignored on purpose.
                pass
        db_var.set(None)

    _close_and_clear()  # pre-test reset
    yield  # the test runs here
    _close_and_clear()  # post-test reset
@pytest.fixture
def cleanup_repo():
    """Set ``trajectory_repo_var`` to ``None`` before and after the test."""
    trajectory_repo_var.set(None)  # start the test with no repository registered
    yield  # the test runs here
    trajectory_repo_var.set(None)  # discard whatever the test left registered
@pytest.fixture
def setup_db(cleanup_db):
    """Yield an in-memory database that has the tables used by these tests.

    While the fixture is active, ``BaseModel._meta.database`` is patched to
    point at the in-memory connection. The Trajectory, HumanInput and Session
    tables are created before the test and dropped again afterwards.
    """
    models = [Trajectory, HumanInput, Session]

    # Open the in-memory connection and redirect the models to it in one step.
    with DatabaseManager(in_memory=True) as db, patch.object(BaseModel._meta, 'database', db):
        # Schema creation happens inside a single transaction.
        with db.atomic():
            db.create_tables(models, safe=True)

        yield db

        # Teardown: drop the tables one by one, in the order listed above.
        with db.atomic():
            for model in models:
                model.drop_table(safe=True)
@pytest.fixture
def sample_human_input(setup_db):
    """Insert a single HumanInput row and return the created record."""
    record = HumanInput.create(content="Test human input", source="test")
    return record
@pytest.fixture
def test_tool_parameters():
    """Provide a nested dict of example tool parameters."""
    search_options = {"case_sensitive": True, "whole_words": False}
    return {
        "pattern": "test pattern",
        "file_path": "/path/to/file",
        "options": search_options,
    }
@pytest.fixture
def test_tool_result():
    """Provide an example tool result containing two matches."""
    found = [
        {"line": 10, "content": "This is a test pattern"},
        {"line": 20, "content": "Another test pattern here"},
    ]
    return {
        "matches": found,
        "total_matches": 2,
        "execution_time": 0.5,
    }
@pytest.fixture
def test_step_data():
    """Provide example step data of the kind used for UI rendering."""
    highlighted_ranges = [{"start": 10, "end": 15, "color": "red"}]
    return {
        "display_type": "text",
        "content": "Tool execution results",
        "highlights": highlighted_ranges,
    }
@pytest.fixture
def sample_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
    """Insert one Trajectory row linked to ``sample_human_input`` and return it.

    The three structured payloads are written to the row as JSON strings.
    """
    # Serialize the dict payloads up front so the create() call stays flat.
    parameters_json = json.dumps(test_tool_parameters)
    result_json = json.dumps(test_tool_result)
    step_data_json = json.dumps(test_step_data)

    return Trajectory.create(
        human_input=sample_human_input,
        tool_name="ripgrep_search",
        tool_parameters=parameters_json,
        tool_result=result_json,
        step_data=step_data_json,
        record_type="tool_execution",
        cost=0.001,
        tokens=100,
        is_error=False,
    )
def test_create_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
    """Test creating a trajectory with all fields.

    Creates a trajectory through the repository with every payload supplied
    and checks that the returned object is a TrajectoryModel whose JSON
    fields come back as dictionaries with their nested content intact.
    """
    # Set up repository
    repo = TrajectoryRepository(db=setup_db)

    # Create a trajectory
    trajectory = repo.create(
        tool_name="ripgrep_search",
        tool_parameters=test_tool_parameters,
        tool_result=test_tool_result,
        step_data=test_step_data,
        record_type="tool_execution",
        human_input_id=sample_human_input.id,
        cost=0.001,
        tokens=100,
    )

    # Verify type is TrajectoryModel, not Trajectory
    assert isinstance(trajectory, TrajectoryModel)

    # Verify the trajectory was created correctly
    assert trajectory.id is not None
    assert trajectory.tool_name == "ripgrep_search"

    # Verify the JSON fields are dictionaries, not strings
    assert isinstance(trajectory.tool_parameters, dict)
    assert isinstance(trajectory.tool_result, dict)
    assert isinstance(trajectory.step_data, dict)

    # Verify the nested structure survived; the boolean is checked by identity
    # so a merely truthy value (e.g. 1 or "true") would not pass.
    assert trajectory.tool_parameters["options"]["case_sensitive"] is True
    assert trajectory.tool_result["total_matches"] == 2
    assert trajectory.step_data["highlights"][0]["color"] == "red"

    # Verify foreign key reference
    assert trajectory.human_input_id == sample_human_input.id
def test_create_trajectory_minimal(setup_db):
    """Create a trajectory from a tool name alone and check the remaining fields."""
    repository = TrajectoryRepository(db=setup_db)

    created = repository.create(tool_name="simple_tool")

    # The repository hands back the pydantic model, with an id assigned.
    assert isinstance(created, TrajectoryModel)
    assert created.id is not None
    assert created.tool_name == "simple_tool"

    # Nothing else was supplied, so every optional field is expected to be None.
    assert created.tool_parameters is None
    assert created.tool_result is None
    assert created.step_data is None
    assert created.human_input_id is None
    assert created.cost is None
    assert created.tokens is None

    # The error flag is expected to be False rather than None.
    assert created.is_error is False
def test_get_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
    """Fetch a stored trajectory by ID and compare it with the fixture row."""
    repository = TrajectoryRepository(db=setup_db)

    fetched = repository.get(sample_trajectory.id)

    # The lookup returns the pydantic model describing the same row.
    assert isinstance(fetched, TrajectoryModel)
    assert fetched.id == sample_trajectory.id
    assert fetched.tool_name == sample_trajectory.tool_name

    # JSON columns are exposed as dicts, not as raw strings...
    assert isinstance(fetched.tool_parameters, dict)
    assert isinstance(fetched.tool_result, dict)
    assert isinstance(fetched.step_data, dict)

    # ...and their content equals the dicts the fixture row was built from.
    assert fetched.tool_parameters == test_tool_parameters
    assert fetched.tool_result == test_tool_result
    assert fetched.step_data == test_step_data

    # An ID with no matching row yields None.
    missing = repository.get(999)
    assert missing is None
def test_update_trajectory(setup_db, sample_trajectory):
    """Update an existing trajectory and check changed and unchanged fields."""
    repository = TrajectoryRepository(db=setup_db)

    # Replacement payloads for the update call.
    replacement_result = {
        "matches": [{"line": 15, "content": "Updated test pattern"}],
        "total_matches": 1,
        "execution_time": 0.3,
    }
    replacement_step_data = {
        "display_type": "html",
        "content": "Updated UI rendering",
        "highlights": [],
    }

    updated = repository.update(
        trajectory_id=sample_trajectory.id,
        tool_result=replacement_result,
        step_data=replacement_step_data,
        cost=0.002,
        tokens=200,
        is_error=True,
        error_message="Test error",
        error_type="TestErrorType",
        error_details="Detailed error information",
    )

    # The repository returns the pydantic model, not the Peewee row.
    assert isinstance(updated, TrajectoryModel)

    # Every field passed to update() carries its new value.
    assert updated.tool_result == replacement_result
    assert updated.step_data == replacement_step_data
    assert updated.cost == 0.002
    assert updated.tokens == 200
    assert updated.is_error is True
    assert updated.error_message == "Test error"
    assert updated.error_type == "TestErrorType"
    assert updated.error_details == "Detailed error information"

    # tool_parameters was not part of the update. The Peewee fixture object
    # holds it as a JSON string, so decode it before comparing.
    expected_parameters = json.loads(sample_trajectory.tool_parameters)
    assert updated.tool_parameters == expected_parameters

    # Updating an ID with no matching row yields None.
    assert repository.update(trajectory_id=999, cost=0.005) is None
def test_delete_trajectory(setup_db, sample_trajectory):
    """Delete a stored trajectory, then try to delete one that does not exist."""
    repository = TrajectoryRepository(db=setup_db)
    target_id = sample_trajectory.id

    # Precondition: the row is retrievable before deletion.
    assert repository.get(target_id) is not None

    # Deleting an existing row reports True and the row is gone afterwards.
    assert repository.delete(target_id) is True
    assert repository.get(target_id) is None

    # Deleting an ID with no matching row reports False.
    assert repository.delete(999) is False
def test_get_all_trajectories(setup_db, sample_human_input):
    """Test retrieving all trajectories.

    Creates three trajectories for the same human input and checks that
    ``get_all`` returns all of them as TrajectoryModel objects.
    """
    # Set up repository
    repo = TrajectoryRepository(db=setup_db)

    # Create multiple trajectories
    for i in range(3):
        repo.create(
            tool_name=f"tool_{i}",
            tool_parameters={"index": i},
            human_input_id=sample_human_input.id,
        )

    # Get all trajectories
    trajectories = repo.get_all()

    # Verify we got a dictionary of TrajectoryModel objects
    assert len(trajectories) == 3
    # Only the values are inspected, so iterate them directly instead of
    # unpacking keys that are never used.
    for trajectory in trajectories.values():
        assert isinstance(trajectory, TrajectoryModel)
        assert isinstance(trajectory.tool_parameters, dict)

    # Verify the trajectories have the correct tool names
    tool_names = {trajectory.tool_name for trajectory in trajectories.values()}
    assert "tool_0" in tool_names
    assert "tool_1" in tool_names
    assert "tool_2" in tool_names
def test_get_trajectories_by_human_input(setup_db, sample_human_input):
    """Check that trajectories are filtered by the human input they belong to."""
    repository = TrajectoryRepository(db=setup_db)

    # A second human input, so the filter has something to exclude.
    other_human_input = HumanInput.create(content="Another human input", source="test")

    # Two trajectories for the first human input, three for the second.
    for i in range(2):
        repository.create(tool_name=f"tool_1_{i}", human_input_id=sample_human_input.id)
    for i in range(3):
        repository.create(tool_name=f"tool_2_{i}", human_input_id=other_human_input.id)

    def _check_owned_by(owner, expected_count, name_prefix):
        # Query by owner and verify the count, type, owner and tool-name prefix.
        found = repository.get_trajectories_by_human_input(owner.id)
        assert len(found) == expected_count
        for item in found:
            assert isinstance(item, TrajectoryModel)
            assert item.human_input_id == owner.id
            assert item.tool_name.startswith(name_prefix)

    _check_owned_by(sample_human_input, 2, "tool_1")
    _check_owned_by(other_human_input, 3, "tool_2")
def test_get_parsed_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
    """Fetch a trajectory via ``get_parsed_trajectory`` and compare it with the fixture row."""
    repository = TrajectoryRepository(db=setup_db)

    parsed = repository.get_parsed_trajectory(sample_trajectory.id)

    # The result is the pydantic model describing the same row.
    assert isinstance(parsed, TrajectoryModel)
    assert parsed.id == sample_trajectory.id
    assert parsed.tool_name == sample_trajectory.tool_name

    # JSON columns are exposed as dicts, not as raw strings...
    assert isinstance(parsed.tool_parameters, dict)
    assert isinstance(parsed.tool_result, dict)
    assert isinstance(parsed.step_data, dict)

    # ...and their content equals the dicts the fixture row was built from.
    assert parsed.tool_parameters == test_tool_parameters
    assert parsed.tool_result == test_tool_result
    assert parsed.step_data == test_step_data

    # An ID with no matching row yields None.
    missing = repository.get_parsed_trajectory(999)
    assert missing is None
def test_trajectory_repository_manager(setup_db, cleanup_repo):
    """Exercise TrajectoryRepositoryManager as a context manager."""
    with TrajectoryRepositoryManager(setup_db) as managed_repo:
        # Entering the context yields a repository bound to the given database.
        assert isinstance(managed_repo, TrajectoryRepository)
        assert managed_repo.db is setup_db

        # The managed repository works like a directly constructed one.
        payload = {"test": "manager"}
        created = managed_repo.create(tool_name="manager_test", tool_parameters=payload)
        assert isinstance(created, TrajectoryModel)
        assert created.tool_parameters["test"] == "manager"

        # Inside the context, the accessor returns the very same repository.
        assert get_trajectory_repository() is managed_repo

    # After the context exits, the accessor has no repository to return.
    with pytest.raises(RuntimeError) as raised:
        get_trajectory_repository()
    assert "No TrajectoryRepository available" in str(raised.value)
def test_repository_init_without_db():
    """Constructing TrajectoryRepository with ``db=None`` raises ValueError."""
    # ``match`` searches the exception text, so this checks both the exception
    # type and that the message contains the expected phrase.
    with pytest.raises(ValueError, match="Database connection is required"):
        TrajectoryRepository(db=None)