feat(anthropic_token_limiter): add get_provider_and_model_for_agent_type function to streamline provider and model retrieval based on agent type

fix(anthropic_token_limiter): refactor get_model_token_limit to use the new get_provider_and_model_for_agent_type function for cleaner code

test(anthropic_token_limiter): add unit tests for get_provider_and_model_for_agent_type and adjust_claude_37_token_limit functions to ensure correctness and coverage
commit 92faf8fc2d (parent 29c9cac4f4)
@@ -1,24 +1,19 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 from langchain_core.language_models import BaseChatModel
 from ra_aid.model_detection import is_claude_37
-from dataclasses import dataclass
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
-    AIMessage,
     BaseMessage,
-    RemoveMessage,
-    ToolMessage,
     trim_messages,
 )
 from langchain_core.messages.base import message_to_dict
 
 from ra_aid.anthropic_message_utils import (
     anthropic_trim_messages,
-    has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import token_counter, get_model_info
@@ -27,7 +22,6 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.console.output import cpm, print_messages_compact
 
 logger = get_logger(__name__)
 
@@ -170,6 +164,29 @@ def sonnet_35_state_modifier(
     return result
 
 
+def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
+    """Get the provider and model name for the specified agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Tuple[str, str]: A tuple containing (provider, model_name)
+    """
+    if agent_type == "research":
+        provider = config.get("research_provider", "") or config.get("provider", "")
+        model_name = config.get("research_model", "") or config.get("model", "")
+    elif agent_type == "planner":
+        provider = config.get("planner_provider", "") or config.get("provider", "")
+        model_name = config.get("planner_model", "") or config.get("model", "")
+    else:
+        provider = config.get("provider", "")
+        model_name = config.get("model", "")
+
+    return provider, model_name
+
+
 def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
     """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
 
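For illustration, here is how the new helper behaves under a hypothetical config (the dict below is made up for the example; per the implementation above, each agent-specific key falls back to the generic provider/model when missing or empty):

    # Hypothetical config, for illustration only.
    from ra_aid.anthropic_token_limiter import get_provider_and_model_for_agent_type

    config = {
        "provider": "anthropic",
        "model": "claude-3-7-sonnet-20250219",
        "research_provider": "openai",
        "research_model": "gpt-4",
    }

    # Agent-specific keys win when present...
    assert get_provider_and_model_for_agent_type(config, "research") == ("openai", "gpt-4")
    # ...and fall back to the generic provider/model when absent or empty.
    assert get_provider_and_model_for_agent_type(config, "planner") == ("anthropic", "claude-3-7-sonnet-20250219")
    assert get_provider_and_model_for_agent_type(config, "default") == ("anthropic", "claude-3-7-sonnet-20250219")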
@@ -217,19 +234,13 @@ def get_model_token_limit(
         # In tests, this may fail because the repository isn't set up
         # So we'll use the passed config directly
         pass
-    if agent_type == "research":
-        provider = config.get("research_provider", "") or config.get("provider", "")
-        model_name = config.get("research_model", "") or config.get("model", "")
-    elif agent_type == "planner":
-        provider = config.get("planner_provider", "") or config.get("provider", "")
-        model_name = config.get("planner_model", "") or config.get("model", "")
-    else:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+
+    provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
+
+    # Always attempt to get model info from litellm first
+    provider_model = model_name if not provider else f"{provider}/{model_name}"
 
     try:
-        provider_model = model_name if not provider else f"{provider}/{model_name}"
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
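Taken together, the refactored lookup in get_model_token_limit now follows a single path; a simplified sketch (error handling and the models_params fallback omitted, names as in the diff above):

    # Simplified sketch of the refactored flow, not the full function:
    # 1. resolve provider/model for the agent type,
    # 2. build the litellm model key,
    # 3. query litellm and adjust for the Claude 3.7 output reservation.
    provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
    provider_model = model_name if not provider else f"{provider}/{model_name}"  # e.g. "anthropic/claude-3-7-sonnet-20250219"
    max_input_tokens = get_model_info(provider_model).get("max_input_tokens")
    token_limit = adjust_claude_37_token_limit(max_input_tokens, model)

The remaining hunks update the unit tests covering these functions.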
@@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
     state_modifier,
     sonnet_35_state_modifier,
-    convert_message_to_litellm_format
+    convert_message_to_litellm_format,
+    adjust_claude_37_token_limit
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
-from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
+from ra_aid.models_params import models_params
 
 
 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 0)
 
     @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
-    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
-    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
+    def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
         # Setup a proper token counter function that returns integers
         def token_counter(msgs):
             # Return token count based on number of messages
@@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         model.model = "claude-3-opus-20240229"
 
         with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
-             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
-             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
             # Setup mock to return a fixed token count per message
             mock_wrapper.return_value = lambda msgs: len(msgs) * 100
             # Setup mock to return a subset of messages
@@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(call_args["include_system"], True)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.is_claude_37")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
+        # Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
+        model_name = "claude-3-7-sonnet-20250219"
+
         # Setup mocks
-        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_config = {"provider": "anthropic", "model": model_name}
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock litellm's get_model_info to return a token limit
         mock_get_model_info.return_value = {"max_input_tokens": 100000}
 
+        # Mock is_claude_37 to return True
+        mock_is_claude_37.return_value = True
+
+        # Mock adjust_claude_37_token_limit to return the original value
+        mock_adjust.return_value = 100000
+
         # Test getting token limit
         result = get_model_token_limit(mock_config)
         self.assertEqual(result, 100000)
 
         # Verify get_model_info was called with the right model
-        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+        mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
+
+        # Verify adjust_claude_37_token_limit was called
+        mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_research(self):
         """Test get_model_token_limit with research provider and model."""
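Note the recurring mock-target change in these tests: anthropic_token_limiter imports get_model_info directly (from litellm import token_counter, get_model_info), so patching litellm.get_model_info no longer intercepts the call; mock.patch must target the namespace where the name is looked up. A minimal illustration:

    from unittest.mock import patch

    # Patch the module-local reference, not the defining module:
    with patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info:
        mock_get_info.return_value = {"max_input_tokens": 100000}
        # ... exercise get_model_token_limit here ...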
@@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             "provider": "openai",
             "model": "gpt-4",
             "research_provider": "anthropic",
-            "research_model": "claude-2",
+            "research_model": "claude-3-7-sonnet-20250219",
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 150000}
+            mock_adjust.return_value = 150000
 
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "research")
             self.assertEqual(token_limit, 150000)
 
             # Verify get_model_info was called with the research model
-            mock_get_info.assert_called_with("anthropic/claude-2")
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(150000, None)
 
     def test_get_model_token_limit_planner(self):
         """Test get_model_token_limit with planner provider and model."""
@@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 120000}
+            mock_adjust.return_value = 120000
 
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "planner")
             self.assertEqual(token_limit, 120000)
 
             # Verify get_model_info was called with the planner model
-            mock_get_info.assert_called_with("deepseek/dsm-1")
+            mock_get_info.assert_called_once_with("deepseek/dsm-1")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(120000, None)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
     def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
         # Setup mocks
         mock_config = {"provider": "anthropic", "model": "claude-2"}
@@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 100000)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo):
+        # Use specific model names instead of DEFAULT_MODEL to avoid test dependency
+        claude_model = "claude-3-7-sonnet-20250219"
+
         # Setup mocks for different agent types
         mock_config = {
             "provider": "anthropic",
-            "model": DEFAULT_MODEL,
+            "model": claude_model,
             "research_provider": "openai",
             "research_model": "gpt-4",
             "planner_provider": "anthropic",
-            "planner_model": "claude-3-sonnet-20240229"
+            "planner_model": "claude-3-7-opus-20250301"
         }
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock different returns for different models
         def model_info_side_effect(model_name):
-            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
+            if "claude-3-7-sonnet" in model_name:
                 return {"max_input_tokens": 200000}
             elif "gpt-4" in model_name:
                 return {"max_input_tokens": 8192}
-            elif "claude-3-sonnet" in model_name:
-                return {"max_input_tokens": 100000}
+            elif "claude-3-7-opus" in model_name:
+                return {"max_input_tokens": 250000}
             else:
                 raise Exception(f"Unknown model: {model_name}")
 
         mock_get_model_info.side_effect = model_info_side_effect
 
+        # Mock adjust_claude_37_token_limit to return the same values
+        mock_adjust.side_effect = lambda tokens, model: tokens
+
         # Test default agent type
         result = get_model_token_limit(mock_config, "default")
         self.assertEqual(result, 200000)
+        mock_get_model_info.assert_called_with(f"anthropic/{claude_model}")
+
+        # Reset mock
+        mock_get_model_info.reset_mock()
 
         # Test research agent type
         result = get_model_token_limit(mock_config, "research")
         self.assertEqual(result, 8192)
+        mock_get_model_info.assert_called_with("openai/gpt-4")
+
+        # Reset mock
+        mock_get_model_info.reset_mock()
 
         # Test planner agent type
         result = get_model_token_limit(mock_config, "planner")
-        self.assertEqual(result, 100000)
+        self.assertEqual(result, 250000)
+        mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301")
 
     def test_get_model_token_limit_anthropic(self):
         """Test get_model_token_limit with Anthropic model."""
-        config = {"provider": "anthropic", "model": "claude2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
-        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+
+            # Setup mocks
             mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Model not found")
+
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_openai(self):
         """Test get_model_token_limit with OpenAI model."""
@@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
 
     def test_get_model_token_limit_litellm_success(self):
         """Test get_model_token_limit successfully getting limit from litellm."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 100000}
+            mock_adjust.return_value = 100000
 
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "default")
             self.assertEqual(token_limit, 100000)
-            mock_get_info.assert_called_with("anthropic/claude-2")
+
+            # Verify get_model_info was called with the right model
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_litellm_not_found(self):
         """Test fallback to models_tokens when litellm raises NotFoundError."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-                message="Model not found", model="claude-2", llm_provider="anthropic"
+                message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
             )
 
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_litellm_error(self):
         """Test fallback to models_tokens when litellm raises other exceptions."""
@@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             token_limit = get_model_token_limit(config, "default")
             self.assertIsNone(token_limit)
 
+    def test_adjust_claude_37_token_limit(self):
+        """Test adjust_claude_37_token_limit function."""
+        # Create a mock model
+        mock_model = MagicMock()
+        mock_model.model = "claude-3.7-sonnet"
+        mock_model.max_tokens = 4096
+
+        # Test with Claude 3.7 model
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 95904)  # 100000 - 4096
+
+        # Test with non-Claude 3.7 model
+        mock_model.model = "claude-3-opus"
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 100000)  # No adjustment
+
+        # Test with None max_input_tokens
+        result = adjust_claude_37_token_limit(None, mock_model)
+        self.assertIsNone(result)
+
+        # Test with None model
+        result = adjust_claude_37_token_limit(100000, None)
+        self.assertEqual(result, 100000)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message
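The implementation of adjust_claude_37_token_limit is not part of this diff; judging from its docstring and the new test above, its behavior amounts to roughly the following sketch (an inference, not the actual code):

    # Behavioral sketch inferred from test_adjust_claude_37_token_limit;
    # the real implementation lives unchanged in anthropic_token_limiter.py.
    from ra_aid.model_detection import is_claude_37

    def adjust_claude_37_token_limit_sketch(max_input_tokens, model):
        # None passes through, and no model means no adjustment.
        if max_input_tokens is None or model is None:
            return max_input_tokens
        # For Claude 3.7, reserve the model's max_tokens for output:
        # 100000 - 4096 == 95904, matching the test above.
        if is_claude_37(model.model):
            return max_input_tokens - model.max_tokens
        return max_input_tokens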