track agent depth via context
This commit is contained in:
parent
a1b268fdf4
commit
60a6707107
|
|
@ -90,6 +90,17 @@ class AgentContext:
|
|||
def is_completed(self) -> bool:
|
||||
"""Check if the current context is marked as completed."""
|
||||
return self.task_completed or self.plan_completed
|
||||
|
||||
@property
|
||||
def depth(self) -> int:
|
||||
"""Calculate the depth of this context based on parent chain.
|
||||
|
||||
Returns:
|
||||
int: 0 for a context with no parent, parent.depth + 1 otherwise
|
||||
"""
|
||||
if self.parent is None:
|
||||
return 0
|
||||
return self.parent.depth + 1
|
||||
|
||||
|
||||
def get_current_context() -> Optional[AgentContext]:
|
||||
|
|
@ -101,6 +112,18 @@ def get_current_context() -> Optional[AgentContext]:
|
|||
return getattr(_thread_local, "current_context", None)
|
||||
|
||||
|
||||
def get_depth() -> int:
|
||||
"""Get the depth of the current agent context.
|
||||
|
||||
Returns:
|
||||
int: Depth of the current context, or 0 if no context exists
|
||||
"""
|
||||
ctx = get_current_context()
|
||||
if ctx is None:
|
||||
return 0
|
||||
return ctx.depth
|
||||
|
||||
|
||||
@contextmanager
|
||||
def agent_context(parent_context=None):
|
||||
"""Context manager for agent execution.
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from rich.panel import Panel
|
|||
|
||||
from ra_aid.agent_context import (
|
||||
agent_context,
|
||||
get_depth,
|
||||
is_completed,
|
||||
reset_completion_flags,
|
||||
should_exit,
|
||||
|
|
@ -904,15 +905,6 @@ def _restore_interrupt_handling(original_handler):
|
|||
signal.signal(signal.SIGINT, original_handler)
|
||||
|
||||
|
||||
def _increment_agent_depth():
|
||||
current_depth = _global_memory.get("agent_depth", 0)
|
||||
_global_memory["agent_depth"] = current_depth + 1
|
||||
|
||||
|
||||
def _decrement_agent_depth():
|
||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||
|
||||
|
||||
def reset_agent_completion_flags():
|
||||
"""Reset completion flags in the current context."""
|
||||
reset_completion_flags()
|
||||
|
|
@ -1031,7 +1023,6 @@ def run_agent_with_retry(
|
|||
# Create a new agent context for this run
|
||||
with InterruptibleSection(), agent_context() as ctx:
|
||||
try:
|
||||
_increment_agent_depth()
|
||||
for attempt in range(max_retries):
|
||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||
check_interrupt()
|
||||
|
|
@ -1103,5 +1094,4 @@ def run_agent_with_retry(
|
|||
|
||||
_handle_api_error(e, attempt, max_retries, base_delay)
|
||||
finally:
|
||||
_decrement_agent_depth()
|
||||
_restore_interrupt_handling(original_handler)
|
||||
|
|
@ -10,6 +10,7 @@ from rich.console import Console
|
|||
from ra_aid.agent_context import (
|
||||
get_completion_message,
|
||||
get_crash_message,
|
||||
get_depth,
|
||||
is_crashed,
|
||||
reset_completion_flags,
|
||||
)
|
||||
|
|
@ -59,7 +60,7 @@ def request_research(query: str) -> ResearchResult:
|
|||
)
|
||||
|
||||
# Check recursion depth
|
||||
current_depth = _global_memory.get("agent_depth", 0)
|
||||
current_depth = get_depth()
|
||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||
print_error("Maximum research recursion depth reached")
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -41,9 +41,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
|
|||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||
|
||||
# Global memory store
|
||||
_global_memory: Dict[str, Any] = {
|
||||
"agent_depth": 0,
|
||||
}
|
||||
_global_memory: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@tool("emit_research_notes")
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from ra_aid.agent_context import (
|
|||
agent_context,
|
||||
get_completion_message,
|
||||
get_current_context,
|
||||
get_depth,
|
||||
is_completed,
|
||||
mark_plan_completed,
|
||||
mark_should_exit,
|
||||
|
|
@ -285,6 +286,31 @@ class TestUtilityFunctions:
|
|||
mark_task_completed("Task done via utility")
|
||||
assert is_completed() is True
|
||||
assert get_completion_message() == "Task done via utility"
|
||||
|
||||
def test_agent_context_depth_property(self):
|
||||
"""Test that the depth property correctly calculates context depth."""
|
||||
# Create contexts with different nesting levels
|
||||
ctx1 = AgentContext() # Depth 0
|
||||
ctx2 = AgentContext(ctx1) # Depth 1
|
||||
ctx3 = AgentContext(ctx2) # Depth 2
|
||||
|
||||
# Verify depths
|
||||
assert ctx1.depth == 0
|
||||
assert ctx2.depth == 1
|
||||
assert ctx3.depth == 2
|
||||
|
||||
def test_get_depth_function(self):
|
||||
"""Test that get_depth() returns the correct depth of the current context."""
|
||||
# No context active
|
||||
assert get_depth() == 0
|
||||
|
||||
# With nested contexts
|
||||
with agent_context() as ctx1:
|
||||
assert get_depth() == 0
|
||||
with agent_context() as ctx2:
|
||||
assert get_depth() == 1
|
||||
with agent_context() as ctx3:
|
||||
assert get_depth() == 2
|
||||
|
||||
def test_mark_plan_completed_utility(self):
|
||||
"""Test the mark_plan_completed utility function."""
|
||||
|
|
|
|||
|
|
@ -316,18 +316,22 @@ def test_setup_and_restore_interrupt_handling():
|
|||
assert signal.getsignal(signal.SIGINT) == original_handler
|
||||
|
||||
|
||||
def test_increment_and_decrement_agent_depth():
|
||||
from ra_aid.agent_utils import (
|
||||
_decrement_agent_depth,
|
||||
_global_memory,
|
||||
_increment_agent_depth,
|
||||
)
|
||||
def test_agent_context_depth():
|
||||
from ra_aid.agent_context import agent_context, get_depth
|
||||
|
||||
_global_memory["agent_depth"] = 10
|
||||
_increment_agent_depth()
|
||||
assert _global_memory["agent_depth"] == 11
|
||||
_decrement_agent_depth()
|
||||
assert _global_memory["agent_depth"] == 10
|
||||
# Test depth with nested contexts
|
||||
assert get_depth() == 0 # No context
|
||||
with agent_context() as ctx1:
|
||||
assert get_depth() == 0 # Root context has depth 0
|
||||
assert ctx1.depth == 0
|
||||
|
||||
with agent_context() as ctx2:
|
||||
assert get_depth() == 1 # Nested context has depth 1
|
||||
assert ctx2.depth == 1
|
||||
|
||||
with agent_context() as ctx3:
|
||||
assert get_depth() == 2 # Doubly nested context has depth 2
|
||||
assert ctx3.depth == 2
|
||||
|
||||
|
||||
def test_run_agent_stream(monkeypatch):
|
||||
|
|
@ -501,12 +505,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
|||
def mock_restore_interrupt_handling(handler):
|
||||
pass
|
||||
|
||||
def mock_increment_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_decrement_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_is_crashed():
|
||||
return ctx.is_crashed() if ctx else False
|
||||
|
||||
|
|
@ -522,12 +520,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
|||
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||
mock_restore_interrupt_handling,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
|
||||
# First, run without a crash - agent should be run
|
||||
|
|
@ -581,12 +573,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
|||
def mock_restore_interrupt_handling(handler):
|
||||
pass
|
||||
|
||||
def mock_increment_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_decrement_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_mark_agent_crashed(message):
|
||||
ctx.agent_has_crashed = True
|
||||
ctx.agent_crashed_message = message
|
||||
|
|
@ -603,12 +589,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
|||
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||
mock_restore_interrupt_handling,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
|
||||
with agent_context() as ctx:
|
||||
|
|
@ -660,12 +640,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
|||
def mock_restore_interrupt_handling(handler):
|
||||
pass
|
||||
|
||||
def mock_increment_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_decrement_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_mark_agent_crashed(message):
|
||||
ctx.agent_has_crashed = True
|
||||
ctx.agent_crashed_message = message
|
||||
|
|
@ -682,12 +656,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
|||
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||
mock_restore_interrupt_handling,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
||||
)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
||||
|
|
@ -705,6 +673,7 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
|||
# Verify the agent is marked as crashed
|
||||
assert is_crashed()
|
||||
|
||||
|
||||
def test_handle_api_error_resource_exhausted():
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ def mock_dependencies(monkeypatch):
|
|||
"""Mock all dependencies needed for main()."""
|
||||
# Initialize global memory with necessary keys to prevent KeyError
|
||||
_global_memory.clear()
|
||||
_global_memory["agent_depth"] = 0
|
||||
_global_memory["config"] = {}
|
||||
|
||||
# Mock dependencies that interact with external systems
|
||||
|
|
@ -193,7 +192,6 @@ def test_temperature_validation(mock_dependencies):
|
|||
|
||||
# Reset global memory for clean test
|
||||
_global_memory.clear()
|
||||
_global_memory["agent_depth"] = 0
|
||||
_global_memory["config"] = {}
|
||||
|
||||
# Test valid temperature (0.7)
|
||||
|
|
|
|||
|
|
@ -154,8 +154,9 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
|
|||
|
||||
def test_request_research_max_depth(reset_memory, mock_functions):
|
||||
"""Test that max recursion depth handling uses KeyFactRepository."""
|
||||
# Set recursion depth to max
|
||||
_global_memory["agent_depth"] = 3
|
||||
# Mock depth using context-based approach
|
||||
with patch('ra_aid.tools.agent.get_depth') as mock_get_depth:
|
||||
mock_get_depth.return_value = 3
|
||||
|
||||
# Call the function (should hit max depth case)
|
||||
result = request_research("test query")
|
||||
|
|
|
|||
Loading…
Reference in New Issue