use create_react_agent for sonnet via openrouter

This commit is contained in:
AI Christianson 2025-02-18 22:09:00 -05:00
parent 66aa13f6ee
commit 2741b54357
2 changed files with 139 additions and 4 deletions

View File

@ -225,11 +225,11 @@ def is_anthropic_claude(config: Dict[str, Any]) -> bool:
"""
provider = config.get("provider", "")
model_name = config.get("model", "")
return (
provider.lower() == "anthropic"
and model_name
and "claude" in model_name.lower()
result = (
(provider.lower() == "anthropic" and model_name and "claude" in model_name.lower())
or (provider.lower() == "openrouter" and model_name.lower().startswith("anthropic/claude-"))
)
return result
def create_agent(

View File

@ -11,6 +11,7 @@ from ra_aid.agent_utils import (
AgentState,
create_agent,
get_model_token_limit,
is_anthropic_claude,
state_modifier,
)
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
@ -275,3 +276,137 @@ def test_get_model_token_limit_planner(mock_memory):
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")
assert token_limit == 120000
# New tests for private helper methods in agent_utils.py
def test_setup_and_restore_interrupt_handling():
import signal
from ra_aid.agent_utils import (
_request_interrupt,
_restore_interrupt_handling,
_setup_interrupt_handling,
)
original_handler = signal.getsignal(signal.SIGINT)
handler = _setup_interrupt_handling()
# Verify the SIGINT handler is set to _request_interrupt
assert signal.getsignal(signal.SIGINT) == _request_interrupt
_restore_interrupt_handling(handler)
# Verify the SIGINT handler is restored to the original
assert signal.getsignal(signal.SIGINT) == original_handler
def test_increment_and_decrement_agent_depth():
from ra_aid.agent_utils import (
_decrement_agent_depth,
_global_memory,
_increment_agent_depth,
)
_global_memory["agent_depth"] = 10
_increment_agent_depth()
assert _global_memory["agent_depth"] == 11
_decrement_agent_depth()
assert _global_memory["agent_depth"] == 10
def test_run_agent_stream(monkeypatch):
from ra_aid.agent_utils import _global_memory, _run_agent_stream
# Create a dummy agent that yields one chunk
class DummyAgent:
def stream(self, input_data, cfg: dict):
yield {"content": "chunk1"}
dummy_agent = DummyAgent()
# Set flags so that _run_agent_stream will reset them
_global_memory["plan_completed"] = True
_global_memory["task_completed"] = True
_global_memory["completion_message"] = "existing"
call_flag = {"called": False}
def fake_print_agent_output(
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
):
call_flag["called"] = True
monkeypatch.setattr(
"ra_aid.agent_utils.print_agent_output", fake_print_agent_output
)
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
assert call_flag["called"]
assert _global_memory["plan_completed"] is False
assert _global_memory["task_completed"] is False
assert _global_memory["completion_message"] == ""
def test_execute_test_command_wrapper(monkeypatch):
from ra_aid.agent_utils import _execute_test_command_wrapper
# Patch execute_test_command to return a testable tuple
def fake_execute(config, orig, tests, auto):
return (True, "new prompt", auto, tests + 1)
monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
result = _execute_test_command_wrapper("orig", {}, 0, False)
assert result == (True, "new prompt", False, 1)
def test_handle_api_error_valueerror():
import pytest
from ra_aid.agent_utils import _handle_api_error
# ValueError not containing "code" or "429" should be re-raised
with pytest.raises(ValueError):
_handle_api_error(ValueError("some error"), 0, 5, 1)
def test_handle_api_error_max_retries():
import pytest
from ra_aid.agent_utils import _handle_api_error
# When attempt reaches max retries, a RuntimeError should be raised
with pytest.raises(RuntimeError):
_handle_api_error(Exception("error code 429"), 4, 5, 1)
def test_handle_api_error_retry(monkeypatch):
import time
from ra_aid.agent_utils import _handle_api_error
# Patch time.monotonic and time.sleep to simulate immediate delay expiration
fake_time = [0]
def fake_monotonic():
fake_time[0] += 0.5
return fake_time[0]
monkeypatch.setattr(time, "monotonic", fake_monotonic)
monkeypatch.setattr(time, "sleep", lambda s: None)
# Should not raise error when attempt is lower than max retries
_handle_api_error(Exception("error code 429"), 0, 5, 1)
def test_is_anthropic_claude():
"""Test is_anthropic_claude function with various configurations."""
# Test Anthropic provider cases
assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
# Test OpenRouter provider cases
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"})
assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-instant"})
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
# Test edge cases
assert not is_anthropic_claude({}) # Empty config
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider