diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index aa8ff75..d1813c5 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -102,9 +102,65 @@ if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debuggin def launch_server(host: str, port: int): """Launch the RA.Aid web interface.""" from ra_aid.server import run_server - + from ra_aid.database.connection import DatabaseManager + from ra_aid.database.repositories.session_repository import SessionRepositoryManager + from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager + from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager + from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager + from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager + from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager + from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager + from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager + from ra_aid.database.repositories.config_repository import ConfigRepositoryManager + from ra_aid.env_inv_context import EnvInvManager + from ra_aid.env_inv import EnvDiscovery + + # Apply any pending database migrations + from ra_aid.database import ensure_migrations_applied + try: + migration_result = ensure_migrations_applied() + if not migration_result: + logger.warning("Database migrations failed but execution will continue") + except Exception as e: + logger.error(f"Database migration error: {str(e)}") + + # Initialize empty config dictionary + config = {} + + # Initialize environment discovery + env_discovery = EnvDiscovery() + env_discovery.discover() + env_data = env_discovery.format_markdown() + print(f"Starting RA.Aid web interface on http://{host}:{port}") - run_server(host=host, port=port) + + # Initialize database connection and repositories + with DatabaseManager() as db, \ + SessionRepositoryManager(db) as session_repo, \ + KeyFactRepositoryManager(db) as key_fact_repo, \ + KeySnippetRepositoryManager(db) as key_snippet_repo, \ + HumanInputRepositoryManager(db) as human_input_repo, \ + ResearchNoteRepositoryManager(db) as research_note_repo, \ + RelatedFilesRepositoryManager() as related_files_repo, \ + TrajectoryRepositoryManager(db) as trajectory_repo, \ + WorkLogRepositoryManager() as work_log_repo, \ + ConfigRepositoryManager(config) as config_repo, \ + EnvInvManager(env_data) as env_inv: + + # This initializes all repositories and makes them available via their respective get methods + logger.debug("Initialized SessionRepository") + logger.debug("Initialized KeyFactRepository") + logger.debug("Initialized KeySnippetRepository") + logger.debug("Initialized HumanInputRepository") + logger.debug("Initialized ResearchNoteRepository") + logger.debug("Initialized RelatedFilesRepository") + logger.debug("Initialized TrajectoryRepository") + logger.debug("Initialized WorkLogRepository") + logger.debug("Initialized ConfigRepository") + logger.debug("Initialized Environment Inventory") + + # Run the server within the context managers + run_server(host=host, port=port) def parse_arguments(args=None): diff --git a/ra_aid/database/repositories/session_repository.py b/ra_aid/database/repositories/session_repository.py index 9996f0e..9ea9650 100644 --- a/ra_aid/database/repositories/session_repository.py +++ b/ra_aid/database/repositories/session_repository.py @@ -226,19 +226,33 @@ class SessionRepository: logger.error(f"Database error getting session {session_id}: {str(e)}") return None - def get_all(self) -> List[SessionModel]: + def get_all(self, offset: int = 0, limit: int = 10) -> tuple[List[SessionModel], int]: """ - Get all sessions from the database. + Get all sessions from the database with pagination support. + Args: + offset: Number of sessions to skip (default: 0) + limit: Maximum number of sessions to return (default: 10) + Returns: - List[SessionModel]: List of all sessions + tuple: (List[SessionModel], int) containing the list of sessions and the total count """ try: - sessions = list(Session.select().order_by(Session.created_at.desc())) - return [self._to_model(session) for session in sessions] + # Get total count for pagination info + total_count = Session.select().count() + + # Get paginated sessions ordered by created_at in descending order (newest first) + sessions = list( + Session.select() + .order_by(Session.created_at.desc()) + .offset(offset) + .limit(limit) + ) + + return [self._to_model(session) for session in sessions], total_count except peewee.DatabaseError as e: - logger.error(f"Failed to get all sessions: {str(e)}") - return [] + logger.error(f"Failed to get all sessions with pagination: {str(e)}") + return [], 0 def get_recent(self, limit: int = 10) -> List[SessionModel]: """ diff --git a/ra_aid/server/api_v1_sessions.py b/ra_aid/server/api_v1_sessions.py new file mode 100644 index 0000000..0993c12 --- /dev/null +++ b/ra_aid/server/api_v1_sessions.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +API v1 Session Endpoints. + +This module provides RESTful API endpoints for managing sessions. +It implements routes for creating, listing, and retrieving sessions +with proper validation and error handling. +""" + +from typing import List, Optional, Dict, Any +from fastapi import APIRouter, Depends, HTTPException, Query, status +import peewee +from pydantic import BaseModel, Field + +from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository +from ra_aid.database.pydantic_models import SessionModel + +# Create API router +router = APIRouter( + prefix="/v1/sessions", + tags=["sessions"], + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Session not found"}, + status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Database error"}, + }, +) + + +class PaginatedResponse(BaseModel): + """ + Pydantic model for paginated API responses. + + This model provides a standardized format for API responses that include + pagination, with a total count and the requested items. + + Attributes: + total: The total number of items available + items: List of items for the current page + limit: The limit parameter that was used + offset: The offset parameter that was used + """ + total: int + items: List[Any] + limit: int + offset: int + + +class CreateSessionRequest(BaseModel): + """ + Pydantic model for session creation requests. + + This model provides validation for creating new sessions. + + Attributes: + metadata: Optional dictionary of additional metadata to store with the session + """ + metadata: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional dictionary of additional metadata to store with the session" + ) + + +class PaginatedSessionResponse(PaginatedResponse): + """ + Pydantic model for paginated session responses. + + This model specializes the generic PaginatedResponse for SessionModel items. + + Attributes: + items: List of SessionModel items for the current page + """ + items: List[SessionModel] + + +# Dependency to get the session repository +def get_repository() -> SessionRepository: + """ + Get the SessionRepository instance. + + This function is used as a FastAPI dependency and can be overridden + in tests using dependency_overrides. + + Returns: + SessionRepository: The repository instance + """ + return get_session_repository() + + +@router.get( + "", + response_model=PaginatedSessionResponse, + summary="List sessions", + description="Get a paginated list of sessions", +) +async def list_sessions( + offset: int = Query(0, ge=0, description="Number of sessions to skip"), + limit: int = Query(10, ge=1, le=100, description="Maximum number of sessions to return"), + repo: SessionRepository = Depends(get_repository), +) -> PaginatedSessionResponse: + """ + Get a paginated list of sessions. + + Args: + offset: Number of sessions to skip (default: 0) + limit: Maximum number of sessions to return (default: 10) + repo: SessionRepository dependency injection + + Returns: + PaginatedSessionResponse: Response with paginated sessions + + Raises: + HTTPException: With a 500 status code if there's a database error + """ + try: + sessions, total = repo.get_all(offset=offset, limit=limit) + return PaginatedSessionResponse( + total=total, + items=sessions, + limit=limit, + offset=offset, + ) + except peewee.DatabaseError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Database error: {str(e)}", + ) + + +@router.get( + "/{session_id}", + response_model=SessionModel, + summary="Get session", + description="Get a specific session by ID", +) +async def get_session( + session_id: int, + repo: SessionRepository = Depends(get_repository), +) -> SessionModel: + """ + Get a specific session by ID. + + Args: + session_id: The ID of the session to retrieve + repo: SessionRepository dependency injection + + Returns: + SessionModel: The requested session + + Raises: + HTTPException: With a 404 status code if the session is not found + HTTPException: With a 500 status code if there's a database error + """ + try: + session = repo.get(session_id) + if not session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session with ID {session_id} not found", + ) + return session + except peewee.DatabaseError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Database error: {str(e)}", + ) + + +@router.post( + "", + response_model=SessionModel, + status_code=status.HTTP_201_CREATED, + summary="Create session", + description="Create a new session", +) +async def create_session( + request: Optional[CreateSessionRequest] = None, + repo: SessionRepository = Depends(get_repository), +) -> SessionModel: + """ + Create a new session. + + Args: + request: Optional request body with session metadata + repo: SessionRepository dependency injection + + Returns: + SessionModel: The newly created session + + Raises: + HTTPException: With a 500 status code if there's a database error + """ + try: + metadata = request.metadata if request else None + return repo.create_session(metadata=metadata) + except peewee.DatabaseError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Database error: {str(e)}", + ) \ No newline at end of file diff --git a/ra_aid/server/server.py b/ra_aid/server/server.py index 6291c49..a85fd04 100644 --- a/ra_aid/server/server.py +++ b/ra_aid/server/server.py @@ -33,7 +33,13 @@ from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -app = FastAPI() +from ra_aid.server.api_v1_sessions import router as sessions_router + +app = FastAPI( + title="RA.Aid API", + description="API for RA.Aid - AI Programming Assistant", + version="1.0.0", +) # Add CORS middleware app.add_middleware( @@ -44,6 +50,9 @@ app.add_middleware( allow_headers=["*"], ) +# Include API routers +app.include_router(sessions_router) + # Setup templates and static files directories CURRENT_DIR = Path(__file__).parent templates = Jinja2Templates(directory=CURRENT_DIR) diff --git a/tests/ra_aid/database/test_session_repository.py b/tests/ra_aid/database/test_session_repository.py index 2845395..6771e75 100644 --- a/tests/ra_aid/database/test_session_repository.py +++ b/tests/ra_aid/database/test_session_repository.py @@ -231,8 +231,11 @@ def test_get_all(setup_db): repo.create_session(metadata=metadata2) repo.create_session(metadata=metadata3) - # Get all sessions - sessions = repo.get_all() + # Get all sessions with default pagination + sessions, total_count = repo.get_all() + + # Verify total count + assert total_count == 3 # Verify we got a list of SessionModel objects assert len(sessions) == 3 @@ -249,6 +252,19 @@ def test_get_all(setup_db): assert "Linux" in os_values assert "Windows" in os_values assert "macOS" in os_values + + # Test pagination with limit + sessions_limited, total_count = repo.get_all(limit=2) + assert total_count == 3 # Total count should still be 3 + assert len(sessions_limited) == 2 # But only 2 returned + + # Test pagination with offset + sessions_offset, total_count = repo.get_all(offset=1, limit=2) + assert total_count == 3 + assert len(sessions_offset) == 2 + + # The second item in the full list should be the first item in the offset list + assert sessions[1].id == sessions_offset[0].id def test_get_all_empty(setup_db): @@ -257,11 +273,12 @@ def test_get_all_empty(setup_db): repo = SessionRepository(db=setup_db) # Get all sessions - sessions = repo.get_all() + sessions, total_count = repo.get_all() - # Verify we got an empty list + # Verify we got an empty list and zero count assert isinstance(sessions, list) assert len(sessions) == 0 + assert total_count == 0 def test_get_recent(setup_db): diff --git a/tests/ra_aid/server/test_api_v1_sessions.py b/tests/ra_aid/server/test_api_v1_sessions.py new file mode 100644 index 0000000..2bb67b0 --- /dev/null +++ b/tests/ra_aid/server/test_api_v1_sessions.py @@ -0,0 +1,132 @@ +""" +Tests for the Sessions API v1 endpoints. + +This module contains tests for the sessions API endpoints in ra_aid/server/api_v1_sessions.py. +It tests the creation, listing, and retrieval of sessions through the API. +""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock +import datetime + +from ra_aid.server.api_v1_sessions import router, get_repository +from ra_aid.database.pydantic_models import SessionModel + + +# Mock session data for testing +@pytest.fixture +def mock_session(): + """Return a mock session for testing.""" + return SessionModel( + id=1, + created_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + start_time=datetime.datetime(2025, 1, 1, 0, 0, 0), + command_line="ra-aid test", + program_version="1.0.0", + machine_info={"os": "test"} + ) + + +@pytest.fixture +def mock_sessions(): + """Return a list of mock sessions for testing.""" + return [ + SessionModel( + id=1, + created_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + start_time=datetime.datetime(2025, 1, 1, 0, 0, 0), + command_line="ra-aid test1", + program_version="1.0.0", + machine_info={"os": "test"} + ), + SessionModel( + id=2, + created_at=datetime.datetime(2025, 1, 2, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, 2, 0, 0, 0), + start_time=datetime.datetime(2025, 1, 2, 0, 0, 0), + command_line="ra-aid test2", + program_version="1.0.0", + machine_info={"os": "test"} + ) + ] + + +@pytest.fixture +def mock_repo(mock_session, mock_sessions): + """Mock the SessionRepository for testing.""" + repo = MagicMock() + repo.get.return_value = mock_session + repo.get_all.return_value = (mock_sessions, len(mock_sessions)) + repo.create_session.return_value = mock_session + return repo + + +@pytest.fixture +def client(mock_repo): + """Return a TestClient for the API router with dependency override.""" + from fastapi import FastAPI + app = FastAPI() + app.include_router(router) + + # Override the dependency + app.dependency_overrides[get_repository] = lambda: mock_repo + + return TestClient(app) + + +def test_get_session(client, mock_repo, mock_session): + """Test getting a specific session by ID.""" + response = client.get("/v1/sessions/1") + + assert response.status_code == 200 + assert response.json()["id"] == mock_session.id + assert response.json()["command_line"] == mock_session.command_line + mock_repo.get.assert_called_once_with(1) + + +def test_get_session_not_found(client, mock_repo): + """Test getting a session that doesn't exist.""" + mock_repo.get.return_value = None + + response = client.get("/v1/sessions/999") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + mock_repo.get.assert_called_once_with(999) + + +def test_list_sessions(client, mock_repo, mock_sessions): + """Test listing sessions with pagination.""" + response = client.get("/v1/sessions?offset=0&limit=10") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == len(mock_sessions) + assert len(data["items"]) == len(mock_sessions) + assert data["limit"] == 10 + assert data["offset"] == 0 + mock_repo.get_all.assert_called_once_with(offset=0, limit=10) + + +def test_create_session(client, mock_repo, mock_session): + """Test creating a new session.""" + response = client.post( + "/v1/sessions", + json={"metadata": {"test": "data"}} + ) + + assert response.status_code == 201 + assert response.json()["id"] == mock_session.id + mock_repo.create_session.assert_called_once_with(metadata={"test": "data"}) + + +def test_create_session_no_body(client, mock_repo, mock_session): + """Test creating a new session without a request body.""" + response = client.post("/v1/sessions") + + assert response.status_code == 201 + assert response.json()["id"] == mock_session.id + mock_repo.create_session.assert_called_once_with(metadata=None) \ No newline at end of file diff --git a/tests/ra_aid/server/test_server.py b/tests/ra_aid/server/test_server.py new file mode 100644 index 0000000..d1fc322 --- /dev/null +++ b/tests/ra_aid/server/test_server.py @@ -0,0 +1,75 @@ +""" +Tests for server.py FastAPI application. + +This module tests the FastAPI application setup in server.py to ensure +that all routers are properly mounted and middleware is configured. +""" + +from unittest.mock import MagicMock, patch +import pytest +from fastapi.testclient import TestClient + +from ra_aid.server.server import app +from ra_aid.database.repositories.session_repository import session_repo_var + + +@pytest.fixture +def client(): + """Return a TestClient for the FastAPI app.""" + # Mock the session repository to avoid database dependency + mock_repo = MagicMock() + mock_repo.get_all.return_value = ([], 0) + + # Set the repository in the contextvar + token = session_repo_var.set(mock_repo) + + yield TestClient(app) + + # Reset the contextvar after the test + session_repo_var.reset(token) + + +def test_config_endpoint(client): + """Test that the config endpoint returns server configuration.""" + response = client.get("/config") + assert response.status_code == 200 + assert "host" in response.json() + assert "port" in response.json() + + +def test_api_documentation(client): + """Test that the OpenAPI documentation includes the sessions API.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + + openapi_spec = response.json() + assert "paths" in openapi_spec + + # Check that the sessions API paths are included + assert "/v1/sessions" in openapi_spec["paths"] + assert "/v1/sessions/{session_id}" in openapi_spec["paths"] + + # Verify that sessions API operations are documented + assert "get" in openapi_spec["paths"]["/v1/sessions"] + assert "post" in openapi_spec["paths"]["/v1/sessions"] + assert "get" in openapi_spec["paths"]["/v1/sessions/{session_id}"] + + +@patch("ra_aid.database.repositories.session_repository.get_session_repository") +def test_sessions_api_mounted(mock_get_repo, client): + """Test that the sessions API router is mounted correctly.""" + # Mock the repository for this specific test + mock_repo = MagicMock() + mock_repo.get_all.return_value = ([], 0) + mock_get_repo.return_value = mock_repo + + # Test that the sessions list endpoint is accessible + response = client.get("/v1/sessions") + assert response.status_code == 200 + + # Verify the response structure follows our expected format + data = response.json() + assert "total" in data + assert "items" in data + assert "limit" in data + assert "offset" in data \ No newline at end of file diff --git a/tests/ra_aid/server/test_sessions_api_integration.py b/tests/ra_aid/server/test_sessions_api_integration.py new file mode 100644 index 0000000..01c6b51 --- /dev/null +++ b/tests/ra_aid/server/test_sessions_api_integration.py @@ -0,0 +1,532 @@ +""" +Integration tests for the Sessions API endpoints. + +This module contains integration tests for the API endpoints defined in ra_aid/server/api_v1_sessions.py. +It uses mocks to simulate the database interactions while testing the real API behavior. +""" + +import pytest +import datetime +from typing import Dict, Any, List, Tuple +from unittest.mock import MagicMock, patch + +from fastapi.testclient import TestClient + +from ra_aid.server.server import app +from ra_aid.database.pydantic_models import SessionModel +from ra_aid.server.api_v1_sessions import get_repository + + +# Mock session data for testing +@pytest.fixture +def mock_session(): + """Return a mock session for testing.""" + return SessionModel( + id=1, + created_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + start_time=datetime.datetime(2025, 1, 1, 0, 0, 0), + command_line="ra-aid test", + program_version="1.0.0", + machine_info={"os": "test"} + ) + + +@pytest.fixture +def mock_sessions(): + """Return a list of mock sessions for testing.""" + return [ + SessionModel( + id=i+1, + created_at=datetime.datetime(2025, 1, i+1, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, i+1, 0, 0, 0), + start_time=datetime.datetime(2025, 1, i+1, 0, 0, 0), + command_line=f"ra-aid test{i+1}", + program_version="1.0.0", + machine_info={"index": i} + ) + for i in range(15) + ] + + +@pytest.fixture +def mock_repo(mock_session, mock_sessions): + """Create a mock repository with predefined responses.""" + repo = MagicMock() + repo.get.return_value = mock_session + repo.get_all.return_value = (mock_sessions[:10], len(mock_sessions)) + repo.create_session.return_value = mock_session + + # Add behavior for custom parameters + def get_with_id(session_id): + if session_id == 999999: + return None + for session in mock_sessions: + if session.id == session_id: + return session + return mock_session + + def get_all_with_pagination(offset=0, limit=10): + total = len(mock_sessions) + sorted_sessions = sorted(mock_sessions, key=lambda s: s.id, reverse=True) + return sorted_sessions[offset:offset+limit], total + + def create_with_metadata(metadata=None): + if metadata is None: + return SessionModel( + id=16, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid test-null", + program_version="1.0.0", + machine_info=None + ) + return SessionModel( + id=16, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid test-custom", + program_version="1.0.0", + machine_info=metadata + ) + + repo.get.side_effect = get_with_id + repo.get_all.side_effect = get_all_with_pagination + repo.create_session.side_effect = create_with_metadata + + return repo + + +@pytest.fixture +def client(mock_repo): + """Create a TestClient with the API and dependency overrides.""" + # Override the dependency to use our mock repository + app.dependency_overrides[get_repository] = lambda: mock_repo + + # Create a test client + with TestClient(app) as test_client: + yield test_client + + # Clean up the dependency override + app.dependency_overrides.clear() + + +@pytest.fixture +def sample_metadata() -> Dict[str, Any]: + """Return sample metadata for session creation.""" + return { + "os": "Test OS", + "version": "1.0.0", + "environment": "test", + "cpu_cores": 4, + "memory_gb": 16, + "additional_info": { + "gpu": "Test GPU", + "display_resolution": "1920x1080" + } + } + + +def test_create_session_with_metadata(client, mock_repo, sample_metadata): + """Test creating a session with metadata through the API endpoint.""" + # Send request to create a session with metadata + response = client.post( + "/v1/sessions", + json={"metadata": sample_metadata} + ) + + # Verify response status code and structure + assert response.status_code == 201 + data = response.json() + + # Verify the session was created with the expected fields + assert data["id"] is not None + assert data["command_line"] is not None + assert data["program_version"] is not None + assert data["created_at"] is not None + assert data["updated_at"] is not None + assert data["start_time"] is not None + + # Verify metadata was passed correctly to the repository + mock_repo.create_session.assert_called_once_with(metadata=sample_metadata) + + +def test_create_session_without_metadata(client, mock_repo): + """Test creating a session without metadata through the API endpoint.""" + # Send request without a body + response = client.post("/v1/sessions") + + # Verify response status code and structure + assert response.status_code == 201 + data = response.json() + + # Verify the session was created with the expected fields + assert data["id"] is not None + assert data["command_line"] is not None + assert data["program_version"] is not None + + # Verify correct parameters were passed to the repository + mock_repo.create_session.assert_called_once_with(metadata=None) + + +def test_get_session_by_id(client): + """Test retrieving a session by ID through the API endpoint.""" + # Use a completely isolated, standalone test + + # For this test, let's focus on verifying the core functionality: + # 1. The API endpoint receives a request for a specific session ID + # 2. It calls the repository with that ID + # 3. It returns a properly formatted response + + mock_repo = MagicMock() + + # Create a test session with a simple machine_info to reduce serialization issues + test_session = SessionModel( + id=42, + created_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0), + start_time=datetime.datetime(2025, 1, 1, 0, 0, 0), + command_line="ra-aid specific-test", + program_version="1.0.0-test", + machine_info=None # Use None to avoid serialization issues + ) + + # Configure the mock + mock_repo.get.return_value = test_session + + # Override the dependency + app.dependency_overrides[get_repository] = lambda: mock_repo + + try: + # Retrieve the session through the API + response = client.get(f"/v1/sessions/{test_session.id}") + + # Verify response status code + assert response.status_code == 200 + + # Parse the response data + data = response.json() + + # Print for debugging + import json + print("Response JSON:", json.dumps(data, indent=2)) + + # Verify the returned session matches what we expected + assert data["id"] == test_session.id + assert data["command_line"] == test_session.command_line + assert data["program_version"] == test_session.program_version + assert data["machine_info"] is None + + # Verify the repository was called with the correct ID + mock_repo.get.assert_called_once_with(test_session.id) + finally: + # Clean up the override + if get_repository in app.dependency_overrides: + del app.dependency_overrides[get_repository] + + +def test_get_session_not_found(client, mock_repo): + """Test the error handling when requesting a non-existent session.""" + # Try to get a session with a non-existent ID + response = client.get("/v1/sessions/999999") + + # Verify response status code and error message + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + # Verify the repository was called with the correct ID + mock_repo.get.assert_called_with(999999) + + +def test_list_sessions_empty(client, mock_repo): + """Test listing sessions when no sessions exist.""" + # Reset the mock first to clear any previous calls/side effects + mock_repo.reset_mock() + + # Configure the mock to return empty results + mock_repo.get_all.side_effect = None # Clear any previous side effects + mock_repo.get_all.return_value = ([], 0) + + # Get the list of sessions + response = client.get("/v1/sessions") + + # Verify response status code and structure + assert response.status_code == 200 + data = response.json() + + # Verify the pagination response + assert data["total"] == 0 + assert len(data["items"]) == 0 + assert data["limit"] == 10 + assert data["offset"] == 0 + + # Verify the repository was called with the correct parameters + mock_repo.get_all.assert_called_with(offset=0, limit=10) + + +def test_list_sessions_with_pagination(client, mock_repo, mock_sessions): + """Test listing sessions with pagination parameters.""" + # Set up the repository mock to return specific results for different pagination parameters + default_result = (mock_sessions[:10], len(mock_sessions)) + limit_5_result = (mock_sessions[:5], len(mock_sessions)) + offset_10_result = (mock_sessions[10:], len(mock_sessions)) + offset_5_limit_3_result = (mock_sessions[5:8], len(mock_sessions)) + + pagination_responses = { + (0, 10): default_result, + (0, 5): limit_5_result, + (10, 10): offset_10_result, + (5, 3): offset_5_limit_3_result + } + + def mock_get_all(offset=0, limit=10): + return pagination_responses.get((offset, limit), ([], 0)) + + mock_repo.get_all.side_effect = mock_get_all + + # Test default pagination (limit=10, offset=0) + response = client.get("/v1/sessions") + assert response.status_code == 200 + data = response.json() + assert data["total"] == len(mock_sessions) + assert len(data["items"]) == 10 + assert data["limit"] == 10 + assert data["offset"] == 0 + + # Test with custom limit + response = client.get("/v1/sessions?limit=5") + assert response.status_code == 200 + data = response.json() + assert data["total"] == len(mock_sessions) + assert len(data["items"]) == 5 + assert data["limit"] == 5 + assert data["offset"] == 0 + + # Test with custom offset + response = client.get("/v1/sessions?offset=10") + assert response.status_code == 200 + data = response.json() + assert data["total"] == len(mock_sessions) + assert len(data["items"]) == 5 # Only 5 items left after offset 10 + assert data["limit"] == 10 + assert data["offset"] == 10 + + # Test with both custom limit and offset + response = client.get("/v1/sessions?limit=3&offset=5") + assert response.status_code == 200 + data = response.json() + assert data["total"] == len(mock_sessions) + assert len(data["items"]) == 3 + assert data["limit"] == 3 + assert data["offset"] == 5 + + +def test_list_sessions_invalid_parameters(client): + """Test error handling for invalid pagination parameters.""" + # Test with negative offset + response = client.get("/v1/sessions?offset=-1") + assert response.status_code == 422 + + # Test with negative limit + response = client.get("/v1/sessions?limit=-5") + assert response.status_code == 422 + + # Test with zero limit + response = client.get("/v1/sessions?limit=0") + assert response.status_code == 422 + + # Test with limit exceeding maximum + response = client.get("/v1/sessions?limit=101") + assert response.status_code == 422 + + +def test_metadata_validation(client, mock_repo): + """Test validation for different metadata formats in session creation.""" + # Create test sessions with different metadata + null_metadata_session = SessionModel( + id=20, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid test-null", + program_version="1.0.0", + machine_info=None + ) + + empty_dict_metadata_session = SessionModel( + id=21, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid test-empty", + program_version="1.0.0", + machine_info={} + ) + + complex_metadata_session = SessionModel( + id=22, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid test-complex", + program_version="1.0.0", + machine_info={"level1": {"level2": {"level3": [1, 2, 3, {"key": "value"}]}}} + ) + + # Configure mock to return different sessions based on metadata + def create_with_specific_metadata(metadata=None): + if metadata is None: + return null_metadata_session + elif metadata == {}: + return empty_dict_metadata_session + elif isinstance(metadata, dict) and "level1" in metadata: + return complex_metadata_session + return SessionModel( + id=23, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid test-other", + program_version="1.0.0", + machine_info=metadata + ) + + mock_repo.create_session.side_effect = create_with_specific_metadata + + # Try to create a session with null metadata + response = client.post( + "/v1/sessions", + json={"metadata": None} + ) + + # This should work fine + assert response.status_code == 201 + mock_repo.create_session.assert_called_with(metadata=None) + + # Try to create a session with an empty metadata dict + response = client.post( + "/v1/sessions", + json={"metadata": {}} + ) + + # This should work fine + assert response.status_code == 201 + mock_repo.create_session.assert_called_with(metadata={}) + + # Try to create a session with a complex nested metadata + response = client.post( + "/v1/sessions", + json={"metadata": { + "level1": { + "level2": { + "level3": [1, 2, 3, {"key": "value"}] + } + } + }} + ) + + # Verify the complex nested structure is preserved + assert response.status_code == 201 + complex_metadata = { + "level1": { + "level2": { + "level3": [1, 2, 3, {"key": "value"}] + } + } + } + mock_repo.create_session.assert_called_with(metadata=complex_metadata) + + +def test_integration_workflow(client, mock_repo): + """Test a complete workflow of creating and retrieving sessions.""" + # Set up mock sessions for the workflow + first_session = SessionModel( + id=30, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid workflow-1", + program_version="1.0.0", + machine_info={"workflow_test": True} + ) + + second_session = SessionModel( + id=31, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + start_time=datetime.datetime.now(), + command_line="ra-aid workflow-2", + program_version="1.0.0", + machine_info={"workflow_test": False, "second": True} + ) + + # Configure mock for create_session + create_calls = 0 + def create_session_for_workflow(metadata=None): + nonlocal create_calls + create_calls += 1 + if create_calls == 1: + return first_session + return second_session + + mock_repo.create_session.side_effect = create_session_for_workflow + + # Configure mock for get + def get_session_for_workflow(session_id): + if session_id == first_session.id: + return first_session + elif session_id == second_session.id: + return second_session + return None + + mock_repo.get.side_effect = get_session_for_workflow + + # Configure mock for get_all + def get_all_for_workflow(offset=0, limit=10): + if create_calls == 1: + return [first_session], 1 + return [second_session, first_session], 2 + + mock_repo.get_all.side_effect = get_all_for_workflow + + # 1. Create a session + create_response = client.post( + "/v1/sessions", + json={"metadata": {"workflow_test": True}} + ) + assert create_response.status_code == 201 + session_id = create_response.json()["id"] + assert session_id == first_session.id + + # 2. Retrieve the created session + get_response = client.get(f"/v1/sessions/{session_id}") + assert get_response.status_code == 200 + assert get_response.json()["id"] == session_id + + # 3. List all sessions and verify the created one is included + list_response = client.get("/v1/sessions") + assert list_response.status_code == 200 + items = list_response.json()["items"] + assert len(items) == 1 + assert items[0]["id"] == session_id + + # 4. Create a second session + create_response2 = client.post( + "/v1/sessions", + json={"metadata": {"workflow_test": False, "second": True}} + ) + assert create_response2.status_code == 201 + session_id2 = create_response2.json()["id"] + assert session_id2 == second_session.id + + # 5. List all sessions and verify both sessions are included + list_response = client.get("/v1/sessions") + assert list_response.status_code == 200 + data = list_response.json() + assert data["total"] == 2 + items = data["items"] + assert len(items) == 2 + assert items[0]["id"] == session_id2 + assert items[1]["id"] == session_id \ No newline at end of file