refactor(anthropic_message_utils.py): clean up whitespace and improve code readability by removing unnecessary blank lines and aligning code formatting

fix(anthropic_message_utils.py): add warning in docstring for anthropic_trim_messages function to indicate incomplete implementation and clarify behavior
fix(anthropic_message_utils.py): ensure consistent formatting in conditional statements and improve readability of logical checks
This commit is contained in:
Ariel Frischer 2025-03-11 23:38:31 -07:00
parent a3284c9d7e
commit 376d486db8
1 changed file with 85 additions and 57 deletions

View File

@ -36,47 +36,49 @@ def _is_message_type(
def has_tool_use(message: BaseMessage) -> bool:
    """Check if a message contains tool use.

    Only ``AIMessage`` instances are considered; any other message type
    yields ``False``. An ``AIMessage`` counts as tool use when any of the
    following holds:

    * its string content contains the substring ``"tool_use"``,
    * its list content contains a dict block whose ``"type"`` is
      ``"tool_use"``,
    * its ``tool_calls`` attribute is non-empty, or
    * its ``additional_kwargs`` has a truthy ``"tool_calls"`` entry.

    Args:
        message: The message to check

    Returns:
        bool: True if the message contains tool use
    """
    if not isinstance(message, AIMessage):
        return False

    content = message.content

    # Substring heuristic on plain-text content. It can also match ordinary
    # prose that merely mentions "tool_use"; kept as-is so existing callers
    # see the same result for string content.
    if isinstance(content, str) and "tool_use" in content:
        return True

    # Structured content: look for an explicit tool_use block.
    if isinstance(content, list) and any(
        isinstance(item, dict) and item.get("type") == "tool_use"
        for item in content
    ):
        return True

    # Tool calls stored directly on the message. Without this check a message
    # built with ``tool_calls=[...]`` and plain text content would be missed.
    if getattr(message, "tool_calls", None):
        return True

    # Tool calls carried in additional_kwargs; tolerate a missing or None
    # mapping instead of raising.
    additional_kwargs = getattr(message, "additional_kwargs", None) or {}
    return bool(additional_kwargs.get("tool_calls"))
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
    """Tell whether two messages are a tool use followed by its tool result.

    Args:
        message1: Candidate tool-use message (expected to be an AIMessage)
        message2: Candidate tool-result message (expected to be a ToolMessage)

    Returns:
        bool: True only when ``message1`` is an AIMessage for which
        ``has_tool_use`` is true and ``message2`` is a ToolMessage
    """
    # Guard clauses, evaluated in the same order as the original `and` chain.
    if not isinstance(message1, AIMessage):
        return False
    if not isinstance(message2, ToolMessage):
        return False
    return has_tool_use(message1)
@ -129,6 +131,8 @@ def anthropic_trim_messages(
) -> List[BaseMessage]: ) -> List[BaseMessage]:
"""Trim messages to fit within a token limit, with Anthropic-specific handling. """Trim messages to fit within a token limit, with Anthropic-specific handling.
Warning: not fully implemented. Only the 'last' strategy is supported and
tested; allowing partial messages and the 'first' strategy are not supported.
This function is similar to langchain_core's trim_messages but with special This function is similar to langchain_core's trim_messages but with special
handling for Anthropic message formats to avoid API errors. handling for Anthropic message formats to avoid API errors.
@ -176,11 +180,11 @@ def anthropic_trim_messages(
# For Anthropic, we need to maintain the conversation structure where: # For Anthropic, we need to maintain the conversation structure where:
# 1. Every AIMessage with tool_use must be followed by a ToolMessage # 1. Every AIMessage with tool_use must be followed by a ToolMessage
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result # 2. Every AIMessage that follows a ToolMessage must start with a tool_result
# First, check if we have any tool_use in the messages # First, check if we have any tool_use in the messages
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages) has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}") print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}")
# Print debug info for AIMessages # Print debug info for AIMessages
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
@ -188,46 +192,52 @@ def anthropic_trim_messages(
print(f" has_tool_use: {has_tool_use(msg)}") print(f" has_tool_use: {has_tool_use(msg)}")
if hasattr(msg, "additional_kwargs"): if hasattr(msg, "additional_kwargs"):
print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}") print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}")
# If we have tool_use anywhere, we need to be very careful about trimming # If we have tool_use anywhere, we need to be very careful about trimming
if has_tool_use_anywhere: if has_tool_use_anywhere:
# For safety, just keep all messages if we're under the token limit # For safety, just keep all messages if we're under the token limit
if token_counter(messages) <= max_tokens: if token_counter(messages) <= max_tokens:
print("DEBUG - All messages fit within token limit, keeping all") print("DEBUG - All messages fit within token limit, keeping all")
return messages return messages
# We need to identify all tool_use/tool_result relationships # We need to identify all tool_use/tool_result relationships
# First, find all AIMessage+ToolMessage pairs # First, find all AIMessage+ToolMessage pairs
pairs = [] pairs = []
i = 0 i = 0
while i < len(messages) - 1: while i < len(messages) - 1:
if is_tool_pair(messages[i], messages[i+1]): if is_tool_pair(messages[i], messages[i + 1]):
pairs.append((i, i+1)) pairs.append((i, i + 1))
print(f"DEBUG - Found tool_use pair: ({i}, {i+1})") print(f"DEBUG - Found tool_use pair: ({i}, {i+1})")
i += 2 i += 2
else: else:
i += 1 i += 1
print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs") print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs")
# For Anthropic, we need to ensure that: # For Anthropic, we need to ensure that:
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use # 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together # The safest approach is to always keep complete AIMessage+ToolMessage pairs together
# First, identify all complete pairs # First, identify all complete pairs
complete_pairs = [] complete_pairs = []
for start, end in pairs: for start, end in pairs:
complete_pairs.append((start, end)) complete_pairs.append((start, end))
print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs") print(
f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs"
)
# Now we'll build our result, starting with the kept_messages # Now we'll build our result, starting with the kept_messages
# But we need to be careful about the first message if it has tool_use # But we need to be careful about the first message if it has tool_use
result = [] result = []
# Check if the last message in kept_messages has tool_use # Check if the last message in kept_messages has tool_use
if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]): if (
kept_messages
and isinstance(kept_messages[-1], AIMessage)
and has_tool_use(kept_messages[-1])
):
# We need to find the corresponding ToolMessage # We need to find the corresponding ToolMessage
for i, (ai_idx, tool_idx) in enumerate(pairs): for i, (ai_idx, tool_idx) in enumerate(pairs):
if messages[ai_idx] is kept_messages[-1]: if messages[ai_idx] is kept_messages[-1]:
@ -236,7 +246,7 @@ def anthropic_trim_messages(
# Add the AIMessage and ToolMessage as a pair # Add the AIMessage and ToolMessage as a pair
result.extend([messages[ai_idx], messages[tool_idx]]) result.extend([messages[ai_idx], messages[tool_idx]])
# Remove this pair from the list of pairs to process later # Remove this pair from the list of pairs to process later
pairs = pairs[:i] + pairs[i+1:] pairs = pairs[:i] + pairs[i + 1 :]
break break
else: else:
# If we didn't find a matching pair, just add all kept_messages # If we didn't find a matching pair, just add all kept_messages
@ -244,48 +254,50 @@ def anthropic_trim_messages(
else: else:
# No tool_use in the last kept message, just add all kept_messages # No tool_use in the last kept message, just add all kept_messages
result.extend(kept_messages) result.extend(kept_messages)
# If we're using the "last" strategy, we'll try to include pairs from the end # If we're using the "last" strategy, we'll try to include pairs from the end
if strategy == "last": if strategy == "last":
# First collect all pairs we can include within the token limit # First collect all pairs we can include within the token limit
pairs_to_include = [] pairs_to_include = []
# Process pairs from the end (newest first) # Process pairs from the end (newest first)
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)): for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
# Try adding this pair # Try adding this pair
test_msgs = result.copy() test_msgs = result.copy()
# Add all previously selected pairs # Add all previously selected pairs
for prev_ai_idx, prev_tool_idx in pairs_to_include: for prev_ai_idx, prev_tool_idx in pairs_to_include:
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]]) test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
# Add this pair # Add this pair
test_msgs.extend([messages[ai_idx], messages[tool_idx]]) test_msgs.extend([messages[ai_idx], messages[tool_idx]])
if token_counter(test_msgs) <= max_tokens: if token_counter(test_msgs) <= max_tokens:
# This pair fits, add it to our list # This pair fits, add it to our list
pairs_to_include.append((ai_idx, tool_idx)) pairs_to_include.append((ai_idx, tool_idx))
print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})") print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})")
else: else:
# This pair would exceed the token limit # This pair would exceed the token limit
print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping") print(
f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping"
)
break break
# Now add the pairs in the correct order # Now add the pairs in the correct order
# Sort by index to maintain the original conversation flow # Sort by index to maintain the original conversation flow
pairs_to_include.sort(key=lambda x: x[0]) pairs_to_include.sort(key=lambda x: x[0])
for ai_idx, tool_idx in pairs_to_include: for ai_idx, tool_idx in pairs_to_include:
result.extend([messages[ai_idx], messages[tool_idx]]) result.extend([messages[ai_idx], messages[tool_idx]])
# No need to sort - we've already added messages in the correct order # No need to sort - we've already added messages in the correct order
print(f"DEBUG - Final result has {len(result)} messages") print(f"DEBUG - Final result has {len(result)} messages")
return result return result
# If no tool_use, proceed with normal segmentation # If no tool_use, proceed with normal segmentation
segments = [] segments = []
i = 0 i = 0
# Group messages into segments # Group messages into segments
while i < len(remaining_msgs): while i < len(remaining_msgs):
segments.append([remaining_msgs[i]]) segments.append([remaining_msgs[i]])
@ -305,50 +317,60 @@ def anthropic_trim_messages(
# If we have no segments, just return kept_messages # If we have no segments, just return kept_messages
if not segments: if not segments:
return kept_messages return kept_messages
result = [] result = []
# Process segments from the end # Process segments from the end
for i, segment in enumerate(reversed(segments)): for i, segment in enumerate(reversed(segments)):
# Try adding this segment # Try adding this segment
test_msgs = segment + result test_msgs = segment + result
if token_counter(kept_messages + test_msgs) <= max_tokens: if token_counter(kept_messages + test_msgs) <= max_tokens:
result = segment + result result = segment + result
print(f"DEBUG - Added segment {len(segments)-i-1} to result") print(f"DEBUG - Added segment {len(segments)-i-1} to result")
else: else:
# This segment would exceed the token limit # This segment would exceed the token limit
print(f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping") print(
f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping"
)
break break
final_result = kept_messages + result final_result = kept_messages + result
# For Anthropic, we need to ensure the conversation follows a valid structure # For Anthropic, we need to ensure the conversation follows a valid structure
# We'll do a final check of the entire conversation # We'll do a final check of the entire conversation
print("\nDEBUG - Final result before validation:") print("\nDEBUG - Final result before validation:")
for i, msg in enumerate(final_result): for i, msg in enumerate(final_result):
msg_type = type(msg).__name__ msg_type = type(msg).__name__
print(f" [{i}] {msg_type}") print(f" [{i}] {msg_type}")
# Validate the conversation structure # Validate the conversation structure
valid_result = [] valid_result = []
i = 0 i = 0
# Process messages in order # Process messages in order
while i < len(final_result): while i < len(final_result):
current_msg = final_result[i] current_msg = final_result[i]
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage # If this is an AIMessage with tool_use, it must be followed by a ToolMessage
if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg): if (
if isinstance(final_result[i+1], ToolMessage): i < len(final_result) - 1
and isinstance(current_msg, AIMessage)
and has_tool_use(current_msg)
):
if isinstance(final_result[i + 1], ToolMessage):
# This is a valid tool_use + tool_result pair # This is a valid tool_use + tool_result pair
valid_result.append(current_msg) valid_result.append(current_msg)
valid_result.append(final_result[i+1]) valid_result.append(final_result[i + 1])
print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}") print(
f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}"
)
i += 2 i += 2
else: else:
# Invalid: AIMessage with tool_use not followed by ToolMessage # Invalid: AIMessage with tool_use not followed by ToolMessage
print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage") print(
f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage"
)
# Skip this message to maintain valid structure # Skip this message to maintain valid structure
i += 1 i += 1
else: else:
@ -356,17 +378,23 @@ def anthropic_trim_messages(
valid_result.append(current_msg) valid_result.append(current_msg)
print(f"DEBUG - Added regular message at position {i}") print(f"DEBUG - Added regular message at position {i}")
i += 1 i += 1
# Final check: don't end with an AIMessage that has tool_use # Final check: don't end with an AIMessage that has tool_use
if valid_result and isinstance(valid_result[-1], AIMessage) and has_tool_use(valid_result[-1]): if (
print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage") valid_result
and isinstance(valid_result[-1], AIMessage)
and has_tool_use(valid_result[-1])
):
print(
"WARNING: Last message is AIMessage with tool_use but no following ToolMessage"
)
valid_result.pop() # Remove the last message valid_result.pop() # Remove the last message
print("\nDEBUG - Final validated result:") print("\nDEBUG - Final validated result:")
for i, msg in enumerate(valid_result): for i, msg in enumerate(valid_result):
msg_type = type(msg).__name__ msg_type = type(msg).__name__
print(f" [{i}] {msg_type}") print(f" [{i}] {msg_type}")
return valid_result return valid_result
elif strategy == "first": elif strategy == "first":