diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py index 49ffeb5..c3b11e7 100644 --- a/ra_aid/agent_context.py +++ b/ra_aid/agent_context.py @@ -90,6 +90,17 @@ class AgentContext: def is_completed(self) -> bool: """Check if the current context is marked as completed.""" return self.task_completed or self.plan_completed + + @property + def depth(self) -> int: + """Calculate the depth of this context based on parent chain. + + Returns: + int: 0 for a context with no parent, parent.depth + 1 otherwise + """ + if self.parent is None: + return 0 + return self.parent.depth + 1 def get_current_context() -> Optional[AgentContext]: @@ -101,6 +112,18 @@ def get_current_context() -> Optional[AgentContext]: return getattr(_thread_local, "current_context", None) +def get_depth() -> int: + """Get the depth of the current agent context. + + Returns: + int: Depth of the current context, or 0 if no context exists + """ + ctx = get_current_context() + if ctx is None: + return 0 + return ctx.depth + + @contextmanager def agent_context(parent_context=None): """Context manager for agent execution. diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index cd5ef50..7ef9efa 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -32,6 +32,7 @@ from rich.panel import Panel from ra_aid.agent_context import ( agent_context, + get_depth, is_completed, reset_completion_flags, should_exit, @@ -904,15 +905,6 @@ def _restore_interrupt_handling(original_handler): signal.signal(signal.SIGINT, original_handler) -def _increment_agent_depth(): - current_depth = _global_memory.get("agent_depth", 0) - _global_memory["agent_depth"] = current_depth + 1 - - -def _decrement_agent_depth(): - _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 - - def reset_agent_completion_flags(): """Reset completion flags in the current context.""" reset_completion_flags() @@ -1031,7 +1023,6 @@ def run_agent_with_retry( # Create a new agent context for this run with InterruptibleSection(), agent_context() as ctx: try: - _increment_agent_depth() for attempt in range(max_retries): logger.debug("Attempt %d/%d", attempt + 1, max_retries) check_interrupt() @@ -1103,5 +1094,4 @@ def run_agent_with_retry( _handle_api_error(e, attempt, max_retries, base_delay) finally: - _decrement_agent_depth() _restore_interrupt_handling(original_handler) \ No newline at end of file diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index c44cc61..a87841f 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -10,6 +10,7 @@ from rich.console import Console from ra_aid.agent_context import ( get_completion_message, get_crash_message, + get_depth, is_crashed, reset_completion_flags, ) @@ -59,7 +60,7 @@ def request_research(query: str) -> ResearchResult: ) # Check recursion depth - current_depth = _global_memory.get("agent_depth", 0) + current_depth = get_depth() if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT: print_error("Maximum research recursion depth reached") try: diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 66ad6ff..789d703 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -41,9 +41,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi from ra_aid.database.repositories.related_files_repository import get_related_files_repository # Global memory store -_global_memory: Dict[str, Any] = { - "agent_depth": 0, -} +_global_memory: Dict[str, Any] = {} @tool("emit_research_notes") diff --git a/tests/ra_aid/test_agent_context.py b/tests/ra_aid/test_agent_context.py index a9e46ed..4d6b1fd 100644 --- a/tests/ra_aid/test_agent_context.py +++ b/tests/ra_aid/test_agent_context.py @@ -9,6 +9,7 @@ from ra_aid.agent_context import ( agent_context, get_completion_message, get_current_context, + get_depth, is_completed, mark_plan_completed, mark_should_exit, @@ -285,6 +286,31 @@ class TestUtilityFunctions: mark_task_completed("Task done via utility") assert is_completed() is True assert get_completion_message() == "Task done via utility" + + def test_agent_context_depth_property(self): + """Test that the depth property correctly calculates context depth.""" + # Create contexts with different nesting levels + ctx1 = AgentContext() # Depth 0 + ctx2 = AgentContext(ctx1) # Depth 1 + ctx3 = AgentContext(ctx2) # Depth 2 + + # Verify depths + assert ctx1.depth == 0 + assert ctx2.depth == 1 + assert ctx3.depth == 2 + + def test_get_depth_function(self): + """Test that get_depth() returns the correct depth of the current context.""" + # No context active + assert get_depth() == 0 + + # With nested contexts + with agent_context() as ctx1: + assert get_depth() == 0 + with agent_context() as ctx2: + assert get_depth() == 1 + with agent_context() as ctx3: + assert get_depth() == 2 def test_mark_plan_completed_utility(self): """Test the mark_plan_completed utility function.""" diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 4b5dd2c..591c0db 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -316,18 +316,22 @@ def test_setup_and_restore_interrupt_handling(): assert signal.getsignal(signal.SIGINT) == original_handler -def test_increment_and_decrement_agent_depth(): - from ra_aid.agent_utils import ( - _decrement_agent_depth, - _global_memory, - _increment_agent_depth, - ) +def test_agent_context_depth(): + from ra_aid.agent_context import agent_context, get_depth - _global_memory["agent_depth"] = 10 - _increment_agent_depth() - assert _global_memory["agent_depth"] == 11 - _decrement_agent_depth() - assert _global_memory["agent_depth"] == 10 + # Test depth with nested contexts + assert get_depth() == 0 # No context + with agent_context() as ctx1: + assert get_depth() == 0 # Root context has depth 0 + assert ctx1.depth == 0 + + with agent_context() as ctx2: + assert get_depth() == 1 # Nested context has depth 1 + assert ctx2.depth == 1 + + with agent_context() as ctx3: + assert get_depth() == 2 # Doubly nested context has depth 2 + assert ctx3.depth == 2 def test_run_agent_stream(monkeypatch): @@ -501,12 +505,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch): def mock_restore_interrupt_handling(handler): pass - def mock_increment_agent_depth(): - pass - - def mock_decrement_agent_depth(): - pass - def mock_is_crashed(): return ctx.is_crashed() if ctx else False @@ -522,12 +520,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch): "ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling, ) - monkeypatch.setattr( - "ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth - ) - monkeypatch.setattr( - "ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth - ) monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) # First, run without a crash - agent should be run @@ -581,12 +573,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch): def mock_restore_interrupt_handling(handler): pass - def mock_increment_agent_depth(): - pass - - def mock_decrement_agent_depth(): - pass - def mock_mark_agent_crashed(message): ctx.agent_has_crashed = True ctx.agent_crashed_message = message @@ -603,12 +589,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch): "ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling, ) - monkeypatch.setattr( - "ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth - ) - monkeypatch.setattr( - "ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth - ) monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) with agent_context() as ctx: @@ -660,12 +640,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch): def mock_restore_interrupt_handling(handler): pass - def mock_increment_agent_depth(): - pass - - def mock_decrement_agent_depth(): - pass - def mock_mark_agent_crashed(message): ctx.agent_has_crashed = True ctx.agent_crashed_message = message @@ -682,12 +656,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch): "ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling, ) - monkeypatch.setattr( - "ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth - ) - monkeypatch.setattr( - "ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth - ) monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None) monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError) @@ -705,6 +673,7 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch): # Verify the agent is marked as crashed assert is_crashed() + def test_handle_api_error_resource_exhausted(): from google.api_core.exceptions import ResourceExhausted from ra_aid.agent_utils import _handle_api_error diff --git a/tests/ra_aid/test_main.py b/tests/ra_aid/test_main.py index bf4f985..2787b8f 100644 --- a/tests/ra_aid/test_main.py +++ b/tests/ra_aid/test_main.py @@ -14,7 +14,6 @@ def mock_dependencies(monkeypatch): """Mock all dependencies needed for main().""" # Initialize global memory with necessary keys to prevent KeyError _global_memory.clear() - _global_memory["agent_depth"] = 0 _global_memory["config"] = {} # Mock dependencies that interact with external systems @@ -193,7 +192,6 @@ def test_temperature_validation(mock_dependencies): # Reset global memory for clean test _global_memory.clear() - _global_memory["agent_depth"] = 0 _global_memory["config"] = {} # Test valid temperature (0.7) diff --git a/tests/ra_aid/tools/test_agent.py b/tests/ra_aid/tools/test_agent.py index 073154e..fdb36df 100644 --- a/tests/ra_aid/tools/test_agent.py +++ b/tests/ra_aid/tools/test_agent.py @@ -154,8 +154,9 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions) def test_request_research_max_depth(reset_memory, mock_functions): """Test that max recursion depth handling uses KeyFactRepository.""" - # Set recursion depth to max - _global_memory["agent_depth"] = 3 + # Mock depth using context-based approach + with patch('ra_aid.tools.agent.get_depth') as mock_get_depth: + mock_get_depth.return_value = 3 # Call the function (should hit max depth case) result = request_research("test query")