track agent depth via context

This commit is contained in:
AI Christianson 2025-03-04 19:05:33 -05:00
parent a1b268fdf4
commit 60a6707107
8 changed files with 72 additions and 66 deletions

View File

@ -90,6 +90,17 @@ class AgentContext:
def is_completed(self) -> bool: def is_completed(self) -> bool:
"""Check if the current context is marked as completed.""" """Check if the current context is marked as completed."""
return self.task_completed or self.plan_completed return self.task_completed or self.plan_completed
@property
def depth(self) -> int:
"""Calculate the depth of this context based on parent chain.
Returns:
int: 0 for a context with no parent, parent.depth + 1 otherwise
"""
if self.parent is None:
return 0
return self.parent.depth + 1
def get_current_context() -> Optional[AgentContext]: def get_current_context() -> Optional[AgentContext]:
@ -101,6 +112,18 @@ def get_current_context() -> Optional[AgentContext]:
return getattr(_thread_local, "current_context", None) return getattr(_thread_local, "current_context", None)
def get_depth() -> int:
"""Get the depth of the current agent context.
Returns:
int: Depth of the current context, or 0 if no context exists
"""
ctx = get_current_context()
if ctx is None:
return 0
return ctx.depth
@contextmanager @contextmanager
def agent_context(parent_context=None): def agent_context(parent_context=None):
"""Context manager for agent execution. """Context manager for agent execution.

View File

@ -32,6 +32,7 @@ from rich.panel import Panel
from ra_aid.agent_context import ( from ra_aid.agent_context import (
agent_context, agent_context,
get_depth,
is_completed, is_completed,
reset_completion_flags, reset_completion_flags,
should_exit, should_exit,
@ -904,15 +905,6 @@ def _restore_interrupt_handling(original_handler):
signal.signal(signal.SIGINT, original_handler) signal.signal(signal.SIGINT, original_handler)
def _increment_agent_depth():
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
def _decrement_agent_depth():
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
def reset_agent_completion_flags(): def reset_agent_completion_flags():
"""Reset completion flags in the current context.""" """Reset completion flags in the current context."""
reset_completion_flags() reset_completion_flags()
@ -1031,7 +1023,6 @@ def run_agent_with_retry(
# Create a new agent context for this run # Create a new agent context for this run
with InterruptibleSection(), agent_context() as ctx: with InterruptibleSection(), agent_context() as ctx:
try: try:
_increment_agent_depth()
for attempt in range(max_retries): for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries) logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt() check_interrupt()
@ -1103,5 +1094,4 @@ def run_agent_with_retry(
_handle_api_error(e, attempt, max_retries, base_delay) _handle_api_error(e, attempt, max_retries, base_delay)
finally: finally:
_decrement_agent_depth()
_restore_interrupt_handling(original_handler) _restore_interrupt_handling(original_handler)

View File

@ -10,6 +10,7 @@ from rich.console import Console
from ra_aid.agent_context import ( from ra_aid.agent_context import (
get_completion_message, get_completion_message,
get_crash_message, get_crash_message,
get_depth,
is_crashed, is_crashed,
reset_completion_flags, reset_completion_flags,
) )
@ -59,7 +60,7 @@ def request_research(query: str) -> ResearchResult:
) )
# Check recursion depth # Check recursion depth
current_depth = _global_memory.get("agent_depth", 0) current_depth = get_depth()
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT: if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
print_error("Maximum research recursion depth reached") print_error("Maximum research recursion depth reached")
try: try:

View File

@ -41,9 +41,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
from ra_aid.database.repositories.related_files_repository import get_related_files_repository from ra_aid.database.repositories.related_files_repository import get_related_files_repository
# Global memory store # Global memory store
_global_memory: Dict[str, Any] = { _global_memory: Dict[str, Any] = {}
"agent_depth": 0,
}
@tool("emit_research_notes") @tool("emit_research_notes")

View File

@ -9,6 +9,7 @@ from ra_aid.agent_context import (
agent_context, agent_context,
get_completion_message, get_completion_message,
get_current_context, get_current_context,
get_depth,
is_completed, is_completed,
mark_plan_completed, mark_plan_completed,
mark_should_exit, mark_should_exit,
@ -285,6 +286,31 @@ class TestUtilityFunctions:
mark_task_completed("Task done via utility") mark_task_completed("Task done via utility")
assert is_completed() is True assert is_completed() is True
assert get_completion_message() == "Task done via utility" assert get_completion_message() == "Task done via utility"
def test_agent_context_depth_property(self):
"""Test that the depth property correctly calculates context depth."""
# Create contexts with different nesting levels
ctx1 = AgentContext() # Depth 0
ctx2 = AgentContext(ctx1) # Depth 1
ctx3 = AgentContext(ctx2) # Depth 2
# Verify depths
assert ctx1.depth == 0
assert ctx2.depth == 1
assert ctx3.depth == 2
def test_get_depth_function(self):
"""Test that get_depth() returns the correct depth of the current context."""
# No context active
assert get_depth() == 0
# With nested contexts
with agent_context() as ctx1:
assert get_depth() == 0
with agent_context() as ctx2:
assert get_depth() == 1
with agent_context() as ctx3:
assert get_depth() == 2
def test_mark_plan_completed_utility(self): def test_mark_plan_completed_utility(self):
"""Test the mark_plan_completed utility function.""" """Test the mark_plan_completed utility function."""

View File

@ -316,18 +316,22 @@ def test_setup_and_restore_interrupt_handling():
assert signal.getsignal(signal.SIGINT) == original_handler assert signal.getsignal(signal.SIGINT) == original_handler
def test_increment_and_decrement_agent_depth(): def test_agent_context_depth():
from ra_aid.agent_utils import ( from ra_aid.agent_context import agent_context, get_depth
_decrement_agent_depth,
_global_memory,
_increment_agent_depth,
)
_global_memory["agent_depth"] = 10 # Test depth with nested contexts
_increment_agent_depth() assert get_depth() == 0 # No context
assert _global_memory["agent_depth"] == 11 with agent_context() as ctx1:
_decrement_agent_depth() assert get_depth() == 0 # Root context has depth 0
assert _global_memory["agent_depth"] == 10 assert ctx1.depth == 0
with agent_context() as ctx2:
assert get_depth() == 1 # Nested context has depth 1
assert ctx2.depth == 1
with agent_context() as ctx3:
assert get_depth() == 2 # Doubly nested context has depth 2
assert ctx3.depth == 2
def test_run_agent_stream(monkeypatch): def test_run_agent_stream(monkeypatch):
@ -501,12 +505,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
def mock_restore_interrupt_handling(handler): def mock_restore_interrupt_handling(handler):
pass pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_is_crashed(): def mock_is_crashed():
return ctx.is_crashed() if ctx else False return ctx.is_crashed() if ctx else False
@ -522,12 +520,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
"ra_aid.agent_utils._restore_interrupt_handling", "ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling, mock_restore_interrupt_handling,
) )
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
# First, run without a crash - agent should be run # First, run without a crash - agent should be run
@ -581,12 +573,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
def mock_restore_interrupt_handling(handler): def mock_restore_interrupt_handling(handler):
pass pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_mark_agent_crashed(message): def mock_mark_agent_crashed(message):
ctx.agent_has_crashed = True ctx.agent_has_crashed = True
ctx.agent_crashed_message = message ctx.agent_crashed_message = message
@ -603,12 +589,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
"ra_aid.agent_utils._restore_interrupt_handling", "ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling, mock_restore_interrupt_handling,
) )
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
with agent_context() as ctx: with agent_context() as ctx:
@ -660,12 +640,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
def mock_restore_interrupt_handling(handler): def mock_restore_interrupt_handling(handler):
pass pass
def mock_increment_agent_depth():
pass
def mock_decrement_agent_depth():
pass
def mock_mark_agent_crashed(message): def mock_mark_agent_crashed(message):
ctx.agent_has_crashed = True ctx.agent_has_crashed = True
ctx.agent_crashed_message = message ctx.agent_crashed_message = message
@ -682,12 +656,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
"ra_aid.agent_utils._restore_interrupt_handling", "ra_aid.agent_utils._restore_interrupt_handling",
mock_restore_interrupt_handling, mock_restore_interrupt_handling,
) )
monkeypatch.setattr(
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
)
monkeypatch.setattr(
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
)
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None) monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError) monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
@ -705,6 +673,7 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
# Verify the agent is marked as crashed # Verify the agent is marked as crashed
assert is_crashed() assert is_crashed()
def test_handle_api_error_resource_exhausted(): def test_handle_api_error_resource_exhausted():
from google.api_core.exceptions import ResourceExhausted from google.api_core.exceptions import ResourceExhausted
from ra_aid.agent_utils import _handle_api_error from ra_aid.agent_utils import _handle_api_error

