track agent depth via context

This commit is contained in:
AI Christianson 2025-03-04 19:05:33 -05:00
parent a1b268fdf4
commit 60a6707107
8 changed files with 72 additions and 66 deletions

View File

@ -91,6 +91,17 @@ class AgentContext:
"""Check if the current context is marked as completed."""
return self.task_completed or self.plan_completed
@property
def depth(self) -> int:
"""Calculate the depth of this context based on parent chain.
Returns:
int: 0 for a context with no parent, parent.depth + 1 otherwise
"""
if self.parent is None:
return 0
return self.parent.depth + 1
def get_current_context() -> Optional[AgentContext]:
"""Get the current agent context for this thread.
@ -101,6 +112,18 @@ def get_current_context() -> Optional[AgentContext]:
return getattr(_thread_local, "current_context", None)
def get_depth() -> int:
"""Get the depth of the current agent context.
Returns:
int: Depth of the current context, or 0 if no context exists
"""
ctx = get_current_context()
if ctx is None:
return 0
return ctx.depth
@contextmanager
def agent_context(parent_context=None):
"""Context manager for agent execution.

View File

@ -32,6 +32,7 @@ from rich.panel import Panel
from ra_aid.agent_context import (
agent_context,
get_depth,
is_completed,
reset_completion_flags,
should_exit,
@ -904,15 +905,6 @@ def _restore_interrupt_handling(original_handler):
signal.signal(signal.SIGINT, original_handler)
def _increment_agent_depth():
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
def _decrement_agent_depth():
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
def reset_agent_completion_flags():
"""Reset completion flags in the current context."""
reset_completion_flags()
@ -1031,7 +1023,6 @@ def run_agent_with_retry(
# Create a new agent context for this run
with InterruptibleSection(), agent_context() as ctx:
try:
_increment_agent_depth()
for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt()
@ -1103,5 +1094,4 @@ def run_agent_with_retry(
_handle_api_error(e, attempt, max_retries, base_delay)
finally:
_decrement_agent_depth()
_restore_interrupt_handling(original_handler)

View File

@ -10,6 +10,7 @@ from rich.console import Console
from ra_aid.agent_context import (
get_completion_message,
get_crash_message,
get_depth,
is_crashed,
reset_completion_flags,
)
@ -59,7 +60,7 @@ def request_research(query: str) -> ResearchResult:
)
# Check recursion depth
current_depth = _global_memory.get("agent_depth", 0)
current_depth = get_depth()
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
print_error("Maximum research recursion depth reached")
try:

View File

@ -41,9 +41,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
# Global memory store
_global_memory: Dict[str, Any] = {
"agent_depth": 0,
}
_global_memory: Dict[str, Any] = {}
@tool("emit_research_notes")

View File

@ -9,6 +9,7 @@ from ra_aid.agent_context import (
agent_context,
get_completion_message,
get_current_context,
get_depth,
is_completed,
mark_plan_completed,
mark_should_exit,
@ -286,6 +287,31 @@ class TestUtilityFunctions:
assert is_completed() is True
assert get_completion_message() == "Task done via utility"
def test_agent_context_depth_property(self):
"""Test that the depth property correctly calculates context depth."""
# Create contexts with different nesting levels
ctx1 = AgentContext() # Depth 0
ctx2 = AgentContext(ctx1) # Depth 1
ctx3 = AgentContext(ctx2) # Depth 2
# Verify depths
assert ctx1.depth == 0
assert ctx2.depth == 1
assert ctx3.depth == 2
def test_get_depth_function(self):
"""Test that get_depth() returns the correct depth of the current context."""
# No context active
assert get_depth() == 0
# With nested contexts
with agent_context() as ctx1:
assert get_depth() == 0
with agent_context() as ctx2:
assert get_depth() == 1
with agent_context() as ctx3:
assert get_depth() == 2
def test_mark_plan_completed_utility(self):
"""Test the mark_plan_completed utility function."""
with agent_context():

View File

@ -316,18 +316,22 @@ def test_setup_and_restore_interrupt_handling():
assert signal.getsignal(signal.SIGINT) == original_handler
def test_increment_and_decrement_agent_depth():
from ra_aid.agent_utils import (
_decrement_agent_depth,
_global_memory,
_increment_agent_depth,
)
def test_agent_context_depth():
from ra_aid.agent_context import agent_context, get_depth
_global_memory["agent_depth"] = 10
_increment_agent_depth()
assert _global_memory["agent_depth"] == 11
_decrement_agent_depth()
assert _global_memory["agent_depth"] == 10
# Test depth with nested contexts
assert get_depth() == 0 # No context
with agent_context() as ctx1:
assert get_depth() == 0 # Root context has depth 0
assert ctx1.depth == 0
with agent_context() as ctx2:
assert get_depth() == 1 # Nested context has depth 1
assert ctx2.depth == 1
with agent_context() as ctx3:
assert get_depth() == 2 # Doubly nested context has depth 2
assert ctx3.depth == 2
def test_run_agent_stream(monkeypatch):
@ -501,12 +505,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
def mock_restore_interrupt_handling(handler):
pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_is_crashed():
return ctx.is_crashed() if ctx else False
@ -522,12 +520,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
"ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling,
)
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
# First, run without a crash - agent should be run
@ -581,12 +573,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
def mock_restore_interrupt_handling(handler):
pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_mark_agent_crashed(message):
ctx.agent_has_crashed = True
ctx.agent_crashed_message = message
@ -603,12 +589,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
"ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling,
)
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
with agent_context() as ctx:
@ -660,12 +640,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
def mock_restore_interrupt_handling(handler):
pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_mark_agent_crashed(message):
ctx.agent_has_crashed = True
ctx.agent_crashed_message = message
@ -682,12 +656,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
"ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling,
)
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
@ -705,6 +673,7 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
# Verify the agent is marked as crashed
assert is_crashed()
def test_handle_api_error_resource_exhausted():
from google.api_core.exceptions import ResourceExhausted
from ra_aid.agent_utils import _handle_api_error

View File

@ -14,7 +14,6 @@ def mock_dependencies(monkeypatch):
"""Mock all dependencies needed for main()."""
# Initialize global memory with necessary keys to prevent KeyError
_global_memory.clear()
_global_memory["agent_depth"] = 0
_global_memory["config"] = {}
# Mock dependencies that interact with external systems
@ -193,7 +192,6 @@ def test_temperature_validation(mock_dependencies):
# Reset global memory for clean test
_global_memory.clear()
_global_memory["agent_depth"] = 0
_global_memory["config"] = {}
# Test valid temperature (0.7)

View File

@ -154,8 +154,9 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
def test_request_research_max_depth(reset_memory, mock_functions):
"""Test that max recursion depth handling uses KeyFactRepository."""
# Set recursion depth to max
_global_memory["agent_depth"] = 3
# Mock depth using context-based approach
with patch('ra_aid.tools.agent.get_depth') as mock_get_depth:
mock_get_depth.return_value = 3
# Call the function (should hit max depth case)
result = request_research("test query")