364 lines
11 KiB
Python
364 lines
11 KiB
Python
"""
|
|
Tests for the SessionRepository class.
|
|
"""
|
|
|
|
import pytest
|
|
import datetime
|
|
import json
|
|
from unittest.mock import patch
|
|
|
|
import peewee
|
|
|
|
from ra_aid.database.connection import DatabaseManager, db_var
|
|
from ra_aid.database.models import Session, BaseModel
|
|
from ra_aid.database.repositories.session_repository import (
|
|
SessionRepository,
|
|
SessionRepositoryManager,
|
|
get_session_repository,
|
|
session_repo_var
|
|
)
|
|
from ra_aid.database.pydantic_models import SessionModel
|
|
|
|
|
|
def _close_and_clear_db_var():
    """Close the connection held in ``db_var`` (if any) and set ``db_var`` to None.

    Closing is best-effort: any exception raised while checking or closing the
    connection is intentionally ignored so that fixture setup/teardown is never
    broken by a stale or already-failed connection.
    """
    db = db_var.get()
    if db is not None:
        try:
            if not db.is_closed():
                db.close()
        except Exception:
            # Deliberate best-effort cleanup: a failure to close must not
            # fail the test run.
            pass
    db_var.set(None)


@pytest.fixture
def cleanup_db():
    """Reset the database contextvar and connection state around each test.

    The same close-and-clear routine runs before the test (to discard state
    leaked by earlier tests) and after it (so this test does not leak state
    into later ones).
    """
    _close_and_clear_db_var()
    yield
    _close_and_clear_db_var()
|
|
|
|
|
|
@pytest.fixture
def cleanup_repo():
    """Guarantee ``session_repo_var`` is empty when a test starts and when it ends.

    Clearing on entry discards any repository left behind by an earlier test;
    clearing on exit keeps this test's repository from leaking forward.
    """
    session_repo_var.set(None)  # entry: start from a clean contextvar
    yield
    session_repo_var.set(None)  # exit: leave a clean contextvar behind
|
|
|
|
|
|
@pytest.fixture
def setup_db(cleanup_db):
    """Yield an in-memory database that already contains the Session table.

    Depends on ``cleanup_db`` so the ``db_var`` contextvar is reset around the
    test. While this fixture is active, ``BaseModel._meta.database`` is patched
    to point at the in-memory connection. The Session table is created
    (``safe=True``) before the test and dropped (``safe=True``) afterwards,
    each step inside its own ``atomic()`` block.
    """
    manager = DatabaseManager(in_memory=True)
    with manager as memory_db:
        database_patch = patch.object(BaseModel._meta, 'database', memory_db)
        with database_patch:
            with memory_db.atomic():
                memory_db.create_tables([Session], safe=True)

            yield memory_db

            # Teardown: remove the table created above.
            with memory_db.atomic():
                Session.drop_table(safe=True)
|
|
|
|
|
|
@pytest.fixture
def test_metadata():
    """Machine-info style metadata dict (with one nested dict) shared by the tests."""
    extra = dict(gpu="Test GPU", display_resolution="1920x1080")
    return dict(
        os="Test OS",
        version="1.0",
        cpu_cores=4,
        memory_gb=16,
        additional_info=extra,
    )
|
|
|
|
|
|
@pytest.fixture
def sample_session(setup_db, test_metadata):
    """Insert one Session row, storing the metadata as a JSON string, and return it."""
    fields = {
        "start_time": datetime.datetime.now(),
        "command_line": "ra-aid test",
        "program_version": "1.0.0",
        "machine_info": json.dumps(test_metadata),
    }
    return Session.create(**fields)
|
|
|
|
|
|
def test_create_session_with_metadata(setup_db, test_metadata):
    """A session created with metadata exposes it as a dict on a SessionModel."""
    created = SessionRepository(db=setup_db).create_session(metadata=test_metadata)

    # The repository must hand back the Pydantic model, not the peewee row.
    assert isinstance(created, SessionModel)

    # Core fields are populated by the repository.
    for attr in ("id", "command_line", "program_version"):
        assert getattr(created, attr) is not None

    # machine_info comes back as a dict (not a JSON string) with nesting intact.
    info = created.machine_info
    assert isinstance(info, dict)
    assert info == test_metadata
    assert "additional_info" in info
    assert info["additional_info"]["gpu"] == "Test GPU"
|
|
|
|
|
|
def test_create_session_without_metadata(setup_db):
    """Omitting metadata still creates a session, with machine_info left as None."""
    created = SessionRepository(db=setup_db).create_session()

    # Pydantic model, not the peewee row.
    assert isinstance(created, SessionModel)

    for attr in ("id", "command_line", "program_version"):
        assert getattr(created, attr) is not None

    assert created.machine_info is None
|
|
|
|
|
|
def test_get_current_session(setup_db, sample_session):
    """An explicitly assigned current session is returned as a SessionModel."""
    repo = SessionRepository(db=setup_db)
    repo.current_session = sample_session

    result = repo.get_current_session()

    assert isinstance(result, SessionModel)

    # Identity and descriptive fields must match the row that was assigned.
    actual = (result.id, result.command_line, result.program_version)
    expected = (
        sample_session.id,
        sample_session.command_line,
        sample_session.program_version,
    )
    assert actual == expected

    # Stored JSON is exposed as a dict.
    assert isinstance(result.machine_info, dict)
