chore(anthropic_message_utils.py): remove unused fix_anthropic_message_content function to clean up codebase

chore(anthropic_token_limiter.py): remove import of fix_anthropic_message_content as it is no longer needed
test: add unit tests for has_tool_use and is_tool_pair functions to ensure correct functionality
test: enhance test coverage for anthropic_trim_messages with tool use scenarios to validate message handling
This commit is contained in:
Ariel Frischer 2025-03-11 23:48:08 -07:00
parent 376d486db8
commit e42f281f94
3 changed files with 189 additions and 63 deletions

View File

@ -82,41 +82,6 @@ def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
) )
def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage:
"""Fix message content format for Anthropic API compatibility."""
if not isinstance(message, AIMessage) or not isinstance(message.content, list):
return message
fixed_message = message.model_copy(deep=True)
# Ensure first block is valid thinking type
if fixed_message.content and isinstance(fixed_message.content[0], dict):
first_block_type = fixed_message.content[0].get("type")
if first_block_type not in ("thinking", "redacted_thinking"):
# Prepend redacted_thinking block instead of thinking
fixed_message.content.insert(
0,
{
"type": "redacted_thinking",
"data": "ENCRYPTED_REASONING", # Required field for redacted_thinking
},
)
# Ensure all thinking blocks have valid structure
for i, block in enumerate(fixed_message.content):
if block.get("type") == "thinking":
# Convert thinking blocks to redacted_thinking to avoid signature validation
fixed_message.content[i] = {
"type": "redacted_thinking",
"data": "ENCRYPTED_REASONING",
}
elif block.get("type") == "redacted_thinking":
# Ensure required data field exists
if "data" not in block:
fixed_message.content[i]["data"] = "ENCRYPTED_REASONING"
return fixed_message
def anthropic_trim_messages( def anthropic_trim_messages(
messages: Sequence[BaseMessage], messages: Sequence[BaseMessage],

View File

@ -15,7 +15,6 @@ from langchain_core.messages import (
from langchain_core.messages.base import message_to_dict from langchain_core.messages.base import message_to_dict
from ra_aid.anthropic_message_utils import ( from ra_aid.anthropic_message_utils import (
fix_anthropic_message_content,
anthropic_trim_messages, anthropic_trim_messages,
has_tool_use, has_tool_use,
) )

View File

@ -2,7 +2,12 @@ import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage
)
from langgraph.prebuilt.chat_agent_executor import AgentState from langgraph.prebuilt.chat_agent_executor import AgentState
from ra_aid.anthropic_token_limiter import ( from ra_aid.anthropic_token_limiter import (
@ -10,7 +15,10 @@ from ra_aid.anthropic_token_limiter import (
estimate_messages_tokens, estimate_messages_tokens,
get_model_token_limit, get_model_token_limit,
state_modifier, state_modifier,
sonnet_35_state_modifier,
convert_message_to_litellm_format
) )
from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
class TestAnthropicTokenLimiter(unittest.TestCase): class TestAnthropicTokenLimiter(unittest.TestCase):
@ -23,6 +31,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
# Sample messages for testing # Sample messages for testing
self.system_message = SystemMessage(content="You are a helpful assistant.") self.system_message = SystemMessage(content="You are a helpful assistant.")
self.human_message = HumanMessage(content="Hello, can you help me with a task?") self.human_message = HumanMessage(content="Hello, can you help me with a task?")
self.ai_message = AIMessage(content="I'd be happy to help! What do you need?")
self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming
# Create more messages for testing # Create more messages for testing
@ -36,6 +45,34 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
next=None, next=None,
) )
# Create tool-related messages for testing
self.ai_with_tool_use = AIMessage(
content="I'll use a tool to help you",
additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]}
)
self.tool_message = ToolMessage(
content="4",
tool_call_id="tool_call_1",
name="calculator"
)
def test_convert_message_to_litellm_format(self):
"""Test conversion of BaseMessage to litellm format."""
# Test human message
human_result = convert_message_to_litellm_format(self.human_message)
self.assertEqual(human_result["role"], "human")
self.assertEqual(human_result["content"], "Hello, can you help me with a task?")
# Test system message
system_result = convert_message_to_litellm_format(self.system_message)
self.assertEqual(system_result["role"], "system")
self.assertEqual(system_result["content"], "You are a helpful assistant.")
# Test AI message
ai_result = convert_message_to_litellm_format(self.ai_message)
self.assertEqual(ai_result["role"], "ai")
self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?")
@patch("ra_aid.anthropic_token_limiter.token_counter") @patch("ra_aid.anthropic_token_limiter.token_counter")
def test_create_token_counter_wrapper(self, mock_token_counter): def test_create_token_counter_wrapper(self, mock_token_counter):
from ra_aid.config import DEFAULT_MODEL from ra_aid.config import DEFAULT_MODEL
@ -75,44 +112,66 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
@patch("ra_aid.anthropic_token_limiter.print_messages_compact") @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
def test_state_modifier(self, mock_print, mock_create_wrapper): @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
# Setup a proper token counter function that returns integers # Setup a proper token counter function that returns integers
# This function needs to return values that will cause trim_messages to keep only the first message
def token_counter(msgs): def token_counter(msgs):
# For a single message, return a small token count # Return token count based on number of messages
if len(msgs) == 1: return len(msgs) * 10
return 10
# For two messages (first + one more), return a value under our limit
elif len(msgs) == 2:
return 30 # This is under our 40 token remaining budget (50-10)
# For three messages, return a value just under our limit
elif len(msgs) == 3:
return 40 # This is exactly at our 40 token remaining budget (50-10)
# For four messages, return a value just at our limit
elif len(msgs) == 4:
return 40 # This is exactly at our 40 token remaining budget (50-10)
# For five messages, return a value that exceeds our 40 token budget
elif len(msgs) == 5:
return 60 # This exceeds our 40 token budget, forcing only 4 more messages
# For more messages, return a value over our limit
else:
return 100 # This exceeds our limit
# Don't use side_effect here, directly return the function # Configure the mock to return our token counter
mock_create_wrapper.return_value = token_counter mock_create_wrapper.return_value = token_counter
# Configure anthropic_trim_messages to return a subset of messages
mock_trim_messages.return_value = [self.system_message, self.human_message]
# Call state_modifier with a max token limit of 50 # Call state_modifier with a max token limit of 50
result = state_modifier(self.state, self.mock_model, max_input_tokens=50) result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
# Should keep first message and some of the others (up to 5 total) # Should return what anthropic_trim_messages returned
self.assertEqual(len(result), 5) # First message plus four more self.assertEqual(result, [self.system_message, self.human_message])
self.assertEqual(result[0], self.system_message) # First message is preserved
# Verify the wrapper was created with the right model # Verify the wrapper was created with the right model
mock_create_wrapper.assert_called_with(self.mock_model.model) mock_create_wrapper.assert_called_with(self.mock_model.model)
# Verify print_messages_compact was called # Verify anthropic_trim_messages was called with the right parameters
mock_print.assert_called_once() mock_trim_messages.assert_called_once()
# Verify print_messages_compact was called at least once
self.assertTrue(mock_print.call_count >= 1)
@patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
"""Test the sonnet 35 state modifier function."""
# Setup mocks
mock_estimate.side_effect = lambda msgs: len(msgs) * 1000
mock_trim.return_value = [self.human_message, self.ai_message]
# Create a state with messages
state = {"messages": [self.system_message, self.human_message, self.ai_message]}
# Test with empty messages
empty_state = {"messages": []}
self.assertEqual(sonnet_35_state_modifier(empty_state), [])
# Test with messages under the limit
result = sonnet_35_state_modifier(state, max_input_tokens=10000)
# Should keep the first message and call anthropic_trim_messages for the rest
self.assertEqual(len(result), 3)
self.assertEqual(result[0], self.system_message)
self.assertEqual(result[1:], [self.human_message, self.ai_message])
# Verify anthropic_trim_messages was called with the right parameters
mock_trim.assert_called_once_with(
[self.human_message, self.ai_message],
token_counter=mock_estimate,
max_tokens=9000, # 10000 - 1000 (first message)
strategy="last",
allow_partial=False,
include_system=True
)
@patch("ra_aid.anthropic_token_limiter.get_config_repository") @patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info") @patch("litellm.get_model_info")
@ -193,6 +252,109 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
result = get_model_token_limit(mock_config, "planner") result = get_model_token_limit(mock_config, "planner")
self.assertEqual(result, 100000) self.assertEqual(result, 100000)
def test_has_tool_use(self):
"""Test the has_tool_use function."""
# Test with regular AI message
self.assertFalse(has_tool_use(self.ai_message))
# Test with AI message containing tool_use in string content
ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
self.assertTrue(has_tool_use(ai_with_tool_str))
# Test with AI message containing tool_use in structured content
ai_with_tool_dict = AIMessage(content=[
{"type": "text", "text": "I'll use a tool to help you"},
{"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}}
])
self.assertTrue(has_tool_use(ai_with_tool_dict))
# Test with AI message containing tool_calls in additional_kwargs
self.assertTrue(has_tool_use(self.ai_with_tool_use))
# Test with non-AI message
self.assertFalse(has_tool_use(self.human_message))
def test_is_tool_pair(self):
"""Test the is_tool_pair function."""
# Test with valid tool pair
self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message))
# Test with non-tool pair (wrong order)
self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use))
# Test with non-tool pair (wrong types)
self.assertFalse(is_tool_pair(self.ai_message, self.human_message))
# Test with non-tool pair (AI message without tool use)
self.assertFalse(is_tool_pair(self.ai_message, self.tool_message))
@patch("ra_aid.anthropic_message_utils.has_tool_use")
def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use):
"""Test anthropic_trim_messages with a sequence of messages including tool use."""
from ra_aid.anthropic_message_utils import anthropic_trim_messages
# Setup mock for has_tool_use to return True for AI messages at even indices
def side_effect(msg):
if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'):
return msg.test_index % 2 == 0 # Even indices have tool use
return False
mock_has_tool_use.side_effect = side_effect
# Create a sequence of alternating human and AI messages with tool use
messages = []
# Start with system message
system_msg = SystemMessage(content="You are a helpful assistant.")
messages.append(system_msg)
# Add alternating human and AI messages with tool use
for i in range(8):
if i % 2 == 0:
# Human message
msg = HumanMessage(content=f"Human message {i}")
messages.append(msg)
else:
# AI message, every other one has tool use
ai_msg = AIMessage(content=f"AI message {i}")
# Add a test_index attribute to track position
ai_msg.test_index = i
messages.append(ai_msg)
# If this AI message has tool use (even index), add a tool message after it
if i % 4 == 1: # 1, 5, etc.
tool_msg = ToolMessage(
content=f"Tool result {i}",
tool_call_id=f"tool_call_{i}",
name="test_tool"
)
messages.append(tool_msg)
# Define a token counter that returns a fixed value per message
def token_counter(msgs):
return len(msgs) * 1000
# Test with a token limit that will require trimming
result = anthropic_trim_messages(
messages,
token_counter=token_counter,
max_tokens=5000, # This will allow 5 messages
strategy="last",
allow_partial=False,
include_system=True,
num_messages_to_keep=2 # Keep system and first human message
)
# We should have kept the first 2 messages (system + human)
self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit
self.assertEqual(result[0], system_msg)
# Verify that we don't have any AI messages with tool use that aren't followed by a tool message
for i in range(len(result) - 1):
if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]):
self.assertTrue(isinstance(result[i+1], ToolMessage),
f"AI message with tool use at index {i} not followed by ToolMessage")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()