# RA.Aid/tests/ra_aid/server/test_api_v1_spawn_agent.py
#
# 153 lines
# 4.7 KiB
# Python

"""
Tests for the Spawn Agent API v1 endpoint.
This module contains tests for the spawn-agent API endpoint in ra_aid/server/api_v1_spawn_agent.py.
It tests the creation of agent threads and session handling for the spawn-agent endpoint.
"""
import pytest
import threading
from unittest.mock import MagicMock
from fastapi import FastAPI
from fastapi.testclient import TestClient
from ra_aid.server.api_v1_spawn_agent import router, get_repository
from ra_aid.database.pydantic_models import SessionModel
import datetime
import ra_aid.server.api_v1_spawn_agent
@pytest.fixture
def mock_session():
    """Build the SessionModel that the mocked repository hands back.

    Every timestamp uses the same fixed instant so responses are deterministic.
    """
    fixed_instant = datetime.datetime(2025, 1, 1, 0, 0, 0)
    session_info = {
        "agent_type": "research",
        "expert_enabled": True,
        "web_research_enabled": False,
    }
    return SessionModel(
        id=123,
        created_at=fixed_instant,
        updated_at=fixed_instant,
        start_time=fixed_instant,
        command_line="ra-aid test",
        program_version="1.0.0",
        machine_info=session_info,
    )
@pytest.fixture
def mock_thread():
    """Provide a stand-in Thread object whose ``start()`` only records the call."""
    fake_thread = MagicMock()
    # Start out flagged as a daemon; a fake Thread constructor may overwrite this.
    fake_thread.daemon = True
    return fake_thread
@pytest.fixture
def mock_repository(mock_session):
    """Provide a session-repository double.

    ``create_session()`` on the returned mock always yields *mock_session*.
    """
    # Dotted keyword names are routed through Mock.configure_mock().
    return MagicMock(**{"create_session.return_value": mock_session})
@pytest.fixture
def mock_config_repository():
    """Provide a config-repository double backed by a fixed settings table.

    ``get(key, default=None)`` returns the table entry for *key*, or
    *default* when the key is absent.
    """
    settings = {
        "expert_enabled": True,
        "web_research_enabled": False,
        "provider": "anthropic",
        "model": "claude-3-7-sonnet-20250219",
    }

    # A real function (not the bound dict.get) so callers may pass default=...
    def lookup(key, default=None):
        return settings.get(key, default)

    config_repo = MagicMock()
    config_repo.get.side_effect = lookup
    return config_repo
@pytest.fixture
def client(mock_repository, mock_thread, mock_config_repository, monkeypatch):
    """Set up a TestClient for the spawn-agent router with mocked collaborators.

    - ``get_repository`` is overridden to return *mock_repository*.
    - ``run_agent_thread`` and ``get_config_repository`` in the endpoint module
      are monkeypatched.
    - The module's ``threading`` reference is swapped for a mock whose
      ``Thread`` returns *mock_thread*, so no real thread is ever started.

    The mocks are attached to the client as ``mock_repo``, ``mock_thread`` and
    ``mock_config`` for convenient access from tests.
    """
    # Create FastAPI app with router
    app = FastAPI()
    app.include_router(router)

    # Override the dependency to use our mock repository
    app.dependency_overrides[get_repository] = lambda: mock_repository

    # Mock run_agent_thread to be a no-op
    monkeypatch.setattr(
        "ra_aid.server.api_v1_spawn_agent.run_agent_thread",
        lambda *args, **kwargs: None,
    )

    # Mock get_config_repository to use our mock
    monkeypatch.setattr(
        "ra_aid.server.api_v1_spawn_agent.get_config_repository",
        lambda: mock_config_repository,
    )

    def mock_thread_constructor(
        group=None, target=None, name=None, args=(), kwargs=None, *, daemon=False
    ):
        """Stand-in for threading.Thread that records what it was given.

        Mirrors the threading.Thread signature so target/args are captured
        whether the endpoint passes them positionally or by keyword.
        """
        mock_thread.target = target
        mock_thread.args = args
        mock_thread.kwargs = kwargs
        mock_thread.daemon = daemon
        return mock_thread

    # Replace only the endpoint module's `threading` reference rather than
    # patching threading.Thread globally, which would also affect any threads
    # created outside the endpoint module (possibly including the test client's own).
    monkeypatch.setattr(
        ra_aid.server.api_v1_spawn_agent,
        "threading",
        MagicMock(Thread=mock_thread_constructor),
    )

    test_client = TestClient(app)
    # Add mocks to client for test access
    test_client.mock_repo = mock_repository
    test_client.mock_thread = mock_thread
    test_client.mock_config = mock_config_repository

    yield test_client

    # Clean up the dependency override
    app.dependency_overrides.clear()
def test_spawn_agent(client, mock_repository, mock_thread):
    """A valid spawn request returns 201 and starts a daemon worker thread."""
    request_body = dict(
        message="Test task for the agent",
        research_only=False,
        expert_enabled=True,
        web_research_enabled=False,
    )

    resp = client.post("/v1/spawn-agent", json=request_body)

    # 201 Created, with the session id (123 in the fixture) rendered as a string.
    assert resp.status_code == 201
    assert resp.json() == {"session_id": "123"}

    # Exactly one session is requested from the repository.
    mock_repository.create_session.assert_called_once()

    # The worker thread gets the message and session id, and runs as a daemon.
    expected_args = ("Test task for the agent", "123", False)
    assert mock_thread.args == expected_args
    assert mock_thread.daemon is True
    mock_thread.start.assert_called_once()
def test_spawn_agent_missing_message(client):
    """A request missing the required ``message`` field is rejected with no side effects.

    Expects a 422 response with a validation error whose ``loc`` names
    ``message``. Also checks that no session is created and no agent thread
    is started.
    """
    # Create a request payload missing the required message
    payload = {
        "research_only": False,
        "expert_enabled": True,
        "web_research_enabled": False,
    }

    response = client.post("/v1/spawn-agent", json=payload)

    # Verify response indicates validation error
    assert response.status_code == 422
    error_detail = response.json().get("detail", [])
    assert any("message" in error.get("loc", []) for error in error_detail)

    # Validation fails before the endpoint body runs, so nothing may be created.
    client.mock_repo.create_session.assert_not_called()
    client.mock_thread.start.assert_not_called()