session API endpoint
This commit is contained in:
parent
77cfbdeca7
commit
c18c4dbd22
|
|
@ -102,9 +102,65 @@ if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debuggin
|
||||||
def launch_server(host: str, port: int):
|
def launch_server(host: str, port: int):
|
||||||
"""Launch the RA.Aid web interface."""
|
"""Launch the RA.Aid web interface."""
|
||||||
from ra_aid.server import run_server
|
from ra_aid.server import run_server
|
||||||
|
from ra_aid.database.connection import DatabaseManager
|
||||||
|
from ra_aid.database.repositories.session_repository import SessionRepositoryManager
|
||||||
|
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager
|
||||||
|
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager
|
||||||
|
from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager
|
||||||
|
from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager
|
||||||
|
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
|
||||||
|
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||||
|
from ra_aid.env_inv_context import EnvInvManager
|
||||||
|
from ra_aid.env_inv import EnvDiscovery
|
||||||
|
|
||||||
|
# Apply any pending database migrations
|
||||||
|
from ra_aid.database import ensure_migrations_applied
|
||||||
|
try:
|
||||||
|
migration_result = ensure_migrations_applied()
|
||||||
|
if not migration_result:
|
||||||
|
logger.warning("Database migrations failed but execution will continue")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database migration error: {str(e)}")
|
||||||
|
|
||||||
|
# Initialize empty config dictionary
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
# Initialize environment discovery
|
||||||
|
env_discovery = EnvDiscovery()
|
||||||
|
env_discovery.discover()
|
||||||
|
env_data = env_discovery.format_markdown()
|
||||||
|
|
||||||
print(f"Starting RA.Aid web interface on http://{host}:{port}")
|
print(f"Starting RA.Aid web interface on http://{host}:{port}")
|
||||||
run_server(host=host, port=port)
|
|
||||||
|
# Initialize database connection and repositories
|
||||||
|
with DatabaseManager() as db, \
|
||||||
|
SessionRepositoryManager(db) as session_repo, \
|
||||||
|
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||||
|
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||||
|
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||||
|
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||||
|
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||||
|
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||||
|
WorkLogRepositoryManager() as work_log_repo, \
|
||||||
|
ConfigRepositoryManager(config) as config_repo, \
|
||||||
|
EnvInvManager(env_data) as env_inv:
|
||||||
|
|
||||||
|
# This initializes all repositories and makes them available via their respective get methods
|
||||||
|
logger.debug("Initialized SessionRepository")
|
||||||
|
logger.debug("Initialized KeyFactRepository")
|
||||||
|
logger.debug("Initialized KeySnippetRepository")
|
||||||
|
logger.debug("Initialized HumanInputRepository")
|
||||||
|
logger.debug("Initialized ResearchNoteRepository")
|
||||||
|
logger.debug("Initialized RelatedFilesRepository")
|
||||||
|
logger.debug("Initialized TrajectoryRepository")
|
||||||
|
logger.debug("Initialized WorkLogRepository")
|
||||||
|
logger.debug("Initialized ConfigRepository")
|
||||||
|
logger.debug("Initialized Environment Inventory")
|
||||||
|
|
||||||
|
# Run the server within the context managers
|
||||||
|
run_server(host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments(args=None):
|
def parse_arguments(args=None):
|
||||||
|
|
|
||||||
|
|
@ -226,19 +226,33 @@ class SessionRepository:
|
||||||
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_all(self) -> List[SessionModel]:
|
def get_all(self, offset: int = 0, limit: int = 10) -> tuple[List[SessionModel], int]:
|
||||||
"""
|
"""
|
||||||
Get all sessions from the database.
|
Get all sessions from the database with pagination support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
offset: Number of sessions to skip (default: 0)
|
||||||
|
limit: Maximum number of sessions to return (default: 10)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[SessionModel]: List of all sessions
|
tuple: (List[SessionModel], int) containing the list of sessions and the total count
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
sessions = list(Session.select().order_by(Session.created_at.desc()))
|
# Get total count for pagination info
|
||||||
return [self._to_model(session) for session in sessions]
|
total_count = Session.select().count()
|
||||||
|
|
||||||
|
# Get paginated sessions ordered by created_at in descending order (newest first)
|
||||||
|
sessions = list(
|
||||||
|
Session.select()
|
||||||
|
.order_by(Session.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [self._to_model(session) for session in sessions], total_count
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to get all sessions: {str(e)}")
|
logger.error(f"Failed to get all sessions with pagination: {str(e)}")
|
||||||
return []
|
return [], 0
|
||||||
|
|
||||||
def get_recent(self, limit: int = 10) -> List[SessionModel]:
|
def get_recent(self, limit: int = 10) -> List[SessionModel]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,200 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
API v1 Session Endpoints.
|
||||||
|
|
||||||
|
This module provides RESTful API endpoints for managing sessions.
|
||||||
|
It implements routes for creating, listing, and retrieving sessions
|
||||||
|
with proper validation and error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
import peewee
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository
|
||||||
|
from ra_aid.database.pydantic_models import SessionModel
|
||||||
|
|
||||||
|
# Create API router
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/v1/sessions",
|
||||||
|
tags=["sessions"],
|
||||||
|
responses={
|
||||||
|
status.HTTP_404_NOT_FOUND: {"description": "Session not found"},
|
||||||
|
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"},
|
||||||
|
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Database error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for paginated API responses.
|
||||||
|
|
||||||
|
This model provides a standardized format for API responses that include
|
||||||
|
pagination, with a total count and the requested items.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
total: The total number of items available
|
||||||
|
items: List of items for the current page
|
||||||
|
limit: The limit parameter that was used
|
||||||
|
offset: The offset parameter that was used
|
||||||
|
"""
|
||||||
|
total: int
|
||||||
|
items: List[Any]
|
||||||
|
limit: int
|
||||||
|
offset: int
|
||||||
|
|
||||||
|
|
||||||
|
class CreateSessionRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for session creation requests.
|
||||||
|
|
||||||
|
This model provides validation for creating new sessions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
metadata: Optional dictionary of additional metadata to store with the session
|
||||||
|
"""
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional dictionary of additional metadata to store with the session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedSessionResponse(PaginatedResponse):
|
||||||
|
"""
|
||||||
|
Pydantic model for paginated session responses.
|
||||||
|
|
||||||
|
This model specializes the generic PaginatedResponse for SessionModel items.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
items: List of SessionModel items for the current page
|
||||||
|
"""
|
||||||
|
items: List[SessionModel]
|
||||||
|
|
||||||
|
|
||||||
|
# Dependency to get the session repository
|
||||||
|
def get_repository() -> SessionRepository:
|
||||||
|
"""
|
||||||
|
Get the SessionRepository instance.
|
||||||
|
|
||||||
|
This function is used as a FastAPI dependency and can be overridden
|
||||||
|
in tests using dependency_overrides.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionRepository: The repository instance
|
||||||
|
"""
|
||||||
|
return get_session_repository()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedSessionResponse,
|
||||||
|
summary="List sessions",
|
||||||
|
description="Get a paginated list of sessions",
|
||||||
|
)
|
||||||
|
async def list_sessions(
|
||||||
|
offset: int = Query(0, ge=0, description="Number of sessions to skip"),
|
||||||
|
limit: int = Query(10, ge=1, le=100, description="Maximum number of sessions to return"),
|
||||||
|
repo: SessionRepository = Depends(get_repository),
|
||||||
|
) -> PaginatedSessionResponse:
|
||||||
|
"""
|
||||||
|
Get a paginated list of sessions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
offset: Number of sessions to skip (default: 0)
|
||||||
|
limit: Maximum number of sessions to return (default: 10)
|
||||||
|
repo: SessionRepository dependency injection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PaginatedSessionResponse: Response with paginated sessions
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: With a 500 status code if there's a database error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sessions, total = repo.get_all(offset=offset, limit=limit)
|
||||||
|
return PaginatedSessionResponse(
|
||||||
|
total=total,
|
||||||
|
items=sessions,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Database error: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{session_id}",
|
||||||
|
response_model=SessionModel,
|
||||||
|
summary="Get session",
|
||||||
|
description="Get a specific session by ID",
|
||||||
|
)
|
||||||
|
async def get_session(
|
||||||
|
session_id: int,
|
||||||
|
repo: SessionRepository = Depends(get_repository),
|
||||||
|
) -> SessionModel:
|
||||||
|
"""
|
||||||
|
Get a specific session by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the session to retrieve
|
||||||
|
repo: SessionRepository dependency injection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionModel: The requested session
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: With a 404 status code if the session is not found
|
||||||
|
HTTPException: With a 500 status code if there's a database error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session = repo.get(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Session with ID {session_id} not found",
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Database error: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"",
|
||||||
|
response_model=SessionModel,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Create session",
|
||||||
|
description="Create a new session",
|
||||||
|
)
|
||||||
|
async def create_session(
|
||||||
|
request: Optional[CreateSessionRequest] = None,
|
||||||
|
repo: SessionRepository = Depends(get_repository),
|
||||||
|
) -> SessionModel:
|
||||||
|
"""
|
||||||
|
Create a new session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Optional request body with session metadata
|
||||||
|
repo: SessionRepository dependency injection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionModel: The newly created session
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: With a 500 status code if there's a database error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
metadata = request.metadata if request else None
|
||||||
|
return repo.create_session(metadata=metadata)
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Database error: {str(e)}",
|
||||||
|
)
|
||||||
|
|
@ -33,7 +33,13 @@ from fastapi.responses import HTMLResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
app = FastAPI()
|
from ra_aid.server.api_v1_sessions import router as sessions_router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="RA.Aid API",
|
||||||
|
description="API for RA.Aid - AI Programming Assistant",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
|
||||||
# Add CORS middleware
|
# Add CORS middleware
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
@ -44,6 +50,9 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Include API routers
|
||||||
|
app.include_router(sessions_router)
|
||||||
|
|
||||||
# Setup templates and static files directories
|
# Setup templates and static files directories
|
||||||
CURRENT_DIR = Path(__file__).parent
|
CURRENT_DIR = Path(__file__).parent
|
||||||
templates = Jinja2Templates(directory=CURRENT_DIR)
|
templates = Jinja2Templates(directory=CURRENT_DIR)
|
||||||
|
|
|
||||||
|
|
@ -231,8 +231,11 @@ def test_get_all(setup_db):
|
||||||
repo.create_session(metadata=metadata2)
|
repo.create_session(metadata=metadata2)
|
||||||
repo.create_session(metadata=metadata3)
|
repo.create_session(metadata=metadata3)
|
||||||
|
|
||||||
# Get all sessions
|
# Get all sessions with default pagination
|
||||||
sessions = repo.get_all()
|
sessions, total_count = repo.get_all()
|
||||||
|
|
||||||
|
# Verify total count
|
||||||
|
assert total_count == 3
|
||||||
|
|
||||||
# Verify we got a list of SessionModel objects
|
# Verify we got a list of SessionModel objects
|
||||||
assert len(sessions) == 3
|
assert len(sessions) == 3
|
||||||
|
|
@ -249,6 +252,19 @@ def test_get_all(setup_db):
|
||||||
assert "Linux" in os_values
|
assert "Linux" in os_values
|
||||||
assert "Windows" in os_values
|
assert "Windows" in os_values
|
||||||
assert "macOS" in os_values
|
assert "macOS" in os_values
|
||||||
|
|
||||||
|
# Test pagination with limit
|
||||||
|
sessions_limited, total_count = repo.get_all(limit=2)
|
||||||
|
assert total_count == 3 # Total count should still be 3
|
||||||
|
assert len(sessions_limited) == 2 # But only 2 returned
|
||||||
|
|
||||||
|
# Test pagination with offset
|
||||||
|
sessions_offset, total_count = repo.get_all(offset=1, limit=2)
|
||||||
|
assert total_count == 3
|
||||||
|
assert len(sessions_offset) == 2
|
||||||
|
|
||||||
|
# The second item in the full list should be the first item in the offset list
|
||||||
|
assert sessions[1].id == sessions_offset[0].id
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_empty(setup_db):
|
def test_get_all_empty(setup_db):
|
||||||
|
|
@ -257,11 +273,12 @@ def test_get_all_empty(setup_db):
|
||||||
repo = SessionRepository(db=setup_db)
|
repo = SessionRepository(db=setup_db)
|
||||||
|
|
||||||
# Get all sessions
|
# Get all sessions
|
||||||
sessions = repo.get_all()
|
sessions, total_count = repo.get_all()
|
||||||
|
|
||||||
# Verify we got an empty list
|
# Verify we got an empty list and zero count
|
||||||
assert isinstance(sessions, list)
|
assert isinstance(sessions, list)
|
||||||
assert len(sessions) == 0
|
assert len(sessions) == 0
|
||||||
|
assert total_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_get_recent(setup_db):
|
def test_get_recent(setup_db):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,132 @@
|
||||||
|
"""
|
||||||
|
Tests for the Sessions API v1 endpoints.
|
||||||
|
|
||||||
|
This module contains tests for the sessions API endpoints in ra_aid/server/api_v1_sessions.py.
|
||||||
|
It tests the creation, listing, and retrieval of sessions through the API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from ra_aid.server.api_v1_sessions import router, get_repository
|
||||||
|
from ra_aid.database.pydantic_models import SessionModel
|
||||||
|
|
||||||
|
|
||||||
|
# Mock session data for testing
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session():
|
||||||
|
"""Return a mock session for testing."""
|
||||||
|
return SessionModel(
|
||||||
|
id=1,
|
||||||
|
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
command_line="ra-aid test",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"os": "test"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sessions():
|
||||||
|
"""Return a list of mock sessions for testing."""
|
||||||
|
return [
|
||||||
|
SessionModel(
|
||||||
|
id=1,
|
||||||
|
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
command_line="ra-aid test1",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"os": "test"}
|
||||||
|
),
|
||||||
|
SessionModel(
|
||||||
|
id=2,
|
||||||
|
created_at=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||||
|
updated_at=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||||
|
start_time=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||||
|
command_line="ra-aid test2",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"os": "test"}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_repo(mock_session, mock_sessions):
|
||||||
|
"""Mock the SessionRepository for testing."""
|
||||||
|
repo = MagicMock()
|
||||||
|
repo.get.return_value = mock_session
|
||||||
|
repo.get_all.return_value = (mock_sessions, len(mock_sessions))
|
||||||
|
repo.create_session.return_value = mock_session
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_repo):
|
||||||
|
"""Return a TestClient for the API router with dependency override."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
# Override the dependency
|
||||||
|
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session(client, mock_repo, mock_session):
|
||||||
|
"""Test getting a specific session by ID."""
|
||||||
|
response = client.get("/v1/sessions/1")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["id"] == mock_session.id
|
||||||
|
assert response.json()["command_line"] == mock_session.command_line
|
||||||
|
mock_repo.get.assert_called_once_with(1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_not_found(client, mock_repo):
|
||||||
|
"""Test getting a session that doesn't exist."""
|
||||||
|
mock_repo.get.return_value = None
|
||||||
|
|
||||||
|
response = client.get("/v1/sessions/999")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert "not found" in response.json()["detail"]
|
||||||
|
mock_repo.get.assert_called_once_with(999)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_sessions(client, mock_repo, mock_sessions):
|
||||||
|
"""Test listing sessions with pagination."""
|
||||||
|
response = client.get("/v1/sessions?offset=0&limit=10")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == len(mock_sessions)
|
||||||
|
assert len(data["items"]) == len(mock_sessions)
|
||||||
|
assert data["limit"] == 10
|
||||||
|
assert data["offset"] == 0
|
||||||
|
mock_repo.get_all.assert_called_once_with(offset=0, limit=10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session(client, mock_repo, mock_session):
|
||||||
|
"""Test creating a new session."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": {"test": "data"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert response.json()["id"] == mock_session.id
|
||||||
|
mock_repo.create_session.assert_called_once_with(metadata={"test": "data"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session_no_body(client, mock_repo, mock_session):
|
||||||
|
"""Test creating a new session without a request body."""
|
||||||
|
response = client.post("/v1/sessions")
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert response.json()["id"] == mock_session.id
|
||||||
|
mock_repo.create_session.assert_called_once_with(metadata=None)
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""
|
||||||
|
Tests for server.py FastAPI application.
|
||||||
|
|
||||||
|
This module tests the FastAPI application setup in server.py to ensure
|
||||||
|
that all routers are properly mounted and middleware is configured.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from ra_aid.server.server import app
|
||||||
|
from ra_aid.database.repositories.session_repository import session_repo_var
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
"""Return a TestClient for the FastAPI app."""
|
||||||
|
# Mock the session repository to avoid database dependency
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
mock_repo.get_all.return_value = ([], 0)
|
||||||
|
|
||||||
|
# Set the repository in the contextvar
|
||||||
|
token = session_repo_var.set(mock_repo)
|
||||||
|
|
||||||
|
yield TestClient(app)
|
||||||
|
|
||||||
|
# Reset the contextvar after the test
|
||||||
|
session_repo_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_endpoint(client):
|
||||||
|
"""Test that the config endpoint returns server configuration."""
|
||||||
|
response = client.get("/config")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "host" in response.json()
|
||||||
|
assert "port" in response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_documentation(client):
|
||||||
|
"""Test that the OpenAPI documentation includes the sessions API."""
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
openapi_spec = response.json()
|
||||||
|
assert "paths" in openapi_spec
|
||||||
|
|
||||||
|
# Check that the sessions API paths are included
|
||||||
|
assert "/v1/sessions" in openapi_spec["paths"]
|
||||||
|
assert "/v1/sessions/{session_id}" in openapi_spec["paths"]
|
||||||
|
|
||||||
|
# Verify that sessions API operations are documented
|
||||||
|
assert "get" in openapi_spec["paths"]["/v1/sessions"]
|
||||||
|
assert "post" in openapi_spec["paths"]["/v1/sessions"]
|
||||||
|
assert "get" in openapi_spec["paths"]["/v1/sessions/{session_id}"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("ra_aid.database.repositories.session_repository.get_session_repository")
|
||||||
|
def test_sessions_api_mounted(mock_get_repo, client):
|
||||||
|
"""Test that the sessions API router is mounted correctly."""
|
||||||
|
# Mock the repository for this specific test
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
mock_repo.get_all.return_value = ([], 0)
|
||||||
|
mock_get_repo.return_value = mock_repo
|
||||||
|
|
||||||
|
# Test that the sessions list endpoint is accessible
|
||||||
|
response = client.get("/v1/sessions")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify the response structure follows our expected format
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "items" in data
|
||||||
|
assert "limit" in data
|
||||||
|
assert "offset" in data
|
||||||
|
|
@ -0,0 +1,532 @@
|
||||||
|
"""
|
||||||
|
Integration tests for the Sessions API endpoints.
|
||||||
|
|
||||||
|
This module contains integration tests for the API endpoints defined in ra_aid/server/api_v1_sessions.py.
|
||||||
|
It uses mocks to simulate the database interactions while testing the real API behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import datetime
|
||||||
|
from typing import Dict, Any, List, Tuple
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from ra_aid.server.server import app
|
||||||
|
from ra_aid.database.pydantic_models import SessionModel
|
||||||
|
from ra_aid.server.api_v1_sessions import get_repository
|
||||||
|
|
||||||
|
|
||||||
|
# Mock session data for testing
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session():
|
||||||
|
"""Return a mock session for testing."""
|
||||||
|
return SessionModel(
|
||||||
|
id=1,
|
||||||
|
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
command_line="ra-aid test",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"os": "test"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sessions():
|
||||||
|
"""Return a list of mock sessions for testing."""
|
||||||
|
return [
|
||||||
|
SessionModel(
|
||||||
|
id=i+1,
|
||||||
|
created_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||||
|
updated_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||||
|
start_time=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||||
|
command_line=f"ra-aid test{i+1}",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"index": i}
|
||||||
|
)
|
||||||
|
for i in range(15)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_repo(mock_session, mock_sessions):
|
||||||
|
"""Create a mock repository with predefined responses."""
|
||||||
|
repo = MagicMock()
|
||||||
|
repo.get.return_value = mock_session
|
||||||
|
repo.get_all.return_value = (mock_sessions[:10], len(mock_sessions))
|
||||||
|
repo.create_session.return_value = mock_session
|
||||||
|
|
||||||
|
# Add behavior for custom parameters
|
||||||
|
def get_with_id(session_id):
|
||||||
|
if session_id == 999999:
|
||||||
|
return None
|
||||||
|
for session in mock_sessions:
|
||||||
|
if session.id == session_id:
|
||||||
|
return session
|
||||||
|
return mock_session
|
||||||
|
|
||||||
|
def get_all_with_pagination(offset=0, limit=10):
|
||||||
|
total = len(mock_sessions)
|
||||||
|
sorted_sessions = sorted(mock_sessions, key=lambda s: s.id, reverse=True)
|
||||||
|
return sorted_sessions[offset:offset+limit], total
|
||||||
|
|
||||||
|
def create_with_metadata(metadata=None):
|
||||||
|
if metadata is None:
|
||||||
|
return SessionModel(
|
||||||
|
id=16,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid test-null",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info=None
|
||||||
|
)
|
||||||
|
return SessionModel(
|
||||||
|
id=16,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid test-custom",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
repo.get.side_effect = get_with_id
|
||||||
|
repo.get_all.side_effect = get_all_with_pagination
|
||||||
|
repo.create_session.side_effect = create_with_metadata
|
||||||
|
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_repo):
|
||||||
|
"""Create a TestClient with the API and dependency overrides."""
|
||||||
|
# Override the dependency to use our mock repository
|
||||||
|
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||||
|
|
||||||
|
# Create a test client
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
# Clean up the dependency override
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_metadata() -> Dict[str, Any]:
|
||||||
|
"""Return sample metadata for session creation."""
|
||||||
|
return {
|
||||||
|
"os": "Test OS",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"environment": "test",
|
||||||
|
"cpu_cores": 4,
|
||||||
|
"memory_gb": 16,
|
||||||
|
"additional_info": {
|
||||||
|
"gpu": "Test GPU",
|
||||||
|
"display_resolution": "1920x1080"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session_with_metadata(client, mock_repo, sample_metadata):
|
||||||
|
"""Test creating a session with metadata through the API endpoint."""
|
||||||
|
# Send request to create a session with metadata
|
||||||
|
response = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": sample_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response status code and structure
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify the session was created with the expected fields
|
||||||
|
assert data["id"] is not None
|
||||||
|
assert data["command_line"] is not None
|
||||||
|
assert data["program_version"] is not None
|
||||||
|
assert data["created_at"] is not None
|
||||||
|
assert data["updated_at"] is not None
|
||||||
|
assert data["start_time"] is not None
|
||||||
|
|
||||||
|
# Verify metadata was passed correctly to the repository
|
||||||
|
mock_repo.create_session.assert_called_once_with(metadata=sample_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session_without_metadata(client, mock_repo):
|
||||||
|
"""Test creating a session without metadata through the API endpoint."""
|
||||||
|
# Send request without a body
|
||||||
|
response = client.post("/v1/sessions")
|
||||||
|
|
||||||
|
# Verify response status code and structure
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify the session was created with the expected fields
|
||||||
|
assert data["id"] is not None
|
||||||
|
assert data["command_line"] is not None
|
||||||
|
assert data["program_version"] is not None
|
||||||
|
|
||||||
|
# Verify correct parameters were passed to the repository
|
||||||
|
mock_repo.create_session.assert_called_once_with(metadata=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_by_id(client):
|
||||||
|
"""Test retrieving a session by ID through the API endpoint."""
|
||||||
|
# Use a completely isolated, standalone test
|
||||||
|
|
||||||
|
# For this test, let's focus on verifying the core functionality:
|
||||||
|
# 1. The API endpoint receives a request for a specific session ID
|
||||||
|
# 2. It calls the repository with that ID
|
||||||
|
# 3. It returns a properly formatted response
|
||||||
|
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a test session with a simple machine_info to reduce serialization issues
|
||||||
|
test_session = SessionModel(
|
||||||
|
id=42,
|
||||||
|
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||||
|
command_line="ra-aid specific-test",
|
||||||
|
program_version="1.0.0-test",
|
||||||
|
machine_info=None # Use None to avoid serialization issues
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure the mock
|
||||||
|
mock_repo.get.return_value = test_session
|
||||||
|
|
||||||
|
# Override the dependency
|
||||||
|
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Retrieve the session through the API
|
||||||
|
response = client.get(f"/v1/sessions/{test_session.id}")
|
||||||
|
|
||||||
|
# Verify response status code
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Parse the response data
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Print for debugging
|
||||||
|
import json
|
||||||
|
print("Response JSON:", json.dumps(data, indent=2))
|
||||||
|
|
||||||
|
# Verify the returned session matches what we expected
|
||||||
|
assert data["id"] == test_session.id
|
||||||
|
assert data["command_line"] == test_session.command_line
|
||||||
|
assert data["program_version"] == test_session.program_version
|
||||||
|
assert data["machine_info"] is None
|
||||||
|
|
||||||
|
# Verify the repository was called with the correct ID
|
||||||
|
mock_repo.get.assert_called_once_with(test_session.id)
|
||||||
|
finally:
|
||||||
|
# Clean up the override
|
||||||
|
if get_repository in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_repository]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_not_found(client, mock_repo):
|
||||||
|
"""Test the error handling when requesting a non-existent session."""
|
||||||
|
# Try to get a session with a non-existent ID
|
||||||
|
response = client.get("/v1/sessions/999999")
|
||||||
|
|
||||||
|
# Verify response status code and error message
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert "not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify the repository was called with the correct ID
|
||||||
|
mock_repo.get.assert_called_with(999999)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_sessions_empty(client, mock_repo):
|
||||||
|
"""Test listing sessions when no sessions exist."""
|
||||||
|
# Reset the mock first to clear any previous calls/side effects
|
||||||
|
mock_repo.reset_mock()
|
||||||
|
|
||||||
|
# Configure the mock to return empty results
|
||||||
|
mock_repo.get_all.side_effect = None # Clear any previous side effects
|
||||||
|
mock_repo.get_all.return_value = ([], 0)
|
||||||
|
|
||||||
|
# Get the list of sessions
|
||||||
|
response = client.get("/v1/sessions")
|
||||||
|
|
||||||
|
# Verify response status code and structure
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify the pagination response
|
||||||
|
assert data["total"] == 0
|
||||||
|
assert len(data["items"]) == 0
|
||||||
|
assert data["limit"] == 10
|
||||||
|
assert data["offset"] == 0
|
||||||
|
|
||||||
|
# Verify the repository was called with the correct parameters
|
||||||
|
mock_repo.get_all.assert_called_with(offset=0, limit=10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_sessions_with_pagination(client, mock_repo, mock_sessions):
|
||||||
|
"""Test listing sessions with pagination parameters."""
|
||||||
|
# Set up the repository mock to return specific results for different pagination parameters
|
||||||
|
default_result = (mock_sessions[:10], len(mock_sessions))
|
||||||
|
limit_5_result = (mock_sessions[:5], len(mock_sessions))
|
||||||
|
offset_10_result = (mock_sessions[10:], len(mock_sessions))
|
||||||
|
offset_5_limit_3_result = (mock_sessions[5:8], len(mock_sessions))
|
||||||
|
|
||||||
|
pagination_responses = {
|
||||||
|
(0, 10): default_result,
|
||||||
|
(0, 5): limit_5_result,
|
||||||
|
(10, 10): offset_10_result,
|
||||||
|
(5, 3): offset_5_limit_3_result
|
||||||
|
}
|
||||||
|
|
||||||
|
def mock_get_all(offset=0, limit=10):
|
||||||
|
return pagination_responses.get((offset, limit), ([], 0))
|
||||||
|
|
||||||
|
mock_repo.get_all.side_effect = mock_get_all
|
||||||
|
|
||||||
|
# Test default pagination (limit=10, offset=0)
|
||||||
|
response = client.get("/v1/sessions")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == len(mock_sessions)
|
||||||
|
assert len(data["items"]) == 10
|
||||||
|
assert data["limit"] == 10
|
||||||
|
assert data["offset"] == 0
|
||||||
|
|
||||||
|
# Test with custom limit
|
||||||
|
response = client.get("/v1/sessions?limit=5")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == len(mock_sessions)
|
||||||
|
assert len(data["items"]) == 5
|
||||||
|
assert data["limit"] == 5
|
||||||
|
assert data["offset"] == 0
|
||||||
|
|
||||||
|
# Test with custom offset
|
||||||
|
response = client.get("/v1/sessions?offset=10")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == len(mock_sessions)
|
||||||
|
assert len(data["items"]) == 5 # Only 5 items left after offset 10
|
||||||
|
assert data["limit"] == 10
|
||||||
|
assert data["offset"] == 10
|
||||||
|
|
||||||
|
# Test with both custom limit and offset
|
||||||
|
response = client.get("/v1/sessions?limit=3&offset=5")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == len(mock_sessions)
|
||||||
|
assert len(data["items"]) == 3
|
||||||
|
assert data["limit"] == 3
|
||||||
|
assert data["offset"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_sessions_invalid_parameters(client):
|
||||||
|
"""Test error handling for invalid pagination parameters."""
|
||||||
|
# Test with negative offset
|
||||||
|
response = client.get("/v1/sessions?offset=-1")
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
# Test with negative limit
|
||||||
|
response = client.get("/v1/sessions?limit=-5")
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
# Test with zero limit
|
||||||
|
response = client.get("/v1/sessions?limit=0")
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
# Test with limit exceeding maximum
|
||||||
|
response = client.get("/v1/sessions?limit=101")
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_validation(client, mock_repo):
|
||||||
|
"""Test validation for different metadata formats in session creation."""
|
||||||
|
# Create test sessions with different metadata
|
||||||
|
null_metadata_session = SessionModel(
|
||||||
|
id=20,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid test-null",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info=None
|
||||||
|
)
|
||||||
|
|
||||||
|
empty_dict_metadata_session = SessionModel(
|
||||||
|
id=21,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid test-empty",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={}
|
||||||
|
)
|
||||||
|
|
||||||
|
complex_metadata_session = SessionModel(
|
||||||
|
id=22,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid test-complex",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"level1": {"level2": {"level3": [1, 2, 3, {"key": "value"}]}}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure mock to return different sessions based on metadata
|
||||||
|
def create_with_specific_metadata(metadata=None):
|
||||||
|
if metadata is None:
|
||||||
|
return null_metadata_session
|
||||||
|
elif metadata == {}:
|
||||||
|
return empty_dict_metadata_session
|
||||||
|
elif isinstance(metadata, dict) and "level1" in metadata:
|
||||||
|
return complex_metadata_session
|
||||||
|
return SessionModel(
|
||||||
|
id=23,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid test-other",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_repo.create_session.side_effect = create_with_specific_metadata
|
||||||
|
|
||||||
|
# Try to create a session with null metadata
|
||||||
|
response = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should work fine
|
||||||
|
assert response.status_code == 201
|
||||||
|
mock_repo.create_session.assert_called_with(metadata=None)
|
||||||
|
|
||||||
|
# Try to create a session with an empty metadata dict
|
||||||
|
response = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": {}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should work fine
|
||||||
|
assert response.status_code == 201
|
||||||
|
mock_repo.create_session.assert_called_with(metadata={})
|
||||||
|
|
||||||
|
# Try to create a session with a complex nested metadata
|
||||||
|
response = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": {
|
||||||
|
"level1": {
|
||||||
|
"level2": {
|
||||||
|
"level3": [1, 2, 3, {"key": "value"}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the complex nested structure is preserved
|
||||||
|
assert response.status_code == 201
|
||||||
|
complex_metadata = {
|
||||||
|
"level1": {
|
||||||
|
"level2": {
|
||||||
|
"level3": [1, 2, 3, {"key": "value"}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_repo.create_session.assert_called_with(metadata=complex_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def test_integration_workflow(client, mock_repo):
|
||||||
|
"""Test a complete workflow of creating and retrieving sessions."""
|
||||||
|
# Set up mock sessions for the workflow
|
||||||
|
first_session = SessionModel(
|
||||||
|
id=30,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid workflow-1",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"workflow_test": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
second_session = SessionModel(
|
||||||
|
id=31,
|
||||||
|
created_at=datetime.datetime.now(),
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line="ra-aid workflow-2",
|
||||||
|
program_version="1.0.0",
|
||||||
|
machine_info={"workflow_test": False, "second": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure mock for create_session
|
||||||
|
create_calls = 0
|
||||||
|
def create_session_for_workflow(metadata=None):
|
||||||
|
nonlocal create_calls
|
||||||
|
create_calls += 1
|
||||||
|
if create_calls == 1:
|
||||||
|
return first_session
|
||||||
|
return second_session
|
||||||
|
|
||||||
|
mock_repo.create_session.side_effect = create_session_for_workflow
|
||||||
|
|
||||||
|
# Configure mock for get
|
||||||
|
def get_session_for_workflow(session_id):
|
||||||
|
if session_id == first_session.id:
|
||||||
|
return first_session
|
||||||
|
elif session_id == second_session.id:
|
||||||
|
return second_session
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_repo.get.side_effect = get_session_for_workflow
|
||||||
|
|
||||||
|
# Configure mock for get_all
|
||||||
|
def get_all_for_workflow(offset=0, limit=10):
|
||||||
|
if create_calls == 1:
|
||||||
|
return [first_session], 1
|
||||||
|
return [second_session, first_session], 2
|
||||||
|
|
||||||
|
mock_repo.get_all.side_effect = get_all_for_workflow
|
||||||
|
|
||||||
|
# 1. Create a session
|
||||||
|
create_response = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": {"workflow_test": True}}
|
||||||
|
)
|
||||||
|
assert create_response.status_code == 201
|
||||||
|
session_id = create_response.json()["id"]
|
||||||
|
assert session_id == first_session.id
|
||||||
|
|
||||||
|
# 2. Retrieve the created session
|
||||||
|
get_response = client.get(f"/v1/sessions/{session_id}")
|
||||||
|
assert get_response.status_code == 200
|
||||||
|
assert get_response.json()["id"] == session_id
|
||||||
|
|
||||||
|
# 3. List all sessions and verify the created one is included
|
||||||
|
list_response = client.get("/v1/sessions")
|
||||||
|
assert list_response.status_code == 200
|
||||||
|
items = list_response.json()["items"]
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["id"] == session_id
|
||||||
|
|
||||||
|
# 4. Create a second session
|
||||||
|
create_response2 = client.post(
|
||||||
|
"/v1/sessions",
|
||||||
|
json={"metadata": {"workflow_test": False, "second": True}}
|
||||||
|
)
|
||||||
|
assert create_response2.status_code == 201
|
||||||
|
session_id2 = create_response2.json()["id"]
|
||||||
|
assert session_id2 == second_session.id
|
||||||
|
|
||||||
|
# 5. List all sessions and verify both sessions are included
|
||||||
|
list_response = client.get("/v1/sessions")
|
||||||
|
assert list_response.status_code == 200
|
||||||
|
data = list_response.json()
|
||||||
|
assert data["total"] == 2
|
||||||
|
items = data["items"]
|
||||||
|
assert len(items) == 2
|
||||||
|
assert items[0]["id"] == session_id2
|
||||||
|
assert items[1]["id"] == session_id
|
||||||
Loading…
Reference in New Issue