track agent depth via context
This commit is contained in:
parent
a1b268fdf4
commit
60a6707107
|
|
@ -90,6 +90,17 @@ class AgentContext:
|
||||||
def is_completed(self) -> bool:
|
def is_completed(self) -> bool:
|
||||||
"""Check if the current context is marked as completed."""
|
"""Check if the current context is marked as completed."""
|
||||||
return self.task_completed or self.plan_completed
|
return self.task_completed or self.plan_completed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def depth(self) -> int:
|
||||||
|
"""Calculate the depth of this context based on parent chain.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 0 for a context with no parent, parent.depth + 1 otherwise
|
||||||
|
"""
|
||||||
|
if self.parent is None:
|
||||||
|
return 0
|
||||||
|
return self.parent.depth + 1
|
||||||
|
|
||||||
|
|
||||||
def get_current_context() -> Optional[AgentContext]:
|
def get_current_context() -> Optional[AgentContext]:
|
||||||
|
|
@ -101,6 +112,18 @@ def get_current_context() -> Optional[AgentContext]:
|
||||||
return getattr(_thread_local, "current_context", None)
|
return getattr(_thread_local, "current_context", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_depth() -> int:
|
||||||
|
"""Get the depth of the current agent context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Depth of the current context, or 0 if no context exists
|
||||||
|
"""
|
||||||
|
ctx = get_current_context()
|
||||||
|
if ctx is None:
|
||||||
|
return 0
|
||||||
|
return ctx.depth
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def agent_context(parent_context=None):
|
def agent_context(parent_context=None):
|
||||||
"""Context manager for agent execution.
|
"""Context manager for agent execution.
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
agent_context,
|
agent_context,
|
||||||
|
get_depth,
|
||||||
is_completed,
|
is_completed,
|
||||||
reset_completion_flags,
|
reset_completion_flags,
|
||||||
should_exit,
|
should_exit,
|
||||||
|
|
@ -904,15 +905,6 @@ def _restore_interrupt_handling(original_handler):
|
||||||
signal.signal(signal.SIGINT, original_handler)
|
signal.signal(signal.SIGINT, original_handler)
|
||||||
|
|
||||||
|
|
||||||
def _increment_agent_depth():
|
|
||||||
current_depth = _global_memory.get("agent_depth", 0)
|
|
||||||
_global_memory["agent_depth"] = current_depth + 1
|
|
||||||
|
|
||||||
|
|
||||||
def _decrement_agent_depth():
|
|
||||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
|
||||||
|
|
||||||
|
|
||||||
def reset_agent_completion_flags():
|
def reset_agent_completion_flags():
|
||||||
"""Reset completion flags in the current context."""
|
"""Reset completion flags in the current context."""
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
|
|
@ -1031,7 +1023,6 @@ def run_agent_with_retry(
|
||||||
# Create a new agent context for this run
|
# Create a new agent context for this run
|
||||||
with InterruptibleSection(), agent_context() as ctx:
|
with InterruptibleSection(), agent_context() as ctx:
|
||||||
try:
|
try:
|
||||||
_increment_agent_depth()
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
|
|
@ -1103,5 +1094,4 @@ def run_agent_with_retry(
|
||||||
|
|
||||||
_handle_api_error(e, attempt, max_retries, base_delay)
|
_handle_api_error(e, attempt, max_retries, base_delay)
|
||||||
finally:
|
finally:
|
||||||
_decrement_agent_depth()
|
|
||||||
_restore_interrupt_handling(original_handler)
|
_restore_interrupt_handling(original_handler)
|
||||||
|
|
@ -10,6 +10,7 @@ from rich.console import Console
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
get_completion_message,
|
get_completion_message,
|
||||||
get_crash_message,
|
get_crash_message,
|
||||||
|
get_depth,
|
||||||
is_crashed,
|
is_crashed,
|
||||||
reset_completion_flags,
|
reset_completion_flags,
|
||||||
)
|
)
|
||||||
|
|
@ -59,7 +60,7 @@ def request_research(query: str) -> ResearchResult:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check recursion depth
|
# Check recursion depth
|
||||||
current_depth = _global_memory.get("agent_depth", 0)
|
current_depth = get_depth()
|
||||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||||
print_error("Maximum research recursion depth reached")
|
print_error("Maximum research recursion depth reached")
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -41,9 +41,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
|
||||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
||||||
# Global memory store
|
# Global memory store
|
||||||
_global_memory: Dict[str, Any] = {
|
_global_memory: Dict[str, Any] = {}
|
||||||
"agent_depth": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@tool("emit_research_notes")
|
@tool("emit_research_notes")
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from ra_aid.agent_context import (
|
||||||
agent_context,
|
agent_context,
|
||||||
get_completion_message,
|
get_completion_message,
|
||||||
get_current_context,
|
get_current_context,
|
||||||
|
get_depth,
|
||||||
is_completed,
|
is_completed,
|
||||||
mark_plan_completed,
|
mark_plan_completed,
|
||||||
mark_should_exit,
|
mark_should_exit,
|
||||||
|
|
@ -285,6 +286,31 @@ class TestUtilityFunctions:
|
||||||
mark_task_completed("Task done via utility")
|
mark_task_completed("Task done via utility")
|
||||||
assert is_completed() is True
|
assert is_completed() is True
|
||||||
assert get_completion_message() == "Task done via utility"
|
assert get_completion_message() == "Task done via utility"
|
||||||
|
|
||||||
|
def test_agent_context_depth_property(self):
|
||||||
|
"""Test that the depth property correctly calculates context depth."""
|
||||||
|
# Create contexts with different nesting levels
|
||||||
|
ctx1 = AgentContext() # Depth 0
|
||||||
|
ctx2 = AgentContext(ctx1) # Depth 1
|
||||||
|
ctx3 = AgentContext(ctx2) # Depth 2
|
||||||
|
|
||||||
|
# Verify depths
|
||||||
|
assert ctx1.depth == 0
|
||||||
|
assert ctx2.depth == 1
|
||||||
|
assert ctx3.depth == 2
|
||||||
|
|
||||||
|
def test_get_depth_function(self):
|
||||||
|
"""Test that get_depth() returns the correct depth of the current context."""
|
||||||
|
# No context active
|
||||||
|
assert get_depth() == 0
|
||||||
|
|
||||||
|
# With nested contexts
|
||||||
|
with agent_context() as ctx1:
|
||||||
|
assert get_depth() == 0
|
||||||
|
with agent_context() as ctx2:
|
||||||
|
assert get_depth() == 1
|
||||||
|
with agent_context() as ctx3:
|
||||||
|
assert get_depth() == 2
|
||||||
|
|
||||||
def test_mark_plan_completed_utility(self):
|
def test_mark_plan_completed_utility(self):
|
||||||
"""Test the mark_plan_completed utility function."""
|
"""Test the mark_plan_completed utility function."""
|
||||||
|
|
|
||||||
|
|
@ -316,18 +316,22 @@ def test_setup_and_restore_interrupt_handling():
|
||||||
assert signal.getsignal(signal.SIGINT) == original_handler
|
assert signal.getsignal(signal.SIGINT) == original_handler
|
||||||
|
|
||||||
|
|
||||||
def test_increment_and_decrement_agent_depth():
|
def test_agent_context_depth():
|
||||||
from ra_aid.agent_utils import (
|
from ra_aid.agent_context import agent_context, get_depth
|
||||||
_decrement_agent_depth,
|
|
||||||
_global_memory,
|
|
||||||
_increment_agent_depth,
|
|
||||||
)
|
|
||||||
|
|
||||||
_global_memory["agent_depth"] = 10
|
# Test depth with nested contexts
|
||||||
_increment_agent_depth()
|
assert get_depth() == 0 # No context
|
||||||
assert _global_memory["agent_depth"] == 11
|
with agent_context() as ctx1:
|
||||||
_decrement_agent_depth()
|
assert get_depth() == 0 # Root context has depth 0
|
||||||
assert _global_memory["agent_depth"] == 10
|
assert ctx1.depth == 0
|
||||||
|
|
||||||
|
with agent_context() as ctx2:
|
||||||
|
assert get_depth() == 1 # Nested context has depth 1
|
||||||
|
assert ctx2.depth == 1
|
||||||
|
|
||||||
|
with agent_context() as ctx3:
|
||||||
|
assert get_depth() == 2 # Doubly nested context has depth 2
|
||||||
|
assert ctx3.depth == 2
|
||||||
|
|
||||||
|
|
||||||
def test_run_agent_stream(monkeypatch):
|
def test_run_agent_stream(monkeypatch):
|
||||||
|
|
@ -501,12 +505,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
||||||
def mock_restore_interrupt_handling(handler):
|
def mock_restore_interrupt_handling(handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_increment_agent_depth():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mock_decrement_agent_depth():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mock_is_crashed():
|
def mock_is_crashed():
|
||||||
return ctx.is_crashed() if ctx else False
|
return ctx.is_crashed() if ctx else False
|
||||||
|
|
||||||
|
|
@ -522,12 +520,6 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
||||||
"ra_aid.agent_utils._restore_interrupt_handling",
|
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||||
mock_restore_interrupt_handling,
|
mock_restore_interrupt_handling,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||||
|
|
||||||
# First, run without a crash - agent should be run
|
# First, run without a crash - agent should be run
|
||||||
|
|
@ -581,12 +573,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
||||||
def mock_restore_interrupt_handling(handler):
|
def mock_restore_interrupt_handling(handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_increment_agent_depth():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mock_decrement_agent_depth():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mock_mark_agent_crashed(message):
|
def mock_mark_agent_crashed(message):
|
||||||
ctx.agent_has_crashed = True
|
ctx.agent_has_crashed = True
|
||||||
ctx.agent_crashed_message = message
|
ctx.agent_crashed_message = message
|
||||||
|
|
@ -603,12 +589,6 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
||||||
"ra_aid.agent_utils._restore_interrupt_handling",
|
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||||
mock_restore_interrupt_handling,
|
mock_restore_interrupt_handling,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||||
|
|
||||||
with agent_context() as ctx:
|
with agent_context() as ctx:
|
||||||
|
|
@ -660,12 +640,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
||||||
def mock_restore_interrupt_handling(handler):
|
def mock_restore_interrupt_handling(handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_increment_agent_depth():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mock_decrement_agent_depth():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mock_mark_agent_crashed(message):
|
def mock_mark_agent_crashed(message):
|
||||||
ctx.agent_has_crashed = True
|
ctx.agent_has_crashed = True
|
||||||
ctx.agent_crashed_message = message
|
ctx.agent_crashed_message = message
|
||||||
|
|
@ -682,12 +656,6 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
||||||
"ra_aid.agent_utils._restore_interrupt_handling",
|
"ra_aid.agent_utils._restore_interrupt_handling",
|
||||||
mock_restore_interrupt_handling,
|
mock_restore_interrupt_handling,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
||||||
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
||||||
|
|
@ -705,6 +673,7 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
||||||
# Verify the agent is marked as crashed
|
# Verify the agent is marked as crashed
|
||||||
assert is_crashed()
|
assert is_crashed()
|
||||||
|
|
||||||
|
|
||||||
def test_handle_api_error_resource_exhausted():
|
def test_handle_api_error_resource_exhausted():
|
||||||
from google.api_core.exceptions import ResourceExhausted
|
from google.api_core.exceptions import ResourceExhausted
|
||||||
from ra_aid.agent_utils import _handle_api_error
|
from ra_aid.agent_utils import _handle_api_error
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ def mock_dependencies(monkeypatch):
|
||||||
"""Mock all dependencies needed for main()."""
|
"""Mock all dependencies needed for main()."""
|
||||||
# Initialize global memory with necessary keys to prevent KeyError
|
# Initialize global memory with necessary keys to prevent KeyError
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
_global_memory["agent_depth"] = 0
|
|
||||||
_global_memory["config"] = {}
|
_global_memory["config"] = {}
|
||||||
|
|
||||||
# Mock dependencies that interact with external systems
|
# Mock dependencies that interact with external systems
|
||||||
|
|
@ -193,7 +192,6 @@ def test_temperature_validation(mock_dependencies):
|
||||||
|
|
||||||
# Reset global memory for clean test
|
# Reset global memory for clean test
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
_global_memory["agent_depth"] = 0
|
|
||||||
_global_memory["config"] = {}
|
_global_memory["config"] = {}
|
||||||
|
|
||||||
# Test valid temperature (0.7)
|
# Test valid temperature (0.7)
|
||||||
|
|
|
||||||
|
|
@ -154,8 +154,9 @@ def test_request_research_uses_key_fact_repository(reset_memory, mock_functions)
|
||||||
|
|
||||||
def test_request_research_max_depth(reset_memory, mock_functions):
|
def test_request_research_max_depth(reset_memory, mock_functions):
|
||||||
"""Test that max recursion depth handling uses KeyFactRepository."""
|
"""Test that max recursion depth handling uses KeyFactRepository."""
|
||||||
# Set recursion depth to max
|
# Mock depth using context-based approach
|
||||||
_global_memory["agent_depth"] = 3
|
with patch('ra_aid.tools.agent.get_depth') as mock_get_depth:
|
||||||
|
mock_get_depth.return_value = 3
|
||||||
|
|
||||||
# Call the function (should hit max depth case)
|
# Call the function (should hit max depth case)
|
||||||
result = request_research("test query")
|
result = request_research("test query")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue