refactor(tests): move token limit tests from test_agent_utils.py to test_anthropic_token_limiter.py for better organization and clarity

Ariel Frischer 2025-03-11 23:53:37 -07:00
parent e42f281f94
commit 8d2d273c6b
2 changed files with 156 additions and 122 deletions
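
For orientation, the moved tests pin down a token-limit resolution order: try litellm's model metadata first, fall back to the static models_params table, and return None when nothing matches or the config is unusable. The sketch below is a hypothetical reconstruction inferred from the assertions in the diff; the real function lives in ra_aid.anthropic_token_limiter and may differ in detail.

# Hypothetical sketch of the fallback chain these tests exercise; names and
# table contents are illustrative, not copied from the ra_aid codebase.
from typing import Optional

import litellm

# Stand-in for ra_aid.models_params; only the shape matters here.
models_params = {
    "anthropic": {"claude2": {"token_limit": 100_000}},
    "openai": {"gpt-4": {"token_limit": 8_192}},
}

def get_model_token_limit(config: dict, agent_type: str = "default") -> Optional[int]:
    """Resolve a token limit: litellm first, then models_params, else None."""
    try:
        # Agent-specific settings (research_*/planner_*) override the defaults.
        provider = config.get(f"{agent_type}_provider") or config.get("provider")
        model = config.get(f"{agent_type}_model") or config.get("model")
        if not provider or not model:
            return None
        try:
            return litellm.get_model_info(f"{provider}/{model}")["max_input_tokens"]
        except Exception:
            # NotFoundError or any other litellm failure: use the static table,
            # trying both the raw name and a dash-stripped variant ("claude2").
            table = models_params.get(provider, {})
            entry = table.get(model) or table.get(model.replace("-", ""))
            return entry["token_limit"] if entry else None
    except Exception:
        # e.g. config=None in test_get_model_token_limit_unexpected_error
        return None

Under these assumptions, the NotFoundError test is satisfied because "claude-2" misses litellm, and the table lookup then normalizes the name to "claude2".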

test_agent_utils.py

@@ -14,8 +14,10 @@ from ra_aid.agent_context import (
 from ra_aid.agent_utils import (
     AgentState,
     create_agent,
-    get_model_token_limit,
     is_anthropic_claude,
-    state_modifier,
 )
+from ra_aid.anthropic_token_limiter import (
+    get_model_token_limit,
+    state_modifier,
+)
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
@@ -63,87 +65,15 @@ def mock_config_repository():
         yield mock_repo
 
-def test_get_model_token_limit_anthropic(mock_config_repository):
-    """Test get_model_token_limit with Anthropic model."""
-    config = {"provider": "anthropic", "model": "claude2"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-def test_get_model_token_limit_openai(mock_config_repository):
-    """Test get_model_token_limit with OpenAI model."""
-    config = {"provider": "openai", "model": "gpt-4"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
-
-def test_get_model_token_limit_unknown(mock_config_repository):
-    """Test get_model_token_limit with unknown provider/model."""
-    config = {"provider": "unknown", "model": "unknown-model"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-def test_get_model_token_limit_missing_config(mock_config_repository):
-    """Test get_model_token_limit with missing configuration."""
-    config = {}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-def test_get_model_token_limit_litellm_success():
-    """Test get_model_token_limit successfully getting limit from litellm."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 100000}
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == 100000
-
-def test_get_model_token_limit_litellm_not_found():
-    """Test fallback to models_tokens when litellm raises NotFoundError."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found", model="claude-2", llm_provider="anthropic"
-        )
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-def test_get_model_token_limit_litellm_error():
-    """Test fallback to models_tokens when litellm raises other exceptions."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = Exception("Unknown error")
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-def test_get_model_token_limit_unexpected_error():
-    """Test returning None when unexpected errors occur."""
-    config = None  # This will cause an attribute error when accessed
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
+# These tests have been moved to test_anthropic_token_limiter.py
 
 def test_create_agent_anthropic(mock_model, mock_config_repository):
     """Test create_agent with Anthropic Claude model."""
     mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
-    with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
+    with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \
+         patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
         mock_react.return_value = "react_agent"
         agent = create_agent(mock_model, [])
@@ -221,20 +151,7 @@ def mock_messages():
     ]
 
-def test_state_modifier(mock_messages):
-    """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
-    state = AgentState(messages=mock_messages)
-
-    with patch(
-        "ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens"
-    ) as mock_estimate:
-        mock_estimate.side_effect = lambda msg: 100 if msg else 0
-
-        result = state_modifier(state, max_input_tokens=250)
-
-        assert len(result) < len(mock_messages)
-        assert isinstance(result[0], SystemMessage)
-        assert result[-1] == mock_messages[-1]
+# This test has been moved to test_anthropic_token_limiter.py
 
 def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
@@ -265,7 +182,7 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000
@@ -288,7 +205,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000
@@ -299,36 +216,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
         mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2")
 
-def test_get_model_token_limit_research(mock_config_repository):
-    """Test get_model_token_limit with research provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "research_provider": "anthropic",
-        "research_model": "claude-2",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 150000}
-        token_limit = get_model_token_limit(config, "research")
-        assert token_limit == 150000
-
-def test_get_model_token_limit_planner(mock_config_repository):
-    """Test get_model_token_limit with planner provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "planner_provider": "deepseek",
-        "planner_model": "dsm-1",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 120000}
-        token_limit = get_model_token_limit(config, "planner")
-        assert token_limit == 120000
+# These tests have been moved to test_anthropic_token_limiter.py
 
 # New tests for private helper methods in agent_utils.py
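
One detail worth noting in the retargeted patches above: unittest.mock's patch replaces a name in the namespace where it is looked up at call time, not where it was defined. The switch from patch("ra_aid.agent_utils.get_model_token_limit") to patch("ra_aid.anthropic_token_limiter.get_model_token_limit") suggests create_agent now resolves the helper through the anthropic_token_limiter module. Below is a self-contained sketch of that rule; the module names are purely illustrative.

import sys
import types
from unittest.mock import patch

# Fake "limiter" module that defines the function (illustrative names only).
limiter = types.ModuleType("limiter")
limiter.get_token_limit = lambda: 8192
sys.modules["limiter"] = limiter

# Fake "agent" module that imported the function at load time.
agent = types.ModuleType("agent")
agent.get_token_limit = limiter.get_token_limit  # what `from limiter import ...` does
agent.build = lambda: agent.get_token_limit()
sys.modules["agent"] = agent

# Patching the defining module does not affect agent.build(): agent kept
# its own reference to the original function object.
with patch("limiter.get_token_limit", return_value=1):
    assert agent.build() == 8192

# Patching the name where it is looked up does take effect.
with patch("agent.get_token_limit", return_value=1):
    assert agent.build() == 1

When a module holds a load-time from-import like this, patches must target that module's own copy of the name, which is why patch targets have to track refactors like this one.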

test_anthropic_token_limiter.py

@@ -1,5 +1,6 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
+import litellm
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
@@ -19,6 +20,7 @@ from ra_aid.anthropic_token_limiter import (
     convert_message_to_litellm_format
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
+from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
 
 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -140,6 +142,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         # Verify print_messages_compact was called at least once
         self.assertTrue(mock_print.call_count >= 1)
 
+    def test_state_modifier_with_messages(self):
+        """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
+        # Create a state with messages
+        messages = [
+            SystemMessage(content="System prompt"),
+            HumanMessage(content="Human message 1"),
+            AIMessage(content="AI response 1"),
+            HumanMessage(content="Human message 2"),
+            AIMessage(content="AI response 2"),
+        ]
+        state = AgentState(messages=messages)
+
+        model = MagicMock(spec=ChatAnthropic)
+        model.model = "claude-3-opus-20240229"
+
+        with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
+             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+            # Setup mock to return a fixed token count per message
+            mock_wrapper.return_value = lambda msgs: len(msgs) * 100
+            # Setup mock to return a subset of messages
+            mock_trim.return_value = [messages[0], messages[-2], messages[-1]]
+
+            result = state_modifier(state, model, max_input_tokens=250)
+
+            # Should return what anthropic_trim_messages returned
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], messages[0])  # First message preserved
+            self.assertEqual(result[-1], messages[-1])  # Last message preserved
+
     @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
     def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
@@ -191,6 +222,42 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         # Verify get_model_info was called with the right model
         mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
 
+    def test_get_model_token_limit_research(self):
+        """Test get_model_token_limit with research provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "research_provider": "anthropic",
+            "research_model": "claude-2",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 150000}
+            token_limit = get_model_token_limit(config, "research")
+            self.assertEqual(token_limit, 150000)
+            # Verify get_model_info was called with the research model
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_planner(self):
+        """Test get_model_token_limit with planner provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "planner_provider": "deepseek",
+            "planner_model": "dsm-1",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 120000}
+            token_limit = get_model_token_limit(config, "planner")
+            self.assertEqual(token_limit, 120000)
+            # Verify get_model_info was called with the planner model
+            mock_get_info.assert_called_with("deepseek/dsm-1")
+
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
     @patch("litellm.get_model_info")
@@ -252,6 +319,85 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         result = get_model_token_limit(mock_config, "planner")
         self.assertEqual(result, 100000)
 
+    def test_get_model_token_limit_anthropic(self):
+        """Test get_model_token_limit with Anthropic model."""
+        config = {"provider": "anthropic", "model": "claude2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_openai(self):
+        """Test get_model_token_limit with OpenAI model."""
+        config = {"provider": "openai", "model": "gpt-4"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])
+
+    def test_get_model_token_limit_unknown(self):
+        """Test get_model_token_limit with unknown provider/model."""
+        config = {"provider": "unknown", "model": "unknown-model"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_missing_config(self):
+        """Test get_model_token_limit with missing configuration."""
+        config = {}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_litellm_success(self):
+        """Test get_model_token_limit successfully getting limit from litellm."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 100000}
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, 100000)
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_litellm_not_found(self):
+        """Test fallback to models_tokens when litellm raises NotFoundError."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+                message="Model not found", model="claude-2", llm_provider="anthropic"
+            )
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_litellm_error(self):
+        """Test fallback to models_tokens when litellm raises other exceptions."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Unknown error")
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_unexpected_error(self):
+        """Test returning None when unexpected errors occur."""
+        config = None  # This will cause an attribute error when accessed
+        token_limit = get_model_token_limit(config, "default")
+        self.assertIsNone(token_limit)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message