458 lines
15 KiB
Python
458 lines
15 KiB
Python
"""
|
|
Tests for the TrajectoryRepository class.
|
|
"""
|
|
|
|
import pytest
|
|
import datetime
|
|
import json
|
|
from unittest.mock import patch
|
|
|
|
import peewee
|
|
|
|
from ra_aid.database.connection import DatabaseManager, db_var
|
|
from ra_aid.database.models import Trajectory, HumanInput, Session, BaseModel
|
|
from ra_aid.database.repositories.trajectory_repository import (
|
|
TrajectoryRepository,
|
|
TrajectoryRepositoryManager,
|
|
get_trajectory_repository,
|
|
trajectory_repo_var
|
|
)
|
|
from ra_aid.database.pydantic_models import TrajectoryModel
|
|
|
|
|
|
def _close_and_reset_db():
    """Close the connection stored in ``db_var`` (if any) and clear the contextvar.

    Closing is best-effort: if ``is_closed()``/``close()`` raises, the error is
    ignored so that ``db_var`` is still reset to ``None``.
    """
    db = db_var.get()
    if db is not None:
        try:
            if not db.is_closed():
                db.close()
        except Exception:
            # Deliberately ignored: a failed close must not stop the
            # contextvar from being cleared for the next test.
            pass
    db_var.set(None)


@pytest.fixture
def cleanup_db():
    """Reset the database contextvar and connection state before and after each test."""
    # Reset before the test
    _close_and_reset_db()

    # Run the test
    yield

    # Reset after the test
    _close_and_reset_db()
|
|
|
|
|
|
@pytest.fixture
def cleanup_repo():
    """Ensure ``trajectory_repo_var`` holds ``None`` both before and after a test.

    The value is cleared on entry, so nothing a previous test installed leaks
    in. It is cleared again on teardown, discarding whatever this test installed.
    """
    clear_repo = trajectory_repo_var.set

    clear_repo(None)  # start the test with no repository published
    yield
    clear_repo(None)  # drop any repository the test left behind
|
|
|
|
|
|
@pytest.fixture
def setup_db(cleanup_db):
    """Provide an in-memory database with the Trajectory, HumanInput and Session tables.

    ``BaseModel._meta.database`` is patched to the in-memory connection for the
    lifetime of the fixture, so model-level queries go to it. The tables are
    created in a transaction before the test and dropped in another afterwards.
    """
    with DatabaseManager(in_memory=True) as db, patch.object(BaseModel._meta, 'database', db):
        with db.atomic():
            db.create_tables([Trajectory, HumanInput, Session], safe=True)

        yield db

        # Teardown: drop in the same order as before (Trajectory first)
        with db.atomic():
            for model in (Trajectory, HumanInput, Session):
                model.drop_table(safe=True)
|
|
|
|
|
|
@pytest.fixture
def sample_human_input(setup_db):
    """Insert a ``HumanInput`` row (source ``"test"``) and return it."""
    human_input = HumanInput.create(content="Test human input", source="test")
    return human_input
|
|
|
|
|
|
@pytest.fixture
def test_tool_parameters():
    """Provide nested tool parameters (including a sub-dict of options)."""
    search_options = {"case_sensitive": True, "whole_words": False}
    return {
        "pattern": "test pattern",
        "file_path": "/path/to/file",
        "options": search_options,
    }
|
|
|
|
|
|
@pytest.fixture
def test_tool_result():
    """Provide a tool result with a list of two match records plus summary values."""
    matches = [
        {"line": 10, "content": "This is a test pattern"},
        {"line": 20, "content": "Another test pattern here"},
    ]
    return {"matches": matches, "total_matches": 2, "execution_time": 0.5}
|
|
|
|
|
|
@pytest.fixture
def test_step_data():
    """Provide UI-rendering step data containing a single highlight span."""
    highlight = {"start": 10, "end": 15, "color": "red"}
    return {
        "display_type": "text",
        "content": "Tool execution results",
        "highlights": [highlight],
    }
|
|
|
|
|
|
@pytest.fixture
def sample_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
    """Insert a ``Trajectory`` row directly via the Peewee model and return it.

    The row is created with ``Trajectory.create`` rather than the repository,
    and its JSON columns are stored as ``json.dumps`` strings.
    """
    json_columns = {
        "tool_parameters": json.dumps(test_tool_parameters),
        "tool_result": json.dumps(test_tool_result),
        "step_data": json.dumps(test_step_data),
    }
    return Trajectory.create(
        human_input=sample_human_input,
        tool_name="ripgrep_search",
        **json_columns,
        record_type="tool_execution",
        cost=0.001,
        tokens=100,
        is_error=False,
    )
|
|
|
|
|
|
def test_create_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
    """Test creating a trajectory with all fields.

    Checks that ``create`` returns a ``TrajectoryModel`` whose JSON fields are
    decoded back into the exact dicts that were passed in. Also checks that the
    scalar values (``cost``, ``tokens``) and the human-input foreign key are
    carried through.
    """
    # Set up repository
    repo = TrajectoryRepository(db=setup_db)

    # Create a trajectory
    trajectory = repo.create(
        tool_name="ripgrep_search",
        tool_parameters=test_tool_parameters,
        tool_result=test_tool_result,
        step_data=test_step_data,
        record_type="tool_execution",
        human_input_id=sample_human_input.id,
        cost=0.001,
        tokens=100,
    )

    # Verify type is TrajectoryModel, not Trajectory
    assert isinstance(trajectory, TrajectoryModel)

    # Verify the trajectory was created correctly
    assert trajectory.id is not None
    assert trajectory.tool_name == "ripgrep_search"

    # Verify the JSON fields are dictionaries, not strings
    assert isinstance(trajectory.tool_parameters, dict)
    assert isinstance(trajectory.tool_result, dict)
    assert isinstance(trajectory.step_data, dict)

    # Verify the JSON fields round-trip without loss
    assert trajectory.tool_parameters == test_tool_parameters
    assert trajectory.tool_result == test_tool_result
    assert trajectory.step_data == test_step_data

    # Spot-check nested values; `is True` rejects truthy non-bools such as 1
    assert trajectory.tool_parameters["options"]["case_sensitive"] is True
    assert trajectory.tool_result["total_matches"] == 2
    assert trajectory.step_data["highlights"][0]["color"] == "red"

    # Verify the scalar fields passed to create() were stored
    assert trajectory.cost == pytest.approx(0.001)
    assert trajectory.tokens == 100
    assert trajectory.is_error is False

    # Verify foreign key reference
    assert trajectory.human_input_id == sample_human_input.id
|
|
|
|
|
|
def test_create_trajectory_minimal(setup_db):
    """Test creating a trajectory with minimal fields.

    Only ``tool_name`` is supplied. Every optional column must come back as
    ``None``, and ``is_error`` must default to ``False``.
    """
    trajectory = TrajectoryRepository(db=setup_db).create(tool_name="simple_tool")

    # The repository hands back its Pydantic model, not the Peewee row
    assert isinstance(trajectory, TrajectoryModel)
    assert trajectory.id is not None
    assert trajectory.tool_name == "simple_tool"

    unset_fields = (
        "tool_parameters",
        "tool_result",
        "step_data",
        "human_input_id",
        "cost",
        "tokens",
    )
    for field_name in unset_fields:
        assert getattr(trajectory, field_name) is None, field_name

    assert trajectory.is_error is False
|
|
|
|
|
|
def test_get_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
    """Test retrieving a trajectory by ID.

    The fixture row stores its JSON columns as strings. ``get`` must return a
    ``TrajectoryModel`` with those columns decoded into the original dicts,
    and must return ``None`` for an ID that does not exist.
    """
    repo = TrajectoryRepository(db=setup_db)
    fetched = repo.get(sample_trajectory.id)

    assert isinstance(fetched, TrajectoryModel)
    assert (fetched.id, fetched.tool_name) == (sample_trajectory.id, sample_trajectory.tool_name)

    expected_json = {
        "tool_parameters": test_tool_parameters,
        "tool_result": test_tool_result,
        "step_data": test_step_data,
    }
    for field_name, expected in expected_json.items():
        actual = getattr(fetched, field_name)
        assert isinstance(actual, dict), field_name
        assert actual == expected, field_name

    # Unknown IDs yield None rather than raising
    assert repo.get(999) is None
|
|
|
|
|
|
def test_update_trajectory(setup_db, sample_trajectory):
    """Test updating a trajectory.

    Verifies that:
    - ``update`` returns a ``TrajectoryModel`` carrying the new values;
    - the changes are persisted, because a fresh ``get`` sees them too;
    - fields not passed to ``update`` (``tool_parameters``) are unchanged;
    - updating an unknown ID returns ``None``.
    """
    # Set up repository
    repo = TrajectoryRepository(db=setup_db)

    # New data for update
    new_tool_result = {
        "matches": [
            {"line": 15, "content": "Updated test pattern"}
        ],
        "total_matches": 1,
        "execution_time": 0.3
    }

    new_step_data = {
        "display_type": "html",
        "content": "Updated UI rendering",
        "highlights": []
    }

    changes = {
        "tool_result": new_tool_result,
        "step_data": new_step_data,
        "cost": 0.002,
        "tokens": 200,
        "is_error": True,
        "error_message": "Test error",
        "error_type": "TestErrorType",
        "error_details": "Detailed error information",
    }

    # Update the trajectory
    updated_trajectory = repo.update(trajectory_id=sample_trajectory.id, **changes)

    # Verify type is TrajectoryModel, not Trajectory
    assert isinstance(updated_trajectory, TrajectoryModel)

    # Re-read from the database so the checks cover persistence, not just the
    # object returned by update()
    reloaded = repo.get(sample_trajectory.id)
    assert isinstance(reloaded, TrajectoryModel)

    # sample_trajectory is the Peewee row, whose JSON column is still a string
    original_params = json.loads(sample_trajectory.tool_parameters)

    for result in (updated_trajectory, reloaded):
        # Verify the fields were updated
        for field_name, expected in changes.items():
            assert getattr(result, field_name) == expected, field_name
        assert result.is_error is True

        # Original tool parameters should not change
        assert result.tool_parameters == original_params

    # Verify updating a non-existent trajectory returns None
    non_existent_update = repo.update(trajectory_id=999, cost=0.005)
    assert non_existent_update is None
|
|
|
|
|
|
def test_delete_trajectory(setup_db, sample_trajectory):
    """Test deleting a trajectory.

    ``delete`` returns ``True`` when it removes an existing row, after which
    ``get`` yields ``None``. It returns ``False`` for an ID that does not exist.
    """
    repo = TrajectoryRepository(db=setup_db)
    target_id = sample_trajectory.id

    # Precondition: the fixture row is present
    assert repo.get(target_id) is not None

    assert repo.delete(target_id) is True
    assert repo.get(target_id) is None

    missing_id = 999
    assert repo.delete(missing_id) is False
|
|
|
|
|
|
def test_get_all_trajectories(setup_db, sample_human_input):
    """Test retrieving all trajectories.

    ``get_all`` returns a dict whose values are ``TrajectoryModel`` objects.
    Every created trajectory must be present with its JSON parameters decoded
    back to the dict that was stored.
    """
    # Set up repository
    repo = TrajectoryRepository(db=setup_db)

    # tool name -> tool parameters for each trajectory created below
    expected = {f"tool_{i}": {"index": i} for i in range(3)}

    # Create multiple trajectories
    for tool_name, params in expected.items():
        repo.create(
            tool_name=tool_name,
            tool_parameters=params,
            human_input_id=sample_human_input.id,
        )

    # Get all trajectories
    trajectories = repo.get_all()

    # Verify we got a dictionary of TrajectoryModel objects
    assert len(trajectories) == 3
    for trajectory in trajectories.values():
        assert isinstance(trajectory, TrajectoryModel)
        assert isinstance(trajectory.tool_parameters, dict)

    # Verify each trajectory has the correct tool name and parameters
    actual = {t.tool_name: t.tool_parameters for t in trajectories.values()}
    assert actual == expected
