diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py index 91c285f..0df4564 100644 --- a/ra_aid/anthropic_message_utils.py +++ b/ra_aid/anthropic_message_utils.py @@ -36,47 +36,49 @@ def _is_message_type( def has_tool_use(message: BaseMessage) -> bool: """Check if a message contains tool use. - + Args: message: The message to check - + Returns: bool: True if the message contains tool use """ if not isinstance(message, AIMessage): return False - + # Check content for tool_use if isinstance(message.content, str) and "tool_use" in message.content: return True - + # Check content list for tool_use blocks if isinstance(message.content, list): for item in message.content: if isinstance(item, dict) and item.get("type") == "tool_use": return True - + # Check additional_kwargs for tool_calls - if hasattr(message, "additional_kwargs") and message.additional_kwargs.get("tool_calls"): + if hasattr(message, "additional_kwargs") and message.additional_kwargs.get( + "tool_calls" + ): return True - + return False def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool: """Check if two messages form a tool use/result pair. - + Args: message1: First message message2: Second message - + Returns: bool: True if the messages form a tool use/result pair """ return ( - isinstance(message1, AIMessage) and - isinstance(message2, ToolMessage) and - has_tool_use(message1) + isinstance(message1, AIMessage) + and isinstance(message2, ToolMessage) + and has_tool_use(message1) ) @@ -129,6 +131,8 @@ def anthropic_trim_messages( ) -> List[BaseMessage]: """Trim messages to fit within a token limit, with Anthropic-specific handling. + Warning - not fully implemented - last strategy is supported and test, not + allow partial, not 'first' strategy either. This function is similar to langchain_core's trim_messages but with special handling for Anthropic message formats to avoid API errors. @@ -176,11 +180,11 @@ def anthropic_trim_messages( # For Anthropic, we need to maintain the conversation structure where: # 1. Every AIMessage with tool_use must be followed by a ToolMessage # 2. Every AIMessage that follows a ToolMessage must start with a tool_result - + # First, check if we have any tool_use in the messages has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages) print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}") - + # Print debug info for AIMessages for i, msg in enumerate(messages): if isinstance(msg, AIMessage): @@ -188,46 +192,52 @@ def anthropic_trim_messages( print(f" has_tool_use: {has_tool_use(msg)}") if hasattr(msg, "additional_kwargs"): print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}") - + # If we have tool_use anywhere, we need to be very careful about trimming if has_tool_use_anywhere: # For safety, just keep all messages if we're under the token limit if token_counter(messages) <= max_tokens: print("DEBUG - All messages fit within token limit, keeping all") return messages - + # We need to identify all tool_use/tool_result relationships # First, find all AIMessage+ToolMessage pairs pairs = [] i = 0 while i < len(messages) - 1: - if is_tool_pair(messages[i], messages[i+1]): - pairs.append((i, i+1)) + if is_tool_pair(messages[i], messages[i + 1]): + pairs.append((i, i + 1)) print(f"DEBUG - Found tool_use pair: ({i}, {i+1})") i += 2 else: i += 1 - + print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs") - + # For Anthropic, we need to ensure that: # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage # 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use - + # The safest approach is to always keep complete AIMessage+ToolMessage pairs together # First, identify all complete pairs complete_pairs = [] for start, end in pairs: complete_pairs.append((start, end)) - - print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs") - + + print( + f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs" + ) + # Now we'll build our result, starting with the kept_messages # But we need to be careful about the first message if it has tool_use result = [] - + # Check if the last message in kept_messages has tool_use - if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]): + if ( + kept_messages + and isinstance(kept_messages[-1], AIMessage) + and has_tool_use(kept_messages[-1]) + ): # We need to find the corresponding ToolMessage for i, (ai_idx, tool_idx) in enumerate(pairs): if messages[ai_idx] is kept_messages[-1]: @@ -236,7 +246,7 @@ def anthropic_trim_messages( # Add the AIMessage and ToolMessage as a pair result.extend([messages[ai_idx], messages[tool_idx]]) # Remove this pair from the list of pairs to process later - pairs = pairs[:i] + pairs[i+1:] + pairs = pairs[:i] + pairs[i + 1 :] break else: # If we didn't find a matching pair, just add all kept_messages @@ -244,48 +254,50 @@ def anthropic_trim_messages( else: # No tool_use in the last kept message, just add all kept_messages result.extend(kept_messages) - + # If we're using the "last" strategy, we'll try to include pairs from the end if strategy == "last": # First collect all pairs we can include within the token limit pairs_to_include = [] - + # Process pairs from the end (newest first) for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)): # Try adding this pair test_msgs = result.copy() - + # Add all previously selected pairs for prev_ai_idx, prev_tool_idx in pairs_to_include: test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]]) - + # Add this pair test_msgs.extend([messages[ai_idx], messages[tool_idx]]) - + if token_counter(test_msgs) <= max_tokens: # This pair fits, add it to our list pairs_to_include.append((ai_idx, tool_idx)) print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})") else: # This pair would exceed the token limit - print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping") + print( + f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping" + ) break - + # Now add the pairs in the correct order # Sort by index to maintain the original conversation flow pairs_to_include.sort(key=lambda x: x[0]) for ai_idx, tool_idx in pairs_to_include: result.extend([messages[ai_idx], messages[tool_idx]]) - + # No need to sort - we've already added messages in the correct order - + print(f"DEBUG - Final result has {len(result)} messages") return result - + # If no tool_use, proceed with normal segmentation segments = [] i = 0 - + # Group messages into segments while i < len(remaining_msgs): segments.append([remaining_msgs[i]]) @@ -305,50 +317,60 @@ def anthropic_trim_messages( # If we have no segments, just return kept_messages if not segments: return kept_messages - + result = [] - + # Process segments from the end for i, segment in enumerate(reversed(segments)): # Try adding this segment test_msgs = segment + result - + if token_counter(kept_messages + test_msgs) <= max_tokens: result = segment + result print(f"DEBUG - Added segment {len(segments)-i-1} to result") else: # This segment would exceed the token limit - print(f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping") + print( + f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping" + ) break - + final_result = kept_messages + result - + # For Anthropic, we need to ensure the conversation follows a valid structure # We'll do a final check of the entire conversation print("\nDEBUG - Final result before validation:") for i, msg in enumerate(final_result): msg_type = type(msg).__name__ print(f" [{i}] {msg_type}") - + # Validate the conversation structure valid_result = [] i = 0 - + # Process messages in order while i < len(final_result): current_msg = final_result[i] - + # If this is an AIMessage with tool_use, it must be followed by a ToolMessage - if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg): - if isinstance(final_result[i+1], ToolMessage): + if ( + i < len(final_result) - 1 + and isinstance(current_msg, AIMessage) + and has_tool_use(current_msg) + ): + if isinstance(final_result[i + 1], ToolMessage): # This is a valid tool_use + tool_result pair valid_result.append(current_msg) - valid_result.append(final_result[i+1]) - print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}") + valid_result.append(final_result[i + 1]) + print( + f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}" + ) i += 2 else: # Invalid: AIMessage with tool_use not followed by ToolMessage - print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage") + print( + f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage" + ) # Skip this message to maintain valid structure i += 1 else: @@ -356,17 +378,23 @@ def anthropic_trim_messages( valid_result.append(current_msg) print(f"DEBUG - Added regular message at position {i}") i += 1 - + # Final check: don't end with an AIMessage that has tool_use - if valid_result and isinstance(valid_result[-1], AIMessage) and has_tool_use(valid_result[-1]): - print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage") + if ( + valid_result + and isinstance(valid_result[-1], AIMessage) + and has_tool_use(valid_result[-1]) + ): + print( + "WARNING: Last message is AIMessage with tool_use but no following ToolMessage" + ) valid_result.pop() # Remove the last message - + print("\nDEBUG - Final validated result:") for i, msg in enumerate(valid_result): msg_type = type(msg).__name__ print(f" [{i}] {msg_type}") - + return valid_result elif strategy == "first":