add /v1/spawn-agent
This commit is contained in:
parent
510e1016f8
commit
fee23fcc21
|
|
@ -0,0 +1,213 @@
|
|||
"""API router for spawning an RA.Aid agent."""
|
||||
|
||||
import threading
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository
|
||||
from ra_aid.database.connection import DatabaseManager
|
||||
from ra_aid.database.repositories.session_repository import SessionRepositoryManager
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager
|
||||
from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager
|
||||
from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager
|
||||
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||
from ra_aid.env_inv_context import EnvInvManager
|
||||
from ra_aid.env_inv import EnvDiscovery
|
||||
from ra_aid.database import ensure_migrations_applied
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create API router
|
||||
router = APIRouter(
|
||||
prefix="/v1/spawn-agent",
|
||||
tags=["agent"],
|
||||
responses={
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"},
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Agent spawn error"},
|
||||
},
|
||||
)
|
||||
|
||||
class SpawnAgentRequest(BaseModel):
|
||||
"""
|
||||
Pydantic model for agent spawn requests.
|
||||
|
||||
This model provides validation for spawning a new agent.
|
||||
|
||||
Attributes:
|
||||
message: The message or task for the agent to process
|
||||
research_only: Whether to use research-only mode (default: False)
|
||||
expert_enabled: Whether to enable expert assistance (default: True)
|
||||
web_research_enabled: Whether to enable web research (default: False)
|
||||
"""
|
||||
message: str = Field(
|
||||
description="The message or task for the agent to process"
|
||||
)
|
||||
research_only: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use research-only mode"
|
||||
)
|
||||
expert_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable expert assistance"
|
||||
)
|
||||
web_research_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable web research"
|
||||
)
|
||||
|
||||
class SpawnAgentResponse(BaseModel):
|
||||
"""
|
||||
Pydantic model for agent spawn responses.
|
||||
|
||||
This model defines the response format for the spawn-agent endpoint.
|
||||
|
||||
Attributes:
|
||||
session_id: The ID of the created session
|
||||
"""
|
||||
session_id: str = Field(
|
||||
description="The ID of the created session"
|
||||
)
|
||||
|
||||
|
||||
def get_repository() -> SessionRepository:
|
||||
"""
|
||||
Get the SessionRepository instance.
|
||||
|
||||
This function is used as a FastAPI dependency and can be overridden
|
||||
in tests using dependency_overrides.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The repository instance
|
||||
"""
|
||||
return get_session_repository()
|
||||
|
||||
def run_agent_thread(
|
||||
message: str,
|
||||
session_id: str,
|
||||
research_only: bool = False,
|
||||
expert_enabled: bool = True,
|
||||
web_research_enabled: bool = False,
|
||||
):
|
||||
"""
|
||||
Run a research agent in a separate thread with proper repository initialization.
|
||||
|
||||
Args:
|
||||
message: The message or task for the agent to process
|
||||
session_id: The ID of the session to associate with this agent
|
||||
research_only: Whether to use research-only mode
|
||||
expert_enabled: Whether to enable expert assistance
|
||||
web_research_enabled: Whether to enable web research
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting agent thread for session {session_id}")
|
||||
|
||||
# Initialize environment discovery
|
||||
env_discovery = EnvDiscovery()
|
||||
env_discovery.discover()
|
||||
env_data = env_discovery.format_markdown()
|
||||
|
||||
# Apply any pending database migrations
|
||||
try:
|
||||
migration_result = ensure_migrations_applied()
|
||||
if not migration_result:
|
||||
logger.warning("Database migrations failed but execution will continue")
|
||||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Initialize empty config dictionary
|
||||
config = {}
|
||||
|
||||
# Initialize database connection and repositories
|
||||
with DatabaseManager() as db, \
|
||||
SessionRepositoryManager(db) as session_repo, \
|
||||
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||
WorkLogRepositoryManager() as work_log_repo, \
|
||||
ConfigRepositoryManager(config) as config_repo, \
|
||||
EnvInvManager(env_data) as env_inv:
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.__main__ import run_research_agent
|
||||
|
||||
# Run the research agent
|
||||
run_research_agent(
|
||||
base_task_or_query=message,
|
||||
model=None, # Use default model
|
||||
expert_enabled=expert_enabled,
|
||||
research_only=research_only,
|
||||
hil=False, # No human-in-the-loop for API
|
||||
web_research_enabled=web_research_enabled,
|
||||
thread_id=session_id
|
||||
)
|
||||
|
||||
logger.info(f"Agent completed for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent thread for session {session_id}: {str(e)}")
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=SpawnAgentResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Spawn agent",
|
||||
description="Spawn a new RA.Aid agent to process a message or task",
|
||||
)
|
||||
async def spawn_agent(
|
||||
request: SpawnAgentRequest,
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> SpawnAgentResponse:
|
||||
"""
|
||||
Spawn a new RA.Aid agent to process a message or task.
|
||||
|
||||
Args:
|
||||
request: Request body with message and agent configuration
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
SpawnAgentResponse: Response with session ID
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 500 status code if there's an error spawning the agent
|
||||
"""
|
||||
try:
|
||||
# Create a new session
|
||||
metadata = {
|
||||
"agent_type": "research-only" if request.research_only else "research",
|
||||
"expert_enabled": request.expert_enabled,
|
||||
"web_research_enabled": request.web_research_enabled,
|
||||
}
|
||||
session = repo.create_session(metadata=metadata)
|
||||
|
||||
# Start the agent thread
|
||||
thread = threading.Thread(
|
||||
target=run_agent_thread,
|
||||
args=(
|
||||
request.message,
|
||||
str(session.id),
|
||||
request.research_only,
|
||||
request.expert_enabled,
|
||||
request.web_research_enabled,
|
||||
)
|
||||
)
|
||||
thread.daemon = True # Thread will terminate when main process exits
|
||||
thread.start()
|
||||
|
||||
# Return the session ID
|
||||
return SpawnAgentResponse(session_id=str(session.id))
|
||||
except Exception as e:
|
||||
logger.error(f"Error spawning agent: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error spawning agent: {str(e)}",
|
||||
)
|
||||
|
|
@ -34,6 +34,7 @@ from fastapi.staticfiles import StaticFiles
|
|||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from ra_aid.server.api_v1_sessions import router as sessions_router
|
||||
from ra_aid.server.api_v1_spawn_agent import router as spawn_agent_router
|
||||
|
||||
app = FastAPI(
|
||||
title="RA.Aid API",
|
||||
|
|
@ -52,6 +53,7 @@ app.add_middleware(
|
|||
|
||||
# Include API routers
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(spawn_agent_router)
|
||||
|
||||
# Setup templates and static files directories
|
||||
CURRENT_DIR = Path(__file__).parent
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Tests for the Spawn Agent API v1 endpoint.
|
||||
|
||||
This module contains tests for the spawn-agent API endpoint in ra_aid/server/api_v1_spawn_agent.py.
|
||||
It tests the creation of agent threads and session handling for the spawn-agent endpoint.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ra_aid.server.api_v1_spawn_agent import router, get_repository
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
import datetime
|
||||
import ra_aid.server.api_v1_spawn_agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Return a mock session for testing."""
|
||||
return SessionModel(
|
||||
id=123,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info={"agent_type": "research", "expert_enabled": True, "web_research_enabled": False}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_thread():
|
||||
"""Create a mock thread that does nothing when started."""
|
||||
mock = MagicMock()
|
||||
mock.daemon = True
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repository(mock_session):
|
||||
"""Create a mock repository for testing."""
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.create_session.return_value = mock_session
|
||||
return mock_repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repository, mock_thread, monkeypatch):
|
||||
"""Set up a test client with mocked dependencies."""
|
||||
# Create FastAPI app with router
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
# Override the dependency to use our mock repository
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repository
|
||||
|
||||
# Mock run_agent_thread to be a no-op
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.server.api_v1_spawn_agent.run_agent_thread",
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
|
||||
# Mock threading.Thread to return our mock thread
|
||||
def mock_thread_constructor(*args, **kwargs):
|
||||
mock_thread.target = kwargs.get('target')
|
||||
mock_thread.args = kwargs.get('args')
|
||||
mock_thread.daemon = kwargs.get('daemon', False)
|
||||
return mock_thread
|
||||
|
||||
monkeypatch.setattr(
|
||||
ra_aid.server.api_v1_spawn_agent,
|
||||
"threading",
|
||||
MagicMock(Thread=mock_thread_constructor)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Add mocks to client for test access
|
||||
client.mock_repo = mock_repository
|
||||
client.mock_thread = mock_thread
|
||||
|
||||
yield client
|
||||
|
||||
# Clean up the dependency override
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_spawn_agent(client, mock_repository, mock_thread):
|
||||
"""Test spawning an agent with valid parameters."""
|
||||
# Create the request payload
|
||||
payload = {
|
||||
"message": "Test task for the agent",
|
||||
"research_only": False,
|
||||
"expert_enabled": True,
|
||||
"web_research_enabled": False
|
||||
}
|
||||
|
||||
# Send the request
|
||||
response = client.post("/v1/spawn-agent", json=payload)
|
||||
|
||||
# Verify response
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {"session_id": "123"}
|
||||
|
||||
# Verify session creation
|
||||
mock_repository.create_session.assert_called_once()
|
||||
|
||||
# Verify thread was created with correct args
|
||||
assert mock_thread.args == ("Test task for the agent", "123", False, True, False)
|
||||
assert mock_thread.daemon is True
|
||||
|
||||
# Verify thread.start was called
|
||||
mock_thread.start.assert_called_once()
|
||||
|
||||
|
||||
def test_spawn_agent_missing_message(client):
|
||||
"""Test spawning an agent with missing required message parameter."""
|
||||
# Create a request payload missing the required message
|
||||
payload = {
|
||||
"research_only": False,
|
||||
"expert_enabled": True,
|
||||
"web_research_enabled": False
|
||||
}
|
||||
|
||||
# Send the request
|
||||
response = client.post("/v1/spawn-agent", json=payload)
|
||||
|
||||
# Verify response indicates validation error
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json().get("detail", [])
|
||||
assert any("message" in error.get("loc", []) for error in error_detail)
|
||||
Loading…
Reference in New Issue