From e42f281f94853bad796e279716fed89d344cdd44 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:48:08 -0700 Subject: [PATCH] chore(anthropic_message_utils.py): remove unused fix_anthropic_message_content function to clean up codebase chore(anthropic_token_limiter.py): remove import of fix_anthropic_message_content as it is no longer needed test: add unit tests for has_tool_use and is_tool_pair functions to ensure correct functionality test: enhance test coverage for anthropic_trim_messages with tool use scenarios to validate message handling --- ra_aid/anthropic_message_utils.py | 35 --- ra_aid/anthropic_token_limiter.py | 1 - tests/ra_aid/test_anthropic_token_limiter.py | 216 ++++++++++++++++--- 3 files changed, 189 insertions(+), 63 deletions(-) diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py index 0df4564..e71d0ed 100644 --- a/ra_aid/anthropic_message_utils.py +++ b/ra_aid/anthropic_message_utils.py @@ -82,41 +82,6 @@ def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool: ) -def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage: - """Fix message content format for Anthropic API compatibility.""" - if not isinstance(message, AIMessage) or not isinstance(message.content, list): - return message - - fixed_message = message.model_copy(deep=True) - - # Ensure first block is valid thinking type - if fixed_message.content and isinstance(fixed_message.content[0], dict): - first_block_type = fixed_message.content[0].get("type") - if first_block_type not in ("thinking", "redacted_thinking"): - # Prepend redacted_thinking block instead of thinking - fixed_message.content.insert( - 0, - { - "type": "redacted_thinking", - "data": "ENCRYPTED_REASONING", # Required field for redacted_thinking - }, - ) - - # Ensure all thinking blocks have valid structure - for i, block in enumerate(fixed_message.content): - if block.get("type") == "thinking": - # Convert thinking blocks to redacted_thinking to avoid signature validation - fixed_message.content[i] = { - "type": "redacted_thinking", - "data": "ENCRYPTED_REASONING", - } - elif block.get("type") == "redacted_thinking": - # Ensure required data field exists - if "data" not in block: - fixed_message.content[i]["data"] = "ENCRYPTED_REASONING" - - return fixed_message - def anthropic_trim_messages( messages: Sequence[BaseMessage], diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index ae82ab2..d9e7355 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -15,7 +15,6 @@ from langchain_core.messages import ( from langchain_core.messages.base import message_to_dict from ra_aid.anthropic_message_utils import ( - fix_anthropic_message_content, anthropic_trim_messages, has_tool_use, ) diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py index 933c73e..3f7e35e 100644 --- a/tests/ra_aid/test_anthropic_token_limiter.py +++ b/tests/ra_aid/test_anthropic_token_limiter.py @@ -2,7 +2,12 @@ import unittest from unittest.mock import MagicMock, patch from langchain_anthropic import ChatAnthropic -from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage +) from langgraph.prebuilt.chat_agent_executor import AgentState from ra_aid.anthropic_token_limiter import ( @@ -10,7 +15,10 @@ from ra_aid.anthropic_token_limiter import ( estimate_messages_tokens, get_model_token_limit, state_modifier, + sonnet_35_state_modifier, + convert_message_to_litellm_format ) +from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair class TestAnthropicTokenLimiter(unittest.TestCase): @@ -23,6 +31,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase): # Sample messages for testing self.system_message = SystemMessage(content="You are a helpful assistant.") self.human_message = HumanMessage(content="Hello, can you help me with a task?") + self.ai_message = AIMessage(content="I'd be happy to help! What do you need?") self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming # Create more messages for testing @@ -35,6 +44,34 @@ class TestAnthropicTokenLimiter(unittest.TestCase): messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages, next=None, ) + + # Create tool-related messages for testing + self.ai_with_tool_use = AIMessage( + content="I'll use a tool to help you", + additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]} + ) + self.tool_message = ToolMessage( + content="4", + tool_call_id="tool_call_1", + name="calculator" + ) + + def test_convert_message_to_litellm_format(self): + """Test conversion of BaseMessage to litellm format.""" + # Test human message + human_result = convert_message_to_litellm_format(self.human_message) + self.assertEqual(human_result["role"], "human") + self.assertEqual(human_result["content"], "Hello, can you help me with a task?") + + # Test system message + system_result = convert_message_to_litellm_format(self.system_message) + self.assertEqual(system_result["role"], "system") + self.assertEqual(system_result["content"], "You are a helpful assistant.") + + # Test AI message + ai_result = convert_message_to_litellm_format(self.ai_message) + self.assertEqual(ai_result["role"], "ai") + self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?") @patch("ra_aid.anthropic_token_limiter.token_counter") def test_create_token_counter_wrapper(self, mock_token_counter): @@ -75,44 +112,66 @@ class TestAnthropicTokenLimiter(unittest.TestCase): @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") @patch("ra_aid.anthropic_token_limiter.print_messages_compact") - def test_state_modifier(self, mock_print, mock_create_wrapper): + @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") + def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper): # Setup a proper token counter function that returns integers - # This function needs to return values that will cause trim_messages to keep only the first message def token_counter(msgs): - # For a single message, return a small token count - if len(msgs) == 1: - return 10 - # For two messages (first + one more), return a value under our limit - elif len(msgs) == 2: - return 30 # This is under our 40 token remaining budget (50-10) - # For three messages, return a value just under our limit - elif len(msgs) == 3: - return 40 # This is exactly at our 40 token remaining budget (50-10) - # For four messages, return a value just at our limit - elif len(msgs) == 4: - return 40 # This is exactly at our 40 token remaining budget (50-10) - # For five messages, return a value that exceeds our 40 token budget - elif len(msgs) == 5: - return 60 # This exceeds our 40 token budget, forcing only 4 more messages - # For more messages, return a value over our limit - else: - return 100 # This exceeds our limit + # Return token count based on number of messages + return len(msgs) * 10 - # Don't use side_effect here, directly return the function + # Configure the mock to return our token counter mock_create_wrapper.return_value = token_counter + # Configure anthropic_trim_messages to return a subset of messages + mock_trim_messages.return_value = [self.system_message, self.human_message] + # Call state_modifier with a max token limit of 50 result = state_modifier(self.state, self.mock_model, max_input_tokens=50) - # Should keep first message and some of the others (up to 5 total) - self.assertEqual(len(result), 5) # First message plus four more - self.assertEqual(result[0], self.system_message) # First message is preserved + # Should return what anthropic_trim_messages returned + self.assertEqual(result, [self.system_message, self.human_message]) # Verify the wrapper was created with the right model mock_create_wrapper.assert_called_with(self.mock_model.model) - # Verify print_messages_compact was called - mock_print.assert_called_once() + # Verify anthropic_trim_messages was called with the right parameters + mock_trim_messages.assert_called_once() + + # Verify print_messages_compact was called at least once + self.assertTrue(mock_print.call_count >= 1) + + @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens") + @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") + def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate): + """Test the sonnet 35 state modifier function.""" + # Setup mocks + mock_estimate.side_effect = lambda msgs: len(msgs) * 1000 + mock_trim.return_value = [self.human_message, self.ai_message] + + # Create a state with messages + state = {"messages": [self.system_message, self.human_message, self.ai_message]} + + # Test with empty messages + empty_state = {"messages": []} + self.assertEqual(sonnet_35_state_modifier(empty_state), []) + + # Test with messages under the limit + result = sonnet_35_state_modifier(state, max_input_tokens=10000) + + # Should keep the first message and call anthropic_trim_messages for the rest + self.assertEqual(len(result), 3) + self.assertEqual(result[0], self.system_message) + self.assertEqual(result[1:], [self.human_message, self.ai_message]) + + # Verify anthropic_trim_messages was called with the right parameters + mock_trim.assert_called_once_with( + [self.human_message, self.ai_message], + token_counter=mock_estimate, + max_tokens=9000, # 10000 - 1000 (first message) + strategy="last", + allow_partial=False, + include_system=True + ) @patch("ra_aid.anthropic_token_limiter.get_config_repository") @patch("litellm.get_model_info") @@ -192,6 +251,109 @@ class TestAnthropicTokenLimiter(unittest.TestCase): # Test planner agent type result = get_model_token_limit(mock_config, "planner") self.assertEqual(result, 100000) + + def test_has_tool_use(self): + """Test the has_tool_use function.""" + # Test with regular AI message + self.assertFalse(has_tool_use(self.ai_message)) + + # Test with AI message containing tool_use in string content + ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you") + self.assertTrue(has_tool_use(ai_with_tool_str)) + + # Test with AI message containing tool_use in structured content + ai_with_tool_dict = AIMessage(content=[ + {"type": "text", "text": "I'll use a tool to help you"}, + {"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}} + ]) + self.assertTrue(has_tool_use(ai_with_tool_dict)) + + # Test with AI message containing tool_calls in additional_kwargs + self.assertTrue(has_tool_use(self.ai_with_tool_use)) + + # Test with non-AI message + self.assertFalse(has_tool_use(self.human_message)) + + def test_is_tool_pair(self): + """Test the is_tool_pair function.""" + # Test with valid tool pair + self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message)) + + # Test with non-tool pair (wrong order) + self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use)) + + # Test with non-tool pair (wrong types) + self.assertFalse(is_tool_pair(self.ai_message, self.human_message)) + + # Test with non-tool pair (AI message without tool use) + self.assertFalse(is_tool_pair(self.ai_message, self.tool_message)) + + @patch("ra_aid.anthropic_message_utils.has_tool_use") + def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use): + """Test anthropic_trim_messages with a sequence of messages including tool use.""" + from ra_aid.anthropic_message_utils import anthropic_trim_messages + + # Setup mock for has_tool_use to return True for AI messages at even indices + def side_effect(msg): + if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'): + return msg.test_index % 2 == 0 # Even indices have tool use + return False + + mock_has_tool_use.side_effect = side_effect + + # Create a sequence of alternating human and AI messages with tool use + messages = [] + + # Start with system message + system_msg = SystemMessage(content="You are a helpful assistant.") + messages.append(system_msg) + + # Add alternating human and AI messages with tool use + for i in range(8): + if i % 2 == 0: + # Human message + msg = HumanMessage(content=f"Human message {i}") + messages.append(msg) + else: + # AI message, every other one has tool use + ai_msg = AIMessage(content=f"AI message {i}") + # Add a test_index attribute to track position + ai_msg.test_index = i + messages.append(ai_msg) + + # If this AI message has tool use (even index), add a tool message after it + if i % 4 == 1: # 1, 5, etc. + tool_msg = ToolMessage( + content=f"Tool result {i}", + tool_call_id=f"tool_call_{i}", + name="test_tool" + ) + messages.append(tool_msg) + + # Define a token counter that returns a fixed value per message + def token_counter(msgs): + return len(msgs) * 1000 + + # Test with a token limit that will require trimming + result = anthropic_trim_messages( + messages, + token_counter=token_counter, + max_tokens=5000, # This will allow 5 messages + strategy="last", + allow_partial=False, + include_system=True, + num_messages_to_keep=2 # Keep system and first human message + ) + + # We should have kept the first 2 messages (system + human) + self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit + self.assertEqual(result[0], system_msg) + + # Verify that we don't have any AI messages with tool use that aren't followed by a tool message + for i in range(len(result) - 1): + if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]): + self.assertTrue(isinstance(result[i+1], ToolMessage), + f"AI message with tool use at index {i} not followed by ToolMessage") if __name__ == "__main__":