RA.Aid/tests/ra_aid/server/test_sessions_api_integrati...

532 lines
18 KiB
Python

"""
Integration tests for the Sessions API endpoints.
This module contains integration tests for the API endpoints defined in ra_aid/server/api_v1_sessions.py.
It uses mocks to simulate the database interactions while testing the real API behavior.
"""
import pytest
import datetime
from typing import Dict, Any, List, Tuple
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from ra_aid.server.server import app
from ra_aid.database.pydantic_models import SessionModel
from ra_aid.server.api_v1_sessions import get_repository
# Mock session data for testing
@pytest.fixture
def mock_session():
"""Return a mock session for testing."""
return SessionModel(
id=1,
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
command_line="ra-aid test",
program_version="1.0.0",
machine_info={"os": "test"}
)
@pytest.fixture
def mock_sessions():
"""Return a list of mock sessions for testing."""
return [
SessionModel(
id=i+1,
created_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
updated_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
start_time=datetime.datetime(2025, 1, i+1, 0, 0, 0),
command_line=f"ra-aid test{i+1}",
program_version="1.0.0",
machine_info={"index": i}
)
for i in range(15)
]
@pytest.fixture
def mock_repo(mock_session, mock_sessions):
"""Create a mock repository with predefined responses."""
repo = MagicMock()
repo.get.return_value = mock_session
repo.get_all.return_value = (mock_sessions[:10], len(mock_sessions))
repo.create_session.return_value = mock_session
# Add behavior for custom parameters
def get_with_id(session_id):
if session_id == 999999:
return None
for session in mock_sessions:
if session.id == session_id:
return session
return mock_session
def get_all_with_pagination(offset=0, limit=10):
total = len(mock_sessions)
sorted_sessions = sorted(mock_sessions, key=lambda s: s.id, reverse=True)
return sorted_sessions[offset:offset+limit], total
def create_with_metadata(metadata=None):
if metadata is None:
return SessionModel(
id=16,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid test-null",
program_version="1.0.0",
machine_info=None
)
return SessionModel(
id=16,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid test-custom",
program_version="1.0.0",
machine_info=metadata
)
repo.get.side_effect = get_with_id
repo.get_all.side_effect = get_all_with_pagination
repo.create_session.side_effect = create_with_metadata
return repo
@pytest.fixture
def client(mock_repo):
"""Create a TestClient with the API and dependency overrides."""
# Override the dependency to use our mock repository
app.dependency_overrides[get_repository] = lambda: mock_repo
# Create a test client
with TestClient(app) as test_client:
yield test_client
# Clean up the dependency override
app.dependency_overrides.clear()
@pytest.fixture
def sample_metadata() -> Dict[str, Any]:
"""Return sample metadata for session creation."""
return {
"os": "Test OS",
"version": "1.0.0",
"environment": "test",
"cpu_cores": 4,
"memory_gb": 16,
"additional_info": {
"gpu": "Test GPU",
"display_resolution": "1920x1080"
}
}
def test_create_session_with_metadata(client, mock_repo, sample_metadata):
"""Test creating a session with metadata through the API endpoint."""
# Send request to create a session with metadata
response = client.post(
"/v1/sessions",
json={"metadata": sample_metadata}
)
# Verify response status code and structure
assert response.status_code == 201
data = response.json()
# Verify the session was created with the expected fields
assert data["id"] is not None
assert data["command_line"] is not None
assert data["program_version"] is not None
assert data["created_at"] is not None
assert data["updated_at"] is not None
assert data["start_time"] is not None
# Verify metadata was passed correctly to the repository
mock_repo.create_session.assert_called_once_with(metadata=sample_metadata)
def test_create_session_without_metadata(client, mock_repo):
"""Test creating a session without metadata through the API endpoint."""
# Send request without a body
response = client.post("/v1/sessions")
# Verify response status code and structure
assert response.status_code == 201
data = response.json()
# Verify the session was created with the expected fields
assert data["id"] is not None
assert data["command_line"] is not None
assert data["program_version"] is not None
# Verify correct parameters were passed to the repository
mock_repo.create_session.assert_called_once_with(metadata=None)
def test_get_session_by_id(client):
"""Test retrieving a session by ID through the API endpoint."""
# Use a completely isolated, standalone test
# For this test, let's focus on verifying the core functionality:
# 1. The API endpoint receives a request for a specific session ID
# 2. It calls the repository with that ID
# 3. It returns a properly formatted response
mock_repo = MagicMock()
# Create a test session with a simple machine_info to reduce serialization issues
test_session = SessionModel(
id=42,
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
command_line="ra-aid specific-test",
program_version="1.0.0-test",
machine_info=None # Use None to avoid serialization issues
)
# Configure the mock
mock_repo.get.return_value = test_session
# Override the dependency
app.dependency_overrides[get_repository] = lambda: mock_repo
try:
# Retrieve the session through the API
response = client.get(f"/v1/sessions/{test_session.id}")
# Verify response status code
assert response.status_code == 200
# Parse the response data
data = response.json()
# Print for debugging
import json
print("Response JSON:", json.dumps(data, indent=2))
# Verify the returned session matches what we expected
assert data["id"] == test_session.id
assert data["command_line"] == test_session.command_line
assert data["program_version"] == test_session.program_version
assert data["machine_info"] is None
# Verify the repository was called with the correct ID
mock_repo.get.assert_called_once_with(test_session.id)
finally:
# Clean up the override
if get_repository in app.dependency_overrides:
del app.dependency_overrides[get_repository]
def test_get_session_not_found(client, mock_repo):
"""Test the error handling when requesting a non-existent session."""
# Try to get a session with a non-existent ID
response = client.get("/v1/sessions/999999")
# Verify response status code and error message
assert response.status_code == 404
assert "not found" in response.json()["detail"]
# Verify the repository was called with the correct ID
mock_repo.get.assert_called_with(999999)
def test_list_sessions_empty(client, mock_repo):
"""Test listing sessions when no sessions exist."""
# Reset the mock first to clear any previous calls/side effects
mock_repo.reset_mock()
# Configure the mock to return empty results
mock_repo.get_all.side_effect = None # Clear any previous side effects
mock_repo.get_all.return_value = ([], 0)
# Get the list of sessions
response = client.get("/v1/sessions")
# Verify response status code and structure
assert response.status_code == 200
data = response.json()
# Verify the pagination response
assert data["total"] == 0
assert len(data["items"]) == 0
assert data["limit"] == 10
assert data["offset"] == 0
# Verify the repository was called with the correct parameters
mock_repo.get_all.assert_called_with(offset=0, limit=10)
def test_list_sessions_with_pagination(client, mock_repo, mock_sessions):
"""Test listing sessions with pagination parameters."""
# Set up the repository mock to return specific results for different pagination parameters
default_result = (mock_sessions[:10], len(mock_sessions))
limit_5_result = (mock_sessions[:5], len(mock_sessions))
offset_10_result = (mock_sessions[10:], len(mock_sessions))
offset_5_limit_3_result = (mock_sessions[5:8], len(mock_sessions))
pagination_responses = {
(0, 10): default_result,
(0, 5): limit_5_result,
(10, 10): offset_10_result,
(5, 3): offset_5_limit_3_result
}
def mock_get_all(offset=0, limit=10):
return pagination_responses.get((offset, limit), ([], 0))
mock_repo.get_all.side_effect = mock_get_all
# Test default pagination (limit=10, offset=0)
response = client.get("/v1/sessions")
assert response.status_code == 200
data = response.json()
assert data["total"] == len(mock_sessions)
assert len(data["items"]) == 10
assert data["limit"] == 10
assert data["offset"] == 0
# Test with custom limit
response = client.get("/v1/sessions?limit=5")
assert response.status_code == 200
data = response.json()
assert data["total"] == len(mock_sessions)
assert len(data["items"]) == 5
assert data["limit"] == 5
assert data["offset"] == 0
# Test with custom offset
response = client.get("/v1/sessions?offset=10")
assert response.status_code == 200
data = response.json()
assert data["total"] == len(mock_sessions)
assert len(data["items"]) == 5 # Only 5 items left after offset 10
assert data["limit"] == 10
assert data["offset"] == 10
# Test with both custom limit and offset
response = client.get("/v1/sessions?limit=3&offset=5")
assert response.status_code == 200
data = response.json()
assert data["total"] == len(mock_sessions)
assert len(data["items"]) == 3
assert data["limit"] == 3
assert data["offset"] == 5
def test_list_sessions_invalid_parameters(client):
"""Test error handling for invalid pagination parameters."""
# Test with negative offset
response = client.get("/v1/sessions?offset=-1")
assert response.status_code == 422
# Test with negative limit
response = client.get("/v1/sessions?limit=-5")
assert response.status_code == 422
# Test with zero limit
response = client.get("/v1/sessions?limit=0")
assert response.status_code == 422
# Test with limit exceeding maximum
response = client.get("/v1/sessions?limit=101")
assert response.status_code == 422
def test_metadata_validation(client, mock_repo):
"""Test validation for different metadata formats in session creation."""
# Create test sessions with different metadata
null_metadata_session = SessionModel(
id=20,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid test-null",
program_version="1.0.0",
machine_info=None
)
empty_dict_metadata_session = SessionModel(
id=21,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid test-empty",
program_version="1.0.0",
machine_info={}
)
complex_metadata_session = SessionModel(
id=22,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid test-complex",
program_version="1.0.0",
machine_info={"level1": {"level2": {"level3": [1, 2, 3, {"key": "value"}]}}}
)
# Configure mock to return different sessions based on metadata
def create_with_specific_metadata(metadata=None):
if metadata is None:
return null_metadata_session
elif metadata == {}:
return empty_dict_metadata_session
elif isinstance(metadata, dict) and "level1" in metadata:
return complex_metadata_session
return SessionModel(
id=23,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid test-other",
program_version="1.0.0",
machine_info=metadata
)
mock_repo.create_session.side_effect = create_with_specific_metadata
# Try to create a session with null metadata
response = client.post(
"/v1/sessions",
json={"metadata": None}
)
# This should work fine
assert response.status_code == 201
mock_repo.create_session.assert_called_with(metadata=None)
# Try to create a session with an empty metadata dict
response = client.post(
"/v1/sessions",
json={"metadata": {}}
)
# This should work fine
assert response.status_code == 201
mock_repo.create_session.assert_called_with(metadata={})
# Try to create a session with a complex nested metadata
response = client.post(
"/v1/sessions",
json={"metadata": {
"level1": {
"level2": {
"level3": [1, 2, 3, {"key": "value"}]
}
}
}}
)
# Verify the complex nested structure is preserved
assert response.status_code == 201
complex_metadata = {
"level1": {
"level2": {
"level3": [1, 2, 3, {"key": "value"}]
}
}
}
mock_repo.create_session.assert_called_with(metadata=complex_metadata)
def test_integration_workflow(client, mock_repo):
"""Test a complete workflow of creating and retrieving sessions."""
# Set up mock sessions for the workflow
first_session = SessionModel(
id=30,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid workflow-1",
program_version="1.0.0",
machine_info={"workflow_test": True}
)
second_session = SessionModel(
id=31,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
start_time=datetime.datetime.now(),
command_line="ra-aid workflow-2",
program_version="1.0.0",
machine_info={"workflow_test": False, "second": True}
)
# Configure mock for create_session
create_calls = 0
def create_session_for_workflow(metadata=None):
nonlocal create_calls
create_calls += 1
if create_calls == 1:
return first_session
return second_session
mock_repo.create_session.side_effect = create_session_for_workflow
# Configure mock for get
def get_session_for_workflow(session_id):
if session_id == first_session.id:
return first_session
elif session_id == second_session.id:
return second_session
return None
mock_repo.get.side_effect = get_session_for_workflow
# Configure mock for get_all
def get_all_for_workflow(offset=0, limit=10):
if create_calls == 1:
return [first_session], 1
return [second_session, first_session], 2
mock_repo.get_all.side_effect = get_all_for_workflow
# 1. Create a session
create_response = client.post(
"/v1/sessions",
json={"metadata": {"workflow_test": True}}
)
assert create_response.status_code == 201
session_id = create_response.json()["id"]
assert session_id == first_session.id
# 2. Retrieve the created session
get_response = client.get(f"/v1/sessions/{session_id}")
assert get_response.status_code == 200
assert get_response.json()["id"] == session_id
# 3. List all sessions and verify the created one is included
list_response = client.get("/v1/sessions")
assert list_response.status_code == 200
items = list_response.json()["items"]
assert len(items) == 1
assert items[0]["id"] == session_id
# 4. Create a second session
create_response2 = client.post(
"/v1/sessions",
json={"metadata": {"workflow_test": False, "second": True}}
)
assert create_response2.status_code == 201
session_id2 = create_response2.json()["id"]
assert session_id2 == second_session.id
# 5. List all sessions and verify both sessions are included
list_response = client.get("/v1/sessions")
assert list_response.status_code == 200
data = list_response.json()
assert data["total"] == 2
items = data["items"]
assert len(items) == 2
assert items[0]["id"] == session_id2
assert items[1]["id"] == session_id