refactor(anthropic_message_utils.py): clean up whitespace and improve code readability by removing unnecessary blank lines and aligning code formatting
fix(anthropic_message_utils.py): add warning in docstring for anthropic_trim_messages function to indicate incomplete implementation and clarify behavior fix(anthropic_message_utils.py): ensure consistent formatting in conditional statements and improve readability of logical checks
This commit is contained in:
parent
a3284c9d7e
commit
376d486db8
|
|
@ -36,47 +36,49 @@ def _is_message_type(
|
||||||
|
|
||||||
def has_tool_use(message: BaseMessage) -> bool:
|
def has_tool_use(message: BaseMessage) -> bool:
|
||||||
"""Check if a message contains tool use.
|
"""Check if a message contains tool use.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: The message to check
|
message: The message to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the message contains tool use
|
bool: True if the message contains tool use
|
||||||
"""
|
"""
|
||||||
if not isinstance(message, AIMessage):
|
if not isinstance(message, AIMessage):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check content for tool_use
|
# Check content for tool_use
|
||||||
if isinstance(message.content, str) and "tool_use" in message.content:
|
if isinstance(message.content, str) and "tool_use" in message.content:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check content list for tool_use blocks
|
# Check content list for tool_use blocks
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
for item in message.content:
|
for item in message.content:
|
||||||
if isinstance(item, dict) and item.get("type") == "tool_use":
|
if isinstance(item, dict) and item.get("type") == "tool_use":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check additional_kwargs for tool_calls
|
# Check additional_kwargs for tool_calls
|
||||||
if hasattr(message, "additional_kwargs") and message.additional_kwargs.get("tool_calls"):
|
if hasattr(message, "additional_kwargs") and message.additional_kwargs.get(
|
||||||
|
"tool_calls"
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
|
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
|
||||||
"""Check if two messages form a tool use/result pair.
|
"""Check if two messages form a tool use/result pair.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message1: First message
|
message1: First message
|
||||||
message2: Second message
|
message2: Second message
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the messages form a tool use/result pair
|
bool: True if the messages form a tool use/result pair
|
||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
isinstance(message1, AIMessage) and
|
isinstance(message1, AIMessage)
|
||||||
isinstance(message2, ToolMessage) and
|
and isinstance(message2, ToolMessage)
|
||||||
has_tool_use(message1)
|
and has_tool_use(message1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -129,6 +131,8 @@ def anthropic_trim_messages(
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""Trim messages to fit within a token limit, with Anthropic-specific handling.
|
"""Trim messages to fit within a token limit, with Anthropic-specific handling.
|
||||||
|
|
||||||
|
Warning - not fully implemented - last strategy is supported and test, not
|
||||||
|
allow partial, not 'first' strategy either.
|
||||||
This function is similar to langchain_core's trim_messages but with special
|
This function is similar to langchain_core's trim_messages but with special
|
||||||
handling for Anthropic message formats to avoid API errors.
|
handling for Anthropic message formats to avoid API errors.
|
||||||
|
|
||||||
|
|
@ -176,11 +180,11 @@ def anthropic_trim_messages(
|
||||||
# For Anthropic, we need to maintain the conversation structure where:
|
# For Anthropic, we need to maintain the conversation structure where:
|
||||||
# 1. Every AIMessage with tool_use must be followed by a ToolMessage
|
# 1. Every AIMessage with tool_use must be followed by a ToolMessage
|
||||||
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result
|
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result
|
||||||
|
|
||||||
# First, check if we have any tool_use in the messages
|
# First, check if we have any tool_use in the messages
|
||||||
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
|
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
|
||||||
print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}")
|
print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}")
|
||||||
|
|
||||||
# Print debug info for AIMessages
|
# Print debug info for AIMessages
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
|
|
@ -188,46 +192,52 @@ def anthropic_trim_messages(
|
||||||
print(f" has_tool_use: {has_tool_use(msg)}")
|
print(f" has_tool_use: {has_tool_use(msg)}")
|
||||||
if hasattr(msg, "additional_kwargs"):
|
if hasattr(msg, "additional_kwargs"):
|
||||||
print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}")
|
print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}")
|
||||||
|
|
||||||
# If we have tool_use anywhere, we need to be very careful about trimming
|
# If we have tool_use anywhere, we need to be very careful about trimming
|
||||||
if has_tool_use_anywhere:
|
if has_tool_use_anywhere:
|
||||||
# For safety, just keep all messages if we're under the token limit
|
# For safety, just keep all messages if we're under the token limit
|
||||||
if token_counter(messages) <= max_tokens:
|
if token_counter(messages) <= max_tokens:
|
||||||
print("DEBUG - All messages fit within token limit, keeping all")
|
print("DEBUG - All messages fit within token limit, keeping all")
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
# We need to identify all tool_use/tool_result relationships
|
# We need to identify all tool_use/tool_result relationships
|
||||||
# First, find all AIMessage+ToolMessage pairs
|
# First, find all AIMessage+ToolMessage pairs
|
||||||
pairs = []
|
pairs = []
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(messages) - 1:
|
while i < len(messages) - 1:
|
||||||
if is_tool_pair(messages[i], messages[i+1]):
|
if is_tool_pair(messages[i], messages[i + 1]):
|
||||||
pairs.append((i, i+1))
|
pairs.append((i, i + 1))
|
||||||
print(f"DEBUG - Found tool_use pair: ({i}, {i+1})")
|
print(f"DEBUG - Found tool_use pair: ({i}, {i+1})")
|
||||||
i += 2
|
i += 2
|
||||||
else:
|
else:
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs")
|
print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs")
|
||||||
|
|
||||||
# For Anthropic, we need to ensure that:
|
# For Anthropic, we need to ensure that:
|
||||||
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
|
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
|
||||||
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
|
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
|
||||||
|
|
||||||
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together
|
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together
|
||||||
# First, identify all complete pairs
|
# First, identify all complete pairs
|
||||||
complete_pairs = []
|
complete_pairs = []
|
||||||
for start, end in pairs:
|
for start, end in pairs:
|
||||||
complete_pairs.append((start, end))
|
complete_pairs.append((start, end))
|
||||||
|
|
||||||
print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs")
|
print(
|
||||||
|
f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs"
|
||||||
|
)
|
||||||
|
|
||||||
# Now we'll build our result, starting with the kept_messages
|
# Now we'll build our result, starting with the kept_messages
|
||||||
# But we need to be careful about the first message if it has tool_use
|
# But we need to be careful about the first message if it has tool_use
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
# Check if the last message in kept_messages has tool_use
|
# Check if the last message in kept_messages has tool_use
|
||||||
if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]):
|
if (
|
||||||
|
kept_messages
|
||||||
|
and isinstance(kept_messages[-1], AIMessage)
|
||||||
|
and has_tool_use(kept_messages[-1])
|
||||||
|
):
|
||||||
# We need to find the corresponding ToolMessage
|
# We need to find the corresponding ToolMessage
|
||||||
for i, (ai_idx, tool_idx) in enumerate(pairs):
|
for i, (ai_idx, tool_idx) in enumerate(pairs):
|
||||||
if messages[ai_idx] is kept_messages[-1]:
|
if messages[ai_idx] is kept_messages[-1]:
|
||||||
|
|
@ -236,7 +246,7 @@ def anthropic_trim_messages(
|
||||||
# Add the AIMessage and ToolMessage as a pair
|
# Add the AIMessage and ToolMessage as a pair
|
||||||
result.extend([messages[ai_idx], messages[tool_idx]])
|
result.extend([messages[ai_idx], messages[tool_idx]])
|
||||||
# Remove this pair from the list of pairs to process later
|
# Remove this pair from the list of pairs to process later
|
||||||
pairs = pairs[:i] + pairs[i+1:]
|
pairs = pairs[:i] + pairs[i + 1 :]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# If we didn't find a matching pair, just add all kept_messages
|
# If we didn't find a matching pair, just add all kept_messages
|
||||||
|
|
@ -244,48 +254,50 @@ def anthropic_trim_messages(
|
||||||
else:
|
else:
|
||||||
# No tool_use in the last kept message, just add all kept_messages
|
# No tool_use in the last kept message, just add all kept_messages
|
||||||
result.extend(kept_messages)
|
result.extend(kept_messages)
|
||||||
|
|
||||||
# If we're using the "last" strategy, we'll try to include pairs from the end
|
# If we're using the "last" strategy, we'll try to include pairs from the end
|
||||||
if strategy == "last":
|
if strategy == "last":
|
||||||
# First collect all pairs we can include within the token limit
|
# First collect all pairs we can include within the token limit
|
||||||
pairs_to_include = []
|
pairs_to_include = []
|
||||||
|
|
||||||
# Process pairs from the end (newest first)
|
# Process pairs from the end (newest first)
|
||||||
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
|
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
|
||||||
# Try adding this pair
|
# Try adding this pair
|
||||||
test_msgs = result.copy()
|
test_msgs = result.copy()
|
||||||
|
|
||||||
# Add all previously selected pairs
|
# Add all previously selected pairs
|
||||||
for prev_ai_idx, prev_tool_idx in pairs_to_include:
|
for prev_ai_idx, prev_tool_idx in pairs_to_include:
|
||||||
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
|
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
|
||||||
|
|
||||||
# Add this pair
|
# Add this pair
|
||||||
test_msgs.extend([messages[ai_idx], messages[tool_idx]])
|
test_msgs.extend([messages[ai_idx], messages[tool_idx]])
|
||||||
|
|
||||||
if token_counter(test_msgs) <= max_tokens:
|
if token_counter(test_msgs) <= max_tokens:
|
||||||
# This pair fits, add it to our list
|
# This pair fits, add it to our list
|
||||||
pairs_to_include.append((ai_idx, tool_idx))
|
pairs_to_include.append((ai_idx, tool_idx))
|
||||||
print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})")
|
print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})")
|
||||||
else:
|
else:
|
||||||
# This pair would exceed the token limit
|
# This pair would exceed the token limit
|
||||||
print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping")
|
print(
|
||||||
|
f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Now add the pairs in the correct order
|
# Now add the pairs in the correct order
|
||||||
# Sort by index to maintain the original conversation flow
|
# Sort by index to maintain the original conversation flow
|
||||||
pairs_to_include.sort(key=lambda x: x[0])
|
pairs_to_include.sort(key=lambda x: x[0])
|
||||||
for ai_idx, tool_idx in pairs_to_include:
|
for ai_idx, tool_idx in pairs_to_include:
|
||||||
result.extend([messages[ai_idx], messages[tool_idx]])
|
result.extend([messages[ai_idx], messages[tool_idx]])
|
||||||
|
|
||||||
# No need to sort - we've already added messages in the correct order
|
# No need to sort - we've already added messages in the correct order
|
||||||
|
|
||||||
print(f"DEBUG - Final result has {len(result)} messages")
|
print(f"DEBUG - Final result has {len(result)} messages")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# If no tool_use, proceed with normal segmentation
|
# If no tool_use, proceed with normal segmentation
|
||||||
segments = []
|
segments = []
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
# Group messages into segments
|
# Group messages into segments
|
||||||
while i < len(remaining_msgs):
|
while i < len(remaining_msgs):
|
||||||
segments.append([remaining_msgs[i]])
|
segments.append([remaining_msgs[i]])
|
||||||
|
|
@ -305,50 +317,60 @@ def anthropic_trim_messages(
|
||||||
# If we have no segments, just return kept_messages
|
# If we have no segments, just return kept_messages
|
||||||
if not segments:
|
if not segments:
|
||||||
return kept_messages
|
return kept_messages
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
# Process segments from the end
|
# Process segments from the end
|
||||||
for i, segment in enumerate(reversed(segments)):
|
for i, segment in enumerate(reversed(segments)):
|
||||||
# Try adding this segment
|
# Try adding this segment
|
||||||
test_msgs = segment + result
|
test_msgs = segment + result
|
||||||
|
|
||||||
if token_counter(kept_messages + test_msgs) <= max_tokens:
|
if token_counter(kept_messages + test_msgs) <= max_tokens:
|
||||||
result = segment + result
|
result = segment + result
|
||||||
print(f"DEBUG - Added segment {len(segments)-i-1} to result")
|
print(f"DEBUG - Added segment {len(segments)-i-1} to result")
|
||||||
else:
|
else:
|
||||||
# This segment would exceed the token limit
|
# This segment would exceed the token limit
|
||||||
print(f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping")
|
print(
|
||||||
|
f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
final_result = kept_messages + result
|
final_result = kept_messages + result
|
||||||
|
|
||||||
# For Anthropic, we need to ensure the conversation follows a valid structure
|
# For Anthropic, we need to ensure the conversation follows a valid structure
|
||||||
# We'll do a final check of the entire conversation
|
# We'll do a final check of the entire conversation
|
||||||
print("\nDEBUG - Final result before validation:")
|
print("\nDEBUG - Final result before validation:")
|
||||||
for i, msg in enumerate(final_result):
|
for i, msg in enumerate(final_result):
|
||||||
msg_type = type(msg).__name__
|
msg_type = type(msg).__name__
|
||||||
print(f" [{i}] {msg_type}")
|
print(f" [{i}] {msg_type}")
|
||||||
|
|
||||||
# Validate the conversation structure
|
# Validate the conversation structure
|
||||||
valid_result = []
|
valid_result = []
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
# Process messages in order
|
# Process messages in order
|
||||||
while i < len(final_result):
|
while i < len(final_result):
|
||||||
current_msg = final_result[i]
|
current_msg = final_result[i]
|
||||||
|
|
||||||
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage
|
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage
|
||||||
if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg):
|
if (
|
||||||
if isinstance(final_result[i+1], ToolMessage):
|
i < len(final_result) - 1
|
||||||
|
and isinstance(current_msg, AIMessage)
|
||||||
|
and has_tool_use(current_msg)
|
||||||
|
):
|
||||||
|
if isinstance(final_result[i + 1], ToolMessage):
|
||||||
# This is a valid tool_use + tool_result pair
|
# This is a valid tool_use + tool_result pair
|
||||||
valid_result.append(current_msg)
|
valid_result.append(current_msg)
|
||||||
valid_result.append(final_result[i+1])
|
valid_result.append(final_result[i + 1])
|
||||||
print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}")
|
print(
|
||||||
|
f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}"
|
||||||
|
)
|
||||||
i += 2
|
i += 2
|
||||||
else:
|
else:
|
||||||
# Invalid: AIMessage with tool_use not followed by ToolMessage
|
# Invalid: AIMessage with tool_use not followed by ToolMessage
|
||||||
print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage")
|
print(
|
||||||
|
f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage"
|
||||||
|
)
|
||||||
# Skip this message to maintain valid structure
|
# Skip this message to maintain valid structure
|
||||||
i += 1
|
i += 1
|
||||||
else:
|
else:
|
||||||
|
|
@ -356,17 +378,23 @@ def anthropic_trim_messages(
|
||||||
valid_result.append(current_msg)
|
valid_result.append(current_msg)
|
||||||
print(f"DEBUG - Added regular message at position {i}")
|
print(f"DEBUG - Added regular message at position {i}")
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
# Final check: don't end with an AIMessage that has tool_use
|
# Final check: don't end with an AIMessage that has tool_use
|
||||||
if valid_result and isinstance(valid_result[-1], AIMessage) and has_tool_use(valid_result[-1]):
|
if (
|
||||||
print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage")
|
valid_result
|
||||||
|
and isinstance(valid_result[-1], AIMessage)
|
||||||
|
and has_tool_use(valid_result[-1])
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
"WARNING: Last message is AIMessage with tool_use but no following ToolMessage"
|
||||||
|
)
|
||||||
valid_result.pop() # Remove the last message
|
valid_result.pop() # Remove the last message
|
||||||
|
|
||||||
print("\nDEBUG - Final validated result:")
|
print("\nDEBUG - Final validated result:")
|
||||||
for i, msg in enumerate(valid_result):
|
for i, msg in enumerate(valid_result):
|
||||||
msg_type = type(msg).__name__
|
msg_type = type(msg).__name__
|
||||||
print(f" [{i}] {msg_type}")
|
print(f" [{i}] {msg_type}")
|
||||||
|
|
||||||
return valid_result
|
return valid_result
|
||||||
|
|
||||||
elif strategy == "first":
|
elif strategy == "first":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue