feat(anthropic_message_utils.py): add utilities for handling Anthropic-specific message formats and trimming to improve message processing
fix(agent_utils.py): remove debug print statement for max_input_tokens to clean up code
refactor(anthropic_token_limiter.py): update state_modifier to use anthropic_trim_messages for better token management and maintain message structure
This commit is contained in:
parent 09ba1ee0b9
commit ee73c85b02
agent_utils.py
@@ -164,7 +164,6 @@ def create_agent(
     max_input_tokens = (
         get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
     )
-    print(f"max_input_tokens={max_input_tokens}")
 
     # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
     if is_anthropic_claude(config):
anthropic_message_utils.py (new file)
@@ -0,0 +1,393 @@
"""Utilities for handling Anthropic-specific message formats and trimming."""

from typing import Callable, List, Literal, Optional, Sequence, Union, cast

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)


def _is_message_type(
    message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]]
) -> bool:
    """Check if a message is of a specific type or types.

    Args:
        message: The message to check
        message_types: Type(s) to check against (string name or class)

    Returns:
        bool: True if message matches any of the specified types
    """
    if not isinstance(message_types, list):
        message_types = [message_types]

    types_str = [t for t in message_types if isinstance(t, str)]
    types_classes = tuple(t for t in message_types if isinstance(t, type))

    return message.type in types_str or isinstance(message, types_classes)
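
# Illustrative usage (editor's sketch, not part of the committed file):
# the helper accepts string type names, classes, or a mixed list, so all
# of these hold for a plain AIMessage.
#
#     msg = AIMessage(content="hello")
#     _is_message_type(msg, "ai")                   # True via msg.type
#     _is_message_type(msg, AIMessage)              # True via isinstance
#     _is_message_type(msg, ["human", AIMessage])   # True: any match suffices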


def has_tool_use(message: BaseMessage) -> bool:
    """Check if a message contains tool use.

    Args:
        message: The message to check

    Returns:
        bool: True if the message contains tool use
    """
    if not isinstance(message, AIMessage):
        return False

    # Check content for tool_use
    if isinstance(message.content, str) and "tool_use" in message.content:
        return True

    # Check content list for tool_use blocks
    if isinstance(message.content, list):
        for item in message.content:
            if isinstance(item, dict) and item.get("type") == "tool_use":
                return True

    # Check additional_kwargs for tool_calls
    if hasattr(message, "additional_kwargs") and message.additional_kwargs.get("tool_calls"):
        return True

    return False
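
# Editor's sketch of the three detection paths (not part of the committed
# file): a structured tool_use block, the substring check on string content,
# and the additional_kwargs["tool_calls"] fallback.
#
#     blocks = [{"type": "tool_use", "id": "t1", "name": "search", "input": {}}]
#     has_tool_use(AIMessage(content=blocks))        # True: tool_use block
#     has_tool_use(AIMessage(content="no tools"))    # False
#     has_tool_use(ToolMessage(content="r", tool_call_id="t1"))  # False: not an AIMessage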


def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
    """Check if two messages form a tool use/result pair.

    Args:
        message1: First message
        message2: Second message

    Returns:
        bool: True if the messages form a tool use/result pair
    """
    return (
        isinstance(message1, AIMessage) and
        isinstance(message2, ToolMessage) and
        has_tool_use(message1)
    )
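
# Editor's sketch (not part of the committed file): a pair is exactly the
# shape Anthropic requires -- an AIMessage that uses a tool, immediately
# followed by the ToolMessage carrying its result.
#
#     ai = AIMessage(content=[{"type": "tool_use", "id": "t1", "name": "run", "input": {}}])
#     tool = ToolMessage(content="ok", tool_call_id="t1")
#     is_tool_pair(ai, tool)    # True
#     is_tool_pair(tool, ai)    # False: order matters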


def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage:
    """Fix message content format for Anthropic API compatibility."""
    if not isinstance(message, AIMessage) or not isinstance(message.content, list):
        return message

    fixed_message = message.model_copy(deep=True)

    # Ensure first block is valid thinking type
    if fixed_message.content and isinstance(fixed_message.content[0], dict):
        first_block_type = fixed_message.content[0].get("type")
        if first_block_type not in ("thinking", "redacted_thinking"):
            # Prepend redacted_thinking block instead of thinking
            fixed_message.content.insert(
                0,
                {
                    "type": "redacted_thinking",
                    "data": "ENCRYPTED_REASONING",  # Required field for redacted_thinking
                },
            )

    # Ensure all thinking blocks have valid structure
    for i, block in enumerate(fixed_message.content):
        if block.get("type") == "thinking":
            # Convert thinking blocks to redacted_thinking to avoid signature validation
            fixed_message.content[i] = {
                "type": "redacted_thinking",
                "data": "ENCRYPTED_REASONING",
            }
        elif block.get("type") == "redacted_thinking":
            # Ensure required data field exists
            if "data" not in block:
                fixed_message.content[i]["data"] = "ENCRYPTED_REASONING"

    return fixed_message
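
# Editor's sketch (not part of the committed file): a list-content AIMessage
# whose first block is not a thinking block gets a redacted_thinking block
# prepended; the original message is untouched thanks to model_copy(deep=True).
#
#     msg = AIMessage(content=[{"type": "text", "text": "hi"}])
#     fixed = fix_anthropic_message_content(msg)
#     fixed.content[0]["type"]   # "redacted_thinking"
#     msg.content[0]["type"]     # still "text"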


def anthropic_trim_messages(
    messages: Sequence[BaseMessage],
    *,
    max_tokens: int,
    token_counter: Callable[[List[BaseMessage]], int],
    strategy: Literal["first", "last"] = "last",
    num_messages_to_keep: int = 2,
    allow_partial: bool = False,
    include_system: bool = True,
    start_on: Optional[Union[str, type, List[Union[str, type]]]] = None,
) -> List[BaseMessage]:
    """Trim messages to fit within a token limit, with Anthropic-specific handling.

    This function is similar to langchain_core's trim_messages but with special
    handling for Anthropic message formats to avoid API errors.

    It always keeps the first num_messages_to_keep messages.

    Args:
        messages: Sequence of messages to trim
        max_tokens: Maximum number of tokens allowed
        token_counter: Function to count tokens in messages
        strategy: Whether to keep the "first" or "last" messages
        num_messages_to_keep: Number of leading messages that are always kept
        allow_partial: Whether to allow partial messages
        include_system: Whether to always include the system message
        start_on: Message type to start on (only for "last" strategy)

    Returns:
        List[BaseMessage]: Trimmed messages that fit within token limit
    """
    if not messages:
        return []

    messages = list(messages)

    # Always keep the first num_messages_to_keep messages
    kept_messages = messages[:num_messages_to_keep]
    remaining_msgs = messages[num_messages_to_keep:]

    # Debug: Print message types for all messages
    print("\nDEBUG - All messages:")
    for i, msg in enumerate(messages):
        msg_type = type(msg).__name__
        tool_use = (
            "tool_use"
            if isinstance(msg, AIMessage)
            and hasattr(msg, "additional_kwargs")
            and msg.additional_kwargs.get("tool_calls")
            else ""
        )
        tool_result = (
            f"tool_call_id: {msg.tool_call_id}"
            if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id")
            else ""
        )
        print(f"  [{i}] {msg_type} {tool_use} {tool_result}")

    # For Anthropic, we need to maintain the conversation structure where:
    # 1. Every AIMessage with tool_use must be followed by a ToolMessage
    # 2. Every AIMessage that follows a ToolMessage must start with a tool_result

    # First, check if we have any tool_use in the messages
    has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
    print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}")

    # Print debug info for AIMessages
    for i, msg in enumerate(messages):
        if isinstance(msg, AIMessage):
            print(f"DEBUG - AIMessage[{i}] details:")
            print(f"  has_tool_use: {has_tool_use(msg)}")
            if hasattr(msg, "additional_kwargs"):
                print(f"  additional_kwargs keys: {list(msg.additional_kwargs.keys())}")

    # If we have tool_use anywhere, we need to be very careful about trimming
    if has_tool_use_anywhere:
        # For safety, just keep all messages if we're under the token limit
        if token_counter(messages) <= max_tokens:
            print("DEBUG - All messages fit within token limit, keeping all")
            return messages

        # We need to identify all tool_use/tool_result relationships
        # First, find all AIMessage+ToolMessage pairs
        pairs = []
        i = 0
        while i < len(messages) - 1:
            if is_tool_pair(messages[i], messages[i + 1]):
                pairs.append((i, i + 1))
                print(f"DEBUG - Found tool_use pair: ({i}, {i + 1})")
                i += 2
            else:
                i += 1

        print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs")

        # For Anthropic, we need to ensure that:
        # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
        # 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use

        # The safest approach is to always keep complete AIMessage+ToolMessage pairs together
        # First, identify all complete pairs
        complete_pairs = []
        for start, end in pairs:
            complete_pairs.append((start, end))

        print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs")

        # Now we'll build our result, starting with the kept_messages
        # But we need to be careful about the first message if it has tool_use
        result = []

        # Check if the last message in kept_messages has tool_use
        if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]):
            # We need to find the corresponding ToolMessage
            for i, (ai_idx, tool_idx) in enumerate(pairs):
                if messages[ai_idx] is kept_messages[-1]:
                    # Found the pair, add all kept_messages except the last one
                    result.extend(kept_messages[:-1])
                    # Add the AIMessage and ToolMessage as a pair
                    result.extend([messages[ai_idx], messages[tool_idx]])
                    # Remove this pair from the list of pairs to process later
                    pairs = pairs[:i] + pairs[i + 1:]
                    break
            else:
                # If we didn't find a matching pair, just add all kept_messages
                result.extend(kept_messages)
        else:
            # No tool_use in the last kept message, just add all kept_messages
            result.extend(kept_messages)

        # If we're using the "last" strategy, we'll try to include pairs from the end
        if strategy == "last":
            # First collect all pairs we can include within the token limit
            pairs_to_include = []

            # Process pairs from the end (newest first)
            for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
                # Try adding this pair
                test_msgs = result.copy()

                # Add all previously selected pairs
                for prev_ai_idx, prev_tool_idx in pairs_to_include:
                    test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])

                # Add this pair
                test_msgs.extend([messages[ai_idx], messages[tool_idx]])

                if token_counter(test_msgs) <= max_tokens:
                    # This pair fits, add it to our list
                    pairs_to_include.append((ai_idx, tool_idx))
                    print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})")
                else:
                    # This pair would exceed the token limit
                    print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping")
                    break

            # Now add the pairs in the correct order
            # Sort by index to maintain the original conversation flow
            pairs_to_include.sort(key=lambda x: x[0])
            for ai_idx, tool_idx in pairs_to_include:
                result.extend([messages[ai_idx], messages[tool_idx]])

        # No need to sort - we've already added messages in the correct order

        print(f"DEBUG - Final result has {len(result)} messages")
        return result

    # If no tool_use, proceed with normal segmentation
    segments = []
    i = 0

    # Group messages into segments
    while i < len(remaining_msgs):
        segments.append([remaining_msgs[i]])
        print(f"DEBUG - Added message as segment: [{i}]")
        i += 1

    print(f"\nDEBUG - Created {len(segments)} segments")
    for i, segment in enumerate(segments):
        segment_types = [type(msg).__name__ for msg in segment]
        print(f"  Segment {i}: {segment_types}")

    # Now we have segments that maintain the required structure
    # We'll add segments from the end (for "last" strategy) or beginning (for "first")
    # until we hit the token limit

    if strategy == "last":
        # If we have no segments, just return kept_messages
        if not segments:
            return kept_messages

        result = []

        # Process segments from the end
        for i, segment in enumerate(reversed(segments)):
            # Try adding this segment
            test_msgs = segment + result

            if token_counter(kept_messages + test_msgs) <= max_tokens:
                result = segment + result
                print(f"DEBUG - Added segment {len(segments) - i - 1} to result")
            else:
                # This segment would exceed the token limit
                print(f"DEBUG - Segment {len(segments) - i - 1} would exceed token limit, stopping")
                break

        final_result = kept_messages + result

        # For Anthropic, we need to ensure the conversation follows a valid structure
        # We'll do a final check of the entire conversation
        print("\nDEBUG - Final result before validation:")
        for i, msg in enumerate(final_result):
            msg_type = type(msg).__name__
            print(f"  [{i}] {msg_type}")

        # Validate the conversation structure
        valid_result = []
        i = 0

        # Process messages in order
        while i < len(final_result):
            current_msg = final_result[i]

            # If this is an AIMessage with tool_use, it must be followed by a ToolMessage
            if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg):
                if isinstance(final_result[i + 1], ToolMessage):
                    # This is a valid tool_use + tool_result pair
                    valid_result.append(current_msg)
                    valid_result.append(final_result[i + 1])
                    print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i + 1}")
                    i += 2
                else:
                    # Invalid: AIMessage with tool_use not followed by ToolMessage
                    print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage")
                    # Skip this message to maintain valid structure
                    i += 1
            else:
                # Regular message, just add it
                valid_result.append(current_msg)
                print(f"DEBUG - Added regular message at position {i}")
                i += 1

        # Final check: don't end with an AIMessage that has tool_use
        if valid_result and isinstance(valid_result[-1], AIMessage) and has_tool_use(valid_result[-1]):
            print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage")
            valid_result.pop()  # Remove the last message

        print("\nDEBUG - Final validated result:")
        for i, msg in enumerate(valid_result):
            msg_type = type(msg).__name__
            print(f"  [{i}] {msg_type}")

        return valid_result

    elif strategy == "first":
        result = []

        # Process segments from the beginning
        for i, segment in enumerate(segments):
            # Try adding this segment
            test_msgs = result + segment
            if token_counter(kept_messages + test_msgs) <= max_tokens:
                result = result + segment
                print(f"DEBUG - Added segment {i} to result")
            else:
                # This segment would exceed the token limit
                print(f"DEBUG - Segment {i} would exceed token limit, stopping")
                break

        final_result = kept_messages + result
        print("\nDEBUG - Final result:")
        for i, msg in enumerate(final_result):
            msg_type = type(msg).__name__
            print(f"  [{i}] {msg_type}")

        return final_result
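
A minimal usage sketch of the new trimmer (editor's illustration, not part of the commit; `crude_counter` and `conversation_messages` are hypothetical stand-ins):

    from ra_aid.anthropic_message_utils import anthropic_trim_messages

    def crude_counter(msgs):
        # Hypothetical counter: roughly one token per four characters of content
        return sum(len(str(m.content)) // 4 for m in msgs)

    trimmed = anthropic_trim_messages(
        conversation_messages,  # hypothetical list[BaseMessage]
        max_tokens=21000,
        token_counter=crude_counter,
        strategy="last",
        num_messages_to_keep=2,
    )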
anthropic_token_limiter.py
@@ -1,11 +1,17 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence
 
 from langchain_anthropic import ChatAnthropic
-from langchain_core.messages import BaseMessage, trim_messages
+from langchain_core.messages import AIMessage, BaseMessage, ToolMessage, trim_messages
 from langchain_core.messages.base import message_to_dict
+
+from ra_aid.anthropic_message_utils import (
+    fix_anthropic_message_content,
+    anthropic_trim_messages,
+    has_tool_use,
+)
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import token_counter
@@ -13,7 +19,7 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.console.output import print_messages_compact
+from ra_aid.console.output import cpm, print_messages_compact
 
 logger = get_logger(__name__)
@@ -85,68 +91,58 @@ def create_token_counter_wrapper(model: str):
 def state_modifier(
     state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
-    """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message.
+    """Given the agent state and max_tokens, return a trimmed list of messages.
+
+    This uses anthropic_trim_messages which always keeps the first 2 messages.
 
     Args:
         state: The current agent state containing messages
         model: The language model to use for token counting
-        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
+        max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
 
     Returns:
         list[BaseMessage]: Trimmed list of messages that fits within token limit
     """
     messages = state["messages"]
 
     if not messages:
         return []
 
-    first_message = messages[0]
-    remaining_messages = messages[1:]
-
     wrapped_token_counter = create_token_counter_wrapper(model.model)
 
-    print(f"max_input_tokens={max_input_tokens}")
-    max_input_tokens = 17000
-    first_tokens = wrapped_token_counter([first_message])
-    print(f"first_tokens={first_tokens}")
-    new_max_tokens = max_input_tokens - first_tokens
-
-    # Calculate total tokens before trimming
-    total_tokens_before = wrapped_token_counter(messages)
-    print(
-        f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
-    )
-
-    # Verify the token count is correct
-    if total_tokens_before < first_tokens:
-        print(f"WARNING: Token count inconsistency detected! Recounting...")
-        # Count message by message to debug
-        for i, msg in enumerate(messages):
-            msg_tokens = wrapped_token_counter([msg])
-            print(f"  Message {i}: {msg_tokens} tokens")
-        # Try alternative counting method
-        alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
-        print(f"  Alternative count method: {alt_count} tokens")
-
-        print_messages_compact(messages)
-
-    trimmed_remaining = trim_messages(
-        remaining_messages,
+    # Keep max_input_tokens at 21000 as requested
+    max_input_tokens = 21000
+
+    print("\nDEBUG - Starting token trimming with max_tokens:", max_input_tokens)
+    print(f"Current token total: {wrapped_token_counter(messages)}")
+
+    # Print more details about the messages to help debug
+    for i, msg in enumerate(messages):
+        if isinstance(msg, AIMessage):
+            print(f"DEBUG - AIMessage[{i}] content type: {type(msg.content)}")
+            print(f"DEBUG - AIMessage[{i}] has_tool_use: {has_tool_use(msg)}")
+            if has_tool_use(msg) and i < len(messages) - 1:
+                print(
+                    f"DEBUG - Next message is ToolMessage: {isinstance(messages[i+1], ToolMessage)}"
+                )
+
+    result = anthropic_trim_messages(
+        messages,
         token_counter=wrapped_token_counter,
-        max_tokens=new_max_tokens,
+        max_tokens=max_input_tokens,
         strategy="last",
         allow_partial=False,
+        include_system=True,
+        num_messages_to_keep=2,
     )
 
-    result = [first_message] + trimmed_remaining
-
-    # Only show message if some messages were trimmed
     if len(result) < len(messages):
         print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
-        # Calculate total tokens after trimming
         total_tokens_after = wrapped_token_counter(result)
         print(f"New token total: {total_tokens_after}")
+        print("BEFORE TRIMMING")
+        print_messages_compact(messages)
+        print("AFTER TRIMMING")
+        print_messages_compact(result)
 
     return result
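
Since the file already imports functools.partial, one plausible wiring (an editor's assumption; the binding site is not shown in this diff) is to fix the model and limit so the agent executor receives a single-argument callable:

    from functools import partial

    # Hypothetical binding; `model` and `max_input_tokens` would come from config.
    bound_state_modifier = partial(
        state_modifier, model=model, max_input_tokens=max_input_tokens
    )
    trimmed = bound_state_modifier(state)  # state: AgentState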
@@ -176,12 +172,14 @@ def sonnet_35_state_modifier(
     total_tokens_before = estimate_messages_tokens(messages)
     print(f"Current token total: {total_tokens_before}")
 
-    trimmed_remaining = trim_messages(
+    # Trim remaining messages
+    trimmed_remaining = anthropic_trim_messages(
         remaining_messages,
         token_counter=estimate_messages_tokens,
         max_tokens=new_max_tokens,
         strategy="last",
         allow_partial=False,
+        include_system=True,
     )
 
     result = [first_message] + trimmed_remaining
@@ -193,6 +191,8 @@ def sonnet_35_state_modifier(
     total_tokens_after = estimate_messages_tokens(result)
     print(f"New token total: {total_tokens_after}")
 
+    # No need to fix message content as anthropic_trim_messages already handles this
+
     return result
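
`create_token_counter_wrapper` appears in the hunk header above but its body is not part of this diff. A plausible shape for a litellm-backed wrapper (an editor's assumption built only from the imports visible above) is:

    def create_token_counter_wrapper(model: str):
        # Hypothetical sketch: adapt litellm's token_counter to the
        # Callable[[List[BaseMessage]], int] shape used by the trimmers.
        def count_tokens(msgs: List[BaseMessage]) -> int:
            if not msgs:
                return 0
            payload = [
                {"role": "user", "content": str(message_to_dict(m)["data"].get("content", ""))}
                for m in msgs
            ]
            return token_counter(model=model, messages=payload)

        return count_tokens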