|
|
|
|
|
|
def test_get_trajectories_by_human_input(setup_db, sample_human_input):
    """Test retrieving trajectories by human input ID.

    Two trajectories are attached to the fixture input and three to a second
    input. Each lookup must return exactly its own input's trajectories.
    """
    repo = TrajectoryRepository(db=setup_db)

    other_human_input = HumanInput.create(
        content="Another human input",
        source="test"
    )

    # (owning human input, tool-name prefix, number of trajectories)
    groups = [
        (sample_human_input, "tool_1", 2),
        (other_human_input, "tool_2", 3),
    ]

    for owner, prefix, count in groups:
        for i in range(count):
            repo.create(tool_name=f"{prefix}_{i}", human_input_id=owner.id)

    for owner, prefix, count in groups:
        found = repo.get_trajectories_by_human_input(owner.id)
        assert len(found) == count
        for trajectory in found:
            assert isinstance(trajectory, TrajectoryModel)
            assert trajectory.human_input_id == owner.id
            assert trajectory.tool_name.startswith(prefix)
|
|
|
|
|
|
def test_get_parsed_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
    """Test retrieving a parsed trajectory.

    ``get_parsed_trajectory`` must return a ``TrajectoryModel`` whose JSON
    columns are decoded back into the fixture dicts. It must return ``None``
    when the ID does not exist.
    """
    repo = TrajectoryRepository(db=setup_db)
    parsed = repo.get_parsed_trajectory(sample_trajectory.id)

    assert isinstance(parsed, TrajectoryModel)
    assert parsed.id == sample_trajectory.id
    assert parsed.tool_name == sample_trajectory.tool_name

    decoded_pairs = [
        (parsed.tool_parameters, test_tool_parameters),
        (parsed.tool_result, test_tool_result),
        (parsed.step_data, test_step_data),
    ]
    # Every JSON column is a dict (not the raw stored string) ...
    assert all(isinstance(value, dict) for value, _ in decoded_pairs)
    # ... equal to the data that was serialized into the row
    for value, expected in decoded_pairs:
        assert value == expected

    assert repo.get_parsed_trajectory(999) is None
|
|
|
|
|
|
def test_trajectory_repository_manager(setup_db, cleanup_repo):
    """Test the TrajectoryRepositoryManager context manager.

    Inside the ``with`` block, the manager yields a repository bound to the
    given database, and ``get_trajectory_repository`` returns that same
    object. Once the block exits, the lookup raises ``RuntimeError``.
    """
    with TrajectoryRepositoryManager(setup_db) as managed_repo:
        assert isinstance(managed_repo, TrajectoryRepository)
        assert managed_repo.db is setup_db

        created = managed_repo.create(
            tool_name="manager_test",
            tool_parameters={"test": "manager"},
        )
        assert isinstance(created, TrajectoryModel)
        assert created.tool_parameters["test"] == "manager"

        # The active repository is reachable through the module-level accessor
        assert get_trajectory_repository() is managed_repo

    # After leaving the context the accessor must no longer find a repository
    with pytest.raises(RuntimeError, match="No TrajectoryRepository available"):
        get_trajectory_repository()
|
|
|
|
|
|
def test_repository_init_without_db():
    """Test that TrajectoryRepository raises an error when initialized without a db parameter."""
    # An explicit db=None must be rejected with a descriptive ValueError
    with pytest.raises(ValueError, match="Database connection is required"):
        TrajectoryRepository(db=None)