add /v1/spawn-agent

This commit is contained in:
AI Christianson 2025-03-15 21:35:43 -04:00
parent 510e1016f8
commit fee23fcc21
3 changed files with 349 additions and 0 deletions

View File

@ -0,0 +1,213 @@
"""API router for spawning an RA.Aid agent."""
import threading
import logging
from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository
from ra_aid.database.connection import DatabaseManager
from ra_aid.database.repositories.session_repository import SessionRepositoryManager
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager
from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager
from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager
from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
from ra_aid.env_inv_context import EnvInvManager
from ra_aid.env_inv import EnvDiscovery
from ra_aid.database import ensure_migrations_applied
# Create logger
logger = logging.getLogger(__name__)
# Create API router
router = APIRouter(
prefix="/v1/spawn-agent",
tags=["agent"],
responses={
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Agent spawn error"},
},
)
class SpawnAgentRequest(BaseModel):
"""
Pydantic model for agent spawn requests.
This model provides validation for spawning a new agent.
Attributes:
message: The message or task for the agent to process
research_only: Whether to use research-only mode (default: False)
expert_enabled: Whether to enable expert assistance (default: True)
web_research_enabled: Whether to enable web research (default: False)
"""
message: str = Field(
description="The message or task for the agent to process"
)
research_only: bool = Field(
default=False,
description="Whether to use research-only mode"
)
expert_enabled: bool = Field(
default=True,
description="Whether to enable expert assistance"
)
web_research_enabled: bool = Field(
default=False,
description="Whether to enable web research"
)
class SpawnAgentResponse(BaseModel):
"""
Pydantic model for agent spawn responses.
This model defines the response format for the spawn-agent endpoint.
Attributes:
session_id: The ID of the created session
"""
session_id: str = Field(
description="The ID of the created session"
)
def get_repository() -> SessionRepository:
"""
Get the SessionRepository instance.
This function is used as a FastAPI dependency and can be overridden
in tests using dependency_overrides.
Returns:
SessionRepository: The repository instance
"""
return get_session_repository()
def run_agent_thread(
message: str,
session_id: str,
research_only: bool = False,
expert_enabled: bool = True,
web_research_enabled: bool = False,
):
"""
Run a research agent in a separate thread with proper repository initialization.
Args:
message: The message or task for the agent to process
session_id: The ID of the session to associate with this agent
research_only: Whether to use research-only mode
expert_enabled: Whether to enable expert assistance
web_research_enabled: Whether to enable web research
"""
try:
logger.info(f"Starting agent thread for session {session_id}")
# Initialize environment discovery
env_discovery = EnvDiscovery()
env_discovery.discover()
env_data = env_discovery.format_markdown()
# Apply any pending database migrations
try:
migration_result = ensure_migrations_applied()
if not migration_result:
logger.warning("Database migrations failed but execution will continue")
except Exception as e:
logger.error(f"Database migration error: {str(e)}")
# Initialize empty config dictionary
config = {}
# Initialize database connection and repositories
with DatabaseManager() as db, \
SessionRepositoryManager(db) as session_repo, \
KeyFactRepositoryManager(db) as key_fact_repo, \
KeySnippetRepositoryManager(db) as key_snippet_repo, \
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo, \
RelatedFilesRepositoryManager() as related_files_repo, \
TrajectoryRepositoryManager(db) as trajectory_repo, \
WorkLogRepositoryManager() as work_log_repo, \
ConfigRepositoryManager(config) as config_repo, \
EnvInvManager(env_data) as env_inv:
# Import here to avoid circular imports
from ra_aid.__main__ import run_research_agent
# Run the research agent
run_research_agent(
base_task_or_query=message,
model=None, # Use default model
expert_enabled=expert_enabled,
research_only=research_only,
hil=False, # No human-in-the-loop for API
web_research_enabled=web_research_enabled,
thread_id=session_id
)
logger.info(f"Agent completed for session {session_id}")
except Exception as e:
logger.error(f"Error in agent thread for session {session_id}: {str(e)}")
@router.post(
"",
response_model=SpawnAgentResponse,
status_code=status.HTTP_201_CREATED,
summary="Spawn agent",
description="Spawn a new RA.Aid agent to process a message or task",
)
async def spawn_agent(
request: SpawnAgentRequest,
repo: SessionRepository = Depends(get_repository),
) -> SpawnAgentResponse:
"""
Spawn a new RA.Aid agent to process a message or task.
Args:
request: Request body with message and agent configuration
repo: SessionRepository dependency injection
Returns:
SpawnAgentResponse: Response with session ID
Raises:
HTTPException: With a 500 status code if there's an error spawning the agent
"""
try:
# Create a new session
metadata = {
"agent_type": "research-only" if request.research_only else "research",
"expert_enabled": request.expert_enabled,
"web_research_enabled": request.web_research_enabled,
}
session = repo.create_session(metadata=metadata)
# Start the agent thread
thread = threading.Thread(
target=run_agent_thread,
args=(
request.message,
str(session.id),
request.research_only,
request.expert_enabled,
request.web_research_enabled,
)
)
thread.daemon = True # Thread will terminate when main process exits
thread.start()
# Return the session ID
return SpawnAgentResponse(session_id=str(session.id))
except Exception as e:
logger.error(f"Error spawning agent: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error spawning agent: {str(e)}",
)

View File

@ -34,6 +34,7 @@ from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from ra_aid.server.api_v1_sessions import router as sessions_router from ra_aid.server.api_v1_sessions import router as sessions_router
from ra_aid.server.api_v1_spawn_agent import router as spawn_agent_router
app = FastAPI( app = FastAPI(
title="RA.Aid API", title="RA.Aid API",
@ -52,6 +53,7 @@ app.add_middleware(
# Include API routers # Include API routers
app.include_router(sessions_router) app.include_router(sessions_router)
app.include_router(spawn_agent_router)
# Setup templates and static files directories # Setup templates and static files directories
CURRENT_DIR = Path(__file__).parent CURRENT_DIR = Path(__file__).parent

View File

@ -0,0 +1,134 @@
"""
Tests for the Spawn Agent API v1 endpoint.
This module contains tests for the spawn-agent API endpoint in ra_aid/server/api_v1_spawn_agent.py.
It tests the creation of agent threads and session handling for the spawn-agent endpoint.
"""
import pytest
import threading
from unittest.mock import MagicMock
from fastapi import FastAPI
from fastapi.testclient import TestClient
from ra_aid.server.api_v1_spawn_agent import router, get_repository
from ra_aid.database.pydantic_models import SessionModel
import datetime
import ra_aid.server.api_v1_spawn_agent
@pytest.fixture
def mock_session():
"""Return a mock session for testing."""
return SessionModel(
id=123,
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
command_line="ra-aid test",
program_version="1.0.0",
machine_info={"agent_type": "research", "expert_enabled": True, "web_research_enabled": False}
)
@pytest.fixture
def mock_thread():
"""Create a mock thread that does nothing when started."""
mock = MagicMock()
mock.daemon = True
return mock
@pytest.fixture
def mock_repository(mock_session):
"""Create a mock repository for testing."""
mock_repo = MagicMock()
mock_repo.create_session.return_value = mock_session
return mock_repo
@pytest.fixture
def client(mock_repository, mock_thread, monkeypatch):
"""Set up a test client with mocked dependencies."""
# Create FastAPI app with router
app = FastAPI()
app.include_router(router)
# Override the dependency to use our mock repository
app.dependency_overrides[get_repository] = lambda: mock_repository
# Mock run_agent_thread to be a no-op
monkeypatch.setattr(
"ra_aid.server.api_v1_spawn_agent.run_agent_thread",
lambda *args, **kwargs: None
)
# Mock threading.Thread to return our mock thread
def mock_thread_constructor(*args, **kwargs):
mock_thread.target = kwargs.get('target')
mock_thread.args = kwargs.get('args')
mock_thread.daemon = kwargs.get('daemon', False)
return mock_thread
monkeypatch.setattr(
ra_aid.server.api_v1_spawn_agent,
"threading",
MagicMock(Thread=mock_thread_constructor)
)
client = TestClient(app)
# Add mocks to client for test access
client.mock_repo = mock_repository
client.mock_thread = mock_thread
yield client
# Clean up the dependency override
app.dependency_overrides.clear()
def test_spawn_agent(client, mock_repository, mock_thread):
"""Test spawning an agent with valid parameters."""
# Create the request payload
payload = {
"message": "Test task for the agent",
"research_only": False,
"expert_enabled": True,
"web_research_enabled": False
}
# Send the request
response = client.post("/v1/spawn-agent", json=payload)
# Verify response
assert response.status_code == 201
assert response.json() == {"session_id": "123"}
# Verify session creation
mock_repository.create_session.assert_called_once()
# Verify thread was created with correct args
assert mock_thread.args == ("Test task for the agent", "123", False, True, False)
assert mock_thread.daemon is True
# Verify thread.start was called
mock_thread.start.assert_called_once()
def test_spawn_agent_missing_message(client):
"""Test spawning an agent with missing required message parameter."""
# Create a request payload missing the required message
payload = {
"research_only": False,
"expert_enabled": True,
"web_research_enabled": False
}
# Send the request
response = client.post("/v1/spawn-agent", json=payload)
# Verify response indicates validation error
assert response.status_code == 422
error_detail = response.json().get("detail", [])
assert any("message" in error.get("loc", []) for error in error_detail)