feat(anthropic_token_limiter): add get_provider_and_model_for_agent_type function to streamline provider and model retrieval based on agent type

fix(anthropic_token_limiter): refactor get_model_token_limit to use the new get_provider_and_model_for_agent_type function for cleaner code
test(anthropic_token_limiter): add unit tests for get_provider_and_model_for_agent_type and adjust_claude_37_token_limit functions to ensure correctness and coverage
Author: Ariel Frischer, 2025-03-14 13:31:51 -07:00
parent 29c9cac4f4
commit 92faf8fc2d
2 changed files with 170 additions and 54 deletions

File: ra_aid/anthropic_token_limiter.py

@@ -1,24 +1,19 @@
"""Utilities for handling token limits with Anthropic models."""
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple
from langchain_core.language_models import BaseChatModel
from ra_aid.model_detection import is_claude_37
from dataclasses import dataclass
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import (
AIMessage,
BaseMessage,
RemoveMessage,
ToolMessage,
trim_messages,
)
from langchain_core.messages.base import message_to_dict
from ra_aid.anthropic_message_utils import (
anthropic_trim_messages,
has_tool_use,
)
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import token_counter, get_model_info
@@ -27,7 +22,6 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.console.output import cpm, print_messages_compact
logger = get_logger(__name__)
@@ -170,6 +164,29 @@ def sonnet_35_state_modifier(
return result
def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
"""Get the provider and model name for the specified agent type.
Args:
config: Configuration dictionary containing provider and model information
agent_type: Type of agent ("default", "research", or "planner")
Returns:
Tuple[str, str]: A tuple containing (provider, model_name)
"""
if agent_type == "research":
provider = config.get("research_provider", "") or config.get("provider", "")
model_name = config.get("research_model", "") or config.get("model", "")
elif agent_type == "planner":
provider = config.get("planner_provider", "") or config.get("provider", "")
model_name = config.get("planner_model", "") or config.get("model", "")
else:
provider = config.get("provider", "")
model_name = config.get("model", "")
return provider, model_name
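For illustration, a quick usage sketch of the new helper (hypothetical config values, not part of this diff). Agent-specific keys take precedence when set; empty or missing keys fall back to the generic provider/model pair:

config = {
    "provider": "anthropic",
    "model": "claude-3-7-sonnet-20250219",
    "research_provider": "openai",
    "research_model": "gpt-4",
}
# "research" has dedicated keys, so they win.
assert get_provider_and_model_for_agent_type(config, "research") == ("openai", "gpt-4")
# No "planner_*" keys here, so the generic provider/model is returned.
assert get_provider_and_model_for_agent_type(config, "planner") == (
    "anthropic",
    "claude-3-7-sonnet-20250219",
)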
def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
"""Adjust token limit for Claude 3.7 models by subtracting max_tokens.
@@ -217,19 +234,13 @@ def get_model_token_limit(
# In tests, this may fail because the repository isn't set up
# So we'll use the passed config directly
pass
if agent_type == "research":
provider = config.get("research_provider", "") or config.get("provider", "")
model_name = config.get("research_model", "") or config.get("model", "")
elif agent_type == "planner":
provider = config.get("planner_provider", "") or config.get("provider", "")
model_name = config.get("planner_model", "") or config.get("model", "")
else:
provider = config.get("provider", "")
model_name = config.get("model", "")
provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
# Always attempt to get model info from litellm first
provider_model = model_name if not provider else f"{provider}/{model_name}"
try:
provider_model = model_name if not provider else f"{provider}/{model_name}"
model_info = get_model_info(provider_model)
max_input_tokens = model_info.get("max_input_tokens")
if max_input_tokens:
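After the refactor, every agent type funnels through the same two steps: resolve (provider, model_name) via the shared helper, then query litellm. A sketch (hypothetical config; litellm must recognize the model for get_model_info to succeed):

config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
# provider_model becomes "anthropic/claude-3-7-sonnet-20250219"; with an empty
# provider it would fall back to the bare model name.
limit = get_model_token_limit(config, "default")
# Internally: get_model_info(provider_model)["max_input_tokens"], then the
# result is adjusted for Claude 3.7 output-token reservation.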

File: unit tests for ra_aid.anthropic_token_limiter

@@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
get_model_token_limit,
state_modifier,
sonnet_35_state_modifier,
convert_message_to_litellm_format
convert_message_to_litellm_format,
adjust_claude_37_token_limit
)
from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
from ra_aid.models_params import models_params
class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
self.assertEqual(result, 0)
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
# Setup a proper token counter function that returns integers
def token_counter(msgs):
# Return token count based on number of messages
@@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
model.model = "claude-3-opus-20240229"
with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
# Setup mock to return a fixed token count per message
mock_wrapper.return_value = lambda msgs: len(msgs) * 100
# Setup mock to return a subset of messages
@@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
self.assertEqual(call_args["include_system"], True)
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info")
def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
from ra_aid.config import DEFAULT_MODEL
@patch("ra_aid.anthropic_token_limiter.get_model_info")
@patch("ra_aid.anthropic_token_limiter.is_claude_37")
@patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
# Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
model_name = "claude-3-7-sonnet-20250219"
# Setup mocks
mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
mock_config = {"provider": "anthropic", "model": model_name}
mock_get_config_repo.return_value.get_all.return_value = mock_config
# Mock litellm's get_model_info to return a token limit
mock_get_model_info.return_value = {"max_input_tokens": 100000}
# Mock is_claude_37 to return True
mock_is_claude_37.return_value = True
# Mock adjust_claude_37_token_limit to return the original value
mock_adjust.return_value = 100000
# Test getting token limit
result = get_model_token_limit(mock_config)
self.assertEqual(result, 100000)
# Verify get_model_info was called with the right model
mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
# Verify adjust_claude_37_token_limit was called
mock_adjust.assert_called_once_with(100000, None)
def test_get_model_token_limit_research(self):
"""Test get_model_token_limit with research provider and model."""
@@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
"provider": "openai",
"model": "gpt-4",
"research_provider": "anthropic",
"research_model": "claude-2",
"research_model": "claude-3-7-sonnet-20250219",
}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.return_value = {"max_input_tokens": 150000}
mock_adjust.return_value = 150000
# Call the function to check the return value
token_limit = get_model_token_limit(config, "research")
self.assertEqual(token_limit, 150000)
# Verify get_model_info was called with the research model
mock_get_info.assert_called_with("anthropic/claude-2")
mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
# Verify adjust_claude_37_token_limit was called
mock_adjust.assert_called_once_with(150000, None)
def test_get_model_token_limit_planner(self):
"""Test get_model_token_limit with planner provider and model."""
@@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.return_value = {"max_input_tokens": 120000}
mock_adjust.return_value = 120000
# Call the function to check the return value
token_limit = get_model_token_limit(config, "planner")
self.assertEqual(token_limit, 120000)
# Verify get_model_info was called with the planner model
mock_get_info.assert_called_with("deepseek/dsm-1")
mock_get_info.assert_called_once_with("deepseek/dsm-1")
# Verify adjust_claude_37_token_limit was called
mock_adjust.assert_called_once_with(120000, None)
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info")
@patch("ra_aid.anthropic_token_limiter.get_model_info")
def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
# Setup mocks
mock_config = {"provider": "anthropic", "model": "claude-2"}
@@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
self.assertEqual(result, 100000)
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info")
def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
from ra_aid.config import DEFAULT_MODEL
@patch("ra_aid.anthropic_token_limiter.get_model_info")
@patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo):
# Use specific model names instead of DEFAULT_MODEL to avoid test dependency
claude_model = "claude-3-7-sonnet-20250219"
# Setup mocks for different agent types
mock_config = {
"provider": "anthropic",
"model": DEFAULT_MODEL,
"model": claude_model,
"research_provider": "openai",
"research_model": "gpt-4",
"planner_provider": "anthropic",
"planner_model": "claude-3-sonnet-20240229"
"planner_model": "claude-3-7-opus-20250301"
}
mock_get_config_repo.return_value.get_all.return_value = mock_config
# Mock different returns for different models
def model_info_side_effect(model_name):
if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
if "claude-3-7-sonnet" in model_name:
return {"max_input_tokens": 200000}
elif "gpt-4" in model_name:
return {"max_input_tokens": 8192}
elif "claude-3-sonnet" in model_name:
return {"max_input_tokens": 100000}
elif "claude-3-7-opus" in model_name:
return {"max_input_tokens": 250000}
else:
raise Exception(f"Unknown model: {model_name}")
mock_get_model_info.side_effect = model_info_side_effect
# Mock adjust_claude_37_token_limit to return the same values
mock_adjust.side_effect = lambda tokens, model: tokens
# Test default agent type
result = get_model_token_limit(mock_config, "default")
self.assertEqual(result, 200000)
mock_get_model_info.assert_called_with(f"anthropic/{claude_model}")
# Reset mock
mock_get_model_info.reset_mock()
# Test research agent type
result = get_model_token_limit(mock_config, "research")
self.assertEqual(result, 8192)
mock_get_model_info.assert_called_with("openai/gpt-4")
# Reset mock
mock_get_model_info.reset_mock()
# Test planner agent type
result = get_model_token_limit(mock_config, "planner")
self.assertEqual(result, 100000)
self.assertEqual(result, 250000)
mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301")
def test_get_model_token_limit_anthropic(self):
"""Test get_model_token_limit with Anthropic model."""
config = {"provider": "anthropic", "model": "claude2"}
config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
patch("litellm.get_model_info") as mock_get_info, \
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
# Setup mocks
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.side_effect = Exception("Model not found")
# Create a mock models_params with claude-3-7
mock_models_params_dict = {
"anthropic": {
"claude-3-7-sonnet-20250219": {"token_limit": 200000}
}
}
mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
mock_models_params.get.side_effect = mock_models_params_dict.get
# Mock adjust to return the same value
mock_adjust.return_value = 200000
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
self.assertEqual(token_limit, 200000)
def test_get_model_token_limit_openai(self):
"""Test get_model_token_limit with OpenAI model."""
@@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
def test_get_model_token_limit_litellm_success(self):
"""Test get_model_token_limit successfully getting limit from litellm."""
config = {"provider": "anthropic", "model": "claude-2"}
config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.return_value = {"max_input_tokens": 100000}
mock_adjust.return_value = 100000
# Call the function to check the return value
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, 100000)
mock_get_info.assert_called_with("anthropic/claude-2")
# Verify get_model_info was called with the right model
mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
mock_adjust.assert_called_once_with(100000, None)
def test_get_model_token_limit_litellm_not_found(self):
"""Test fallback to models_tokens when litellm raises NotFoundError."""
config = {"provider": "anthropic", "model": "claude-2"}
config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
patch("litellm.get_model_info") as mock_get_info, \
patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found", model="claude-2", llm_provider="anthropic"
message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
)
# Create a mock models_params with claude-3-7
mock_models_params_dict = {
"anthropic": {
"claude-3-7-sonnet-20250219": {"token_limit": 200000}
}
}
mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
mock_models_params.get.side_effect = mock_models_params_dict.get
# Mock adjust to return the same value
mock_adjust.return_value = 200000
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
self.assertEqual(token_limit, 200000)
def test_get_model_token_limit_litellm_error(self):
"""Test fallback to models_tokens when litellm raises other exceptions."""
@@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
token_limit = get_model_token_limit(config, "default")
self.assertIsNone(token_limit)
def test_adjust_claude_37_token_limit(self):
"""Test adjust_claude_37_token_limit function."""
# Create a mock model
mock_model = MagicMock()
mock_model.model = "claude-3.7-sonnet"
mock_model.max_tokens = 4096
# Test with Claude 3.7 model
result = adjust_claude_37_token_limit(100000, mock_model)
self.assertEqual(result, 95904) # 100000 - 4096
# Test with non-Claude 3.7 model
mock_model.model = "claude-3-opus"
result = adjust_claude_37_token_limit(100000, mock_model)
self.assertEqual(result, 100000) # No adjustment
# Test with None max_input_tokens
result = adjust_claude_37_token_limit(None, mock_model)
self.assertIsNone(result)
# Test with None model
result = adjust_claude_37_token_limit(100000, None)
self.assertEqual(result, 100000)
def test_has_tool_use(self):
"""Test the has_tool_use function."""
# Test with regular AI message