199 lines
8.4 KiB
Python
199 lines
8.4 KiB
Python
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
|
|
|
from ra_aid.anthropic_token_limiter import (
|
|
create_token_counter_wrapper,
|
|
estimate_messages_tokens,
|
|
get_model_token_limit,
|
|
state_modifier,
|
|
)
|
|
|
|
|
|
class TestAnthropicTokenLimiter(unittest.TestCase):
|
|
def setUp(self):
|
|
from ra_aid.config import DEFAULT_MODEL
|
|
|
|
self.mock_model = MagicMock(spec=ChatAnthropic)
|
|
self.mock_model.model = DEFAULT_MODEL
|
|
|
|
# Sample messages for testing
|
|
self.system_message = SystemMessage(content="You are a helpful assistant.")
|
|
self.human_message = HumanMessage(content="Hello, can you help me with a task?")
|
|
self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming
|
|
|
|
# Create more messages for testing
|
|
self.extra_messages = [
|
|
HumanMessage(content=f"Extra message {i}") for i in range(5)
|
|
]
|
|
|
|
# Mock state for testing state_modifier with many messages
|
|
self.state = AgentState(
|
|
messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
|
|
next=None,
|
|
)
|
|
|
|
@patch("ra_aid.anthropic_token_limiter.token_counter")
|
|
def test_create_token_counter_wrapper(self, mock_token_counter):
|
|
from ra_aid.config import DEFAULT_MODEL
|
|
|
|
# Setup mock return values
|
|
mock_token_counter.return_value = 50
|
|
|
|
# Create the wrapper
|
|
wrapper = create_token_counter_wrapper(DEFAULT_MODEL)
|
|
|
|
# Test with BaseMessage objects
|
|
result = wrapper([self.human_message])
|
|
self.assertEqual(result, 50)
|
|
|
|
# Test with empty list
|
|
result = wrapper([])
|
|
self.assertEqual(result, 0)
|
|
|
|
# Verify the mock was called with the right parameters
|
|
mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)
|
|
|
|
@patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
|
|
def test_estimate_messages_tokens(self, mock_estimate_tokens):
|
|
# Setup mock to return different values for different messages
|
|
mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20
|
|
|
|
# Test with multiple messages
|
|
messages = [self.system_message, self.human_message]
|
|
result = estimate_messages_tokens(messages)
|
|
|
|
# Should be sum of individual token counts (10 + 20)
|
|
self.assertEqual(result, 30)
|
|
|
|
# Test with empty list
|
|
result = estimate_messages_tokens([])
|
|
self.assertEqual(result, 0)
|
|
|
|
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
|
|
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
|
|
def test_state_modifier(self, mock_print, mock_create_wrapper):
|
|
# Setup a proper token counter function that returns integers
|
|
# This function needs to return values that will cause trim_messages to keep only the first message
|
|
def token_counter(msgs):
|
|
# For a single message, return a small token count
|
|
if len(msgs) == 1:
|
|
return 10
|
|
# For two messages (first + one more), return a value under our limit
|
|
elif len(msgs) == 2:
|
|
return 30 # This is under our 40 token remaining budget (50-10)
|
|
# For three messages, return a value just under our limit
|
|
elif len(msgs) == 3:
|
|
return 40 # This is exactly at our 40 token remaining budget (50-10)
|
|
# For four messages, return a value just at our limit
|
|
elif len(msgs) == 4:
|
|
return 40 # This is exactly at our 40 token remaining budget (50-10)
|
|
# For five messages, return a value that exceeds our 40 token budget
|
|
elif len(msgs) == 5:
|
|
return 60 # This exceeds our 40 token budget, forcing only 4 more messages
|
|
# For more messages, return a value over our limit
|
|
else:
|
|
return 100 # This exceeds our limit
|
|
|
|
# Don't use side_effect here, directly return the function
|
|
mock_create_wrapper.return_value = token_counter
|
|
|
|
# Call state_modifier with a max token limit of 50
|
|
result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
|
|
|
|
# Should keep first message and some of the others (up to 5 total)
|
|
self.assertEqual(len(result), 5) # First message plus four more
|
|
self.assertEqual(result[0], self.system_message) # First message is preserved
|
|
|
|
# Verify the wrapper was created with the right model
|
|
mock_create_wrapper.assert_called_with(self.mock_model.model)
|
|
|
|
# Verify print_messages_compact was called
|
|
mock_print.assert_called_once()
|
|
|
|
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
|
@patch("litellm.get_model_info")
|
|
def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
|
|
from ra_aid.config import DEFAULT_MODEL
|
|
|
|
# Setup mocks
|
|
mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
|
|
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
|
|
|
# Mock litellm's get_model_info to return a token limit
|
|
mock_get_model_info.return_value = {"max_input_tokens": 100000}
|
|
|
|
# Test getting token limit
|
|
result = get_model_token_limit(mock_config)
|
|
self.assertEqual(result, 100000)
|
|
|
|
# Verify get_model_info was called with the right model
|
|
mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
|
|
|
|
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
|
@patch("litellm.get_model_info")
|
|
def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
|
|
# Setup mocks
|
|
mock_config = {"provider": "anthropic", "model": "claude-2"}
|
|
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
|
|
|
# Make litellm's get_model_info raise an exception to test fallback
|
|
mock_get_model_info.side_effect = Exception("Model not found")
|
|
|
|
# Test getting token limit from models_params fallback
|
|
with patch("ra_aid.anthropic_token_limiter.models_params", {
|
|
"anthropic": {
|
|
"claude2": {"token_limit": 100000}
|
|
}
|
|
}):
|
|
result = get_model_token_limit(mock_config)
|
|
self.assertEqual(result, 100000)
|
|
|
|
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
|
@patch("litellm.get_model_info")
|
|
def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
|
|
from ra_aid.config import DEFAULT_MODEL
|
|
|
|
# Setup mocks for different agent types
|
|
mock_config = {
|
|
"provider": "anthropic",
|
|
"model": DEFAULT_MODEL,
|
|
"research_provider": "openai",
|
|
"research_model": "gpt-4",
|
|
"planner_provider": "anthropic",
|
|
"planner_model": "claude-3-sonnet-20240229"
|
|
}
|
|
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
|
|
|
# Mock different returns for different models
|
|
def model_info_side_effect(model_name):
|
|
if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
|
|
return {"max_input_tokens": 200000}
|
|
elif "gpt-4" in model_name:
|
|
return {"max_input_tokens": 8192}
|
|
elif "claude-3-sonnet" in model_name:
|
|
return {"max_input_tokens": 100000}
|
|
else:
|
|
raise Exception(f"Unknown model: {model_name}")
|
|
|
|
mock_get_model_info.side_effect = model_info_side_effect
|
|
|
|
# Test default agent type
|
|
result = get_model_token_limit(mock_config, "default")
|
|
self.assertEqual(result, 200000)
|
|
|
|
# Test research agent type
|
|
result = get_model_token_limit(mock_config, "research")
|
|
self.assertEqual(result, 8192)
|
|
|
|
# Test planner agent type
|
|
result = get_model_token_limit(mock_config, "planner")
|
|
self.assertEqual(result, 100000)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|