|
|
|
|
|
|
def test_get_current_session_from_db(setup_db, sample_session):
    """With no current session assigned, the repository falls back to the database."""
    repo = SessionRepository(db=setup_db)

    # Nothing was assigned to repo.current_session, so the result must come
    # from the stored rows (the only row here is sample_session).
    fetched = repo.get_current_session()

    assert isinstance(fetched, SessionModel)
    assert fetched.id == sample_session.id
    assert isinstance(fetched.machine_info, dict)
|
|
|
|
|
|
def test_get_by_id(setup_db, sample_session):
    """Sessions can be fetched by primary key; an id with no row yields None."""
    repo = SessionRepository(db=setup_db)

    fetched = repo.get(sample_session.id)

    assert isinstance(fetched, SessionModel)
    assert (fetched.id, fetched.command_line, fetched.program_version) == (
        sample_session.id,
        sample_session.command_line,
        sample_session.program_version,
    )
    assert isinstance(fetched.machine_info, dict)

    # An id with no matching row must not resolve to a session.
    assert repo.get(999) is None
|
|
|
|
|
|
def test_get_all(setup_db):
    """Test retrieving all sessions, including total count and pagination.

    Creates three sessions, then checks:
    - the default ``get_all()`` result (count, types, ordering, contents);
    - that ``limit`` returns exactly the head of that ordering;
    - that ``offset``/``limit`` return exactly the matching window.
    ``total_count`` must always report every row, regardless of pagination.
    """
    repo = SessionRepository(db=setup_db)

    metadata_list = [
        {"os": "Linux", "cpu_cores": 8},
        {"os": "Windows", "cpu_cores": 4},
        {"os": "macOS", "cpu_cores": 10},
    ]
    for metadata in metadata_list:
        repo.create_session(metadata=metadata)

    # Default pagination returns everything.
    sessions, total_count = repo.get_all()
    assert total_count == 3
    assert len(sessions) == 3
    for session in sessions:
        assert isinstance(session, SessionModel)
        assert isinstance(session.machine_info, dict)

    # Newest first (descending creation time).
    assert sessions[0].created_at >= sessions[1].created_at
    assert sessions[1].created_at >= sessions[2].created_at

    # Each created session appears exactly once.
    all_ids = [session.id for session in sessions]
    assert len(set(all_ids)) == 3
    os_values = {session.machine_info["os"] for session in sessions}
    assert os_values == {"Linux", "Windows", "macOS"}

    # limit: total count is unaffected, and the page is the head of the ordering.
    sessions_limited, total_count = repo.get_all(limit=2)
    assert total_count == 3
    assert len(sessions_limited) == 2
    assert [s.id for s in sessions_limited] == all_ids[:2]

    # offset + limit: the page is exactly the window all_ids[1:3]; checking
    # only the first element would miss a wrong second element.
    sessions_offset, total_count = repo.get_all(offset=1, limit=2)
    assert total_count == 3
    assert len(sessions_offset) == 2
    assert [s.id for s in sessions_offset] == all_ids[1:3]
|
|
|
|
|
|
def test_get_all_empty(setup_db):
    """get_all on an empty table returns an empty list and a zero count."""
    sessions, total_count = SessionRepository(db=setup_db).get_all()

    # Must be an actual list (not just any empty iterable), with no entries.
    assert isinstance(sessions, list)
    assert not sessions
    assert total_count == 0
|
|
|
|
|
|
def test_get_recent(setup_db):
    """get_recent(limit=N) returns the N newest sessions, newest first."""
    repo = SessionRepository(db=setup_db)

    # Five sessions tagged with their creation index 0..4.
    for i in range(5):
        repo.create_session(metadata={"index": i, "os": f"OS {i}"})

    recent = repo.get_recent(limit=3)

    assert len(recent) == 3
    for item in recent:
        assert isinstance(item, SessionModel)
        assert isinstance(item.machine_info, dict)

    # The three most recently created, in descending order.
    assert [item.machine_info["index"] for item in recent] == [4, 3, 2]
|
|
|
|
|
|
def test_session_repository_manager(setup_db, cleanup_repo):
    """SessionRepositoryManager publishes a repository only while the block is active."""
    with SessionRepositoryManager(setup_db) as managed_repo:
        assert isinstance(managed_repo, SessionRepository)
        assert managed_repo.db is setup_db

        created = managed_repo.create_session(metadata={"test": "manager"})
        assert isinstance(created, SessionModel)
        assert created.machine_info["test"] == "manager"

        # Inside the block, the accessor must return the very same instance.
        assert get_session_repository() is managed_repo

    # Once the block exits, the contextvar no longer holds a repository.
    with pytest.raises(RuntimeError, match="No SessionRepository available"):
        get_session_repository()
|
|
|
|
|
|
def test_repository_init_without_db():
    """Constructing SessionRepository with db=None is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Database connection is required"):
        SessionRepository(db=None)
|
|
|
|
|
|
def test_get_current_session_id(setup_db, sample_session):
    """get_current_session_id reports the current session's id, or None when none exists."""
    repo = SessionRepository(db=setup_db)

    # With an explicitly assigned current session, its id is reported.
    repo.current_session = sample_session
    assert repo.get_current_session_id() == sample_session.id

    # Clear the assigned session and delete every row, so there is nothing
    # left for the repository to report.
    repo.current_session = None
    Session.delete().execute()

    assert repo.get_current_session_id() is None