refactor(tests): move token limit tests from test_agent_utils.py to test_anthropic_token_limiter.py

The moved tests now live alongside the module they exercise, for better
organization and clarity.
This commit is contained in:
parent e42f281f94
commit 8d2d273c6b
test_agent_utils.py
@@ -14,8 +14,10 @@ from ra_aid.agent_context import (
 from ra_aid.agent_utils import (
     AgentState,
     create_agent,
-    get_model_token_limit,
     is_anthropic_claude,
 )
+from ra_aid.anthropic_token_limiter import (
+    get_model_token_limit,
+    state_modifier,
+)
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-
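For orientation, a minimal usage sketch of the relocated functions, assuming only the module paths shown in this hunk and the config shape used by the tests below:

    # Sketch, not project code: calling the moved helpers directly.
    from ra_aid.anthropic_token_limiter import get_model_token_limit, state_modifier

    config = {"provider": "anthropic", "model": "claude2"}
    limit = get_model_token_limit(config, "default")  # int limit, or None for unknown models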
@@ -63,87 +65,15 @@ def mock_config_repository():
     yield mock_repo
 
 
-def test_get_model_token_limit_anthropic(mock_config_repository):
-    """Test get_model_token_limit with Anthropic model."""
-    config = {"provider": "anthropic", "model": "claude2"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_openai(mock_config_repository):
-    """Test get_model_token_limit with OpenAI model."""
-    config = {"provider": "openai", "model": "gpt-4"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
-
-
-def test_get_model_token_limit_unknown(mock_config_repository):
-    """Test get_model_token_limit with unknown provider/model."""
-    config = {"provider": "unknown", "model": "unknown-model"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
-def test_get_model_token_limit_missing_config(mock_config_repository):
-    """Test get_model_token_limit with missing configuration."""
-    config = {}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
-def test_get_model_token_limit_litellm_success():
-    """Test get_model_token_limit successfully getting limit from litellm."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 100000}
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == 100000
-
-
-def test_get_model_token_limit_litellm_not_found():
-    """Test fallback to models_tokens when litellm raises NotFoundError."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found", model="claude-2", llm_provider="anthropic"
-        )
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_litellm_error():
-    """Test fallback to models_tokens when litellm raises other exceptions."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = Exception("Unknown error")
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_unexpected_error():
-    """Test returning None when unexpected errors occur."""
-    config = None  # This will cause an attribute error when accessed
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
+# These tests have been moved to test_anthropic_token_limiter.py
 
 
 def test_create_agent_anthropic(mock_model, mock_config_repository):
     """Test create_agent with Anthropic Claude model."""
     mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
 
-    with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
+    with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \
+         patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
         mock_react.return_value = "react_agent"
         agent = create_agent(mock_model, [])
 
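One detail in the updated with-block above: unittest.mock patches must name the module through which an attribute is resolved at call time, so patching state_modifier at ra_aid.anthropic_token_limiter assumes create_agent looks it up via that module rather than binding a local copy at import. A minimal sketch of the pattern, with illustrative mock values:

    from unittest.mock import patch

    # Patch where the name is resolved at call time, not where the caller lives.
    with patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
        mock_state_modifier.return_value = []  # stand-in for a trimmed message list
        ...  # exercise create_agent here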
@@ -221,20 +151,7 @@ def mock_messages():
     ]
 
 
-def test_state_modifier(mock_messages):
-    """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
-    state = AgentState(messages=mock_messages)
-
-    with patch(
-        "ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens"
-    ) as mock_estimate:
-        mock_estimate.side_effect = lambda msg: 100 if msg else 0
-
-        result = state_modifier(state, max_input_tokens=250)
-
-        assert len(result) < len(mock_messages)
-        assert isinstance(result[0], SystemMessage)
-        assert result[-1] == mock_messages[-1]
+# This test has been moved to test_anthropic_token_limiter.py
 
 
 def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
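The move also reflects a signature change: the removed test called state_modifier(state, max_input_tokens=250), while its replacement in test_anthropic_token_limiter.py (later in this diff) passes the chat model so the limiter can build a model-aware token counter. Side by side:

    # Old call, exercised here through CiaynAgent._estimate_tokens:
    result = state_modifier(state, max_input_tokens=250)

    # New call, exercised in test_anthropic_token_limiter.py through
    # create_token_counter_wrapper(model):
    result = state_modifier(state, model, max_input_tokens=250)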
@@ -265,7 +182,7 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
 
     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000
@@ -288,7 +205,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
 
     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000
@@ -299,36 +216,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
     mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2")
 
 
-def test_get_model_token_limit_research(mock_config_repository):
-    """Test get_model_token_limit with research provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "research_provider": "anthropic",
-        "research_model": "claude-2",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 150000}
-        token_limit = get_model_token_limit(config, "research")
-        assert token_limit == 150000
-
-
-def test_get_model_token_limit_planner(mock_config_repository):
-    """Test get_model_token_limit with planner provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "planner_provider": "deepseek",
-        "planner_model": "dsm-1",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 120000}
-        token_limit = get_model_token_limit(config, "planner")
-        assert token_limit == 120000
+# These tests have been moved to test_anthropic_token_limiter.py
 
 
 # New tests for private helper methods in agent_utils.py

test_anthropic_token_limiter.py
@@ -1,5 +1,6 @@
 import unittest
 from unittest.mock import MagicMock, patch
+import litellm
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
@@ -19,6 +20,7 @@ from ra_aid.anthropic_token_limiter import (
     convert_message_to_litellm_format
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
+from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
 
 
 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -140,6 +142,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         # Verify print_messages_compact was called at least once
         self.assertTrue(mock_print.call_count >= 1)
 
+    def test_state_modifier_with_messages(self):
+        """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
+        # Create a state with messages
+        messages = [
+            SystemMessage(content="System prompt"),
+            HumanMessage(content="Human message 1"),
+            AIMessage(content="AI response 1"),
+            HumanMessage(content="Human message 2"),
+            AIMessage(content="AI response 2"),
+        ]
+        state = AgentState(messages=messages)
+        model = MagicMock(spec=ChatAnthropic)
+        model.model = "claude-3-opus-20240229"
+
+        with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
+             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+            # Setup mock to return a fixed token count per message
+            mock_wrapper.return_value = lambda msgs: len(msgs) * 100
+            # Setup mock to return a subset of messages
+            mock_trim.return_value = [messages[0], messages[-2], messages[-1]]
+
+            result = state_modifier(state, model, max_input_tokens=250)
+
+            # Should return what anthropic_trim_messages returned
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], messages[0])  # First message preserved
+            self.assertEqual(result[-1], messages[-1])  # Last message preserved
+
     @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
     def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
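The decorated test at the end of this hunk also illustrates the stacked-@patch convention used throughout this suite: decorators apply bottom-up, so the bottom @patch supplies the first mock argument after self. From the context lines above:

    @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")  # -> mock_estimate
    @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")   # -> mock_trim (first)
    def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
        ...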
@@ -191,6 +222,42 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
 
         # Verify get_model_info was called with the right model
         mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
 
+    def test_get_model_token_limit_research(self):
+        """Test get_model_token_limit with research provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "research_provider": "anthropic",
+            "research_model": "claude-2",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 150000}
+            token_limit = get_model_token_limit(config, "research")
+            self.assertEqual(token_limit, 150000)
+            # Verify get_model_info was called with the research model
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_planner(self):
+        """Test get_model_token_limit with planner provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "planner_provider": "deepseek",
+            "planner_model": "dsm-1",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 120000}
+            token_limit = get_model_token_limit(config, "planner")
+            self.assertEqual(token_limit, 120000)
+            # Verify get_model_info was called with the planner model
+            mock_get_info.assert_called_with("deepseek/dsm-1")
+
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
     @patch("litellm.get_model_info")
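These two additions pin down how get_model_token_limit selects the model to size against: per-agent keys such as research_provider/research_model take precedence over the top-level provider/model, and litellm is then queried with the combined "provider/model" string. A sketch of that resolution inferred from the assertions (resolve_token_limit is a hypothetical name, not the actual implementation):

    import litellm

    def resolve_token_limit(config: dict, agent_type: str):
        # Inferred from the tests: prefer per-agent keys, fall back to defaults.
        provider = config.get(f"{agent_type}_provider") or config.get("provider")
        model = config.get(f"{agent_type}_model") or config.get("model")
        info = litellm.get_model_info(f"{provider}/{model}")
        return info.get("max_input_tokens")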
@@ -252,6 +319,85 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             result = get_model_token_limit(mock_config, "planner")
             self.assertEqual(result, 100000)
 
+    def test_get_model_token_limit_anthropic(self):
+        """Test get_model_token_limit with Anthropic model."""
+        config = {"provider": "anthropic", "model": "claude2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_openai(self):
+        """Test get_model_token_limit with OpenAI model."""
+        config = {"provider": "openai", "model": "gpt-4"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])
+
+    def test_get_model_token_limit_unknown(self):
+        """Test get_model_token_limit with unknown provider/model."""
+        config = {"provider": "unknown", "model": "unknown-model"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_missing_config(self):
+        """Test get_model_token_limit with missing configuration."""
+        config = {}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_litellm_success(self):
+        """Test get_model_token_limit successfully getting limit from litellm."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 100000}
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, 100000)
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_litellm_not_found(self):
+        """Test fallback to models_tokens when litellm raises NotFoundError."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+                message="Model not found", model="claude-2", llm_provider="anthropic"
+            )
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_litellm_error(self):
+        """Test fallback to models_tokens when litellm raises other exceptions."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Unknown error")
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_unexpected_error(self):
+        """Test returning None when unexpected errors occur."""
+        config = None  # This will cause an attribute error when accessed
+
+        token_limit = get_model_token_limit(config, "default")
+        self.assertIsNone(token_limit)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message
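Taken together, the relocated tests specify a fallback chain for get_model_token_limit: use litellm.get_model_info's max_input_tokens when available; on NotFoundError or any other litellm failure, fall back to the static models_params table (the tests configure "claude-2" but assert against the key "claude2", so the fallback lookup evidently normalizes the model name); and return None for unknown models or an unusable config. A sketch of that chain under those inferred assumptions (token_limit_sketch is a hypothetical name, not the actual implementation):

    import litellm
    from ra_aid.models_params import models_params

    def token_limit_sketch(provider: str, model: str):
        try:
            return litellm.get_model_info(f"{provider}/{model}")["max_input_tokens"]
        except Exception:
            # Fallback table; "claude-2" is assumed to normalize to "claude2".
            entry = models_params.get(provider, {}).get(model.replace("-", ""))
            return entry["token_limit"] if entry else None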