RA.Aid/tests/ra_aid/test_agent_context.py

249 lines
9.5 KiB
Python

"""Unit tests for the agent_context module."""
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
mark_should_exit,
should_exit,
)
class TestAgentContext:
    """Test cases for the AgentContext class and related functions."""

    def test_context_creation(self):
        """Test creating a new context has every flag cleared."""
        context = AgentContext()
        assert context.task_completed is False
        assert context.plan_completed is False
        assert context.completion_message == ""
        assert context.agent_should_exit is False

    def test_context_inheritance(self):
        """Test that child contexts do not inherit completion flags from parent contexts."""
        parent = AgentContext()
        parent.mark_task_completed("Parent task completed")
        child = AgentContext(parent_context=parent)
        assert child.task_completed is False
        assert child.completion_message == ""

    def test_mark_task_completed(self):
        """Test marking a task as completed."""
        context = AgentContext()
        context.mark_task_completed("Task done")
        assert context.task_completed is True
        assert context.plan_completed is False
        assert context.completion_message == "Task done"

    def test_mark_plan_completed(self):
        """Test marking a plan as completed."""
        context = AgentContext()
        context.mark_plan_completed("Plan done")
        assert context.task_completed is True
        assert context.plan_completed is True
        assert context.completion_message == "Plan done"

    def test_reset_completion_flags(self):
        """Test resetting completion flags after both kinds of completion.

        Marking only the task leaves ``plan_completed`` False, which would make
        the post-reset ``plan_completed`` assertion vacuous. The plan path sets
        *both* flags, so it is exercised too, ensuring the reset really clears
        each of them.
        """
        context = AgentContext()

        context.mark_task_completed("Task done")
        assert context.task_completed is True  # precondition
        context.reset_completion_flags()
        assert context.task_completed is False
        assert context.plan_completed is False
        assert context.completion_message == ""

        context.mark_plan_completed("Plan done")
        # Precondition: both flags are set, so the reset has work to do.
        assert context.task_completed is True
        assert context.plan_completed is True
        context.reset_completion_flags()
        assert context.task_completed is False
        assert context.plan_completed is False
        assert context.completion_message == ""

    def test_is_completed_property(self):
        """Test the is_completed property."""
        context = AgentContext()
        assert context.is_completed is False
        context.mark_task_completed("Task done")
        assert context.is_completed is True
        context.reset_completion_flags()
        assert context.is_completed is False
        context.mark_plan_completed("Plan done")
        assert context.is_completed is True
class TestContextManager:
    """Test cases for the agent_context context manager."""

    def test_context_manager_basic(self):
        """Test basic context manager functionality."""
        assert get_current_context() is None
        with agent_context() as ctx:
            assert get_current_context() is ctx
            assert ctx.task_completed is False
        assert get_current_context() is None

    def test_nested_context_managers(self):
        """Test nested context managers restore the previous context on exit."""
        # Capture whatever was active before so the final check does not
        # depend on global test ordering.
        previous = get_current_context()
        with agent_context() as outer_ctx:
            assert get_current_context() is outer_ctx
            with agent_context() as inner_ctx:
                assert get_current_context() is inner_ctx
                assert inner_ctx is not outer_ctx
            # Leaving the inner manager must re-activate the outer context.
            assert get_current_context() is outer_ctx
        # Leaving the outermost manager must restore the prior context.
        assert get_current_context() is previous

    def test_context_manager_with_parent(self):
        """Test context manager with explicit parent context."""
        parent = AgentContext()
        parent.mark_task_completed("Parent task")
        with agent_context(parent_context=parent) as ctx:
            assert ctx.task_completed is False
            assert ctx.completion_message == ""

    def test_context_manager_inheritance(self):
        """Test that nested contexts do not inherit completion flags from outer contexts."""
        with agent_context() as outer:
            outer.mark_task_completed("Outer task")
            with agent_context() as inner:
                assert inner.task_completed is False
                assert inner.completion_message == ""
                inner.mark_plan_completed("Inner plan")
                # The inner context really changed; otherwise the checks on
                # the outer context below would prove nothing.
                assert inner.plan_completed is True
                assert inner.completion_message == "Inner plan"
                # Outer context should not be affected by inner context changes
                assert outer.task_completed is True
                assert outer.plan_completed is False
                assert outer.completion_message == "Outer task"
class TestExitPropagation:
    """Verify that the agent_should_exit flag travels up the parent chain."""

    def test_mark_should_exit_propagation(self):
        """A child's mark_should_exit() must also flag its direct parent."""
        parent = AgentContext()
        child = AgentContext(parent_context=parent)

        # Nothing has requested an exit yet.
        for ctx in (parent, child):
            assert ctx.agent_should_exit is False

        child.mark_should_exit()

        # The request is visible on the child and has reached the parent.
        for ctx in (child, parent):
            assert ctx.agent_should_exit is True

    def test_nested_should_exit_propagation(self):
        """An exit requested at the bottom of a three-level chain reaches the top."""
        grandparent = AgentContext()
        parent = AgentContext(parent_context=grandparent)
        child = AgentContext(parent_context=parent)

        # Fresh contexts: no level has the flag set.
        for ctx in (grandparent, parent, child):
            assert ctx.agent_should_exit is False

        child.mark_should_exit()

        # Every ancestor, from the child upward, now reports the exit.
        for ctx in (child, parent, grandparent):
            assert ctx.agent_should_exit is True

    def test_context_manager_should_exit_propagation(self):
        """The same upward propagation holds for contexts created via agent_context()."""
        with agent_context() as outer:
            with agent_context() as inner:
                # Fresh contexts: neither level has the flag set.
                for ctx in (outer, inner):
                    assert ctx.agent_should_exit is False

                inner.mark_should_exit()

                # The inner request is visible on both levels.
                for ctx in (inner, outer):
                    assert ctx.agent_should_exit is True
class TestThreadIsolation:
    """Test thread isolation of context variables."""

    def test_thread_isolation(self):
        """Test that contexts are isolated between threads.

        Every worker marks its own context as completed, then waits on a
        barrier until *all* workers have done so before reading the message
        back. If the current context were shared between threads, at least
        one worker would then observe another worker's message. The barrier
        makes that interleaving deterministic; a ``time.sleep`` only makes
        it likely and slows the suite down.
        """
        num_threads = 3
        barrier = threading.Barrier(num_threads)
        results = {}
        errors = []

        def thread_func(thread_id):
            try:
                with agent_context() as ctx:
                    ctx.mark_task_completed(f"Thread {thread_id}")
                    # Wait until every thread has marked its own context, so
                    # the reads below happen after all writes.
                    barrier.wait(timeout=5)
                    # Store the context's message for verification
                    results[thread_id] = get_completion_message()
            except Exception as exc:
                # Thread.join() does not propagate worker exceptions; record
                # them so the test fails with the real cause. Breaking the
                # barrier releases any sibling still waiting on it.
                errors.append(exc)
                barrier.abort()

        threads = [
            threading.Thread(target=thread_func, args=(i,))
            for i in range(num_threads)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        assert not any(t.is_alive() for t in threads), "worker thread did not finish"
        assert not errors, f"worker thread raised: {errors!r}"
        # Each thread should have seen only its own message
        assert results == {i: f"Thread {i}" for i in range(num_threads)}
class TestUtilityFunctions:
    """Test utility functions that operate on the current context."""

    def test_mark_task_completed_utility(self):
        """Test the mark_task_completed utility function."""
        with agent_context():
            mark_task_completed("Task done via utility")
            assert is_completed() is True
            assert get_completion_message() == "Task done via utility"

    def test_mark_plan_completed_utility(self):
        """Test the mark_plan_completed utility function."""
        with agent_context():
            mark_plan_completed("Plan done via utility")
            assert is_completed() is True
            assert get_completion_message() == "Plan done via utility"

    def test_reset_completion_flags_utility(self):
        """Test the reset_completion_flags utility function.

        The plan path is used because it sets both the task and the plan
        flags, so the reset is verified to clear every completion flag, not
        just the task one.
        """
        with agent_context():
            mark_plan_completed("Plan done")
            assert is_completed() is True  # precondition: there is something to reset
            reset_completion_flags()
            assert is_completed() is False
            assert get_completion_message() == ""

    def test_utility_functions_without_context(self):
        """Test utility functions when no context is active."""
        # Make the "no active context" precondition explicit rather than
        # relying on earlier tests having cleaned up after themselves.
        assert get_current_context() is None
        # These should not raise exceptions even without an active context
        mark_task_completed("No context")
        mark_plan_completed("No context")
        reset_completion_flags()
        # These should have safe default returns
        assert is_completed() is False
        assert get_completion_message() == ""

    def test_mark_should_exit_utility(self):
        """Test the mark_should_exit utility function."""
        with agent_context() as outer:
            with agent_context() as inner:
                # Initially both contexts should have agent_should_exit as False
                assert should_exit() is False
                # Mark the current context (inner) as should exit
                mark_should_exit()
                # Both inner and outer should now have agent_should_exit as True
                assert should_exit() is True
                assert inner.agent_should_exit is True
                assert outer.agent_should_exit is True