session API endpoint
This commit is contained in:
parent
77cfbdeca7
commit
c18c4dbd22
|
|
@ -102,9 +102,65 @@ if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debuggin
|
|||
def launch_server(host: str, port: int):
|
||||
"""Launch the RA.Aid web interface."""
|
||||
from ra_aid.server import run_server
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager
|
||||
from ra_aid.database.repositories.session_repository import SessionRepositoryManager
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager
|
||||
from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager
|
||||
from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager
|
||||
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||
from ra_aid.env_inv_context import EnvInvManager
|
||||
from ra_aid.env_inv import EnvDiscovery
|
||||
|
||||
# Apply any pending database migrations
|
||||
from ra_aid.database import ensure_migrations_applied
|
||||
try:
|
||||
migration_result = ensure_migrations_applied()
|
||||
if not migration_result:
|
||||
logger.warning("Database migrations failed but execution will continue")
|
||||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Initialize empty config dictionary
|
||||
config = {}
|
||||
|
||||
# Initialize environment discovery
|
||||
env_discovery = EnvDiscovery()
|
||||
env_discovery.discover()
|
||||
env_data = env_discovery.format_markdown()
|
||||
|
||||
print(f"Starting RA.Aid web interface on http://{host}:{port}")
|
||||
run_server(host=host, port=port)
|
||||
|
||||
# Initialize database connection and repositories
|
||||
with DatabaseManager() as db, \
|
||||
SessionRepositoryManager(db) as session_repo, \
|
||||
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||
WorkLogRepositoryManager() as work_log_repo, \
|
||||
ConfigRepositoryManager(config) as config_repo, \
|
||||
EnvInvManager(env_data) as env_inv:
|
||||
|
||||
# This initializes all repositories and makes them available via their respective get methods
|
||||
logger.debug("Initialized SessionRepository")
|
||||
logger.debug("Initialized KeyFactRepository")
|
||||
logger.debug("Initialized KeySnippetRepository")
|
||||
logger.debug("Initialized HumanInputRepository")
|
||||
logger.debug("Initialized ResearchNoteRepository")
|
||||
logger.debug("Initialized RelatedFilesRepository")
|
||||
logger.debug("Initialized TrajectoryRepository")
|
||||
logger.debug("Initialized WorkLogRepository")
|
||||
logger.debug("Initialized ConfigRepository")
|
||||
logger.debug("Initialized Environment Inventory")
|
||||
|
||||
# Run the server within the context managers
|
||||
run_server(host=host, port=port)
|
||||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
|
|
|
|||
|
|
@ -226,19 +226,33 @@ class SessionRepository:
|
|||
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_all(self) -> List[SessionModel]:
|
||||
def get_all(self, offset: int = 0, limit: int = 10) -> tuple[List[SessionModel], int]:
|
||||
"""
|
||||
Get all sessions from the database.
|
||||
Get all sessions from the database with pagination support.
|
||||
|
||||
Args:
|
||||
offset: Number of sessions to skip (default: 0)
|
||||
limit: Maximum number of sessions to return (default: 10)
|
||||
|
||||
Returns:
|
||||
List[SessionModel]: List of all sessions
|
||||
tuple: (List[SessionModel], int) containing the list of sessions and the total count
|
||||
"""
|
||||
try:
|
||||
sessions = list(Session.select().order_by(Session.created_at.desc()))
|
||||
return [self._to_model(session) for session in sessions]
|
||||
# Get total count for pagination info
|
||||
total_count = Session.select().count()
|
||||
|
||||
# Get paginated sessions ordered by created_at in descending order (newest first)
|
||||
sessions = list(
|
||||
Session.select()
|
||||
.order_by(Session.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [self._to_model(session) for session in sessions], total_count
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get all sessions: {str(e)}")
|
||||
return []
|
||||
logger.error(f"Failed to get all sessions with pagination: {str(e)}")
|
||||
return [], 0
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[SessionModel]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,200 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
API v1 Session Endpoints.
|
||||
|
||||
This module provides RESTful API endpoints for managing sessions.
|
||||
It implements routes for creating, listing, and retrieving sessions
|
||||
with proper validation and error handling.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
import peewee
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
# Create API router
|
||||
router = APIRouter(
|
||||
prefix="/v1/sessions",
|
||||
tags=["sessions"],
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "Session not found"},
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"},
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Database error"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""
|
||||
Pydantic model for paginated API responses.
|
||||
|
||||
This model provides a standardized format for API responses that include
|
||||
pagination, with a total count and the requested items.
|
||||
|
||||
Attributes:
|
||||
total: The total number of items available
|
||||
items: List of items for the current page
|
||||
limit: The limit parameter that was used
|
||||
offset: The offset parameter that was used
|
||||
"""
|
||||
total: int
|
||||
items: List[Any]
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""
|
||||
Pydantic model for session creation requests.
|
||||
|
||||
This model provides validation for creating new sessions.
|
||||
|
||||
Attributes:
|
||||
metadata: Optional dictionary of additional metadata to store with the session
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional dictionary of additional metadata to store with the session"
|
||||
)
|
||||
|
||||
|
||||
class PaginatedSessionResponse(PaginatedResponse):
|
||||
"""
|
||||
Pydantic model for paginated session responses.
|
||||
|
||||
This model specializes the generic PaginatedResponse for SessionModel items.
|
||||
|
||||
Attributes:
|
||||
items: List of SessionModel items for the current page
|
||||
"""
|
||||
items: List[SessionModel]
|
||||
|
||||
|
||||
# Dependency to get the session repository
|
||||
def get_repository() -> SessionRepository:
|
||||
"""
|
||||
Get the SessionRepository instance.
|
||||
|
||||
This function is used as a FastAPI dependency and can be overridden
|
||||
in tests using dependency_overrides.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The repository instance
|
||||
"""
|
||||
return get_session_repository()
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedSessionResponse,
|
||||
summary="List sessions",
|
||||
description="Get a paginated list of sessions",
|
||||
)
|
||||
async def list_sessions(
|
||||
offset: int = Query(0, ge=0, description="Number of sessions to skip"),
|
||||
limit: int = Query(10, ge=1, le=100, description="Maximum number of sessions to return"),
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> PaginatedSessionResponse:
|
||||
"""
|
||||
Get a paginated list of sessions.
|
||||
|
||||
Args:
|
||||
offset: Number of sessions to skip (default: 0)
|
||||
limit: Maximum number of sessions to return (default: 10)
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
PaginatedSessionResponse: Response with paginated sessions
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 500 status code if there's a database error
|
||||
"""
|
||||
try:
|
||||
sessions, total = repo.get_all(offset=offset, limit=limit)
|
||||
return PaginatedSessionResponse(
|
||||
total=total,
|
||||
items=sessions,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
except peewee.DatabaseError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{session_id}",
|
||||
response_model=SessionModel,
|
||||
summary="Get session",
|
||||
description="Get a specific session by ID",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: int,
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> SessionModel:
|
||||
"""
|
||||
Get a specific session by ID.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to retrieve
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
SessionModel: The requested session
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 404 status code if the session is not found
|
||||
HTTPException: With a 500 status code if there's a database error
|
||||
"""
|
||||
try:
|
||||
session = repo.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Session with ID {session_id} not found",
|
||||
)
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=SessionModel,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create session",
|
||||
description="Create a new session",
|
||||
)
|
||||
async def create_session(
|
||||
request: Optional[CreateSessionRequest] = None,
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> SessionModel:
|
||||
"""
|
||||
Create a new session.
|
||||
|
||||
Args:
|
||||
request: Optional request body with session metadata
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
SessionModel: The newly created session
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 500 status code if there's a database error
|
||||
"""
|
||||
try:
|
||||
metadata = request.metadata if request else None
|
||||
return repo.create_session(metadata=metadata)
|
||||
except peewee.DatabaseError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {str(e)}",
|
||||
)
|
||||
|
|
@ -33,7 +33,13 @@ from fastapi.responses import HTMLResponse
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
app = FastAPI()
|
||||
from ra_aid.server.api_v1_sessions import router as sessions_router
|
||||
|
||||
app = FastAPI(
|
||||
title="RA.Aid API",
|
||||
description="API for RA.Aid - AI Programming Assistant",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
|
|
@ -44,6 +50,9 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API routers
|
||||
app.include_router(sessions_router)
|
||||
|
||||
# Setup templates and static files directories
|
||||
CURRENT_DIR = Path(__file__).parent
|
||||
templates = Jinja2Templates(directory=CURRENT_DIR)
|
||||
|
|
|
|||
|
|
@ -231,8 +231,11 @@ def test_get_all(setup_db):
|
|||
repo.create_session(metadata=metadata2)
|
||||
repo.create_session(metadata=metadata3)
|
||||
|
||||
# Get all sessions
|
||||
sessions = repo.get_all()
|
||||
# Get all sessions with default pagination
|
||||
sessions, total_count = repo.get_all()
|
||||
|
||||
# Verify total count
|
||||
assert total_count == 3
|
||||
|
||||
# Verify we got a list of SessionModel objects
|
||||
assert len(sessions) == 3
|
||||
|
|
@ -249,6 +252,19 @@ def test_get_all(setup_db):
|
|||
assert "Linux" in os_values
|
||||
assert "Windows" in os_values
|
||||
assert "macOS" in os_values
|
||||
|
||||
# Test pagination with limit
|
||||
sessions_limited, total_count = repo.get_all(limit=2)
|
||||
assert total_count == 3 # Total count should still be 3
|
||||
assert len(sessions_limited) == 2 # But only 2 returned
|
||||
|
||||
# Test pagination with offset
|
||||
sessions_offset, total_count = repo.get_all(offset=1, limit=2)
|
||||
assert total_count == 3
|
||||
assert len(sessions_offset) == 2
|
||||
|
||||
# The second item in the full list should be the first item in the offset list
|
||||
assert sessions[1].id == sessions_offset[0].id
|
||||
|
||||
|
||||
def test_get_all_empty(setup_db):
|
||||
|
|
@ -257,11 +273,12 @@ def test_get_all_empty(setup_db):
|
|||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get all sessions
|
||||
sessions = repo.get_all()
|
||||
sessions, total_count = repo.get_all()
|
||||
|
||||
# Verify we got an empty list
|
||||
# Verify we got an empty list and zero count
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) == 0
|
||||
assert total_count == 0
|
||||
|
||||
|
||||
def test_get_recent(setup_db):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Tests for the Sessions API v1 endpoints.
|
||||
|
||||
This module contains tests for the sessions API endpoints in ra_aid/server/api_v1_sessions.py.
|
||||
It tests the creation, listing, and retrieval of sessions through the API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
import datetime
|
||||
|
||||
from ra_aid.server.api_v1_sessions import router, get_repository
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
|
||||
# Mock session data for testing
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Return a mock session for testing."""
|
||||
return SessionModel(
|
||||
id=1,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessions():
|
||||
"""Return a list of mock sessions for testing."""
|
||||
return [
|
||||
SessionModel(
|
||||
id=1,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test1",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
),
|
||||
SessionModel(
|
||||
id=2,
|
||||
created_at=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||
command_line="ra-aid test2",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repo(mock_session, mock_sessions):
|
||||
"""Mock the SessionRepository for testing."""
|
||||
repo = MagicMock()
|
||||
repo.get.return_value = mock_session
|
||||
repo.get_all.return_value = (mock_sessions, len(mock_sessions))
|
||||
repo.create_session.return_value = mock_session
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repo):
|
||||
"""Return a TestClient for the API router with dependency override."""
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_get_session(client, mock_repo, mock_session):
|
||||
"""Test getting a specific session by ID."""
|
||||
response = client.get("/v1/sessions/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == mock_session.id
|
||||
assert response.json()["command_line"] == mock_session.command_line
|
||||
mock_repo.get.assert_called_once_with(1)
|
||||
|
||||
|
||||
def test_get_session_not_found(client, mock_repo):
|
||||
"""Test getting a session that doesn't exist."""
|
||||
mock_repo.get.return_value = None
|
||||
|
||||
response = client.get("/v1/sessions/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
mock_repo.get.assert_called_once_with(999)
|
||||
|
||||
|
||||
def test_list_sessions(client, mock_repo, mock_sessions):
|
||||
"""Test listing sessions with pagination."""
|
||||
response = client.get("/v1/sessions?offset=0&limit=10")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == len(mock_sessions)
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 0
|
||||
mock_repo.get_all.assert_called_once_with(offset=0, limit=10)
|
||||
|
||||
|
||||
def test_create_session(client, mock_repo, mock_session):
|
||||
"""Test creating a new session."""
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {"test": "data"}}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["id"] == mock_session.id
|
||||
mock_repo.create_session.assert_called_once_with(metadata={"test": "data"})
|
||||
|
||||
|
||||
def test_create_session_no_body(client, mock_repo, mock_session):
|
||||
"""Test creating a new session without a request body."""
|
||||
response = client.post("/v1/sessions")
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["id"] == mock_session.id
|
||||
mock_repo.create_session.assert_called_once_with(metadata=None)
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
Tests for server.py FastAPI application.
|
||||
|
||||
This module tests the FastAPI application setup in server.py to ensure
|
||||
that all routers are properly mounted and middleware is configured.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ra_aid.server.server import app
|
||||
from ra_aid.database.repositories.session_repository import session_repo_var
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Return a TestClient for the FastAPI app."""
|
||||
# Mock the session repository to avoid database dependency
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_all.return_value = ([], 0)
|
||||
|
||||
# Set the repository in the contextvar
|
||||
token = session_repo_var.set(mock_repo)
|
||||
|
||||
yield TestClient(app)
|
||||
|
||||
# Reset the contextvar after the test
|
||||
session_repo_var.reset(token)
|
||||
|
||||
|
||||
def test_config_endpoint(client):
|
||||
"""Test that the config endpoint returns server configuration."""
|
||||
response = client.get("/config")
|
||||
assert response.status_code == 200
|
||||
assert "host" in response.json()
|
||||
assert "port" in response.json()
|
||||
|
||||
|
||||
def test_api_documentation(client):
|
||||
"""Test that the OpenAPI documentation includes the sessions API."""
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
|
||||
openapi_spec = response.json()
|
||||
assert "paths" in openapi_spec
|
||||
|
||||
# Check that the sessions API paths are included
|
||||
assert "/v1/sessions" in openapi_spec["paths"]
|
||||
assert "/v1/sessions/{session_id}" in openapi_spec["paths"]
|
||||
|
||||
# Verify that sessions API operations are documented
|
||||
assert "get" in openapi_spec["paths"]["/v1/sessions"]
|
||||
assert "post" in openapi_spec["paths"]["/v1/sessions"]
|
||||
assert "get" in openapi_spec["paths"]["/v1/sessions/{session_id}"]
|
||||
|
||||
|
||||
@patch("ra_aid.database.repositories.session_repository.get_session_repository")
|
||||
def test_sessions_api_mounted(mock_get_repo, client):
|
||||
"""Test that the sessions API router is mounted correctly."""
|
||||
# Mock the repository for this specific test
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_all.return_value = ([], 0)
|
||||
mock_get_repo.return_value = mock_repo
|
||||
|
||||
# Test that the sessions list endpoint is accessible
|
||||
response = client.get("/v1/sessions")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the response structure follows our expected format
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "items" in data
|
||||
assert "limit" in data
|
||||
assert "offset" in data
|
||||
|
|
@ -0,0 +1,532 @@
|
|||
"""
|
||||
Integration tests for the Sessions API endpoints.
|
||||
|
||||
This module contains integration tests for the API endpoints defined in ra_aid/server/api_v1_sessions.py.
|
||||
It uses mocks to simulate the database interactions while testing the real API behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ra_aid.server.server import app
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
from ra_aid.server.api_v1_sessions import get_repository
|
||||
|
||||
|
||||
# Mock session data for testing
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Return a mock session for testing."""
|
||||
return SessionModel(
|
||||
id=1,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessions():
|
||||
"""Return a list of mock sessions for testing."""
|
||||
return [
|
||||
SessionModel(
|
||||
id=i+1,
|
||||
created_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||
command_line=f"ra-aid test{i+1}",
|
||||
program_version="1.0.0",
|
||||
machine_info={"index": i}
|
||||
)
|
||||
for i in range(15)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repo(mock_session, mock_sessions):
|
||||
"""Create a mock repository with predefined responses."""
|
||||
repo = MagicMock()
|
||||
repo.get.return_value = mock_session
|
||||
repo.get_all.return_value = (mock_sessions[:10], len(mock_sessions))
|
||||
repo.create_session.return_value = mock_session
|
||||
|
||||
# Add behavior for custom parameters
|
||||
def get_with_id(session_id):
|
||||
if session_id == 999999:
|
||||
return None
|
||||
for session in mock_sessions:
|
||||
if session.id == session_id:
|
||||
return session
|
||||
return mock_session
|
||||
|
||||
def get_all_with_pagination(offset=0, limit=10):
|
||||
total = len(mock_sessions)
|
||||
sorted_sessions = sorted(mock_sessions, key=lambda s: s.id, reverse=True)
|
||||
return sorted_sessions[offset:offset+limit], total
|
||||
|
||||
def create_with_metadata(metadata=None):
|
||||
if metadata is None:
|
||||
return SessionModel(
|
||||
id=16,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-null",
|
||||
program_version="1.0.0",
|
||||
machine_info=None
|
||||
)
|
||||
return SessionModel(
|
||||
id=16,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-custom",
|
||||
program_version="1.0.0",
|
||||
machine_info=metadata
|
||||
)
|
||||
|
||||
repo.get.side_effect = get_with_id
|
||||
repo.get_all.side_effect = get_all_with_pagination
|
||||
repo.create_session.side_effect = create_with_metadata
|
||||
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repo):
|
||||
"""Create a TestClient with the API and dependency overrides."""
|
||||
# Override the dependency to use our mock repository
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||
|
||||
# Create a test client
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
# Clean up the dependency override
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_metadata() -> Dict[str, Any]:
|
||||
"""Return sample metadata for session creation."""
|
||||
return {
|
||||
"os": "Test OS",
|
||||
"version": "1.0.0",
|
||||
"environment": "test",
|
||||
"cpu_cores": 4,
|
||||
"memory_gb": 16,
|
||||
"additional_info": {
|
||||
"gpu": "Test GPU",
|
||||
"display_resolution": "1920x1080"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_create_session_with_metadata(client, mock_repo, sample_metadata):
|
||||
"""Test creating a session with metadata through the API endpoint."""
|
||||
# Send request to create a session with metadata
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": sample_metadata}
|
||||
)
|
||||
|
||||
# Verify response status code and structure
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
# Verify the session was created with the expected fields
|
||||
assert data["id"] is not None
|
||||
assert data["command_line"] is not None
|
||||
assert data["program_version"] is not None
|
||||
assert data["created_at"] is not None
|
||||
assert data["updated_at"] is not None
|
||||
assert data["start_time"] is not None
|
||||
|
||||
# Verify metadata was passed correctly to the repository
|
||||
mock_repo.create_session.assert_called_once_with(metadata=sample_metadata)
|
||||
|
||||
|
||||
def test_create_session_without_metadata(client, mock_repo):
|
||||
"""Test creating a session without metadata through the API endpoint."""
|
||||
# Send request without a body
|
||||
response = client.post("/v1/sessions")
|
||||
|
||||
# Verify response status code and structure
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
# Verify the session was created with the expected fields
|
||||
assert data["id"] is not None
|
||||
assert data["command_line"] is not None
|
||||
assert data["program_version"] is not None
|
||||
|
||||
# Verify correct parameters were passed to the repository
|
||||
mock_repo.create_session.assert_called_once_with(metadata=None)
|
||||
|
||||
|
||||
def test_get_session_by_id(client):
|
||||
"""Test retrieving a session by ID through the API endpoint."""
|
||||
# Use a completely isolated, standalone test
|
||||
|
||||
# For this test, let's focus on verifying the core functionality:
|
||||
# 1. The API endpoint receives a request for a specific session ID
|
||||
# 2. It calls the repository with that ID
|
||||
# 3. It returns a properly formatted response
|
||||
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a test session with a simple machine_info to reduce serialization issues
|
||||
test_session = SessionModel(
|
||||
id=42,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid specific-test",
|
||||
program_version="1.0.0-test",
|
||||
machine_info=None # Use None to avoid serialization issues
|
||||
)
|
||||
|
||||
# Configure the mock
|
||||
mock_repo.get.return_value = test_session
|
||||
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||
|
||||
try:
|
||||
# Retrieve the session through the API
|
||||
response = client.get(f"/v1/sessions/{test_session.id}")
|
||||
|
||||
# Verify response status code
|
||||
assert response.status_code == 200
|
||||
|
||||
# Parse the response data
|
||||
data = response.json()
|
||||
|
||||
# Print for debugging
|
||||
import json
|
||||
print("Response JSON:", json.dumps(data, indent=2))
|
||||
|
||||
# Verify the returned session matches what we expected
|
||||
assert data["id"] == test_session.id
|
||||
assert data["command_line"] == test_session.command_line
|
||||
assert data["program_version"] == test_session.program_version
|
||||
assert data["machine_info"] is None
|
||||
|
||||
# Verify the repository was called with the correct ID
|
||||
mock_repo.get.assert_called_once_with(test_session.id)
|
||||
finally:
|
||||
# Clean up the override
|
||||
if get_repository in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_repository]
|
||||
|
||||
|
||||
def test_get_session_not_found(client, mock_repo):
|
||||
"""Test the error handling when requesting a non-existent session."""
|
||||
# Try to get a session with a non-existent ID
|
||||
response = client.get("/v1/sessions/999999")
|
||||
|
||||
# Verify response status code and error message
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
# Verify the repository was called with the correct ID
|
||||
mock_repo.get.assert_called_with(999999)
|
||||
|
||||
|
||||
def test_list_sessions_empty(client, mock_repo):
|
||||
"""Test listing sessions when no sessions exist."""
|
||||
# Reset the mock first to clear any previous calls/side effects
|
||||
mock_repo.reset_mock()
|
||||
|
||||
# Configure the mock to return empty results
|
||||
mock_repo.get_all.side_effect = None # Clear any previous side effects
|
||||
mock_repo.get_all.return_value = ([], 0)
|
||||
|
||||
# Get the list of sessions
|
||||
response = client.get("/v1/sessions")
|
||||
|
||||
# Verify response status code and structure
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify the pagination response
|
||||
assert data["total"] == 0
|
||||
assert len(data["items"]) == 0
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Verify the repository was called with the correct parameters
|
||||
mock_repo.get_all.assert_called_with(offset=0, limit=10)
|
||||
|
||||
|
||||
def test_list_sessions_with_pagination(client, mock_repo, mock_sessions):
|
||||
"""Test listing sessions with pagination parameters."""
|
||||
# Set up the repository mock to return specific results for different pagination parameters
|
||||
default_result = (mock_sessions[:10], len(mock_sessions))
|
||||
limit_5_result = (mock_sessions[:5], len(mock_sessions))
|
||||
offset_10_result = (mock_sessions[10:], len(mock_sessions))
|
||||
offset_5_limit_3_result = (mock_sessions[5:8], len(mock_sessions))
|
||||
|
||||
pagination_responses = {
|
||||
(0, 10): default_result,
|
||||
(0, 5): limit_5_result,
|
||||
(10, 10): offset_10_result,
|
||||
(5, 3): offset_5_limit_3_result
|
||||
}
|
||||
|
||||
def mock_get_all(offset=0, limit=10):
|
||||
return pagination_responses.get((offset, limit), ([], 0))
|
||||
|
||||
mock_repo.get_all.side_effect = mock_get_all
|
||||
|
||||
# Test default pagination (limit=10, offset=0)
|
||||
response = client.get("/v1/sessions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 10
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Test with custom limit
|
||||
response = client.get("/v1/sessions?limit=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 5
|
||||
assert data["limit"] == 5
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Test with custom offset
|
||||
response = client.get("/v1/sessions?offset=10")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 5 # Only 5 items left after offset 10
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 10
|
||||
|
||||
# Test with both custom limit and offset
|
||||
response = client.get("/v1/sessions?limit=3&offset=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 3
|
||||
assert data["limit"] == 3
|
||||
assert data["offset"] == 5
|
||||
|
||||
|
||||
def test_list_sessions_invalid_parameters(client):
|
||||
"""Test error handling for invalid pagination parameters."""
|
||||
# Test with negative offset
|
||||
response = client.get("/v1/sessions?offset=-1")
|
||||
assert response.status_code == 422
|
||||
|
||||
# Test with negative limit
|
||||
response = client.get("/v1/sessions?limit=-5")
|
||||
assert response.status_code == 422
|
||||
|
||||
# Test with zero limit
|
||||
response = client.get("/v1/sessions?limit=0")
|
||||
assert response.status_code == 422
|
||||
|
||||
# Test with limit exceeding maximum
|
||||
response = client.get("/v1/sessions?limit=101")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_metadata_validation(client, mock_repo):
|
||||
"""Test validation for different metadata formats in session creation."""
|
||||
# Create test sessions with different metadata
|
||||
null_metadata_session = SessionModel(
|
||||
id=20,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-null",
|
||||
program_version="1.0.0",
|
||||
machine_info=None
|
||||
)
|
||||
|
||||
empty_dict_metadata_session = SessionModel(
|
||||
id=21,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-empty",
|
||||
program_version="1.0.0",
|
||||
machine_info={}
|
||||
)
|
||||
|
||||
complex_metadata_session = SessionModel(
|
||||
id=22,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-complex",
|
||||
program_version="1.0.0",
|
||||
machine_info={"level1": {"level2": {"level3": [1, 2, 3, {"key": "value"}]}}}
|
||||
)
|
||||
|
||||
# Configure mock to return different sessions based on metadata
|
||||
def create_with_specific_metadata(metadata=None):
|
||||
if metadata is None:
|
||||
return null_metadata_session
|
||||
elif metadata == {}:
|
||||
return empty_dict_metadata_session
|
||||
elif isinstance(metadata, dict) and "level1" in metadata:
|
||||
return complex_metadata_session
|
||||
return SessionModel(
|
||||
id=23,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-other",
|
||||
program_version="1.0.0",
|
||||
machine_info=metadata
|
||||
)
|
||||
|
||||
mock_repo.create_session.side_effect = create_with_specific_metadata
|
||||
|
||||
# Try to create a session with null metadata
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": None}
|
||||
)
|
||||
|
||||
# This should work fine
|
||||
assert response.status_code == 201
|
||||
mock_repo.create_session.assert_called_with(metadata=None)
|
||||
|
||||
# Try to create a session with an empty metadata dict
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {}}
|
||||
)
|
||||
|
||||
# This should work fine
|
||||
assert response.status_code == 201
|
||||
mock_repo.create_session.assert_called_with(metadata={})
|
||||
|
||||
# Try to create a session with a complex nested metadata
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": [1, 2, 3, {"key": "value"}]
|
||||
}
|
||||
}
|
||||
}}
|
||||
)
|
||||
|
||||
# Verify the complex nested structure is preserved
|
||||
assert response.status_code == 201
|
||||
complex_metadata = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": [1, 2, 3, {"key": "value"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_repo.create_session.assert_called_with(metadata=complex_metadata)
|
||||
|
||||
|
||||
def test_integration_workflow(client, mock_repo):
|
||||
"""Test a complete workflow of creating and retrieving sessions."""
|
||||
# Set up mock sessions for the workflow
|
||||
first_session = SessionModel(
|
||||
id=30,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid workflow-1",
|
||||
program_version="1.0.0",
|
||||
machine_info={"workflow_test": True}
|
||||
)
|
||||
|
||||
second_session = SessionModel(
|
||||
id=31,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid workflow-2",
|
||||
program_version="1.0.0",
|
||||
machine_info={"workflow_test": False, "second": True}
|
||||
)
|
||||
|
||||
# Configure mock for create_session
|
||||
create_calls = 0
|
||||
def create_session_for_workflow(metadata=None):
|
||||
nonlocal create_calls
|
||||
create_calls += 1
|
||||
if create_calls == 1:
|
||||
return first_session
|
||||
return second_session
|
||||
|
||||
mock_repo.create_session.side_effect = create_session_for_workflow
|
||||
|
||||
# Configure mock for get
|
||||
def get_session_for_workflow(session_id):
|
||||
if session_id == first_session.id:
|
||||
return first_session
|
||||
elif session_id == second_session.id:
|
||||
return second_session
|
||||
return None
|
||||
|
||||
mock_repo.get.side_effect = get_session_for_workflow
|
||||
|
||||
# Configure mock for get_all
|
||||
def get_all_for_workflow(offset=0, limit=10):
|
||||
if create_calls == 1:
|
||||
return [first_session], 1
|
||||
return [second_session, first_session], 2
|
||||
|
||||
mock_repo.get_all.side_effect = get_all_for_workflow
|
||||
|
||||
# 1. Create a session
|
||||
create_response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {"workflow_test": True}}
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
session_id = create_response.json()["id"]
|
||||
assert session_id == first_session.id
|
||||
|
||||
# 2. Retrieve the created session
|
||||
get_response = client.get(f"/v1/sessions/{session_id}")
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()["id"] == session_id
|
||||
|
||||
# 3. List all sessions and verify the created one is included
|
||||
list_response = client.get("/v1/sessions")
|
||||
assert list_response.status_code == 200
|
||||
items = list_response.json()["items"]
|
||||
assert len(items) == 1
|
||||
assert items[0]["id"] == session_id
|
||||
|
||||
# 4. Create a second session
|
||||
create_response2 = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {"workflow_test": False, "second": True}}
|
||||
)
|
||||
assert create_response2.status_code == 201
|
||||
session_id2 = create_response2.json()["id"]
|
||||
assert session_id2 == second_session.id
|
||||
|
||||
# 5. List all sessions and verify both sessions are included
|
||||
list_response = client.get("/v1/sessions")
|
||||
assert list_response.status_code == 200
|
||||
data = list_response.json()
|
||||
assert data["total"] == 2
|
||||
items = data["items"]
|
||||
assert len(items) == 2
|
||||
assert items[0]["id"] == session_id2
|
||||
assert items[1]["id"] == session_id
|
||||
Loading…
Reference in New Issue