chore(anthropic_message_utils.py): remove unused fix_anthropic_message_content function to clean up codebase
chore(anthropic_token_limiter.py): remove import of fix_anthropic_message_content as it is no longer needed test: add unit tests for has_tool_use and is_tool_pair functions to ensure correct functionality test: enhance test coverage for anthropic_trim_messages with tool use scenarios to validate message handling
This commit is contained in:
parent
376d486db8
commit
e42f281f94
|
|
@ -82,41 +82,6 @@ def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage:
|
||||
"""Fix message content format for Anthropic API compatibility."""
|
||||
if not isinstance(message, AIMessage) or not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
fixed_message = message.model_copy(deep=True)
|
||||
|
||||
# Ensure first block is valid thinking type
|
||||
if fixed_message.content and isinstance(fixed_message.content[0], dict):
|
||||
first_block_type = fixed_message.content[0].get("type")
|
||||
if first_block_type not in ("thinking", "redacted_thinking"):
|
||||
# Prepend redacted_thinking block instead of thinking
|
||||
fixed_message.content.insert(
|
||||
0,
|
||||
{
|
||||
"type": "redacted_thinking",
|
||||
"data": "ENCRYPTED_REASONING", # Required field for redacted_thinking
|
||||
},
|
||||
)
|
||||
|
||||
# Ensure all thinking blocks have valid structure
|
||||
for i, block in enumerate(fixed_message.content):
|
||||
if block.get("type") == "thinking":
|
||||
# Convert thinking blocks to redacted_thinking to avoid signature validation
|
||||
fixed_message.content[i] = {
|
||||
"type": "redacted_thinking",
|
||||
"data": "ENCRYPTED_REASONING",
|
||||
}
|
||||
elif block.get("type") == "redacted_thinking":
|
||||
# Ensure required data field exists
|
||||
if "data" not in block:
|
||||
fixed_message.content[i]["data"] = "ENCRYPTED_REASONING"
|
||||
|
||||
return fixed_message
|
||||
|
||||
|
||||
def anthropic_trim_messages(
|
||||
messages: Sequence[BaseMessage],
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from langchain_core.messages import (
|
|||
from langchain_core.messages.base import message_to_dict
|
||||
|
||||
from ra_aid.anthropic_message_utils import (
|
||||
fix_anthropic_message_content,
|
||||
anthropic_trim_messages,
|
||||
has_tool_use,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,12 @@ import unittest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage
|
||||
)
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
|
||||
from ra_aid.anthropic_token_limiter import (
|
||||
|
|
@ -10,7 +15,10 @@ from ra_aid.anthropic_token_limiter import (
|
|||
estimate_messages_tokens,
|
||||
get_model_token_limit,
|
||||
state_modifier,
|
||||
sonnet_35_state_modifier,
|
||||
convert_message_to_litellm_format
|
||||
)
|
||||
from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
|
||||
|
||||
|
||||
class TestAnthropicTokenLimiter(unittest.TestCase):
|
||||
|
|
@ -23,6 +31,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
# Sample messages for testing
|
||||
self.system_message = SystemMessage(content="You are a helpful assistant.")
|
||||
self.human_message = HumanMessage(content="Hello, can you help me with a task?")
|
||||
self.ai_message = AIMessage(content="I'd be happy to help! What do you need?")
|
||||
self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming
|
||||
|
||||
# Create more messages for testing
|
||||
|
|
@ -35,6 +44,34 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
|
||||
next=None,
|
||||
)
|
||||
|
||||
# Create tool-related messages for testing
|
||||
self.ai_with_tool_use = AIMessage(
|
||||
content="I'll use a tool to help you",
|
||||
additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]}
|
||||
)
|
||||
self.tool_message = ToolMessage(
|
||||
content="4",
|
||||
tool_call_id="tool_call_1",
|
||||
name="calculator"
|
||||
)
|
||||
|
||||
def test_convert_message_to_litellm_format(self):
|
||||
"""Test conversion of BaseMessage to litellm format."""
|
||||
# Test human message
|
||||
human_result = convert_message_to_litellm_format(self.human_message)
|
||||
self.assertEqual(human_result["role"], "human")
|
||||
self.assertEqual(human_result["content"], "Hello, can you help me with a task?")
|
||||
|
||||
# Test system message
|
||||
system_result = convert_message_to_litellm_format(self.system_message)
|
||||
self.assertEqual(system_result["role"], "system")
|
||||
self.assertEqual(system_result["content"], "You are a helpful assistant.")
|
||||
|
||||
# Test AI message
|
||||
ai_result = convert_message_to_litellm_format(self.ai_message)
|
||||
self.assertEqual(ai_result["role"], "ai")
|
||||
self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?")
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.token_counter")
|
||||
def test_create_token_counter_wrapper(self, mock_token_counter):
|
||||
|
|
@ -75,44 +112,66 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
|
||||
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
|
||||
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
|
||||
def test_state_modifier(self, mock_print, mock_create_wrapper):
|
||||
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
|
||||
def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
|
||||
# Setup a proper token counter function that returns integers
|
||||
# This function needs to return values that will cause trim_messages to keep only the first message
|
||||
def token_counter(msgs):
|
||||
# For a single message, return a small token count
|
||||
if len(msgs) == 1:
|
||||
return 10
|
||||
# For two messages (first + one more), return a value under our limit
|
||||
elif len(msgs) == 2:
|
||||
return 30 # This is under our 40 token remaining budget (50-10)
|
||||
# For three messages, return a value just under our limit
|
||||
elif len(msgs) == 3:
|
||||
return 40 # This is exactly at our 40 token remaining budget (50-10)
|
||||
# For four messages, return a value just at our limit
|
||||
elif len(msgs) == 4:
|
||||
return 40 # This is exactly at our 40 token remaining budget (50-10)
|
||||
# For five messages, return a value that exceeds our 40 token budget
|
||||
elif len(msgs) == 5:
|
||||
return 60 # This exceeds our 40 token budget, forcing only 4 more messages
|
||||
# For more messages, return a value over our limit
|
||||
else:
|
||||
return 100 # This exceeds our limit
|
||||
# Return token count based on number of messages
|
||||
return len(msgs) * 10
|
||||
|
||||
# Don't use side_effect here, directly return the function
|
||||
# Configure the mock to return our token counter
|
||||
mock_create_wrapper.return_value = token_counter
|
||||
|
||||
# Configure anthropic_trim_messages to return a subset of messages
|
||||
mock_trim_messages.return_value = [self.system_message, self.human_message]
|
||||
|
||||
# Call state_modifier with a max token limit of 50
|
||||
result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
|
||||
|
||||
# Should keep first message and some of the others (up to 5 total)
|
||||
self.assertEqual(len(result), 5) # First message plus four more
|
||||
self.assertEqual(result[0], self.system_message) # First message is preserved
|
||||
# Should return what anthropic_trim_messages returned
|
||||
self.assertEqual(result, [self.system_message, self.human_message])
|
||||
|
||||
# Verify the wrapper was created with the right model
|
||||
mock_create_wrapper.assert_called_with(self.mock_model.model)
|
||||
|
||||
# Verify print_messages_compact was called
|
||||
mock_print.assert_called_once()
|
||||
# Verify anthropic_trim_messages was called with the right parameters
|
||||
mock_trim_messages.assert_called_once()
|
||||
|
||||
# Verify print_messages_compact was called at least once
|
||||
self.assertTrue(mock_print.call_count >= 1)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
|
||||
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
|
||||
def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
|
||||
"""Test the sonnet 35 state modifier function."""
|
||||
# Setup mocks
|
||||
mock_estimate.side_effect = lambda msgs: len(msgs) * 1000
|
||||
mock_trim.return_value = [self.human_message, self.ai_message]
|
||||
|
||||
# Create a state with messages
|
||||
state = {"messages": [self.system_message, self.human_message, self.ai_message]}
|
||||
|
||||
# Test with empty messages
|
||||
empty_state = {"messages": []}
|
||||
self.assertEqual(sonnet_35_state_modifier(empty_state), [])
|
||||
|
||||
# Test with messages under the limit
|
||||
result = sonnet_35_state_modifier(state, max_input_tokens=10000)
|
||||
|
||||
# Should keep the first message and call anthropic_trim_messages for the rest
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result[0], self.system_message)
|
||||
self.assertEqual(result[1:], [self.human_message, self.ai_message])
|
||||
|
||||
# Verify anthropic_trim_messages was called with the right parameters
|
||||
mock_trim.assert_called_once_with(
|
||||
[self.human_message, self.ai_message],
|
||||
token_counter=mock_estimate,
|
||||
max_tokens=9000, # 10000 - 1000 (first message)
|
||||
strategy="last",
|
||||
allow_partial=False,
|
||||
include_system=True
|
||||
)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
|
|
@ -192,6 +251,109 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
|
|||
# Test planner agent type
|
||||
result = get_model_token_limit(mock_config, "planner")
|
||||
self.assertEqual(result, 100000)
|
||||
|
||||
def test_has_tool_use(self):
|
||||
"""Test the has_tool_use function."""
|
||||
# Test with regular AI message
|
||||
self.assertFalse(has_tool_use(self.ai_message))
|
||||
|
||||
# Test with AI message containing tool_use in string content
|
||||
ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
|
||||
self.assertTrue(has_tool_use(ai_with_tool_str))
|
||||
|
||||
# Test with AI message containing tool_use in structured content
|
||||
ai_with_tool_dict = AIMessage(content=[
|
||||
{"type": "text", "text": "I'll use a tool to help you"},
|
||||
{"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}}
|
||||
])
|
||||
self.assertTrue(has_tool_use(ai_with_tool_dict))
|
||||
|
||||
# Test with AI message containing tool_calls in additional_kwargs
|
||||
self.assertTrue(has_tool_use(self.ai_with_tool_use))
|
||||
|
||||
# Test with non-AI message
|
||||
self.assertFalse(has_tool_use(self.human_message))
|
||||
|
||||
def test_is_tool_pair(self):
|
||||
"""Test the is_tool_pair function."""
|
||||
# Test with valid tool pair
|
||||
self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message))
|
||||
|
||||
# Test with non-tool pair (wrong order)
|
||||
self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use))
|
||||
|
||||
# Test with non-tool pair (wrong types)
|
||||
self.assertFalse(is_tool_pair(self.ai_message, self.human_message))
|
||||
|
||||
# Test with non-tool pair (AI message without tool use)
|
||||
self.assertFalse(is_tool_pair(self.ai_message, self.tool_message))
|
||||
|
||||
@patch("ra_aid.anthropic_message_utils.has_tool_use")
|
||||
def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use):
|
||||
"""Test anthropic_trim_messages with a sequence of messages including tool use."""
|
||||
from ra_aid.anthropic_message_utils import anthropic_trim_messages
|
||||
|
||||
# Setup mock for has_tool_use to return True for AI messages at even indices
|
||||
def side_effect(msg):
|
||||
if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'):
|
||||
return msg.test_index % 2 == 0 # Even indices have tool use
|
||||
return False
|
||||
|
||||
mock_has_tool_use.side_effect = side_effect
|
||||
|
||||
# Create a sequence of alternating human and AI messages with tool use
|
||||
messages = []
|
||||
|
||||
# Start with system message
|
||||
system_msg = SystemMessage(content="You are a helpful assistant.")
|
||||
messages.append(system_msg)
|
||||
|
||||
# Add alternating human and AI messages with tool use
|
||||
for i in range(8):
|
||||
if i % 2 == 0:
|
||||
# Human message
|
||||
msg = HumanMessage(content=f"Human message {i}")
|
||||
messages.append(msg)
|
||||
else:
|
||||
# AI message, every other one has tool use
|
||||
ai_msg = AIMessage(content=f"AI message {i}")
|
||||
# Add a test_index attribute to track position
|
||||
ai_msg.test_index = i
|
||||
messages.append(ai_msg)
|
||||
|
||||
# If this AI message has tool use (even index), add a tool message after it
|
||||
if i % 4 == 1: # 1, 5, etc.
|
||||
tool_msg = ToolMessage(
|
||||
content=f"Tool result {i}",
|
||||
tool_call_id=f"tool_call_{i}",
|
||||
name="test_tool"
|
||||
)
|
||||
messages.append(tool_msg)
|
||||
|
||||
# Define a token counter that returns a fixed value per message
|
||||
def token_counter(msgs):
|
||||
return len(msgs) * 1000
|
||||
|
||||
# Test with a token limit that will require trimming
|
||||
result = anthropic_trim_messages(
|
||||
messages,
|
||||
token_counter=token_counter,
|
||||
max_tokens=5000, # This will allow 5 messages
|
||||
strategy="last",
|
||||
allow_partial=False,
|
||||
include_system=True,
|
||||
num_messages_to_keep=2 # Keep system and first human message
|
||||
)
|
||||
|
||||
# We should have kept the first 2 messages (system + human)
|
||||
self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit
|
||||
self.assertEqual(result[0], system_msg)
|
||||
|
||||
# Verify that we don't have any AI messages with tool use that aren't followed by a tool message
|
||||
for i in range(len(result) - 1):
|
||||
if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]):
|
||||
self.assertTrue(isinstance(result[i+1], ToolMessage),
|
||||
f"AI message with tool use at index {i} not followed by ToolMessage")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue