diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 8b1708b..a170519 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -236,11 +236,11 @@ def is_anthropic_claude(config: Dict[str, Any]) -> bool: """ provider = config.get("provider", "") model_name = config.get("model", "") - return ( - provider.lower() == "anthropic" - and model_name - and "claude" in model_name.lower() + result = ( + (provider.lower() == "anthropic" and model_name and "claude" in model_name.lower()) + or (provider.lower() == "openrouter" and model_name.lower().startswith("anthropic/claude-")) ) + return result def create_agent( diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 5346921..a9eb515 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -12,6 +12,7 @@ from ra_aid.agent_utils import ( AgentState, create_agent, get_model_token_limit, + is_anthropic_claude, state_modifier, ) from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params @@ -401,3 +402,22 @@ def test_handle_api_error_retry(monkeypatch): monkeypatch.setattr(time, "sleep", lambda s: None) # Should not raise error when attempt is lower than max retries _handle_api_error(Exception("error code 429"), 0, 5, 1) + + +def test_is_anthropic_claude(): + """Test is_anthropic_claude function with various configurations.""" + # Test Anthropic provider cases + assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"}) + assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"}) + assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"}) + + # Test OpenRouter provider cases + assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"}) + assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-instant"}) + assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"}) + + # Test edge cases + assert not is_anthropic_claude({}) # Empty config + assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model + assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider + assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider