feat(anthropic_token_limiter): add get_provider_and_model_for_agent_type function to streamline provider and model retrieval based on agent type
fix(anthropic_token_limiter): refactor get_model_token_limit to use the new get_provider_and_model_for_agent_type function for cleaner code test(anthropic_token_limiter): add unit tests for get_provider_and_model_for_agent_type and adjust_claude_37_token_limit functions to ensure correctness and coverage
This commit is contained in:
parent
29c9cac4f4
commit
92faf8fc2d
|
|
@ -1,24 +1,19 @@
|
|||
"""Utilities for handling token limits with Anthropic models."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from ra_aid.model_detection import is_claude_37
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langchain_core.messages.base import message_to_dict
|
||||
|
||||
from ra_aid.anthropic_message_utils import (
|
||||
anthropic_trim_messages,
|
||||
has_tool_use,
|
||||
)
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import token_counter, get_model_info
|
||||
|
|
@ -27,7 +22,6 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
|||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.console.output import cpm, print_messages_compact
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -170,6 +164,29 @@ def sonnet_35_state_modifier(
|
|||
return result
|
||||
|
||||
|
||||
def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
|
||||
"""Get the provider and model name for the specified agent type.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
agent_type: Type of agent ("default", "research", or "planner")
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: A tuple containing (provider, model_name)
|
||||
"""
|
||||
if agent_type == "research":
|
||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("research_model", "") or config.get("model", "")
|
||||
elif agent_type == "planner":
|
||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||
else:
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
|
||||
return provider, model_name
|
||||
|
||||
|
||||
def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
|
||||
"""Adjust token limit for Claude 3.7 models by subtracting max_tokens.
|
||||
|
||||
|
|
@ -217,19 +234,13 @@ def get_model_token_limit(
|
|||
# In tests, this may fail because the repository isn't set up
|
||||
# So we'll use the passed config directly
|
||||
pass
|
||||
if agent_type == "research":
|
||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("research_model", "") or config.get("model", "")
|
||||
elif agent_type == "planner":
|
||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||
else:
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
|
||||
provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
|
||||
|
||||
# Always attempt to get model info from litellm first
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
|
||||
try:
|
||||
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
model_info = get_model_info(provider_model)
|
||||
max_input_tokens = model_info.get("max_input_tokens")
|
||||
if max_input_tokens:
|
||||
|
|
|
|||
|
|
@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
|
|||
get_model_token_limit,
|
||||
state_modifier,
|
||||
sonnet_35_state_modifier,
|
||||
convert_message_to_litellm_format
|
||||
convert_message_to_litellm_format,
|
||||
adjust_claude_37_token_limit
|
||||
)
|
||||
from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
|
||||
from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.models_params import models_params
|
||||
|
||||
|
||||
class TestAnthropicTokenLimiter(unittest.TestCase):
|
||||
|
|
@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
self.assertEqual(result, 0)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
|
||||
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
|
||||
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
|
||||
def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
|
||||
def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
|
||||
# Setup a proper token counter function that returns integers
|
||||
def token_counter(msgs):
|
||||
# Return token count based on number of messages
|
||||
|
|
@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
model.model = "claude-3-opus-20240229"
|
||||
|
||||
with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
|
||||
patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
|
||||
patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
|
||||
patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
|
||||
# Setup mock to return a fixed token count per message
|
||||
mock_wrapper.return_value = lambda msgs: len(msgs) * 100
|
||||
# Setup mock to return a subset of messages
|
||||
|
|
@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
self.assertEqual(call_args["include_system"], True)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
@patch("ra_aid.anthropic_token_limiter.get_model_info")
|
||||
@patch("ra_aid.anthropic_token_limiter.is_claude_37")
|
||||
@patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
|
||||
def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
|
||||
# Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
|
||||
model_name = "claude-3-7-sonnet-20250219"
|
||||
|
||||
# Setup mocks
|
||||
mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
|
||||
mock_config = {"provider": "anthropic", "model": model_name}
|
||||
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||
|
||||
# Mock litellm's get_model_info to return a token limit
|
||||
mock_get_model_info.return_value = {"max_input_tokens": 100000}
|
||||
|
||||
# Mock is_claude_37 to return True
|
||||
mock_is_claude_37.return_value = True
|
||||
|
||||
# Mock adjust_claude_37_token_limit to return the original value
|
||||
mock_adjust.return_value = 100000
|
||||
|
||||
# Test getting token limit
|
||||
result = get_model_token_limit(mock_config)
|
||||
self.assertEqual(result, 100000)
|
||||
|
||||
# Verify get_model_info was called with the right model
|
||||
mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
|
||||
mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
|
||||
|
||||
# Verify adjust_claude_37_token_limit was called
|
||||
mock_adjust.assert_called_once_with(100000, None)
|
||||
|
||||
def test_get_model_token_limit_research(self):
|
||||
"""Test get_model_token_limit with research provider and model."""
|
||||
|
|
@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
"research_provider": "anthropic",
|
||||
"research_model": "claude-2",
|
||||
"research_model": "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
|
||||
patch("litellm.get_model_info") as mock_get_info:
|
||||
patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
|
||||
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
|
||||
mock_get_config_repo.return_value.get_all.return_value = config
|
||||
mock_get_info.return_value = {"max_input_tokens": 150000}
|
||||
mock_adjust.return_value = 150000
|
||||
|
||||
# Call the function to check the return value
|
||||
token_limit = get_model_token_limit(config, "research")
|
||||
self.assertEqual(token_limit, 150000)
|
||||
|
||||
# Verify get_model_info was called with the research model
|
||||
mock_get_info.assert_called_with("anthropic/claude-2")
|
||||
mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
|
||||
# Verify adjust_claude_37_token_limit was called
|
||||
mock_adjust.assert_called_once_with(150000, None)
|
||||
|
||||
def test_get_model_token_limit_planner(self):
|
||||
"""Test get_model_token_limit with planner provider and model."""
|
||||
|
|
@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
}
|
||||
|
||||
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
|
||||
patch("litellm.get_model_info") as mock_get_info:
|
||||
patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
|
||||
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
|
||||
mock_get_config_repo.return_value.get_all.return_value = config
|
||||
mock_get_info.return_value = {"max_input_tokens": 120000}
|
||||
mock_adjust.return_value = 120000
|
||||
|
||||
# Call the function to check the return value
|
||||
token_limit = get_model_token_limit(config, "planner")
|
||||
self.assertEqual(token_limit, 120000)
|
||||
|
||||
# Verify get_model_info was called with the planner model
|
||||
mock_get_info.assert_called_with("deepseek/dsm-1")
|
||||
mock_get_info.assert_called_once_with("deepseek/dsm-1")
|
||||
# Verify adjust_claude_37_token_limit was called
|
||||
mock_adjust.assert_called_once_with(120000, None)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
@patch("ra_aid.anthropic_token_limiter.get_model_info")
|
||||
def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
|
||||
# Setup mocks
|
||||
mock_config = {"provider": "anthropic", "model": "claude-2"}
|
||||
|
|
@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
self.assertEqual(result, 100000)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
@patch("ra_aid.anthropic_token_limiter.get_model_info")
|
||||
@patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
|
||||
def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo):
|
||||
# Use specific model names instead of DEFAULT_MODEL to avoid test dependency
|
||||
claude_model = "claude-3-7-sonnet-20250219"
|
||||
|
||||
# Setup mocks for different agent types
|
||||
mock_config = {
|
||||
"provider": "anthropic",
|
||||
"model": DEFAULT_MODEL,
|
||||
"model": claude_model,
|
||||
"research_provider": "openai",
|
||||
"research_model": "gpt-4",
|
||||
"planner_provider": "anthropic",
|
||||
"planner_model": "claude-3-sonnet-20240229"
|
||||
"planner_model": "claude-3-7-opus-20250301"
|
||||
}
|
||||
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||
|
||||
# Mock different returns for different models
|
||||
def model_info_side_effect(model_name):
|
||||
if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
|
||||
if "claude-3-7-sonnet" in model_name:
|
||||
return {"max_input_tokens": 200000}
|
||||
elif "gpt-4" in model_name:
|
||||
return {"max_input_tokens": 8192}
|
||||
elif "claude-3-sonnet" in model_name:
|
||||
return {"max_input_tokens": 100000}
|
||||
elif "claude-3-7-opus" in model_name:
|
||||
return {"max_input_tokens": 250000}
|
||||
else:
|
||||
raise Exception(f"Unknown model: {model_name}")
|
||||
|
||||
mock_get_model_info.side_effect = model_info_side_effect
|
||||
|
||||
# Mock adjust_claude_37_token_limit to return the same values
|
||||
mock_adjust.side_effect = lambda tokens, model: tokens
|
||||
|
||||
# Test default agent type
|
||||
result = get_model_token_limit(mock_config, "default")
|
||||
self.assertEqual(result, 200000)
|
||||
mock_get_model_info.assert_called_with(f"anthropic/{claude_model}")
|
||||
|
||||
# Reset mock
|
||||
mock_get_model_info.reset_mock()
|
||||
|
||||
# Test research agent type
|
||||
result = get_model_token_limit(mock_config, "research")
|
||||
self.assertEqual(result, 8192)
|
||||
mock_get_model_info.assert_called_with("openai/gpt-4")
|
||||
|
||||
# Reset mock
|
||||
mock_get_model_info.reset_mock()
|
||||
|
||||
# Test planner agent type
|
||||
result = get_model_token_limit(mock_config, "planner")
|
||||
self.assertEqual(result, 100000)
|
||||
self.assertEqual(result, 250000)
|
||||
mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301")
|
||||
|
||||
def test_get_model_token_limit_anthropic(self):
|
||||
"""Test get_model_token_limit with Anthropic model."""
|
||||
config = {"provider": "anthropic", "model": "claude2"}
|
||||
config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
|
||||
|
||||
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
|
||||
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
|
||||
patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
|
||||
patch("litellm.get_model_info") as mock_get_info, \
|
||||
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
|
||||
|
||||
# Setup mocks
|
||||
mock_get_config_repo.return_value.get_all.return_value = config
|
||||
mock_get_info.side_effect = Exception("Model not found")
|
||||
|
||||
# Create a mock models_params with claude-3-7
|
||||
mock_models_params_dict = {
|
||||
"anthropic": {
|
||||
"claude-3-7-sonnet-20250219": {"token_limit": 200000}
|
||||
}
|
||||
}
|
||||
mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
|
||||
mock_models_params.get.side_effect = mock_models_params_dict.get
|
||||
|
||||
# Mock adjust to return the same value
|
||||
mock_adjust.return_value = 200000
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
|
||||
self.assertEqual(token_limit, 200000)
|
||||
|
||||
def test_get_model_token_limit_openai(self):
|
||||
"""Test get_model_token_limit with OpenAI model."""
|
||||
|
|
@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
|
||||
def test_get_model_token_limit_litellm_success(self):
|
||||
"""Test get_model_token_limit successfully getting limit from litellm."""
|
||||
config = {"provider": "anthropic", "model": "claude-2"}
|
||||
config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
|
||||
|
||||
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
|
||||
patch("litellm.get_model_info") as mock_get_info:
|
||||
patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
|
||||
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
|
||||
mock_get_config_repo.return_value.get_all.return_value = config
|
||||
mock_get_info.return_value = {"max_input_tokens": 100000}
|
||||
mock_adjust.return_value = 100000
|
||||
|
||||
# Call the function to check the return value
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
self.assertEqual(token_limit, 100000)
|
||||
mock_get_info.assert_called_with("anthropic/claude-2")
|
||||
|
||||
# Verify get_model_info was called with the right model
|
||||
mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
|
||||
mock_adjust.assert_called_once_with(100000, None)
|
||||
|
||||
def test_get_model_token_limit_litellm_not_found(self):
|
||||
"""Test fallback to models_tokens when litellm raises NotFoundError."""
|
||||
config = {"provider": "anthropic", "model": "claude-2"}
|
||||
config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
|
||||
|
||||
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
|
||||
patch("litellm.get_model_info") as mock_get_info:
|
||||
patch("litellm.get_model_info") as mock_get_info, \
|
||||
patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
|
||||
patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
|
||||
|
||||
mock_get_config_repo.return_value.get_all.return_value = config
|
||||
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
|
||||
message="Model not found", model="claude-2", llm_provider="anthropic"
|
||||
message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
|
||||
)
|
||||
|
||||
# Create a mock models_params with claude-3-7
|
||||
mock_models_params_dict = {
|
||||
"anthropic": {
|
||||
"claude-3-7-sonnet-20250219": {"token_limit": 200000}
|
||||
}
|
||||
}
|
||||
mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
|
||||
mock_models_params.get.side_effect = mock_models_params_dict.get
|
||||
|
||||
# Mock adjust to return the same value
|
||||
mock_adjust.return_value = 200000
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
|
||||
self.assertEqual(token_limit, 200000)
|
||||
|
||||
def test_get_model_token_limit_litellm_error(self):
|
||||
"""Test fallback to models_tokens when litellm raises other exceptions."""
|
||||
|
|
@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
token_limit = get_model_token_limit(config, "default")
|
||||
self.assertIsNone(token_limit)
|
||||
|
||||
def test_adjust_claude_37_token_limit(self):
|
||||
"""Test adjust_claude_37_token_limit function."""
|
||||
# Create a mock model
|
||||
mock_model = MagicMock()
|
||||
mock_model.model = "claude-3.7-sonnet"
|
||||
mock_model.max_tokens = 4096
|
||||
|
||||
# Test with Claude 3.7 model
|
||||
result = adjust_claude_37_token_limit(100000, mock_model)
|
||||
self.assertEqual(result, 95904) # 100000 - 4096
|
||||
|
||||
# Test with non-Claude 3.7 model
|
||||
mock_model.model = "claude-3-opus"
|
||||
result = adjust_claude_37_token_limit(100000, mock_model)
|
||||
self.assertEqual(result, 100000) # No adjustment
|
||||
|
||||
# Test with None max_input_tokens
|
||||
result = adjust_claude_37_token_limit(None, mock_model)
|
||||
self.assertIsNone(result)
|
||||
|
||||
# Test with None model
|
||||
result = adjust_claude_37_token_limit(100000, None)
|
||||
self.assertEqual(result, 100000)
|
||||
|
||||
def test_has_tool_use(self):
|
||||
"""Test the has_tool_use function."""
|
||||
# Test with regular AI message
|
||||
|
|
|
|||
Loading…
Reference in New Issue