diff --git a/ra_aid/server/api_v1_spawn_agent.py b/ra_aid/server/api_v1_spawn_agent.py new file mode 100644 index 0000000..cb247a0 --- /dev/null +++ b/ra_aid/server/api_v1_spawn_agent.py @@ -0,0 +1,213 @@ +"""API router for spawning an RA.Aid agent.""" + +import threading +import logging +from typing import Dict, Any, Optional + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, Field + +from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository +from ra_aid.database.connection import DatabaseManager +from ra_aid.database.repositories.session_repository import SessionRepositoryManager +from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager +from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager +from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager +from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager +from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager +from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager +from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager +from ra_aid.database.repositories.config_repository import ConfigRepositoryManager +from ra_aid.env_inv_context import EnvInvManager +from ra_aid.env_inv import EnvDiscovery +from ra_aid.database import ensure_migrations_applied + +# Create logger +logger = logging.getLogger(__name__) + +# Create API router +router = APIRouter( + prefix="/v1/spawn-agent", + tags=["agent"], + responses={ + status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Agent spawn error"}, + }, +) + +class SpawnAgentRequest(BaseModel): + """ + Pydantic model for agent spawn requests. + + This model provides validation for spawning a new agent. + + Attributes: + message: The message or task for the agent to process + research_only: Whether to use research-only mode (default: False) + expert_enabled: Whether to enable expert assistance (default: True) + web_research_enabled: Whether to enable web research (default: False) + """ + message: str = Field( + description="The message or task for the agent to process" + ) + research_only: bool = Field( + default=False, + description="Whether to use research-only mode" + ) + expert_enabled: bool = Field( + default=True, + description="Whether to enable expert assistance" + ) + web_research_enabled: bool = Field( + default=False, + description="Whether to enable web research" + ) + +class SpawnAgentResponse(BaseModel): + """ + Pydantic model for agent spawn responses. + + This model defines the response format for the spawn-agent endpoint. + + Attributes: + session_id: The ID of the created session + """ + session_id: str = Field( + description="The ID of the created session" + ) + + +def get_repository() -> SessionRepository: + """ + Get the SessionRepository instance. + + This function is used as a FastAPI dependency and can be overridden + in tests using dependency_overrides. + + Returns: + SessionRepository: The repository instance + """ + return get_session_repository() + +def run_agent_thread( + message: str, + session_id: str, + research_only: bool = False, + expert_enabled: bool = True, + web_research_enabled: bool = False, +): + """ + Run a research agent in a separate thread with proper repository initialization. + + Args: + message: The message or task for the agent to process + session_id: The ID of the session to associate with this agent + research_only: Whether to use research-only mode + expert_enabled: Whether to enable expert assistance + web_research_enabled: Whether to enable web research + """ + try: + logger.info(f"Starting agent thread for session {session_id}") + + # Initialize environment discovery + env_discovery = EnvDiscovery() + env_discovery.discover() + env_data = env_discovery.format_markdown() + + # Apply any pending database migrations + try: + migration_result = ensure_migrations_applied() + if not migration_result: + logger.warning("Database migrations failed but execution will continue") + except Exception as e: + logger.error(f"Database migration error: {str(e)}") + + # Initialize empty config dictionary + config = {} + + # Initialize database connection and repositories + with DatabaseManager() as db, \ + SessionRepositoryManager(db) as session_repo, \ + KeyFactRepositoryManager(db) as key_fact_repo, \ + KeySnippetRepositoryManager(db) as key_snippet_repo, \ + HumanInputRepositoryManager(db) as human_input_repo, \ + ResearchNoteRepositoryManager(db) as research_note_repo, \ + RelatedFilesRepositoryManager() as related_files_repo, \ + TrajectoryRepositoryManager(db) as trajectory_repo, \ + WorkLogRepositoryManager() as work_log_repo, \ + ConfigRepositoryManager(config) as config_repo, \ + EnvInvManager(env_data) as env_inv: + + # Import here to avoid circular imports + from ra_aid.__main__ import run_research_agent + + # Run the research agent + run_research_agent( + base_task_or_query=message, + model=None, # Use default model + expert_enabled=expert_enabled, + research_only=research_only, + hil=False, # No human-in-the-loop for API + web_research_enabled=web_research_enabled, + thread_id=session_id + ) + + logger.info(f"Agent completed for session {session_id}") + except Exception as e: + logger.error(f"Error in agent thread for session {session_id}: {str(e)}") + +@router.post( + "", + response_model=SpawnAgentResponse, + status_code=status.HTTP_201_CREATED, + summary="Spawn agent", + description="Spawn a new RA.Aid agent to process a message or task", +) +async def spawn_agent( + request: SpawnAgentRequest, + repo: SessionRepository = Depends(get_repository), +) -> SpawnAgentResponse: + """ + Spawn a new RA.Aid agent to process a message or task. + + Args: + request: Request body with message and agent configuration + repo: SessionRepository dependency injection + + Returns: + SpawnAgentResponse: Response with session ID + + Raises: + HTTPException: With a 500 status code if there's an error spawning the agent + """ + try: + # Create a new session + metadata = { + "agent_type": "research-only" if request.research_only else "research", + "expert_enabled": request.expert_enabled, + "web_research_enabled": request.web_research_enabled, + } + session = repo.create_session(metadata=metadata) + + # Start the agent thread + thread = threading.Thread( + target=run_agent_thread, + args=( + request.message, + str(session.id), + request.research_only, + request.expert_enabled, + request.web_research_enabled, + ) + ) + thread.daemon = True # Thread will terminate when main process exits + thread.start() + + # Return the session ID + return SpawnAgentResponse(session_id=str(session.id)) + except Exception as e: + logger.error(f"Error spawning agent: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error spawning agent: {str(e)}", + ) \ No newline at end of file diff --git a/ra_aid/server/server.py b/ra_aid/server/server.py index d44f5c4..45b221f 100644 --- a/ra_aid/server/server.py +++ b/ra_aid/server/server.py @@ -34,6 +34,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from ra_aid.server.api_v1_sessions import router as sessions_router +from ra_aid.server.api_v1_spawn_agent import router as spawn_agent_router app = FastAPI( title="RA.Aid API", @@ -52,6 +53,7 @@ app.add_middleware( # Include API routers app.include_router(sessions_router) +app.include_router(spawn_agent_router) # Setup templates and static files directories CURRENT_DIR = Path(__file__).parent diff --git a/tests/ra_aid/server/test_api_v1_spawn_agent.py b/tests/ra_aid/server/test_api_v1_spawn_agent.py new file mode 100644 index 0000000..a440693 --- /dev/null +++ b/tests/ra_aid/server/test_api_v1_spawn_agent.py @@ -0,0 +1,134 @@ +""" +Tests for the Spawn Agent API v1 endpoint. + +This module contains tests for the spawn-agent API endpoint in ra_aid/server/api_v1_spawn_agent.py. +It tests the creation of agent threads and session handling for the spawn-agent endpoint. +""" + +import pytest +import threading +from unittest.mock import MagicMock +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from ra_aid.server.api_v1_spawn_agent import router, get_repository +from ra_aid.database.pydantic_models import SessionModel +import datetime +import ra_aid.server.api_v1_spawn_agent + + +@pytest.fixture +def mock_session(): + """Return a mock session for testing.""" + return SessionModel( + id=123, + created_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + start_time=datetime.datetime(2025, 1, 1, 0, 0, 0), + command_line="ra-aid test", + program_version="1.0.0", + machine_info={"agent_type": "research", "expert_enabled": True, "web_research_enabled": False} + ) + + +@pytest.fixture +def mock_thread(): + """Create a mock thread that does nothing when started.""" + mock = MagicMock() + mock.daemon = True + return mock + + +@pytest.fixture +def mock_repository(mock_session): + """Create a mock repository for testing.""" + mock_repo = MagicMock() + mock_repo.create_session.return_value = mock_session + return mock_repo + + +@pytest.fixture +def client(mock_repository, mock_thread, monkeypatch): + """Set up a test client with mocked dependencies.""" + # Create FastAPI app with router + app = FastAPI() + app.include_router(router) + + # Override the dependency to use our mock repository + app.dependency_overrides[get_repository] = lambda: mock_repository + + # Mock run_agent_thread to be a no-op + monkeypatch.setattr( + "ra_aid.server.api_v1_spawn_agent.run_agent_thread", + lambda *args, **kwargs: None + ) + + # Mock threading.Thread to return our mock thread + def mock_thread_constructor(*args, **kwargs): + mock_thread.target = kwargs.get('target') + mock_thread.args = kwargs.get('args') + mock_thread.daemon = kwargs.get('daemon', False) + return mock_thread + + monkeypatch.setattr( + ra_aid.server.api_v1_spawn_agent, + "threading", + MagicMock(Thread=mock_thread_constructor) + ) + + client = TestClient(app) + + # Add mocks to client for test access + client.mock_repo = mock_repository + client.mock_thread = mock_thread + + yield client + + # Clean up the dependency override + app.dependency_overrides.clear() + + +def test_spawn_agent(client, mock_repository, mock_thread): + """Test spawning an agent with valid parameters.""" + # Create the request payload + payload = { + "message": "Test task for the agent", + "research_only": False, + "expert_enabled": True, + "web_research_enabled": False + } + + # Send the request + response = client.post("/v1/spawn-agent", json=payload) + + # Verify response + assert response.status_code == 201 + assert response.json() == {"session_id": "123"} + + # Verify session creation + mock_repository.create_session.assert_called_once() + + # Verify thread was created with correct args + assert mock_thread.args == ("Test task for the agent", "123", False, True, False) + assert mock_thread.daemon is True + + # Verify thread.start was called + mock_thread.start.assert_called_once() + + +def test_spawn_agent_missing_message(client): + """Test spawning an agent with missing required message parameter.""" + # Create a request payload missing the required message + payload = { + "research_only": False, + "expert_enabled": True, + "web_research_enabled": False + } + + # Send the request + response = client.post("/v1/spawn-agent", json=payload) + + # Verify response indicates validation error + assert response.status_code == 422 + error_detail = response.json().get("detail", []) + assert any("message" in error.get("loc", []) for error in error_detail) \ No newline at end of file