diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 33f6b9e..a8ebdb0 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -225,11 +225,11 @@ def is_anthropic_claude(config: Dict[str, Any]) -> bool:
     """
     provider = config.get("provider", "")
     model_name = config.get("model", "")
-    return (
-        provider.lower() == "anthropic"
-        and model_name
-        and "claude" in model_name.lower()
+    result = (
+        (provider.lower() == "anthropic" and model_name and "claude" in model_name.lower())
+        or (provider.lower() == "openrouter" and model_name and model_name.lower().startswith("anthropic/claude-"))
     )
+    return result
 
 
 def create_agent(
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index 5e935ed..e08ef81 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -11,6 +11,7 @@ from ra_aid.agent_utils import (
     AgentState,
     create_agent,
     get_model_token_limit,
+    is_anthropic_claude,
     state_modifier,
 )
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
@@ -275,3 +276,137 @@ def test_get_model_token_limit_planner(mock_memory):
         mock_get_info.return_value = {"max_input_tokens": 120000}
         token_limit = get_model_token_limit(config, "planner")
         assert token_limit == 120000
+
+
+# New tests for private helper methods in agent_utils.py
+
+
+def test_setup_and_restore_interrupt_handling():
+    import signal
+
+    from ra_aid.agent_utils import (
+        _request_interrupt,
+        _restore_interrupt_handling,
+        _setup_interrupt_handling,
+    )
+
+    original_handler = signal.getsignal(signal.SIGINT)
+    handler = _setup_interrupt_handling()
+    # Verify the SIGINT handler is set to _request_interrupt
+    assert signal.getsignal(signal.SIGINT) == _request_interrupt
+    _restore_interrupt_handling(handler)
+    # Verify the SIGINT handler is restored to the original
+    assert signal.getsignal(signal.SIGINT) == original_handler
+
+
+def test_increment_and_decrement_agent_depth():
+    from ra_aid.agent_utils import (
+        _decrement_agent_depth,
+        _global_memory,
+        _increment_agent_depth,
+    )
+
+    _global_memory["agent_depth"] = 10
+    _increment_agent_depth()
+    assert _global_memory["agent_depth"] == 11
+    _decrement_agent_depth()
+    assert _global_memory["agent_depth"] == 10
+
+
+def test_run_agent_stream(monkeypatch):
+    from ra_aid.agent_utils import _global_memory, _run_agent_stream
+
+    # Create a dummy agent that yields one chunk
+    class DummyAgent:
+        def stream(self, input_data, cfg: dict):
+            yield {"content": "chunk1"}
+
+    dummy_agent = DummyAgent()
+    # Set flags so that _run_agent_stream will reset them
+    _global_memory["plan_completed"] = True
+    _global_memory["task_completed"] = True
+    _global_memory["completion_message"] = "existing"
+    call_flag = {"called": False}
+
+    def fake_print_agent_output(
+        chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
+    ):
+        call_flag["called"] = True
+
+    monkeypatch.setattr(
+        "ra_aid.agent_utils.print_agent_output", fake_print_agent_output
+    )
+    _run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
+    assert call_flag["called"]
+    assert _global_memory["plan_completed"] is False
+    assert _global_memory["task_completed"] is False
+    assert _global_memory["completion_message"] == ""
+
+
+def test_execute_test_command_wrapper(monkeypatch):
+    from ra_aid.agent_utils import _execute_test_command_wrapper
+
+    # Patch execute_test_command to return a testable tuple
+    def fake_execute(config, orig, tests, auto):
+        return (True, "new prompt", auto, tests + 1)
+
+    monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
+    result = _execute_test_command_wrapper("orig", {}, 0, False)
+    assert result == (True, "new prompt", False, 1)
+
+
+def test_handle_api_error_valueerror():
+    import pytest
+
+    from ra_aid.agent_utils import _handle_api_error
+
+    # ValueError not containing "code" or "429" should be re-raised
+    with pytest.raises(ValueError):
+        _handle_api_error(ValueError("some error"), 0, 5, 1)
+
+
+def test_handle_api_error_max_retries():
+    import pytest
+
+    from ra_aid.agent_utils import _handle_api_error
+
+    # When attempt reaches max retries, a RuntimeError should be raised
+    with pytest.raises(RuntimeError):
+        _handle_api_error(Exception("error code 429"), 4, 5, 1)
+
+
+def test_handle_api_error_retry(monkeypatch):
+    import time
+
+    from ra_aid.agent_utils import _handle_api_error
+
+    # Patch time.monotonic and time.sleep to simulate immediate delay expiration
+    fake_time = [0]
+
+    def fake_monotonic():
+        fake_time[0] += 0.5
+        return fake_time[0]
+
+    monkeypatch.setattr(time, "monotonic", fake_monotonic)
+    monkeypatch.setattr(time, "sleep", lambda s: None)
+    # Should not raise error when attempt is lower than max retries
+    _handle_api_error(Exception("error code 429"), 0, 5, 1)
+
+
+def test_is_anthropic_claude():
+    """Test is_anthropic_claude function with various configurations."""
+    # Test Anthropic provider cases
+    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
+    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
+    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
+
+    # Test OpenRouter provider cases
+    assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"})
+    assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-instant"})
+    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
+
+    # Test edge cases
+    assert not is_anthropic_claude({})  # Empty config
+    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
+    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
+    assert not is_anthropic_claude({"provider": "other", "model": "claude-2"})  # Wrong provider