refactor(anthropic_token_limiter.py): rename messages_to_dict to message_to_dict for consistency and clarity

feat(anthropic_token_limiter.py): add convert_message_to_litellm_format function to standardize message format for litellm
fix(anthropic_token_limiter.py): update wrapped_token_counter to handle only BaseMessage objects and improve token counting logic
chore(anthropic_token_limiter.py): add debug print statements to track token counts before and after trimming messages
Ariel Frischer 2025-03-11 21:26:57 -07:00
parent 5c9a1e81d2
commit 09ba1ee0b9
1 changed file with 76 additions and 22 deletions


@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import BaseMessage, trim_messages
-from langchain_core.messages.base import messages_to_dict
+from langchain_core.messages.base import message_to_dict
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import token_counter
@@ -34,38 +34,51 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
     return sum(estimate_tokens(msg) for msg in messages)

+def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
+    """Convert a BaseMessage to the format expected by litellm.
+
+    Args:
+        message: The BaseMessage to convert
+
+    Returns:
+        Dict in litellm format
+    """
+    message_dict = message_to_dict(message)
+    return {
+        "role": message_dict["type"],
+        "content": message_dict["data"]["content"],
+    }
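
For reference (not part of the diff), a minimal sketch of what the new converter yields; the HumanMessage and its content are illustrative:

from langchain_core.messages import HumanMessage

msg = HumanMessage(content="What's our token budget?")
convert_message_to_litellm_format(msg)
# => {"role": "human", "content": "What's our token budget?"}
# message_to_dict stores the message type under "type" and the payload under "data",
# so the litellm "role" here is the langchain type string (e.g. "human", "ai").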
 def create_token_counter_wrapper(model: str):
     """Create a wrapper for token counter that handles BaseMessage conversion.

     Args:
         model: The model name to use for token counting

     Returns:
         A function that accepts BaseMessage objects and returns token count
     """
     # Create a partial function that already has the model parameter set
     base_token_counter = partial(token_counter, model=model)

-    def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
+    def wrapped_token_counter(messages: List[BaseMessage]) -> int:
         """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.

         Args:
-            messages: List of messages (either BaseMessage objects or dicts)
+            messages: List of BaseMessage objects

         Returns:
             Token count for the messages
         """
         if not messages:
             return 0

-        if isinstance(messages[0], BaseMessage):
-            messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
-            return base_token_counter(messages=messages_dicts)
-        else:
-            # Already in dict format
-            return base_token_counter(messages=messages)
+        litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
+        result = base_token_counter(messages=litellm_messages)
+        return result

     return wrapped_token_counter
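
A hypothetical use of the wrapper; the model name is illustrative and actual counts depend on litellm's tokenizer for that model:

from langchain_core.messages import HumanMessage

counter = create_token_counter_wrapper("claude-3-5-sonnet-20240620")  # model name illustrative
print(counter([]))                                # 0: the empty-list short-circuit
print(counter([HumanMessage(content="hello")]))   # a small positive count from litellm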
@@ -90,12 +103,31 @@ def state_modifier(
     first_message = messages[0]
     remaining_messages = messages[1:]

     wrapped_token_counter = create_token_counter_wrapper(model.model)
+    print(f"max_input_tokens={max_input_tokens}")
+    max_input_tokens = 17000
     first_tokens = wrapped_token_counter([first_message])
+    print(f"first_tokens={first_tokens}")
     new_max_tokens = max_input_tokens - first_tokens

+    # Calculate total tokens before trimming
+    total_tokens_before = wrapped_token_counter(messages)
+    print(
+        f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
+    )
+
+    # Verify the token count is correct
+    if total_tokens_before < first_tokens:
+        print(f"WARNING: Token count inconsistency detected! Recounting...")
+        # Count message by message to debug
+        for i, msg in enumerate(messages):
+            msg_tokens = wrapped_token_counter([msg])
+            print(f"  Message {i}: {msg_tokens} tokens")
+        # Try alternative counting method
+        alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
+        print(f"  Alternative count method: {alt_count} tokens")
+
     print_messages_compact(messages)

     trimmed_remaining = trim_messages(
@@ -106,10 +138,19 @@ def state_modifier(
         allow_partial=False,
     )

-    return [first_message] + trimmed_remaining
+    result = [first_message] + trimmed_remaining
+
+    # Only show message if some messages were trimmed
+    if len(result) < len(messages):
+        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
+
+    # Calculate total tokens after trimming
+    total_tokens_after = wrapped_token_counter(result)
+    print(f"New token total: {total_tokens_after}")
+
+    return result
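
For context, the preserve-first-then-trim pattern both modifiers follow, sketched standalone with a stand-in counter (names and budget are illustrative, not from this file):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages

def rough_counter(msgs):  # stand-in token counter: roughly 4 chars per token
    return sum(len(str(m.content)) // 4 + 1 for m in msgs)

history = [SystemMessage("rules"), HumanMessage("q1"), AIMessage("a1"), HumanMessage("q2")]
budget = 1000
kept = [history[0]] + trim_messages(       # the first message is always preserved
    history[1:],
    max_tokens=budget - rough_counter([history[0]]),  # budget minus first-message tokens
    token_counter=rough_counter,
    strategy="last",                       # keep the most recent messages that fit
    allow_partial=False,
)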
-def sonnet_3_5_state_modifier(
+def sonnet_35_state_modifier(
     state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
@@ -131,6 +172,10 @@ def sonnet_3_5_state_modifier(
     first_tokens = estimate_messages_tokens([first_message])
     new_max_tokens = max_input_tokens - first_tokens

+    # Calculate total tokens before trimming
+    total_tokens_before = estimate_messages_tokens(messages)
+    print(f"Current token total: {total_tokens_before}")
+
     trimmed_remaining = trim_messages(
         remaining_messages,
         token_counter=estimate_messages_tokens,
@@ -139,7 +184,16 @@ def sonnet_3_5_state_modifier(
         allow_partial=False,
     )

-    return [first_message] + trimmed_remaining
+    result = [first_message] + trimmed_remaining
+
+    # Only show message if some messages were trimmed
+    if len(result) < len(messages):
+        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
+
+    # Calculate total tokens after trimming
+    total_tokens_after = estimate_messages_tokens(result)
+    print(f"New token total: {total_tokens_after}")
+
+    return result

 def get_model_token_limit(
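
Finally, a hypothetical call site for the renamed modifier, assuming it reads state["messages"] as langgraph's AgentState TypedDict defines; the messages and budget are illustrative:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

state = {"messages": [SystemMessage("Be terse."), HumanMessage("hi"), AIMessage("hello")]}
trimmed = sonnet_35_state_modifier(state, max_input_tokens=17000)
# The first (system) message always survives; older turns are dropped first
# when the remaining messages exceed the reduced budget.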