feat(anthropic_message_utils.py): add utilities for handling Anthropic-specific message formats and trimming to improve message processing

fix(agent_utils.py): remove debug print statement for max_input_tokens to clean up code
refactor(anthropic_token_limiter.py): update state_modifier to use anthropic_trim_messages for better token management and to preserve Anthropic's required message structure
Ariel Frischer 2025-03-11 23:24:57 -07:00
parent 09ba1ee0b9
commit ee73c85b02
3 changed files with 433 additions and 41 deletions

agent_utils.py

@@ -164,7 +164,6 @@ def create_agent(
max_input_tokens = (
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
)
print(f"max_input_tokens={max_input_tokens}")
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
if is_anthropic_claude(config):

ra_aid/anthropic_message_utils.py

@@ -0,0 +1,393 @@
"""Utilities for handling Anthropic-specific message formats and trimming."""
from typing import Callable, List, Literal, Optional, Sequence, Union, cast
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
def _is_message_type(
message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]]
) -> bool:
"""Check if a message is of a specific type or types.
Args:
message: The message to check
message_types: Type(s) to check against (string name or class)
Returns:
bool: True if message matches any of the specified types
"""
if not isinstance(message_types, list):
message_types = [message_types]
types_str = [t for t in message_types if isinstance(t, str)]
types_classes = tuple(t for t in message_types if isinstance(t, type))
return message.type in types_str or isinstance(message, types_classes)
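# Example: _is_message_type(msg, "human") matches by type name,
# _is_message_type(msg, HumanMessage) matches by class, and
# _is_message_type(msg, ["human", AIMessage]) matches on either.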
def has_tool_use(message: BaseMessage) -> bool:
"""Check if a message contains tool use.
Args:
message: The message to check
Returns:
bool: True if the message contains tool use
"""
if not isinstance(message, AIMessage):
return False
# Check content for tool_use
if isinstance(message.content, str) and "tool_use" in message.content:
return True
# Check content list for tool_use blocks
if isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and item.get("type") == "tool_use":
return True
# Check additional_kwargs for tool_calls
if hasattr(message, "additional_kwargs") and message.additional_kwargs.get("tool_calls"):
return True
return False
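# Example: an AIMessage with additional_kwargs={"tool_calls": [...]} or a
# content block {"type": "tool_use", ...} returns True; AIMessages without
# any of these markers, and all non-AI messages, return False.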
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
"""Check if two messages form a tool use/result pair.
Args:
message1: First message
message2: Second message
Returns:
bool: True if the messages form a tool use/result pair
"""
return (
isinstance(message1, AIMessage) and
isinstance(message2, ToolMessage) and
has_tool_use(message1)
)
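# Example: an AIMessage carrying a tool call followed by any ToolMessage is a
# pair; the tool_call_id itself is not checked here.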
def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage:
"""Fix message content format for Anthropic API compatibility."""
if not isinstance(message, AIMessage) or not isinstance(message.content, list):
return message
fixed_message = message.model_copy(deep=True)
# Ensure first block is valid thinking type
if fixed_message.content and isinstance(fixed_message.content[0], dict):
first_block_type = fixed_message.content[0].get("type")
if first_block_type not in ("thinking", "redacted_thinking"):
# Prepend redacted_thinking block instead of thinking
fixed_message.content.insert(
0,
{
"type": "redacted_thinking",
"data": "ENCRYPTED_REASONING", # Required field for redacted_thinking
},
)
    # Ensure all thinking blocks have valid structure
    for i, block in enumerate(fixed_message.content):
        if not isinstance(block, dict):
            # Content blocks can be plain strings; skip those instead of crashing
            continue
        if block.get("type") == "thinking":
# Convert thinking blocks to redacted_thinking to avoid signature validation
fixed_message.content[i] = {
"type": "redacted_thinking",
"data": "ENCRYPTED_REASONING",
}
elif block.get("type") == "redacted_thinking":
# Ensure required data field exists
if "data" not in block:
fixed_message.content[i]["data"] = "ENCRYPTED_REASONING"
return fixed_message
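# Example: content of [{"type": "text", "text": "hi"}] comes back as
# [{"type": "redacted_thinking", "data": "ENCRYPTED_REASONING"},
#  {"type": "text", "text": "hi"}]; string content is returned unchanged.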
def anthropic_trim_messages(
messages: Sequence[BaseMessage],
*,
max_tokens: int,
token_counter: Callable[[List[BaseMessage]], int],
strategy: Literal["first", "last"] = "last",
num_messages_to_keep: int = 2,
allow_partial: bool = False,
include_system: bool = True,
start_on: Optional[Union[str, type, List[Union[str, type]]]] = None,
) -> List[BaseMessage]:
"""Trim messages to fit within a token limit, with Anthropic-specific handling.
This function is similar to langchain_core's trim_messages but with special
handling for Anthropic message formats to avoid API errors.
It always keeps the first num_messages_to_keep messages.
Args:
messages: Sequence of messages to trim
max_tokens: Maximum number of tokens allowed
token_counter: Function to count tokens in messages
        strategy: Whether to keep the "first" or "last" messages
        num_messages_to_keep: Number of leading messages that are always kept
        allow_partial: Whether to allow partial messages
include_system: Whether to always include the system message
start_on: Message type to start on (only for "last" strategy)
Returns:
List[BaseMessage]: Trimmed messages that fit within token limit
"""
if not messages:
return []
messages = list(messages)
# Always keep the first num_messages_to_keep messages
kept_messages = messages[:num_messages_to_keep]
remaining_msgs = messages[num_messages_to_keep:]
# Debug: Print message types for all messages
print("\nDEBUG - All messages:")
for i, msg in enumerate(messages):
msg_type = type(msg).__name__
tool_use = (
"tool_use"
if isinstance(msg, AIMessage)
and hasattr(msg, "additional_kwargs")
and msg.additional_kwargs.get("tool_calls")
else ""
)
tool_result = (
f"tool_call_id: {msg.tool_call_id}"
if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id")
else ""
)
print(f" [{i}] {msg_type} {tool_use} {tool_result}")
# For Anthropic, we need to maintain the conversation structure where:
# 1. Every AIMessage with tool_use must be followed by a ToolMessage
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result
# First, check if we have any tool_use in the messages
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}")
# Print debug info for AIMessages
for i, msg in enumerate(messages):
if isinstance(msg, AIMessage):
print(f"DEBUG - AIMessage[{i}] details:")
print(f" has_tool_use: {has_tool_use(msg)}")
if hasattr(msg, "additional_kwargs"):
print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}")
# If we have tool_use anywhere, we need to be very careful about trimming
if has_tool_use_anywhere:
# For safety, just keep all messages if we're under the token limit
if token_counter(messages) <= max_tokens:
print("DEBUG - All messages fit within token limit, keeping all")
return messages
# We need to identify all tool_use/tool_result relationships
# First, find all AIMessage+ToolMessage pairs
pairs = []
i = 0
while i < len(messages) - 1:
if is_tool_pair(messages[i], messages[i+1]):
pairs.append((i, i+1))
print(f"DEBUG - Found tool_use pair: ({i}, {i+1})")
i += 2
else:
i += 1
print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs")
# For Anthropic, we need to ensure that:
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together
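        # Example: in [System, Human, AI(tool_use), Tool, AI(tool_use), Tool]
        # the pairs are (2, 3) and (4, 5); trimming may drop a whole pair but
        # never splits an AIMessage from its ToolMessage.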
# First, identify all complete pairs
complete_pairs = []
for start, end in pairs:
complete_pairs.append((start, end))
print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs")
# Now we'll build our result, starting with the kept_messages
# But we need to be careful about the first message if it has tool_use
result = []
# Check if the last message in kept_messages has tool_use
if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]):
# We need to find the corresponding ToolMessage
for i, (ai_idx, tool_idx) in enumerate(pairs):
if messages[ai_idx] is kept_messages[-1]:
# Found the pair, add all kept_messages except the last one
result.extend(kept_messages[:-1])
# Add the AIMessage and ToolMessage as a pair
result.extend([messages[ai_idx], messages[tool_idx]])
# Remove this pair from the list of pairs to process later
pairs = pairs[:i] + pairs[i+1:]
break
else:
# If we didn't find a matching pair, just add all kept_messages
result.extend(kept_messages)
else:
# No tool_use in the last kept message, just add all kept_messages
result.extend(kept_messages)
# If we're using the "last" strategy, we'll try to include pairs from the end
if strategy == "last":
# First collect all pairs we can include within the token limit
pairs_to_include = []
# Process pairs from the end (newest first)
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
# Try adding this pair
test_msgs = result.copy()
# Add all previously selected pairs
for prev_ai_idx, prev_tool_idx in pairs_to_include:
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
# Add this pair
test_msgs.extend([messages[ai_idx], messages[tool_idx]])
if token_counter(test_msgs) <= max_tokens:
# This pair fits, add it to our list
pairs_to_include.append((ai_idx, tool_idx))
print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})")
else:
# This pair would exceed the token limit
print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping")
break
# Now add the pairs in the correct order
# Sort by index to maintain the original conversation flow
pairs_to_include.sort(key=lambda x: x[0])
for ai_idx, tool_idx in pairs_to_include:
result.extend([messages[ai_idx], messages[tool_idx]])
        # Pairs were sorted by index above, so messages are already in conversation order
print(f"DEBUG - Final result has {len(result)} messages")
return result
# If no tool_use, proceed with normal segmentation
segments = []
i = 0
    # Treat each remaining message as its own segment
while i < len(remaining_msgs):
segments.append([remaining_msgs[i]])
print(f"DEBUG - Added message as segment: [{i}]")
i += 1
print(f"\nDEBUG - Created {len(segments)} segments")
for i, segment in enumerate(segments):
segment_types = [type(msg).__name__ for msg in segment]
print(f" Segment {i}: {segment_types}")
# Now we have segments that maintain the required structure
# We'll add segments from the end (for "last" strategy) or beginning (for "first")
# until we hit the token limit
if strategy == "last":
# If we have no segments, just return kept_messages
if not segments:
return kept_messages
result = []
# Process segments from the end
for i, segment in enumerate(reversed(segments)):
# Try adding this segment
test_msgs = segment + result
if token_counter(kept_messages + test_msgs) <= max_tokens:
result = segment + result
print(f"DEBUG - Added segment {len(segments)-i-1} to result")
else:
# This segment would exceed the token limit
print(f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping")
break
final_result = kept_messages + result
# For Anthropic, we need to ensure the conversation follows a valid structure
# We'll do a final check of the entire conversation
print("\nDEBUG - Final result before validation:")
for i, msg in enumerate(final_result):
msg_type = type(msg).__name__
print(f" [{i}] {msg_type}")
# Validate the conversation structure
valid_result = []
i = 0
# Process messages in order
while i < len(final_result):
current_msg = final_result[i]
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage
if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg):
if isinstance(final_result[i+1], ToolMessage):
# This is a valid tool_use + tool_result pair
valid_result.append(current_msg)
valid_result.append(final_result[i+1])
print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}")
i += 2
else:
# Invalid: AIMessage with tool_use not followed by ToolMessage
print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage")
# Skip this message to maintain valid structure
i += 1
else:
# Regular message, just add it
valid_result.append(current_msg)
print(f"DEBUG - Added regular message at position {i}")
i += 1
# Final check: don't end with an AIMessage that has tool_use
if valid_result and isinstance(valid_result[-1], AIMessage) and has_tool_use(valid_result[-1]):
print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage")
valid_result.pop() # Remove the last message
print("\nDEBUG - Final validated result:")
for i, msg in enumerate(valid_result):
msg_type = type(msg).__name__
print(f" [{i}] {msg_type}")
return valid_result
elif strategy == "first":
result = []
# Process segments from the beginning
for i, segment in enumerate(segments):
# Try adding this segment
test_msgs = result + segment
if token_counter(kept_messages + test_msgs) <= max_tokens:
result = result + segment
print(f"DEBUG - Added segment {i} to result")
else:
# This segment would exceed the token limit
print(f"DEBUG - Segment {i} would exceed token limit, stopping")
break
final_result = kept_messages + result
print("\nDEBUG - Final result:")
for i, msg in enumerate(final_result):
msg_type = type(msg).__name__
print(f" [{i}] {msg_type}")
return final_result
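
For reference, here is a minimal usage sketch of anthropic_trim_messages (not part of this commit). The character-based counter below is a crude stand-in for a real tokenizer, and the tool_calls payload is illustrative only:

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from ra_aid.anthropic_message_utils import anthropic_trim_messages

def rough_token_counter(msgs):
    # Very rough: assume ~4 characters per token.
    return sum(len(str(m.content)) for m in msgs) // 4

messages = [
    SystemMessage(content="You are a coding agent."),
    HumanMessage(content="Fix the failing test."),
    AIMessage(
        content="",
        additional_kwargs={"tool_calls": [{"id": "call_1"}]},
    ),
    ToolMessage(content="1 test passed", tool_call_id="call_1"),
    AIMessage(content="The test passes now."),
]

trimmed = anthropic_trim_messages(
    messages,
    max_tokens=1000,
    token_counter=rough_token_counter,
    strategy="last",
    num_messages_to_keep=2,  # the system and first human message always survive
)
# Tool pairs are kept or dropped together, so the Anthropic API never sees a
# dangling tool_use or tool_result.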

anthropic_token_limiter.py

@@ -1,11 +1,17 @@
"""Utilities for handling token limits with Anthropic models."""
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, trim_messages
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage, trim_messages
from langchain_core.messages.base import message_to_dict
from ra_aid.anthropic_message_utils import (
fix_anthropic_message_content,
anthropic_trim_messages,
has_tool_use,
)
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import token_counter
@@ -13,7 +19,7 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.console.output import print_messages_compact
from ra_aid.console.output import cpm, print_messages_compact
logger = get_logger(__name__)
@@ -85,68 +91,58 @@ def create_token_counter_wrapper(model: str):
def state_modifier(
state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
"""Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message.
"""Given the agent state and max_tokens, return a trimmed list of messages.
This uses anthropic_trim_messages which always keeps the first 2 messages.
Args:
state: The current agent state containing messages
model: The language model to use for token counting
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
Returns:
list[BaseMessage]: Trimmed list of messages that fits within token limit
"""
messages = state["messages"]
if not messages:
return []
first_message = messages[0]
remaining_messages = messages[1:]
wrapped_token_counter = create_token_counter_wrapper(model.model)
print(f"max_input_tokens={max_input_tokens}")
max_input_tokens = 17000
first_tokens = wrapped_token_counter([first_message])
print(f"first_tokens={first_tokens}")
new_max_tokens = max_input_tokens - first_tokens
# Keep max_input_tokens at 21000 as requested
max_input_tokens = 21000
# Calculate total tokens before trimming
total_tokens_before = wrapped_token_counter(messages)
print(
f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
)
print("\nDEBUG - Starting token trimming with max_tokens:", max_input_tokens)
print(f"Current token total: {wrapped_token_counter(messages)}")
# Verify the token count is correct
if total_tokens_before < first_tokens:
print(f"WARNING: Token count inconsistency detected! Recounting...")
# Count message by message to debug
for i, msg in enumerate(messages):
msg_tokens = wrapped_token_counter([msg])
print(f" Message {i}: {msg_tokens} tokens")
# Try alternative counting method
alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
print(f" Alternative count method: {alt_count} tokens")
# Print more details about the messages to help debug
for i, msg in enumerate(messages):
if isinstance(msg, AIMessage):
print(f"DEBUG - AIMessage[{i}] content type: {type(msg.content)}")
print(f"DEBUG - AIMessage[{i}] has_tool_use: {has_tool_use(msg)}")
if has_tool_use(msg) and i < len(messages) - 1:
print(
f"DEBUG - Next message is ToolMessage: {isinstance(messages[i+1], ToolMessage)}"
)
print_messages_compact(messages)
trimmed_remaining = trim_messages(
remaining_messages,
result = anthropic_trim_messages(
messages,
token_counter=wrapped_token_counter,
max_tokens=new_max_tokens,
max_tokens=max_input_tokens,
strategy="last",
allow_partial=False,
include_system=True,
num_messages_to_keep=2,
)
result = [first_message] + trimmed_remaining
# Only show message if some messages were trimmed
if len(result) < len(messages):
print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
# Calculate total tokens after trimming
total_tokens_after = wrapped_token_counter(result)
print(f"New token total: {total_tokens_after}")
print("BEFORE TRIMMING")
print_messages_compact(messages)
print("AFTER TRIMMING")
print_messages_compact(result)
return result
@@ -176,12 +172,14 @@ def sonnet_35_state_modifier(
total_tokens_before = estimate_messages_tokens(messages)
print(f"Current token total: {total_tokens_before}")
trimmed_remaining = trim_messages(
# Trim remaining messages
trimmed_remaining = anthropic_trim_messages(
remaining_messages,
token_counter=estimate_messages_tokens,
max_tokens=new_max_tokens,
strategy="last",
allow_partial=False,
include_system=True,
)
result = [first_message] + trimmed_remaining
@@ -193,6 +191,8 @@ def sonnet_35_state_modifier(
total_tokens_after = estimate_messages_tokens(result)
print(f"New token total: {total_tokens_after}")
# No need to fix message content as anthropic_trim_messages already handles this
return result
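
For context, a sketch of how this state_modifier is typically bound and handed to the agent (assumed wiring, not shown in this diff; it relies on langgraph's create_react_agent accepting a state_modifier callable at the time, and on the module living at ra_aid.anthropic_token_limiter):

from functools import partial

from langchain_anthropic import ChatAnthropic
from langgraph.prebuilt import create_react_agent

from ra_aid.anthropic_token_limiter import state_modifier

model = ChatAnthropic(model="claude-3-7-sonnet-20250219")  # illustrative model name
agent = create_react_agent(
    model,
    tools=[],  # real tools elided
    state_modifier=partial(state_modifier, model=model, max_input_tokens=100_000),
)
# Each turn, langgraph calls the partial with the current AgentState, and the
# trimmed message list is what actually reaches the Anthropic API.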