diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py new file mode 100644 index 0000000..2d5ddfa --- /dev/null +++ b/ra_aid/agent_context.py @@ -0,0 +1,150 @@ +"""Context manager for tracking agent state and completion status.""" + +import threading +from contextlib import contextmanager +from typing import Dict, Optional, Set + +# Thread-local storage for context variables +_thread_local = threading.local() + + +class AgentContext: + """Context manager for agent state tracking.""" + + def __init__(self, parent_context=None): + """Initialize a new agent context. + + Args: + parent_context: Optional parent context to inherit state from + """ + # Initialize completion flags + self.task_completed = False + self.plan_completed = False + self.completion_message = "" + + # Inherit state from parent if provided + if parent_context: + self.task_completed = parent_context.task_completed + self.plan_completed = parent_context.plan_completed + self.completion_message = parent_context.completion_message + + def mark_task_completed(self, message: str) -> None: + """Mark the current task as completed. + + Args: + message: Completion message explaining how/why the task is complete + """ + self.task_completed = True + self.completion_message = message + + def mark_plan_completed(self, message: str) -> None: + """Mark the current plan as completed. + + Args: + message: Completion message explaining how the plan was completed + """ + self.task_completed = True + self.plan_completed = True + self.completion_message = message + + def reset_completion_flags(self) -> None: + """Reset all completion flags.""" + self.task_completed = False + self.plan_completed = False + self.completion_message = "" + + @property + def is_completed(self) -> bool: + """Check if the current context is marked as completed.""" + return self.task_completed or self.plan_completed + + +def get_current_context() -> Optional[AgentContext]: + """Get the current agent context for this thread. + + Returns: + The current AgentContext or None if no context is active + """ + return getattr(_thread_local, "current_context", None) + + +@contextmanager +def agent_context(parent_context=None): + """Context manager for agent execution. + + Creates a new agent context and makes it the current context for the duration + of the with block. Restores the previous context when exiting the block. + + Args: + parent_context: Optional parent context to inherit state from + + Yields: + The newly created AgentContext + """ + # Save the previous context + previous_context = getattr(_thread_local, "current_context", None) + + # Create a new context, inheriting from parent if provided + # If parent_context is None but previous_context exists, use previous_context as parent + if parent_context is None and previous_context is not None: + context = AgentContext(previous_context) + else: + context = AgentContext(parent_context) + + # Set as current context + _thread_local.current_context = context + + try: + yield context + finally: + # Restore previous context + _thread_local.current_context = previous_context + + +def mark_task_completed(message: str) -> None: + """Mark the current task as completed. + + Args: + message: Completion message explaining how/why the task is complete + """ + context = get_current_context() + if context: + context.mark_task_completed(message) + + +def mark_plan_completed(message: str) -> None: + """Mark the current plan as completed. + + Args: + message: Completion message explaining how the plan was completed + """ + context = get_current_context() + if context: + context.mark_plan_completed(message) + + +def reset_completion_flags() -> None: + """Reset completion flags in the current context.""" + context = get_current_context() + if context: + context.reset_completion_flags() + + +def is_completed() -> bool: + """Check if the current context is marked as completed. + + Returns: + True if the current context is marked as completed, False otherwise + """ + context = get_current_context() + return context.is_completed if context else False + + +def get_completion_message() -> str: + """Get the completion message from the current context. + + Returns: + The completion message or empty string if no context or no message + """ + context = get_current_context() + return context.completion_message if context else "" diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 288d04e..e39d429 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -7,7 +7,7 @@ import threading import time import uuid from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError @@ -69,6 +69,13 @@ from ra_aid.tool_configs import ( get_web_research_tools, ) from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command +from ra_aid.agent_context import ( + agent_context, + get_current_context, + is_completed, + reset_completion_flags, + get_completion_message, +) from ra_aid.tools.memory import ( _global_memory, get_memory_value, @@ -821,9 +828,8 @@ def _decrement_agent_depth(): def reset_agent_completion_flags(): - _global_memory["plan_completed"] = False - _global_memory["task_completed"] = False - _global_memory["completion_message"] = "" + """Reset completion flags in the current context.""" + reset_completion_flags() def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test): @@ -897,8 +903,8 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict) check_interrupt() agent_type = get_agent_type(agent) print_agent_output(chunk, agent_type) - if _global_memory["plan_completed"] or _global_memory["task_completed"]: - reset_agent_completion_flags() + if is_completed(): + reset_completion_flags() break @@ -919,7 +925,8 @@ def run_agent_with_retry( original_prompt = prompt msg_list = [HumanMessage(content=prompt)] - with InterruptibleSection(): + # Create a new agent context for this run + with InterruptibleSection(), agent_context() as ctx: try: _increment_agent_depth() for attempt in range(max_retries): diff --git a/ra_aid/llm.py b/ra_aid/llm.py index af0238d..105eed7 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -247,7 +247,7 @@ def create_llm_client( if supports_thinking: temp_kwargs = {"thinking": { "type": "enabled", - "budget_tokens": 8000 + "budget_tokens": 12000 }} if provider == "deepseek": diff --git a/ra_aid/prompts.py b/ra_aid/prompts.py index 8344838..8cfd984 100644 --- a/ra_aid/prompts.py +++ b/ra_aid/prompts.py @@ -527,6 +527,8 @@ FOLLOW TEST DRIVEN DEVELOPMENT (TDD) PRACTICES WHERE POSSIBE. E.G. COMPILE CODE IF YOU CAN SEE THE CODE WRITTEN/CHANGED BY THE PROGRAMMER, TRUST IT. YOU DO NOT NEED TO RE-READ EVERY FILE WITH EVERY SMALL EDIT. +YOU MUST CALL emit_related_files BEFORE CALLING run_programming_task WITH ALL RELEVANT FILES, UNLESS THEY ARE ALREADY RECORDED AS RELATED FILES. + NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT! """ diff --git a/ra_aid/tests/ra_aid/test_agent_context.py b/ra_aid/tests/ra_aid/test_agent_context.py new file mode 100644 index 0000000..a2a5da4 --- /dev/null +++ b/ra_aid/tests/ra_aid/test_agent_context.py @@ -0,0 +1,372 @@ +"""Unit tests for the agent_context module.""" + +import threading +import time +import pytest + +from ra_aid.agent_context import ( + AgentContext, + agent_context, + get_current_context, + mark_task_completed, + mark_plan_completed, + reset_completion_flags, + is_completed, + get_completion_message, +) + + +class TestAgentContext: + """Test cases for the AgentContext class and related functions.""" + + def test_context_creation(self): + """Test creating a new context.""" + context = AgentContext() + assert context.task_completed is False + assert context.plan_completed is False + assert context.completion_message == "" + + def test_context_inheritance(self): + """Test that child contexts inherit state from parent contexts.""" + parent = AgentContext() + parent.mark_task_completed("Parent task completed") + + child = AgentContext(parent_context=parent) + assert child.task_completed is True + assert child.completion_message == "Parent task completed" + + def test_mark_task_completed(self): + """Test marking a task as completed.""" + context = AgentContext() + context.mark_task_completed("Task done") + + assert context.task_completed is True + assert context.plan_completed is False + assert context.completion_message == "Task done" + + def test_mark_plan_completed(self): + """Test marking a plan as completed.""" + context = AgentContext() + context.mark_plan_completed("Plan done") + + assert context.task_completed is True + assert context.plan_completed is True + assert context.completion_message == "Plan done" + + def test_reset_completion_flags(self): + """Test resetting completion flags.""" + context = AgentContext() + context.mark_task_completed("Task done") + + context.reset_completion_flags() + assert context.task_completed is False + assert context.plan_completed is False + assert context.completion_message == "" + + def test_is_completed_property(self): + """Test the is_completed property.""" + context = AgentContext() + assert context.is_completed is False + + context.mark_task_completed("Task done") + assert context.is_completed is True + + context.reset_completion_flags() + assert context.is_completed is False + + context.mark_plan_completed("Plan done") + assert context.is_completed is True + + +class TestContextManager: + """Test cases for the agent_context context manager.""" + + def test_context_manager_basic(self): + """Test basic context manager functionality.""" + assert get_current_context() is None + + with agent_context() as ctx: + assert get_current_context() is ctx + assert ctx.task_completed is False + + assert get_current_context() is None + + def test_nested_context_managers(self): + """Test nested context managers.""" + with agent_context() as outer_ctx: + assert get_current_context() is outer_ctx + + with agent_context() as inner_ctx: + assert get_current_context() is inner_ctx + assert inner_ctx is not outer_ctx + + assert get_current_context() is outer_ctx + + def test_context_manager_with_parent(self): + """Test context manager with explicit parent context.""" + parent = AgentContext() + parent.mark_task_completed("Parent task") + + with agent_context(parent_context=parent) as ctx: + assert ctx.task_completed is True + assert ctx.completion_message == "Parent task" + + def test_context_manager_inheritance(self): + """Test that nested contexts inherit from outer contexts by default.""" + with agent_context() as outer: + outer.mark_task_completed("Outer task") + + with agent_context() as inner: + assert inner.task_completed is True + assert inner.completion_message == "Outer task" + + inner.mark_plan_completed("Inner plan") + + # Outer context should not be affected by inner context changes + assert outer.task_completed is True + assert outer.plan_completed is False + assert outer.completion_message == "Outer task" + + +class TestThreadIsolation: + """Test thread isolation of context variables.""" + + def test_thread_isolation(self): + """Test that contexts are isolated between threads.""" + results = {} + + def thread_func(thread_id): + with agent_context() as ctx: + ctx.mark_task_completed(f"Thread {thread_id}") + time.sleep(0.1) # Give other threads time to run + # Store the context's message for verification + results[thread_id] = get_completion_message() + + threads = [] + for i in range(3): + t = threading.Thread(target=thread_func, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Each thread should have its own message + assert results[0] == "Thread 0" + assert results[1] == "Thread 1" + assert results[2] == "Thread 2" + + +class TestUtilityFunctions: + """Test utility functions that operate on the current context.""" + + def test_mark_task_completed_utility(self): + """Test the mark_task_completed utility function.""" + with agent_context(): + mark_task_completed("Task done via utility") + assert is_completed() is True + assert get_completion_message() == "Task done via utility" + + def test_mark_plan_completed_utility(self): + """Test the mark_plan_completed utility function.""" + with agent_context(): + mark_plan_completed("Plan done via utility") + assert is_completed() is True + assert get_completion_message() == "Plan done via utility" + + def test_reset_completion_flags_utility(self): + """Test the reset_completion_flags utility function.""" + with agent_context(): + mark_task_completed("Task done") + reset_completion_flags() + assert is_completed() is False + assert get_completion_message() == "" + + def test_utility_functions_without_context(self): + """Test utility functions when no context is active.""" + # These should not raise exceptions even without an active context + mark_task_completed("No context") + mark_plan_completed("No context") + reset_completion_flags() + + # These should have safe default returns + assert is_completed() is False + assert get_completion_message() == "" +"""Unit tests for the agent_context module.""" + +import threading +import time +import pytest + +from ra_aid.agent_context import ( + AgentContext, + agent_context, + get_current_context, + mark_task_completed, + mark_plan_completed, + reset_completion_flags, + is_completed, + get_completion_message, +) + + +class TestAgentContext: + """Test cases for the AgentContext class and related functions.""" + + def test_context_creation(self): + """Test creating a new context.""" + context = AgentContext() + assert context.task_completed is False + assert context.plan_completed is False + assert context.completion_message == "" + + def test_context_inheritance(self): + """Test that child contexts inherit state from parent contexts.""" + parent = AgentContext() + parent.mark_task_completed("Parent task completed") + child = AgentContext(parent_context=parent) + assert child.task_completed is True + assert child.completion_message == "Parent task completed" + + def test_mark_task_completed(self): + """Test marking a task as completed.""" + context = AgentContext() + context.mark_task_completed("Task done") + assert context.task_completed is True + assert context.plan_completed is False + assert context.completion_message == "Task done" + + def test_mark_plan_completed(self): + """Test marking a plan as completed.""" + context = AgentContext() + context.mark_plan_completed("Plan done") + assert context.task_completed is True + assert context.plan_completed is True + assert context.completion_message == "Plan done" + + def test_reset_completion_flags(self): + """Test resetting completion flags.""" + context = AgentContext() + context.mark_task_completed("Task done") + context.reset_completion_flags() + assert context.task_completed is False + assert context.plan_completed is False + assert context.completion_message == "" + + def test_is_completed_property(self): + """Test the is_completed property.""" + context = AgentContext() + assert context.is_completed is False + context.mark_task_completed("Task done") + assert context.is_completed is True + context.reset_completion_flags() + assert context.is_completed is False + context.mark_plan_completed("Plan done") + assert context.is_completed is True + + +class TestContextManager: + """Test cases for the agent_context context manager.""" + + def test_context_manager_basic(self): + """Test basic context manager functionality.""" + assert get_current_context() is None + with agent_context() as ctx: + assert get_current_context() is ctx + assert ctx.task_completed is False + assert get_current_context() is None + + def test_nested_context_managers(self): + """Test nested context managers.""" + with agent_context() as outer_ctx: + assert get_current_context() is outer_ctx + with agent_context() as inner_ctx: + assert get_current_context() is inner_ctx + assert inner_ctx is not outer_ctx + assert get_current_context() is outer_ctx + + def test_context_manager_with_parent(self): + """Test context manager with explicit parent context.""" + parent = AgentContext() + parent.mark_task_completed("Parent task") + with agent_context(parent_context=parent) as ctx: + assert ctx.task_completed is True + assert ctx.completion_message == "Parent task" + + def test_context_manager_inheritance(self): + """Test that nested contexts inherit from outer contexts by default.""" + with agent_context() as outer: + outer.mark_task_completed("Outer task") + with agent_context() as inner: + assert inner.task_completed is True + assert inner.completion_message == "Outer task" + inner.mark_plan_completed("Inner plan") + # Outer context should not be affected by inner context changes + assert outer.task_completed is True + assert outer.plan_completed is False + assert outer.completion_message == "Outer task" + + +class TestThreadIsolation: + """Test thread isolation of context variables.""" + + def test_thread_isolation(self): + """Test that contexts are isolated between threads.""" + results = {} + + def thread_func(thread_id): + with agent_context() as ctx: + ctx.mark_task_completed(f"Thread {thread_id}") + time.sleep(0.1) # Give other threads time to run + # Store the context's message for verification + results[thread_id] = get_completion_message() + + threads = [] + for i in range(3): + t = threading.Thread(target=thread_func, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Each thread should have its own message + assert results[0] == "Thread 0" + assert results[1] == "Thread 1" + assert results[2] == "Thread 2" + + +class TestUtilityFunctions: + """Test utility functions that operate on the current context.""" + + def test_mark_task_completed_utility(self): + """Test the mark_task_completed utility function.""" + with agent_context(): + mark_task_completed("Task done via utility") + assert is_completed() is True + assert get_completion_message() == "Task done via utility" + + def test_mark_plan_completed_utility(self): + """Test the mark_plan_completed utility function.""" + with agent_context(): + mark_plan_completed("Plan done via utility") + assert is_completed() is True + assert get_completion_message() == "Plan done via utility" + + def test_reset_completion_flags_utility(self): + """Test the reset_completion_flags utility function.""" + with agent_context(): + mark_task_completed("Task done") + reset_completion_flags() + assert is_completed() is False + assert get_completion_message() == "" + + def test_utility_functions_without_context(self): + """Test utility functions when no context is active.""" + # These should not raise exceptions even without an active context + mark_task_completed("No context") + mark_plan_completed("No context") + reset_completion_flags() + # These should have safe default returns + assert is_completed() is False + assert get_completion_message() == "" diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 0a75aee..1e75da9 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Union from langchain_core.tools import tool from rich.console import Console +from ra_aid.agent_context import get_completion_message, reset_completion_flags from ra_aid.console.formatting import print_error from ra_aid.exceptions import AgentInterrupt from ra_aid.tools.memory import _global_memory @@ -84,16 +85,12 @@ def request_research(query: str) -> ResearchResult: reason = f"error: {str(e)}" finally: # Get completion message if available - completion_message = _global_memory.get( - "completion_message", - "Task was completed successfully." if success else None, - ) + completion_message = get_completion_message() or ("Task was completed successfully." if success else None) work_log = get_work_log() - # Clear completion state from global memory - _global_memory["completion_message"] = "" - _global_memory["task_completed"] = False + # Clear completion state + reset_completion_flags() response_data = { "completion_message": completion_message, @@ -152,16 +149,12 @@ def request_web_research(query: str) -> ResearchResult: reason = f"error: {str(e)}" finally: # Get completion message if available - completion_message = _global_memory.get( - "completion_message", - "Task was completed successfully." if success else None, - ) + completion_message = get_completion_message() or ("Task was completed successfully." if success else None) work_log = get_work_log() - # Clear completion state from global memory - _global_memory["completion_message"] = "" - _global_memory["task_completed"] = False + # Clear completion state + reset_completion_flags() response_data = { "completion_message": completion_message, @@ -222,16 +215,12 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: reason = f"error: {str(e)}" # Get completion message if available - completion_message = _global_memory.get( - "completion_message", "Task was completed successfully." if success else None - ) + completion_message = get_completion_message() or ("Task was completed successfully." if success else None) work_log = get_work_log() - # Clear completion state from global memory - _global_memory["completion_message"] = "" - _global_memory["task_completed"] = False - _global_memory["plan_completed"] = False + # Clear completion state + reset_completion_flags() response_data = { "completion_message": completion_message, @@ -276,7 +265,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: # Run implementation agent from ..agent_utils import run_task_implementation_agent - _global_memory["completion_message"] = "" + reset_completion_flags() _result = run_task_implementation_agent( base_task=_global_memory.get("base_task", ""), @@ -304,16 +293,13 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: reason = f"error: {str(e)}" # Get completion message if available - completion_message = _global_memory.get( - "completion_message", "Task was completed successfully." if success else None - ) + completion_message = get_completion_message() or ("Task was completed successfully." if success else None) # Get and reset work log if at root depth work_log = get_work_log() - # Clear completion state from global memory - _global_memory["completion_message"] = "" - _global_memory["task_completed"] = False + # Clear completion state + reset_completion_flags() response_data = { "key_facts": get_memory_value("key_facts"), @@ -325,6 +311,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: } if work_log is not None: response_data["work_log"] = work_log + print("TASK HERE", response_data) return response_data @@ -347,7 +334,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: # Run planning agent from ..agent_utils import run_planning_agent - _global_memory["completion_message"] = "" + reset_completion_flags() _result = run_planning_agent( task_spec, @@ -372,17 +359,13 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: reason = f"error: {str(e)}" # Get completion message if available - completion_message = _global_memory.get( - "completion_message", "Task was completed successfully." if success else None - ) + completion_message = get_completion_message() or ("Task was completed successfully." if success else None) # Get and reset work log if at root depth work_log = get_work_log() - # Clear completion state from global memory - _global_memory["completion_message"] = "" - _global_memory["task_completed"] = False - _global_memory["plan_completed"] = False + # Clear completion state + reset_completion_flags() response_data = { "completion_message": completion_message, @@ -394,4 +377,6 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: } if work_log is not None: response_data["work_log"] = work_log + + print("HERE", response_data) return response_data diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 278b435..21e362a 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union try: import magic @@ -12,6 +12,8 @@ from rich.markdown import Markdown from rich.panel import Panel from typing_extensions import TypedDict +from ra_aid.agent_context import mark_task_completed, mark_plan_completed + class WorkLogEntry(TypedDict): timestamp: str @@ -28,25 +30,10 @@ class SnippetInfo(TypedDict): console = Console() # Global memory store -_global_memory: Dict[ - str, - Union[ - Dict[int, str], - Dict[int, SnippetInfo], - Dict[int, WorkLogEntry], - int, - Set[str], - bool, - str, - List[str], - List[WorkLogEntry], - ], -] = { +_global_memory: Dict[str, Any] = { "research_notes": [], "plans": [], "tasks": {}, # Dict[int, str] - ID to task mapping - "task_completed": False, # Flag indicating if task is complete - "completion_message": "", # Message explaining completion "task_id_counter": 1, # Counter for generating unique task IDs "key_facts": {}, # Dict[int, str] - ID to fact mapping "key_fact_id_counter": 1, # Counter for generating unique fact IDs @@ -55,7 +42,6 @@ _global_memory: Dict[ "implementation_requested": False, "related_files": {}, # Dict[int, str] - ID to filepath mapping "related_file_id_counter": 1, # Counter for generating unique file IDs - "plan_completed": False, "agent_depth": 0, "work_log": [], # List[WorkLogEntry] - Timestamped work events } @@ -327,10 +313,9 @@ def one_shot_completed(message: str) -> str: if _global_memory.get("implementation_requested", False): return "Cannot complete in one shot - implementation was requested" - _global_memory["task_completed"] = True - _global_memory["completion_message"] = message + mark_task_completed(message) console.print(Panel(Markdown(message), title="✅ Task Completed")) - log_work_event(f"Task completed\n\n{message}") + log_work_event(f"Task completed:\n\n{message}") return "Completion noted." @@ -341,9 +326,9 @@ def task_completed(message: str) -> str: Args: message: Message explaining how/why the task is complete """ - _global_memory["task_completed"] = True - _global_memory["completion_message"] = message + mark_task_completed(message) console.print(Panel(Markdown(message), title="✅ Task Completed")) + log_work_event(f"Task completed:\n\n{message}") return "Completion noted." @@ -354,14 +339,11 @@ def plan_implementation_completed(message: str) -> str: Args: message: Message explaining how the implementation plan was completed """ - _global_memory["task_completed"] = True - _global_memory["completion_message"] = message - _global_memory["plan_completed"] = True - _global_memory["completion_message"] = message + mark_plan_completed(message) _global_memory["tasks"].clear() # Clear task list when plan is completed _global_memory["task_id_counter"] = 1 console.print(Panel(Markdown(message), title="✅ Plan Executed")) - log_work_event(f"Plan execution completed:\n\n{message}") + log_work_event(f"Completed implementation:\n\n{message}") return "Plan completion noted and task list cleared." diff --git a/tests/ra_aid/test_agent_context.py b/tests/ra_aid/test_agent_context.py new file mode 100644 index 0000000..b14d401 --- /dev/null +++ b/tests/ra_aid/test_agent_context.py @@ -0,0 +1,178 @@ +"""Unit tests for the agent_context module.""" + +import threading +import time +import pytest + +from ra_aid.agent_context import ( + AgentContext, + agent_context, + get_current_context, + mark_task_completed, + mark_plan_completed, + reset_completion_flags, + is_completed, + get_completion_message, +) + + +class TestAgentContext: + """Test cases for the AgentContext class and related functions.""" + + def test_context_creation(self): + """Test creating a new context.""" + context = AgentContext() + assert context.task_completed is False + assert context.plan_completed is False + assert context.completion_message == "" + + def test_context_inheritance(self): + """Test that child contexts inherit state from parent contexts.""" + parent = AgentContext() + parent.mark_task_completed("Parent task completed") + child = AgentContext(parent_context=parent) + assert child.task_completed is True + assert child.completion_message == "Parent task completed" + + def test_mark_task_completed(self): + """Test marking a task as completed.""" + context = AgentContext() + context.mark_task_completed("Task done") + assert context.task_completed is True + assert context.plan_completed is False + assert context.completion_message == "Task done" + + def test_mark_plan_completed(self): + """Test marking a plan as completed.""" + context = AgentContext() + context.mark_plan_completed("Plan done") + assert context.task_completed is True + assert context.plan_completed is True + assert context.completion_message == "Plan done" + + def test_reset_completion_flags(self): + """Test resetting completion flags.""" + context = AgentContext() + context.mark_task_completed("Task done") + context.reset_completion_flags() + assert context.task_completed is False + assert context.plan_completed is False + assert context.completion_message == "" + + def test_is_completed_property(self): + """Test the is_completed property.""" + context = AgentContext() + assert context.is_completed is False + context.mark_task_completed("Task done") + assert context.is_completed is True + context.reset_completion_flags() + assert context.is_completed is False + context.mark_plan_completed("Plan done") + assert context.is_completed is True + + +class TestContextManager: + """Test cases for the agent_context context manager.""" + + def test_context_manager_basic(self): + """Test basic context manager functionality.""" + assert get_current_context() is None + with agent_context() as ctx: + assert get_current_context() is ctx + assert ctx.task_completed is False + assert get_current_context() is None + + def test_nested_context_managers(self): + """Test nested context managers.""" + with agent_context() as outer_ctx: + assert get_current_context() is outer_ctx + with agent_context() as inner_ctx: + assert get_current_context() is inner_ctx + assert inner_ctx is not outer_ctx + assert get_current_context() is outer_ctx + + def test_context_manager_with_parent(self): + """Test context manager with explicit parent context.""" + parent = AgentContext() + parent.mark_task_completed("Parent task") + with agent_context(parent_context=parent) as ctx: + assert ctx.task_completed is True + assert ctx.completion_message == "Parent task" + + def test_context_manager_inheritance(self): + """Test that nested contexts inherit from outer contexts by default.""" + with agent_context() as outer: + outer.mark_task_completed("Outer task") + with agent_context() as inner: + assert inner.task_completed is True + assert inner.completion_message == "Outer task" + inner.mark_plan_completed("Inner plan") + # Outer context should not be affected by inner context changes + assert outer.task_completed is True + assert outer.plan_completed is False + assert outer.completion_message == "Outer task" + + +class TestThreadIsolation: + """Test thread isolation of context variables.""" + + def test_thread_isolation(self): + """Test that contexts are isolated between threads.""" + results = {} + + def thread_func(thread_id): + with agent_context() as ctx: + ctx.mark_task_completed(f"Thread {thread_id}") + time.sleep(0.1) # Give other threads time to run + # Store the context's message for verification + results[thread_id] = get_completion_message() + + threads = [] + for i in range(3): + t = threading.Thread(target=thread_func, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Each thread should have its own message + assert results[0] == "Thread 0" + assert results[1] == "Thread 1" + assert results[2] == "Thread 2" + + +class TestUtilityFunctions: + """Test utility functions that operate on the current context.""" + + def test_mark_task_completed_utility(self): + """Test the mark_task_completed utility function.""" + with agent_context(): + mark_task_completed("Task done via utility") + assert is_completed() is True + assert get_completion_message() == "Task done via utility" + + def test_mark_plan_completed_utility(self): + """Test the mark_plan_completed utility function.""" + with agent_context(): + mark_plan_completed("Plan done via utility") + assert is_completed() is True + assert get_completion_message() == "Plan done via utility" + + def test_reset_completion_flags_utility(self): + """Test the reset_completion_flags utility function.""" + with agent_context(): + mark_task_completed("Task done") + reset_completion_flags() + assert is_completed() is False + assert get_completion_message() == "" + + def test_utility_functions_without_context(self): + """Test utility functions when no context is active.""" + # These should not raise exceptions even without an active context + mark_task_completed("No context") + mark_plan_completed("No context") + reset_completion_flags() + # These should have safe default returns + assert is_completed() is False + assert get_completion_message() == "" diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index e31355a..fca5fcd 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -8,6 +8,7 @@ import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags from ra_aid.agent_utils import ( AgentState, create_agent, @@ -325,7 +326,7 @@ def test_increment_and_decrement_agent_depth(): def test_run_agent_stream(monkeypatch): - from ra_aid.agent_utils import _global_memory, _run_agent_stream + from ra_aid.agent_utils import _run_agent_stream # Create a dummy agent that yields one chunk class DummyAgent: @@ -334,9 +335,11 @@ def test_run_agent_stream(monkeypatch): dummy_agent = DummyAgent() # Set flags so that _run_agent_stream will reset them - _global_memory["plan_completed"] = True - _global_memory["task_completed"] = True - _global_memory["completion_message"] = "existing" + with agent_context() as ctx: + ctx.plan_completed = True + ctx.task_completed = True + ctx.completion_message = "existing" + call_flag = {"called": False} def fake_print_agent_output( @@ -349,9 +352,11 @@ def test_run_agent_stream(monkeypatch): ) _run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {}) assert call_flag["called"] - assert _global_memory["plan_completed"] is False - assert _global_memory["task_completed"] is False - assert _global_memory["completion_message"] == "" + + with agent_context() as ctx: + assert ctx.plan_completed is False + assert ctx.task_completed is False + assert ctx.completion_message == "" def test_execute_test_command_wrapper(monkeypatch):