View File

@ -14,7 +14,6 @@ def mock_dependencies(monkeypatch):
"""Mock all dependencies needed for main().""" """Mock all dependencies needed for main()."""
# Initialize global memory with necessary keys to prevent KeyError # Initialize global memory with necessary keys to prevent KeyError
_global_memory.clear() _global_memory.clear()
_global_memory["agent_depth"] = 0
_global_memory["config"] = {} _global_memory["config"] = {}
# Mock dependencies that interact with external systems # Mock dependencies that interact with external systems
@ -193,7 +192,6 @@ def test_temperature_validation(mock_dependencies):
# Reset global memory for clean test # Reset global memory for clean test
_global_memory.clear() _global_memory.clear()
_global_memory["agent_depth"] = 0
_global_memory["config"] = {} _global_memory["config"] = {}
# Test valid temperature (0.7) # Test valid temperature (0.7)

View File

@ -154,8 +154,9 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
def test_request_research_max_depth(reset_memory, mock_functions): def test_request_research_max_depth(reset_memory, mock_functions):
"""Test that max recursion depth handling uses KeyFactRepository.""" """Test that max recursion depth handling uses KeyFactRepository."""
# Set recursion depth to max # Mock depth using context-based approach
_global_memory["agent_depth"] = 3 with patch('ra_aid.tools.agent.get_depth') as mock_get_depth:
mock_get_depth.return_value = 3
# Call the function (should hit max depth case) # Call the function (should hit max depth case)
result = request_research("test query") result = request_research("test query")