agent context

This commit is contained in:
AI Christianson 2025-02-26 19:30:08 -05:00
parent 724dbd4fda
commit 28d9032ca5
9 changed files with 760 additions and 79 deletions

150
ra_aid/agent_context.py Normal file
View File

@ -0,0 +1,150 @@
"""Context manager for tracking agent state and completion status."""
import threading
from contextlib import contextmanager
from typing import Dict, Optional, Set
# Thread-local storage for context variables
_thread_local = threading.local()
class AgentContext:
"""Context manager for agent state tracking."""
def __init__(self, parent_context=None):
"""Initialize a new agent context.
Args:
parent_context: Optional parent context to inherit state from
"""
# Initialize completion flags
self.task_completed = False
self.plan_completed = False
self.completion_message = ""
# Inherit state from parent if provided
if parent_context:
self.task_completed = parent_context.task_completed
self.plan_completed = parent_context.plan_completed
self.completion_message = parent_context.completion_message
def mark_task_completed(self, message: str) -> None:
"""Mark the current task as completed.
Args:
message: Completion message explaining how/why the task is complete
"""
self.task_completed = True
self.completion_message = message
def mark_plan_completed(self, message: str) -> None:
"""Mark the current plan as completed.
Args:
message: Completion message explaining how the plan was completed
"""
self.task_completed = True
self.plan_completed = True
self.completion_message = message
def reset_completion_flags(self) -> None:
"""Reset all completion flags."""
self.task_completed = False
self.plan_completed = False
self.completion_message = ""
@property
def is_completed(self) -> bool:
"""Check if the current context is marked as completed."""
return self.task_completed or self.plan_completed
def get_current_context() -> Optional[AgentContext]:
"""Get the current agent context for this thread.
Returns:
The current AgentContext or None if no context is active
"""
return getattr(_thread_local, "current_context", None)
@contextmanager
def agent_context(parent_context=None):
"""Context manager for agent execution.
Creates a new agent context and makes it the current context for the duration
of the with block. Restores the previous context when exiting the block.
Args:
parent_context: Optional parent context to inherit state from
Yields:
The newly created AgentContext
"""
# Save the previous context
previous_context = getattr(_thread_local, "current_context", None)
# Create a new context, inheriting from parent if provided
# If parent_context is None but previous_context exists, use previous_context as parent
if parent_context is None and previous_context is not None:
context = AgentContext(previous_context)
else:
context = AgentContext(parent_context)
# Set as current context
_thread_local.current_context = context
try:
yield context
finally:
# Restore previous context
_thread_local.current_context = previous_context
def mark_task_completed(message: str) -> None:
"""Mark the current task as completed.
Args:
message: Completion message explaining how/why the task is complete
"""
context = get_current_context()
if context:
context.mark_task_completed(message)
def mark_plan_completed(message: str) -> None:
"""Mark the current plan as completed.
Args:
message: Completion message explaining how the plan was completed
"""
context = get_current_context()
if context:
context.mark_plan_completed(message)
def reset_completion_flags() -> None:
"""Reset completion flags in the current context."""
context = get_current_context()
if context:
context.reset_completion_flags()
def is_completed() -> bool:
"""Check if the current context is marked as completed.
Returns:
True if the current context is marked as completed, False otherwise
"""
context = get_current_context()
return context.is_completed if context else False
def get_completion_message() -> str:
"""Get the completion message from the current context.
Returns:
The completion message or empty string if no context or no message
"""
context = get_current_context()
return context.completion_message if context else ""

View File

@ -7,7 +7,7 @@ import threading
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@ -69,6 +69,13 @@ from ra_aid.tool_configs import (
get_web_research_tools,
)
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.agent_context import (
agent_context,
get_current_context,
is_completed,
reset_completion_flags,
get_completion_message,
)
from ra_aid.tools.memory import (
_global_memory,
get_memory_value,
@ -821,9 +828,8 @@ def _decrement_agent_depth():
def reset_agent_completion_flags():
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
"""Reset completion flags in the current context."""
reset_completion_flags()
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
@ -897,8 +903,8 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
check_interrupt()
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
reset_agent_completion_flags()
if is_completed():
reset_completion_flags()
break
@ -919,7 +925,8 @@ def run_agent_with_retry(
original_prompt = prompt
msg_list = [HumanMessage(content=prompt)]
with InterruptibleSection():
# Create a new agent context for this run
with InterruptibleSection(), agent_context() as ctx:
try:
_increment_agent_depth()
for attempt in range(max_retries):

View File

@ -247,7 +247,7 @@ def create_llm_client(
if supports_thinking:
temp_kwargs = {"thinking": {
"type": "enabled",
"budget_tokens": 8000
"budget_tokens": 12000
}}
if provider == "deepseek":

View File

@ -527,6 +527,8 @@ FOLLOW TEST DRIVEN DEVELOPMENT (TDD) PRACTICES WHERE POSSIBE. E.G. COMPILE CODE
IF YOU CAN SEE THE CODE WRITTEN/CHANGED BY THE PROGRAMMER, TRUST IT. YOU DO NOT NEED TO RE-READ EVERY FILE WITH EVERY SMALL EDIT.
YOU MUST CALL emit_related_files BEFORE CALLING run_programming_task WITH ALL RELEVANT FILES, UNLESS THEY ARE ALREADY RECORDED AS RELATED FILES.
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
"""

View File

@ -0,0 +1,372 @@
"""Unit tests for the agent_context module."""
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
)
class TestAgentContext:
"""Test cases for the AgentContext class and related functions."""
def test_context_creation(self):
"""Test creating a new context."""
context = AgentContext()
assert context.task_completed is False
assert context.plan_completed is False
assert context.completion_message == ""
def test_context_inheritance(self):
"""Test that child contexts inherit state from parent contexts."""
parent = AgentContext()
parent.mark_task_completed("Parent task completed")
child = AgentContext(parent_context=parent)
assert child.task_completed is True
assert child.completion_message == "Parent task completed"
def test_mark_task_completed(self):
"""Test marking a task as completed."""
context = AgentContext()
context.mark_task_completed("Task done")
assert context.task_completed is True
assert context.plan_completed is False
assert context.completion_message == "Task done"
def test_mark_plan_completed(self):
"""Test marking a plan as completed."""
context = AgentContext()
context.mark_plan_completed("Plan done")
assert context.task_completed is True
assert context.plan_completed is True
assert context.completion_message == "Plan done"
def test_reset_completion_flags(self):
"""Test resetting completion flags."""
context = AgentContext()
context.mark_task_completed("Task done")
context.reset_completion_flags()
assert context.task_completed is False
assert context.plan_completed is False
assert context.completion_message == ""
def test_is_completed_property(self):
"""Test the is_completed property."""
context = AgentContext()
assert context.is_completed is False
context.mark_task_completed("Task done")
assert context.is_completed is True
context.reset_completion_flags()
assert context.is_completed is False
context.mark_plan_completed("Plan done")
assert context.is_completed is True
class TestContextManager:
"""Test cases for the agent_context context manager."""
def test_context_manager_basic(self):
"""Test basic context manager functionality."""
assert get_current_context() is None
with agent_context() as ctx:
assert get_current_context() is ctx
assert ctx.task_completed is False
assert get_current_context() is None
def test_nested_context_managers(self):
"""Test nested context managers."""
with agent_context() as outer_ctx:
assert get_current_context() is outer_ctx
with agent_context() as inner_ctx:
assert get_current_context() is inner_ctx
assert inner_ctx is not outer_ctx
assert get_current_context() is outer_ctx
def test_context_manager_with_parent(self):
"""Test context manager with explicit parent context."""
parent = AgentContext()
parent.mark_task_completed("Parent task")
with agent_context(parent_context=parent) as ctx:
assert ctx.task_completed is True
assert ctx.completion_message == "Parent task"
def test_context_manager_inheritance(self):
"""Test that nested contexts inherit from outer contexts by default."""
with agent_context() as outer:
outer.mark_task_completed("Outer task")
with agent_context() as inner:
assert inner.task_completed is True
assert inner.completion_message == "Outer task"
inner.mark_plan_completed("Inner plan")
# Outer context should not be affected by inner context changes
assert outer.task_completed is True
assert outer.plan_completed is False
assert outer.completion_message == "Outer task"
class TestThreadIsolation:
"""Test thread isolation of context variables."""
def test_thread_isolation(self):
"""Test that contexts are isolated between threads."""
results = {}
def thread_func(thread_id):
with agent_context() as ctx:
ctx.mark_task_completed(f"Thread {thread_id}")
time.sleep(0.1) # Give other threads time to run
# Store the context's message for verification
results[thread_id] = get_completion_message()
threads = []
for i in range(3):
t = threading.Thread(target=thread_func, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
# Each thread should have its own message
assert results[0] == "Thread 0"
assert results[1] == "Thread 1"
assert results[2] == "Thread 2"
class TestUtilityFunctions:
"""Test utility functions that operate on the current context."""
def test_mark_task_completed_utility(self):
"""Test the mark_task_completed utility function."""
with agent_context():
mark_task_completed("Task done via utility")
assert is_completed() is True
assert get_completion_message() == "Task done via utility"
def test_mark_plan_completed_utility(self):
"""Test the mark_plan_completed utility function."""
with agent_context():
mark_plan_completed("Plan done via utility")
assert is_completed() is True
assert get_completion_message() == "Plan done via utility"
def test_reset_completion_flags_utility(self):
"""Test the reset_completion_flags utility function."""
with agent_context():
mark_task_completed("Task done")
reset_completion_flags()
assert is_completed() is False
assert get_completion_message() == ""
def test_utility_functions_without_context(self):
"""Test utility functions when no context is active."""
# These should not raise exceptions even without an active context
mark_task_completed("No context")
mark_plan_completed("No context")
reset_completion_flags()
# These should have safe default returns
assert is_completed() is False
assert get_completion_message() == ""
"""Unit tests for the agent_context module."""
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
)
class TestAgentContext:
"""Test cases for the AgentContext class and related functions."""
def test_context_creation(self):
"""Test creating a new context."""
context = AgentContext()
assert context.task_completed is False
assert context.plan_completed is False
assert context.completion_message == ""
def test_context_inheritance(self):
"""Test that child contexts inherit state from parent contexts."""
parent = AgentContext()
parent.mark_task_completed("Parent task completed")
child = AgentContext(parent_context=parent)
assert child.task_completed is True
assert child.completion_message == "Parent task completed"
def test_mark_task_completed(self):
"""Test marking a task as completed."""
context = AgentContext()
context.mark_task_completed("Task done")
assert context.task_completed is True
assert context.plan_completed is False
assert context.completion_message == "Task done"
def test_mark_plan_completed(self):
"""Test marking a plan as completed."""
context = AgentContext()
context.mark_plan_completed("Plan done")
assert context.task_completed is True
assert context.plan_completed is True
assert context.completion_message == "Plan done"
def test_reset_completion_flags(self):
"""Test resetting completion flags."""
context = AgentContext()
context.mark_task_completed("Task done")
context.reset_completion_flags()
assert context.task_completed is False
assert context.plan_completed is False
assert context.completion_message == ""
def test_is_completed_property(self):
"""Test the is_completed property."""
context = AgentContext()
assert context.is_completed is False
context.mark_task_completed("Task done")
assert context.is_completed is True
context.reset_completion_flags()
assert context.is_completed is False
context.mark_plan_completed("Plan done")
assert context.is_completed is True
class TestContextManager:
"""Test cases for the agent_context context manager."""
def test_context_manager_basic(self):
"""Test basic context manager functionality."""
assert get_current_context() is None
with agent_context() as ctx:
assert get_current_context() is ctx
assert ctx.task_completed is False
assert get_current_context() is None
def test_nested_context_managers(self):
"""Test nested context managers."""
with agent_context() as outer_ctx:
assert get_current_context() is outer_ctx
with agent_context() as inner_ctx:
assert get_current_context() is inner_ctx
assert inner_ctx is not outer_ctx
assert get_current_context() is outer_ctx
def test_context_manager_with_parent(self):
"""Test context manager with explicit parent context."""
parent = AgentContext()
parent.mark_task_completed("Parent task")
with agent_context(parent_context=parent) as ctx:
assert ctx.task_completed is True
assert ctx.completion_message == "Parent task"
def test_context_manager_inheritance(self):
"""Test that nested contexts inherit from outer contexts by default."""
with agent_context() as outer:
outer.mark_task_completed("Outer task")
with agent_context() as inner:
assert inner.task_completed is True
assert inner.completion_message == "Outer task"
inner.mark_plan_completed("Inner plan")
# Outer context should not be affected by inner context changes
assert outer.task_completed is True
assert outer.plan_completed is False
assert outer.completion_message == "Outer task"
class TestThreadIsolation:
"""Test thread isolation of context variables."""
def test_thread_isolation(self):
"""Test that contexts are isolated between threads."""
results = {}
def thread_func(thread_id):
with agent_context() as ctx:
ctx.mark_task_completed(f"Thread {thread_id}")
time.sleep(0.1) # Give other threads time to run
# Store the context's message for verification
results[thread_id] = get_completion_message()
threads = []
for i in range(3):
t = threading.Thread(target=thread_func, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
# Each thread should have its own message
assert results[0] == "Thread 0"
assert results[1] == "Thread 1"
assert results[2] == "Thread 2"
class TestUtilityFunctions:
"""Test utility functions that operate on the current context."""
def test_mark_task_completed_utility(self):
"""Test the mark_task_completed utility function."""
with agent_context():
mark_task_completed("Task done via utility")
assert is_completed() is True
assert get_completion_message() == "Task done via utility"
def test_mark_plan_completed_utility(self):
"""Test the mark_plan_completed utility function."""
with agent_context():
mark_plan_completed("Plan done via utility")
assert is_completed() is True
assert get_completion_message() == "Plan done via utility"
def test_reset_completion_flags_utility(self):
"""Test the reset_completion_flags utility function."""
with agent_context():
mark_task_completed("Task done")
reset_completion_flags()
assert is_completed() is False
assert get_completion_message() == ""
def test_utility_functions_without_context(self):
"""Test utility functions when no context is active."""
# These should not raise exceptions even without an active context
mark_task_completed("No context")
mark_plan_completed("No context")
reset_completion_flags()
# These should have safe default returns
assert is_completed() is False
assert get_completion_message() == ""

View File

@ -5,6 +5,7 @@ from typing import Any, Dict, List, Union
from langchain_core.tools import tool
from rich.console import Console
from ra_aid.agent_context import get_completion_message, reset_completion_flags
from ra_aid.console.formatting import print_error
from ra_aid.exceptions import AgentInterrupt
from ra_aid.tools.memory import _global_memory
@ -84,16 +85,12 @@ def request_research(query: str) -> ResearchResult:
reason = f"error: {str(e)}"
finally:
# Get completion message if available
completion_message = _global_memory.get(
"completion_message",
"Task was completed successfully." if success else None,
)
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
work_log = get_work_log()
# Clear completion state from global memory
_global_memory["completion_message"] = ""
_global_memory["task_completed"] = False
# Clear completion state
reset_completion_flags()
response_data = {
"completion_message": completion_message,
@ -152,16 +149,12 @@ def request_web_research(query: str) -> ResearchResult:
reason = f"error: {str(e)}"
finally:
# Get completion message if available
completion_message = _global_memory.get(
"completion_message",
"Task was completed successfully." if success else None,
)
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
work_log = get_work_log()
# Clear completion state from global memory
_global_memory["completion_message"] = ""
_global_memory["task_completed"] = False
# Clear completion state
reset_completion_flags()
response_data = {
"completion_message": completion_message,
@ -222,16 +215,12 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
reason = f"error: {str(e)}"
# Get completion message if available
completion_message = _global_memory.get(
"completion_message", "Task was completed successfully." if success else None
)
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
work_log = get_work_log()
# Clear completion state from global memory
_global_memory["completion_message"] = ""
_global_memory["task_completed"] = False
_global_memory["plan_completed"] = False
# Clear completion state
reset_completion_flags()
response_data = {
"completion_message": completion_message,
@ -276,7 +265,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
# Run implementation agent
from ..agent_utils import run_task_implementation_agent
_global_memory["completion_message"] = ""
reset_completion_flags()
_result = run_task_implementation_agent(
base_task=_global_memory.get("base_task", ""),
@ -304,16 +293,13 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
reason = f"error: {str(e)}"
# Get completion message if available
completion_message = _global_memory.get(
"completion_message", "Task was completed successfully." if success else None
)
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
# Get and reset work log if at root depth
work_log = get_work_log()
# Clear completion state from global memory
_global_memory["completion_message"] = ""
_global_memory["task_completed"] = False
# Clear completion state
reset_completion_flags()
response_data = {
"key_facts": get_memory_value("key_facts"),
@ -325,6 +311,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
}
if work_log is not None:
response_data["work_log"] = work_log
print("TASK HERE", response_data)
return response_data
@ -347,7 +334,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
# Run planning agent
from ..agent_utils import run_planning_agent
_global_memory["completion_message"] = ""
reset_completion_flags()
_result = run_planning_agent(
task_spec,
@ -372,17 +359,13 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
reason = f"error: {str(e)}"
# Get completion message if available
completion_message = _global_memory.get(
"completion_message", "Task was completed successfully." if success else None
)
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
# Get and reset work log if at root depth
work_log = get_work_log()
# Clear completion state from global memory
_global_memory["completion_message"] = ""
_global_memory["task_completed"] = False
_global_memory["plan_completed"] = False
# Clear completion state
reset_completion_flags()
response_data = {
"completion_message": completion_message,
@ -394,4 +377,6 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
}
if work_log is not None:
response_data["work_log"] = work_log
print("HERE", response_data)
return response_data

View File

@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
try:
import magic
@ -12,6 +12,8 @@ from rich.markdown import Markdown
from rich.panel import Panel
from typing_extensions import TypedDict
from ra_aid.agent_context import mark_task_completed, mark_plan_completed
class WorkLogEntry(TypedDict):
timestamp: str
@ -28,25 +30,10 @@ class SnippetInfo(TypedDict):
console = Console()
# Global memory store
_global_memory: Dict[
str,
Union[
Dict[int, str],
Dict[int, SnippetInfo],
Dict[int, WorkLogEntry],
int,
Set[str],
bool,
str,
List[str],
List[WorkLogEntry],
],
] = {
_global_memory: Dict[str, Any] = {
"research_notes": [],
"plans": [],
"tasks": {}, # Dict[int, str] - ID to task mapping
"task_completed": False, # Flag indicating if task is complete
"completion_message": "", # Message explaining completion
"task_id_counter": 1, # Counter for generating unique task IDs
"key_facts": {}, # Dict[int, str] - ID to fact mapping
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
@ -55,7 +42,6 @@ _global_memory: Dict[
"implementation_requested": False,
"related_files": {}, # Dict[int, str] - ID to filepath mapping
"related_file_id_counter": 1, # Counter for generating unique file IDs
"plan_completed": False,
"agent_depth": 0,
"work_log": [], # List[WorkLogEntry] - Timestamped work events
}
@ -327,10 +313,9 @@ def one_shot_completed(message: str) -> str:
if _global_memory.get("implementation_requested", False):
return "Cannot complete in one shot - implementation was requested"
_global_memory["task_completed"] = True
_global_memory["completion_message"] = message
mark_task_completed(message)
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed\n\n{message}")
log_work_event(f"Task completed:\n\n{message}")
return "Completion noted."
@ -341,9 +326,9 @@ def task_completed(message: str) -> str:
Args:
message: Message explaining how/why the task is complete
"""
_global_memory["task_completed"] = True
_global_memory["completion_message"] = message
mark_task_completed(message)
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed:\n\n{message}")
return "Completion noted."
@ -354,14 +339,11 @@ def plan_implementation_completed(message: str) -> str:
Args:
message: Message explaining how the implementation plan was completed
"""
_global_memory["task_completed"] = True
_global_memory["completion_message"] = message
_global_memory["plan_completed"] = True
_global_memory["completion_message"] = message
mark_plan_completed(message)
_global_memory["tasks"].clear() # Clear task list when plan is completed
_global_memory["task_id_counter"] = 1
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Plan execution completed:\n\n{message}")
log_work_event(f"Completed implementation:\n\n{message}")
return "Plan completion noted and task list cleared."

View File

@ -0,0 +1,178 @@
"""Unit tests for the agent_context module."""
import threading
import time
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_task_completed,
mark_plan_completed,
reset_completion_flags,
is_completed,
get_completion_message,
)
class TestAgentContext:
"""Test cases for the AgentContext class and related functions."""
def test_context_creation(self):
"""Test creating a new context."""
context = AgentContext()
assert context.task_completed is False
assert context.plan_completed is False
assert context.completion_message == ""
def test_context_inheritance(self):
"""Test that child contexts inherit state from parent contexts."""
parent = AgentContext()
parent.mark_task_completed("Parent task completed")
child = AgentContext(parent_context=parent)
assert child.task_completed is True
assert child.completion_message == "Parent task completed"
def test_mark_task_completed(self):
"""Test marking a task as completed."""
context = AgentContext()
context.mark_task_completed("Task done")
assert context.task_completed is True
assert context.plan_completed is False
assert context.completion_message == "Task done"
def test_mark_plan_completed(self):
"""Test marking a plan as completed."""
context = AgentContext()
context.mark_plan_completed("Plan done")
assert context.task_completed is True
assert context.plan_completed is True
assert context.completion_message == "Plan done"
def test_reset_completion_flags(self):
"""Test resetting completion flags."""
context = AgentContext()
context.mark_task_completed("Task done")
context.reset_completion_flags()
assert context.task_completed is False
assert context.plan_completed is False
assert context.completion_message == ""
def test_is_completed_property(self):
"""Test the is_completed property."""
context = AgentContext()
assert context.is_completed is False
context.mark_task_completed("Task done")
assert context.is_completed is True
context.reset_completion_flags()
assert context.is_completed is False
context.mark_plan_completed("Plan done")
assert context.is_completed is True
class TestContextManager:
"""Test cases for the agent_context context manager."""
def test_context_manager_basic(self):
"""Test basic context manager functionality."""
assert get_current_context() is None
with agent_context() as ctx:
assert get_current_context() is ctx
assert ctx.task_completed is False
assert get_current_context() is None
def test_nested_context_managers(self):
"""Test nested context managers."""
with agent_context() as outer_ctx:
assert get_current_context() is outer_ctx
with agent_context() as inner_ctx:
assert get_current_context() is inner_ctx
assert inner_ctx is not outer_ctx
assert get_current_context() is outer_ctx
def test_context_manager_with_parent(self):
"""Test context manager with explicit parent context."""
parent = AgentContext()
parent.mark_task_completed("Parent task")
with agent_context(parent_context=parent) as ctx:
assert ctx.task_completed is True
assert ctx.completion_message == "Parent task"
def test_context_manager_inheritance(self):
"""Test that nested contexts inherit from outer contexts by default."""
with agent_context() as outer:
outer.mark_task_completed("Outer task")
with agent_context() as inner:
assert inner.task_completed is True
assert inner.completion_message == "Outer task"
inner.mark_plan_completed("Inner plan")
# Outer context should not be affected by inner context changes
assert outer.task_completed is True
assert outer.plan_completed is False
assert outer.completion_message == "Outer task"
class TestThreadIsolation:
"""Test thread isolation of context variables."""
def test_thread_isolation(self):
"""Test that contexts are isolated between threads."""
results = {}
def thread_func(thread_id):
with agent_context() as ctx:
ctx.mark_task_completed(f"Thread {thread_id}")
time.sleep(0.1) # Give other threads time to run
# Store the context's message for verification
results[thread_id] = get_completion_message()
threads = []
for i in range(3):
t = threading.Thread(target=thread_func, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
# Each thread should have its own message
assert results[0] == "Thread 0"
assert results[1] == "Thread 1"
assert results[2] == "Thread 2"
class TestUtilityFunctions:
"""Test utility functions that operate on the current context."""
def test_mark_task_completed_utility(self):
"""Test the mark_task_completed utility function."""
with agent_context():
mark_task_completed("Task done via utility")
assert is_completed() is True
assert get_completion_message() == "Task done via utility"
def test_mark_plan_completed_utility(self):
"""Test the mark_plan_completed utility function."""
with agent_context():
mark_plan_completed("Plan done via utility")
assert is_completed() is True
assert get_completion_message() == "Plan done via utility"
def test_reset_completion_flags_utility(self):
"""Test the reset_completion_flags utility function."""
with agent_context():
mark_task_completed("Task done")
reset_completion_flags()
assert is_completed() is False
assert get_completion_message() == ""
def test_utility_functions_without_context(self):
"""Test utility functions when no context is active."""
# These should not raise exceptions even without an active context
mark_task_completed("No context")
mark_plan_completed("No context")
reset_completion_flags()
# These should have safe default returns
assert is_completed() is False
assert get_completion_message() == ""

View File

@ -8,6 +8,7 @@ import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags
from ra_aid.agent_utils import (
AgentState,
create_agent,
@ -325,7 +326,7 @@ def test_increment_and_decrement_agent_depth():
def test_run_agent_stream(monkeypatch):
from ra_aid.agent_utils import _global_memory, _run_agent_stream
from ra_aid.agent_utils import _run_agent_stream
# Create a dummy agent that yields one chunk
class DummyAgent:
@ -334,9 +335,11 @@ def test_run_agent_stream(monkeypatch):
dummy_agent = DummyAgent()
# Set flags so that _run_agent_stream will reset them
_global_memory["plan_completed"] = True
_global_memory["task_completed"] = True
_global_memory["completion_message"] = "existing"
with agent_context() as ctx:
ctx.plan_completed = True
ctx.task_completed = True
ctx.completion_message = "existing"
call_flag = {"called": False}
def fake_print_agent_output(
@ -349,9 +352,11 @@ def test_run_agent_stream(monkeypatch):
)
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
assert call_flag["called"]
assert _global_memory["plan_completed"] is False
assert _global_memory["task_completed"] is False
assert _global_memory["completion_message"] == ""
with agent_context() as ctx:
assert ctx.plan_completed is False
assert ctx.task_completed is False
assert ctx.completion_message == ""
def test_execute_test_command_wrapper(monkeypatch):