From 92faf8fc2db9a12e48f3b022af01b7bef4642a4e Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Fri, 14 Mar 2025 13:31:51 -0700
Subject: [PATCH] feat(anthropic_token_limiter): add
 get_provider_and_model_for_agent_type function to streamline provider and
 model retrieval based on agent type

fix(anthropic_token_limiter): refactor get_model_token_limit to use the new
get_provider_and_model_for_agent_type function for cleaner code

test(anthropic_token_limiter): add unit tests for
get_provider_and_model_for_agent_type and adjust_claude_37_token_limit
functions to ensure correctness and coverage
---
 ra_aid/anthropic_token_limiter.py            |  47 +++--
 tests/ra_aid/test_anthropic_token_limiter.py | 177 +++++++++++++++----
 2 files changed, 170 insertions(+), 54 deletions(-)

diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py
index ba7b890..794459a 100644
--- a/ra_aid/anthropic_token_limiter.py
+++ b/ra_aid/anthropic_token_limiter.py
@@ -1,24 +1,19 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 from langchain_core.language_models import BaseChatModel
 
 from ra_aid.model_detection import is_claude_37
-from dataclasses import dataclass
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
-    AIMessage,
     BaseMessage,
-    RemoveMessage,
-    ToolMessage,
     trim_messages,
 )
 from langchain_core.messages.base import message_to_dict
 from ra_aid.anthropic_message_utils import (
     anthropic_trim_messages,
-    has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import token_counter, get_model_info
@@ -27,7 +22,6 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.console.output import cpm, print_messages_compact
 
 logger = get_logger(__name__)
 
@@ -170,6 +164,29 @@ def sonnet_35_state_modifier(
     return result
 
 
+def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
+    """Get the provider and model name for the specified agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Tuple[str, str]: A tuple containing (provider, model_name)
+    """
+    if agent_type == "research":
+        provider = config.get("research_provider", "") or config.get("provider", "")
+        model_name = config.get("research_model", "") or config.get("model", "")
+    elif agent_type == "planner":
+        provider = config.get("planner_provider", "") or config.get("provider", "")
+        model_name = config.get("planner_model", "") or config.get("model", "")
+    else:
+        provider = config.get("provider", "")
+        model_name = config.get("model", "")
+
+    return provider, model_name
+
+
 def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
     """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
@@ -217,19 +234,13 @@ def get_model_token_limit(
         # In tests, this may fail because the repository isn't set up
         # So we'll use the passed config directly
         pass
-    if agent_type == "research":
-        provider = config.get("research_provider", "") or config.get("provider", "")
-        model_name = config.get("research_model", "") or config.get("model", "")
-    elif agent_type == "planner":
-        provider = config.get("planner_provider", "") or config.get("provider", "")
-        model_name = config.get("planner_model", "") or config.get("model", "")
-    else:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+
+    provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
+
+    # Always attempt to get model info from litellm first
+    provider_model = model_name if not provider else f"{provider}/{model_name}"
 
     try:
-
-        provider_model = model_name if not provider else f"{provider}/{model_name}"
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py
index 36f9528..a3ac861 100644
--- a/tests/ra_aid/test_anthropic_token_limiter.py
+++ b/tests/ra_aid/test_anthropic_token_limiter.py
@@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
     state_modifier,
     sonnet_35_state_modifier,
-    convert_message_to_litellm_format
+    convert_message_to_litellm_format,
+    adjust_claude_37_token_limit
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
-from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
+from ra_aid.models_params import models_params
 
 
 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 0)
 
     @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
-    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
-    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
+    def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
         # Setup a proper token counter function that returns integers
         def token_counter(msgs):
             # Return token count based on number of messages
@@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         model.model = "claude-3-opus-20240229"
 
         with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
-             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
-             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
             # Setup mock to return a fixed token count per message
             mock_wrapper.return_value = lambda msgs: len(msgs) * 100
             # Setup mock to return a subset of messages
@@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(call_args["include_system"], True)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.is_claude_37")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
+        # Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
+        model_name = "claude-3-7-sonnet-20250219"
 
         # Setup mocks
-        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_config = {"provider": "anthropic", "model": model_name}
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock litellm's get_model_info to return a token limit
         mock_get_model_info.return_value = {"max_input_tokens": 100000}
 
+        # Mock is_claude_37 to return True
+        mock_is_claude_37.return_value = True
+
+        # Mock adjust_claude_37_token_limit to return the original value
+        mock_adjust.return_value = 100000
+
         # Test getting token limit
         result = get_model_token_limit(mock_config)
         self.assertEqual(result, 100000)
 
         # Verify get_model_info was called with the right model
-        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+        mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
+
+        # Verify adjust_claude_37_token_limit was called
+        mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_research(self):
         """Test get_model_token_limit with research provider and model."""
@@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             "provider": "openai",
             "model": "gpt-4",
             "research_provider": "anthropic",
-            "research_model": "claude-2",
+            "research_model": "claude-3-7-sonnet-20250219",
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 150000}
+            mock_adjust.return_value = 150000
+
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "research")
             self.assertEqual(token_limit, 150000)
+
             # Verify get_model_info was called with the research model
-            mock_get_info.assert_called_with("anthropic/claude-2")
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(150000, None)
 
     def test_get_model_token_limit_planner(self):
         """Test get_model_token_limit with planner provider and model."""
@@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 120000}
+            mock_adjust.return_value = 120000
+
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "planner")
             self.assertEqual(token_limit, 120000)
+
             # Verify get_model_info was called with the planner model
-            mock_get_info.assert_called_with("deepseek/dsm-1")
+            mock_get_info.assert_called_once_with("deepseek/dsm-1")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(120000, None)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
@patch("ra_aid.anthropic_token_limiter.get_model_info") def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo): # Setup mocks mock_config = {"provider": "anthropic", "model": "claude-2"} @@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase): self.assertEqual(result, 100000) @patch("ra_aid.anthropic_token_limiter.get_config_repository") - @patch("litellm.get_model_info") - def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo): - from ra_aid.config import DEFAULT_MODEL + @patch("ra_aid.anthropic_token_limiter.get_model_info") + @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") + def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo): + # Use specific model names instead of DEFAULT_MODEL to avoid test dependency + claude_model = "claude-3-7-sonnet-20250219" # Setup mocks for different agent types mock_config = { "provider": "anthropic", - "model": DEFAULT_MODEL, + "model": claude_model, "research_provider": "openai", "research_model": "gpt-4", "planner_provider": "anthropic", - "planner_model": "claude-3-sonnet-20240229" + "planner_model": "claude-3-7-opus-20250301" } mock_get_config_repo.return_value.get_all.return_value = mock_config # Mock different returns for different models def model_info_side_effect(model_name): - if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name: + if "claude-3-7-sonnet" in model_name: return {"max_input_tokens": 200000} elif "gpt-4" in model_name: return {"max_input_tokens": 8192} - elif "claude-3-sonnet" in model_name: - return {"max_input_tokens": 100000} + elif "claude-3-7-opus" in model_name: + return {"max_input_tokens": 250000} else: raise Exception(f"Unknown model: {model_name}") mock_get_model_info.side_effect = model_info_side_effect + # Mock adjust_claude_37_token_limit to return the same values + mock_adjust.side_effect = lambda tokens, model: tokens + # Test default agent type result = get_model_token_limit(mock_config, "default") self.assertEqual(result, 200000) + mock_get_model_info.assert_called_with(f"anthropic/{claude_model}") + + # Reset mock + mock_get_model_info.reset_mock() # Test research agent type result = get_model_token_limit(mock_config, "research") self.assertEqual(result, 8192) + mock_get_model_info.assert_called_with("openai/gpt-4") + + # Reset mock + mock_get_model_info.reset_mock() # Test planner agent type result = get_model_token_limit(mock_config, "planner") - self.assertEqual(result, 100000) + self.assertEqual(result, 250000) + mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301") def test_get_model_token_limit_anthropic(self): """Test get_model_token_limit with Anthropic model.""" - config = {"provider": "anthropic", "model": "claude2"} + config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"} - with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \ + patch("litellm.get_model_info") as mock_get_info, \ + patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust: + + # Setup mocks mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.side_effect = Exception("Model not found") + + # Create a mock models_params with claude-3-7 + 
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_openai(self):
         """Test get_model_token_limit with OpenAI model."""
@@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
 
     def test_get_model_token_limit_litellm_success(self):
         """Test get_model_token_limit successfully getting limit from litellm."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 100000}
+            mock_adjust.return_value = 100000
+
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "default")
             self.assertEqual(token_limit, 100000)
-            mock_get_info.assert_called_with("anthropic/claude-2")
+
+            # Verify get_model_info was called with the right model
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_litellm_not_found(self):
         """Test fallback to models_tokens when litellm raises NotFoundError."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-                message="Model not found", model="claude-2", llm_provider="anthropic"
+                message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
             )
+
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_litellm_error(self):
         """Test fallback to models_tokens when litellm raises other exceptions."""
@@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             token_limit = get_model_token_limit(config, "default")
"default") self.assertIsNone(token_limit) + def test_adjust_claude_37_token_limit(self): + """Test adjust_claude_37_token_limit function.""" + # Create a mock model + mock_model = MagicMock() + mock_model.model = "claude-3.7-sonnet" + mock_model.max_tokens = 4096 + + # Test with Claude 3.7 model + result = adjust_claude_37_token_limit(100000, mock_model) + self.assertEqual(result, 95904) # 100000 - 4096 + + # Test with non-Claude 3.7 model + mock_model.model = "claude-3-opus" + result = adjust_claude_37_token_limit(100000, mock_model) + self.assertEqual(result, 100000) # No adjustment + + # Test with None max_input_tokens + result = adjust_claude_37_token_limit(None, mock_model) + self.assertIsNone(result) + + # Test with None model + result = adjust_claude_37_token_limit(100000, None) + self.assertEqual(result, 100000) + def test_has_tool_use(self): """Test the has_tool_use function.""" # Test with regular AI message