chore(anthropic_message_utils.py): remove unused fix_anthropic_message_content function to clean up codebase
chore(anthropic_token_limiter.py): remove import of fix_anthropic_message_content as it is no longer needed
test: add unit tests for has_tool_use and is_tool_pair functions to ensure correct functionality
test: enhance test coverage for anthropic_trim_messages with tool use scenarios to validate message handling
parent 376d486db8
commit e42f281f94
ra_aid/anthropic_message_utils.py:

@@ -82,41 +82,6 @@ def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
     )


-def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage:
-    """Fix message content format for Anthropic API compatibility."""
-    if not isinstance(message, AIMessage) or not isinstance(message.content, list):
-        return message
-
-    fixed_message = message.model_copy(deep=True)
-
-    # Ensure first block is valid thinking type
-    if fixed_message.content and isinstance(fixed_message.content[0], dict):
-        first_block_type = fixed_message.content[0].get("type")
-        if first_block_type not in ("thinking", "redacted_thinking"):
-            # Prepend redacted_thinking block instead of thinking
-            fixed_message.content.insert(
-                0,
-                {
-                    "type": "redacted_thinking",
-                    "data": "ENCRYPTED_REASONING",  # Required field for redacted_thinking
-                },
-            )
-
-    # Ensure all thinking blocks have valid structure
-    for i, block in enumerate(fixed_message.content):
-        if block.get("type") == "thinking":
-            # Convert thinking blocks to redacted_thinking to avoid signature validation
-            fixed_message.content[i] = {
-                "type": "redacted_thinking",
-                "data": "ENCRYPTED_REASONING",
-            }
-        elif block.get("type") == "redacted_thinking":
-            # Ensure required data field exists
-            if "data" not in block:
-                fixed_message.content[i]["data"] = "ENCRYPTED_REASONING"
-
-    return fixed_message
-
-
 def anthropic_trim_messages(
     messages: Sequence[BaseMessage],
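Note: the deleted helper above rewrote Anthropic extended-thinking content blocks in place. For readers unfamiliar with those blocks, a minimal sketch of the two shapes the function dealt with, reconstructed from the deleted code itself (any field beyond "type" and "data" is an assumption, not taken from this diff):

    # Block shapes handled by the removed fix_anthropic_message_content:
    thinking_block = {
        "type": "thinking",        # rewritten to the redacted form by the helper
        "thinking": "...",         # assumed field; not shown in this diff
    }
    redacted_thinking_block = {
        "type": "redacted_thinking",
        "data": "ENCRYPTED_REASONING",  # the field the helper back-filled
    }

The helper prepended a redacted_thinking block when the first block had any other type, and converted every thinking block into the redacted form so signature validation could not fail downstream; the commit removes it because nothing imports it anymore.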
ra_aid/anthropic_token_limiter.py:

@@ -15,7 +15,6 @@ from langchain_core.messages import (
 from langchain_core.messages.base import message_to_dict

 from ra_aid.anthropic_message_utils import (
-    fix_anthropic_message_content,
     anthropic_trim_messages,
     has_tool_use,
 )
tests for anthropic_token_limiter (test module path not shown in this view):

@@ -2,7 +2,12 @@ import unittest
 from unittest.mock import MagicMock, patch

 from langchain_anthropic import ChatAnthropic
-from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage
+)
 from langgraph.prebuilt.chat_agent_executor import AgentState

 from ra_aid.anthropic_token_limiter import (
@@ -10,7 +15,10 @@ from ra_aid.anthropic_token_limiter import (
     estimate_messages_tokens,
     get_model_token_limit,
     state_modifier,
+    sonnet_35_state_modifier,
+    convert_message_to_litellm_format
 )
+from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair


 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -23,6 +31,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         # Sample messages for testing
         self.system_message = SystemMessage(content="You are a helpful assistant.")
         self.human_message = HumanMessage(content="Hello, can you help me with a task?")
+        self.ai_message = AIMessage(content="I'd be happy to help! What do you need?")
         self.long_message = HumanMessage(content="A" * 1000)  # Long message to test trimming

         # Create more messages for testing
@@ -36,6 +45,34 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             next=None,
         )

+        # Create tool-related messages for testing
+        self.ai_with_tool_use = AIMessage(
+            content="I'll use a tool to help you",
+            additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]}
+        )
+        self.tool_message = ToolMessage(
+            content="4",
+            tool_call_id="tool_call_1",
+            name="calculator"
+        )
+
+    def test_convert_message_to_litellm_format(self):
+        """Test conversion of BaseMessage to litellm format."""
+        # Test human message
+        human_result = convert_message_to_litellm_format(self.human_message)
+        self.assertEqual(human_result["role"], "human")
+        self.assertEqual(human_result["content"], "Hello, can you help me with a task?")
+
+        # Test system message
+        system_result = convert_message_to_litellm_format(self.system_message)
+        self.assertEqual(system_result["role"], "system")
+        self.assertEqual(system_result["content"], "You are a helpful assistant.")
+
+        # Test AI message
+        ai_result = convert_message_to_litellm_format(self.ai_message)
+        self.assertEqual(ai_result["role"], "ai")
+        self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?")
+
     @patch("ra_aid.anthropic_token_limiter.token_counter")
     def test_create_token_counter_wrapper(self, mock_token_counter):
         from ra_aid.config import DEFAULT_MODEL
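Note: the new test above pins down only the role and content mapping of convert_message_to_litellm_format. A minimal implementation consistent with those assertions would look like the following sketch (the real function in ra_aid.anthropic_token_limiter may also handle tool calls and structured content, which this test does not cover):

    from langchain_core.messages import BaseMessage

    def convert_message_to_litellm_format(message: BaseMessage) -> dict:
        # BaseMessage.type is "human", "system", or "ai" for the three
        # message classes used above, matching the asserted "role" values.
        return {"role": message.type, "content": message.content}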
@@ -75,44 +112,66 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
     @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
     @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
-    def test_state_modifier(self, mock_print, mock_create_wrapper):
+    @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
+    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
         # Setup a proper token counter function that returns integers
-        # This function needs to return values that will cause trim_messages to keep only the first message
         def token_counter(msgs):
-            # For a single message, return a small token count
-            if len(msgs) == 1:
-                return 10
-            # For two messages (first + one more), return a value under our limit
-            elif len(msgs) == 2:
-                return 30  # This is under our 40 token remaining budget (50-10)
-            # For three messages, return a value just under our limit
-            elif len(msgs) == 3:
-                return 40  # This is exactly at our 40 token remaining budget (50-10)
-            # For four messages, return a value just at our limit
-            elif len(msgs) == 4:
-                return 40  # This is exactly at our 40 token remaining budget (50-10)
-            # For five messages, return a value that exceeds our 40 token budget
-            elif len(msgs) == 5:
-                return 60  # This exceeds our 40 token budget, forcing only 4 more messages
-            # For more messages, return a value over our limit
-            else:
-                return 100  # This exceeds our limit
+            # Return token count based on number of messages
+            return len(msgs) * 10

-        # Don't use side_effect here, directly return the function
+        # Configure the mock to return our token counter
         mock_create_wrapper.return_value = token_counter

+        # Configure anthropic_trim_messages to return a subset of messages
+        mock_trim_messages.return_value = [self.system_message, self.human_message]
+
         # Call state_modifier with a max token limit of 50
         result = state_modifier(self.state, self.mock_model, max_input_tokens=50)

-        # Should keep first message and some of the others (up to 5 total)
-        self.assertEqual(len(result), 5)  # First message plus four more
-        self.assertEqual(result[0], self.system_message)  # First message is preserved
+        # Should return what anthropic_trim_messages returned
+        self.assertEqual(result, [self.system_message, self.human_message])

         # Verify the wrapper was created with the right model
         mock_create_wrapper.assert_called_with(self.mock_model.model)

-        # Verify print_messages_compact was called
-        mock_print.assert_called_once()
+        # Verify anthropic_trim_messages was called with the right parameters
+        mock_trim_messages.assert_called_once()
+
+        # Verify print_messages_compact was called at least once
+        self.assertTrue(mock_print.call_count >= 1)
+
+    @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
+    @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
+    def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
+        """Test the sonnet 35 state modifier function."""
+        # Setup mocks
+        mock_estimate.side_effect = lambda msgs: len(msgs) * 1000
+        mock_trim.return_value = [self.human_message, self.ai_message]
+
+        # Create a state with messages
+        state = {"messages": [self.system_message, self.human_message, self.ai_message]}
+
+        # Test with empty messages
+        empty_state = {"messages": []}
+        self.assertEqual(sonnet_35_state_modifier(empty_state), [])
+
+        # Test with messages under the limit
+        result = sonnet_35_state_modifier(state, max_input_tokens=10000)
+
+        # Should keep the first message and call anthropic_trim_messages for the rest
+        self.assertEqual(len(result), 3)
+        self.assertEqual(result[0], self.system_message)
+        self.assertEqual(result[1:], [self.human_message, self.ai_message])
+
+        # Verify anthropic_trim_messages was called with the right parameters
+        mock_trim.assert_called_once_with(
+            [self.human_message, self.ai_message],
+            token_counter=mock_estimate,
+            max_tokens=9000,  # 10000 - 1000 (first message)
+            strategy="last",
+            allow_partial=False,
+            include_system=True
+        )
+
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
     @patch("litellm.get_model_info")
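Note: together, the mock configuration and the mock_trim.assert_called_once_with call above fully specify the contract of sonnet_35_state_modifier: keep the first message verbatim and hand the remainder to anthropic_trim_messages with whatever budget is left. A sketch of the behavior the test implies (the actual implementation in ra_aid.anthropic_token_limiter may differ in details such as defaults):

    # estimate_messages_tokens and anthropic_trim_messages are the real
    # functions that the test patches; they are assumed importable here.
    def sonnet_35_state_modifier(state, max_input_tokens=10000):
        messages = state["messages"]
        if not messages:
            return []  # empty input short-circuits, per the test
        first, rest = messages[0], messages[1:]
        # Budget for the rest = limit minus the first message's tokens
        # (10000 - 1000 = 9000 in the test).
        budget = max_input_tokens - estimate_messages_tokens([first])
        trimmed = anthropic_trim_messages(
            rest,
            token_counter=estimate_messages_tokens,
            max_tokens=budget,
            strategy="last",
            allow_partial=False,
            include_system=True,
        )
        return [first] + trimmed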
@@ -193,6 +252,109 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         result = get_model_token_limit(mock_config, "planner")
         self.assertEqual(result, 100000)

+    def test_has_tool_use(self):
+        """Test the has_tool_use function."""
+        # Test with regular AI message
+        self.assertFalse(has_tool_use(self.ai_message))
+
+        # Test with AI message containing tool_use in string content
+        ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
+        self.assertTrue(has_tool_use(ai_with_tool_str))
+
+        # Test with AI message containing tool_use in structured content
+        ai_with_tool_dict = AIMessage(content=[
+            {"type": "text", "text": "I'll use a tool to help you"},
+            {"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}}
+        ])
+        self.assertTrue(has_tool_use(ai_with_tool_dict))
+
+        # Test with AI message containing tool_calls in additional_kwargs
+        self.assertTrue(has_tool_use(self.ai_with_tool_use))
+
+        # Test with non-AI message
+        self.assertFalse(has_tool_use(self.human_message))
+
+    def test_is_tool_pair(self):
+        """Test the is_tool_pair function."""
+        # Test with valid tool pair
+        self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message))
+
+        # Test with non-tool pair (wrong order)
+        self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use))
+
+        # Test with non-tool pair (wrong types)
+        self.assertFalse(is_tool_pair(self.ai_message, self.human_message))
+
+        # Test with non-tool pair (AI message without tool use)
+        self.assertFalse(is_tool_pair(self.ai_message, self.tool_message))
+
+    @patch("ra_aid.anthropic_message_utils.has_tool_use")
+    def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use):
+        """Test anthropic_trim_messages with a sequence of messages including tool use."""
+        from ra_aid.anthropic_message_utils import anthropic_trim_messages
+
+        # Setup mock for has_tool_use to return True for AI messages at even indices
+        def side_effect(msg):
+            if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'):
+                return msg.test_index % 2 == 0  # Even indices have tool use
+            return False
+
+        mock_has_tool_use.side_effect = side_effect
+
+        # Create a sequence of alternating human and AI messages with tool use
+        messages = []
+
+        # Start with system message
+        system_msg = SystemMessage(content="You are a helpful assistant.")
+        messages.append(system_msg)
+
+        # Add alternating human and AI messages with tool use
+        for i in range(8):
+            if i % 2 == 0:
+                # Human message
+                msg = HumanMessage(content=f"Human message {i}")
+                messages.append(msg)
+            else:
+                # AI message, every other one has tool use
+                ai_msg = AIMessage(content=f"AI message {i}")
+                # Add a test_index attribute to track position
+                ai_msg.test_index = i
+                messages.append(ai_msg)
+
+                # If this AI message has tool use (even index), add a tool message after it
+                if i % 4 == 1:  # 1, 5, etc.
+                    tool_msg = ToolMessage(
+                        content=f"Tool result {i}",
+                        tool_call_id=f"tool_call_{i}",
+                        name="test_tool"
+                    )
+                    messages.append(tool_msg)
+
+        # Define a token counter that returns a fixed value per message
+        def token_counter(msgs):
+            return len(msgs) * 1000
+
+        # Test with a token limit that will require trimming
+        result = anthropic_trim_messages(
+            messages,
+            token_counter=token_counter,
+            max_tokens=5000,  # This will allow 5 messages
+            strategy="last",
+            allow_partial=False,
+            include_system=True,
+            num_messages_to_keep=2  # Keep system and first human message
+        )
+
+        # We should have kept the first 2 messages (system + human)
+        self.assertEqual(len(result), 5)  # 2 kept + 3 more that fit in token limit
+        self.assertEqual(result[0], system_msg)
+
+        # Verify that we don't have any AI messages with tool use that aren't followed by a tool message
+        for i in range(len(result) - 1):
+            if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]):
+                self.assertTrue(isinstance(result[i + 1], ToolMessage),
+                                f"AI message with tool use at index {i} not followed by ToolMessage")
+

 if __name__ == "__main__":
     unittest.main()
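Note: the three new tests double as documentation for the helpers under test. Sketches consistent with every assertion above (the real implementations live in ra_aid.anthropic_message_utils and may differ in detail):

    from langchain_core.messages import AIMessage, BaseMessage, ToolMessage

    def has_tool_use(message: BaseMessage) -> bool:
        # Non-AI messages never count as tool use.
        if not isinstance(message, AIMessage):
            return False
        # String content mentioning "tool_use" counts (see ai_with_tool_str).
        if isinstance(message.content, str) and "tool_use" in message.content:
            return True
        # Structured content with a tool_use block counts (see ai_with_tool_dict).
        if isinstance(message.content, list) and any(
            isinstance(block, dict) and block.get("type") == "tool_use"
            for block in message.content
        ):
            return True
        # tool_calls in additional_kwargs count (see self.ai_with_tool_use).
        return bool(message.additional_kwargs.get("tool_calls"))

    def is_tool_pair(first: BaseMessage, second: BaseMessage) -> bool:
        # An AI message that uses a tool, immediately followed by its result.
        return (
            isinstance(first, AIMessage)
            and isinstance(second, ToolMessage)
            and has_tool_use(first)
        )

The trimming invariant checked at the end of test_anthropic_trim_messages_with_tool_use is the important one: whatever anthropic_trim_messages drops, it must never leave an AI message that used a tool without the ToolMessage that answers it, since the Anthropic API rejects a tool_use block with no matching tool result.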