refactor(anthropic_token_limiter.py): rename messages_to_dict to message_to_dict for consistency and clarity
feat(anthropic_token_limiter.py): add convert_message_to_litellm_format function to standardize message format for litellm fix(anthropic_token_limiter.py): update wrapped_token_counter to handle only BaseMessage objects and improve token counting logic chore(anthropic_token_limiter.py): add debug print statements to track token counts before and after trimming messages
This commit is contained in:
parent
5c9a1e81d2
commit
09ba1ee0b9
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
|
|||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import BaseMessage, trim_messages
|
||||
from langchain_core.messages.base import messages_to_dict
|
||||
from langchain_core.messages.base import message_to_dict
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import token_counter
|
||||
|
||||
|
|
@ -34,38 +34,51 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
|||
return sum(estimate_tokens(msg) for msg in messages)
|
||||
|
||||
|
||||
def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
|
||||
"""Convert a BaseMessage to the format expected by litellm.
|
||||
|
||||
Args:
|
||||
message: The BaseMessage to convert
|
||||
|
||||
Returns:
|
||||
Dict in litellm format
|
||||
"""
|
||||
message_dict = message_to_dict(message)
|
||||
return {
|
||||
"role": message_dict["type"],
|
||||
"content": message_dict["data"]["content"],
|
||||
}
|
||||
|
||||
|
||||
def create_token_counter_wrapper(model: str):
|
||||
"""Create a wrapper for token counter that handles BaseMessage conversion.
|
||||
|
||||
|
||||
Args:
|
||||
model: The model name to use for token counting
|
||||
|
||||
|
||||
Returns:
|
||||
A function that accepts BaseMessage objects and returns token count
|
||||
"""
|
||||
|
||||
|
||||
# Create a partial function that already has the model parameter set
|
||||
base_token_counter = partial(token_counter, model=model)
|
||||
|
||||
def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
|
||||
|
||||
def wrapped_token_counter(messages: List[BaseMessage]) -> int:
|
||||
"""Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of messages (either BaseMessage objects or dicts)
|
||||
|
||||
messages: List of BaseMessage objects
|
||||
|
||||
Returns:
|
||||
Token count for the messages
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
if isinstance(messages[0], BaseMessage):
|
||||
messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
|
||||
return base_token_counter(messages=messages_dicts)
|
||||
else:
|
||||
# Already in dict format
|
||||
return base_token_counter(messages=messages)
|
||||
|
||||
|
||||
litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
|
||||
result = base_token_counter(messages=litellm_messages)
|
||||
return result
|
||||
|
||||
return wrapped_token_counter
|
||||
|
||||
|
||||
|
|
@ -90,12 +103,31 @@ def state_modifier(
|
|||
first_message = messages[0]
|
||||
remaining_messages = messages[1:]
|
||||
|
||||
|
||||
wrapped_token_counter = create_token_counter_wrapper(model.model)
|
||||
|
||||
|
||||
print(f"max_input_tokens={max_input_tokens}")
|
||||
max_input_tokens = 17000
|
||||
first_tokens = wrapped_token_counter([first_message])
|
||||
print(f"first_tokens={first_tokens}")
|
||||
new_max_tokens = max_input_tokens - first_tokens
|
||||
|
||||
# Calculate total tokens before trimming
|
||||
total_tokens_before = wrapped_token_counter(messages)
|
||||
print(
|
||||
f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
|
||||
)
|
||||
|
||||
# Verify the token count is correct
|
||||
if total_tokens_before < first_tokens:
|
||||
print(f"WARNING: Token count inconsistency detected! Recounting...")
|
||||
# Count message by message to debug
|
||||
for i, msg in enumerate(messages):
|
||||
msg_tokens = wrapped_token_counter([msg])
|
||||
print(f" Message {i}: {msg_tokens} tokens")
|
||||
# Try alternative counting method
|
||||
alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
|
||||
print(f" Alternative count method: {alt_count} tokens")
|
||||
|
||||
print_messages_compact(messages)
|
||||
|
||||
trimmed_remaining = trim_messages(
|
||||
|
|
@ -106,10 +138,19 @@ def state_modifier(
|
|||
allow_partial=False,
|
||||
)
|
||||
|
||||
return [first_message] + trimmed_remaining
|
||||
result = [first_message] + trimmed_remaining
|
||||
|
||||
# Only show message if some messages were trimmed
|
||||
if len(result) < len(messages):
|
||||
print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
|
||||
# Calculate total tokens after trimming
|
||||
total_tokens_after = wrapped_token_counter(result)
|
||||
print(f"New token total: {total_tokens_after}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def sonnet_3_5_state_modifier(
|
||||
def sonnet_35_state_modifier(
|
||||
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||
) -> list[BaseMessage]:
|
||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||
|
|
@ -131,6 +172,10 @@ def sonnet_3_5_state_modifier(
|
|||
first_tokens = estimate_messages_tokens([first_message])
|
||||
new_max_tokens = max_input_tokens - first_tokens
|
||||
|
||||
# Calculate total tokens before trimming
|
||||
total_tokens_before = estimate_messages_tokens(messages)
|
||||
print(f"Current token total: {total_tokens_before}")
|
||||
|
||||
trimmed_remaining = trim_messages(
|
||||
remaining_messages,
|
||||
token_counter=estimate_messages_tokens,
|
||||
|
|
@ -139,7 +184,16 @@ def sonnet_3_5_state_modifier(
|
|||
allow_partial=False,
|
||||
)
|
||||
|
||||
return [first_message] + trimmed_remaining
|
||||
result = [first_message] + trimmed_remaining
|
||||
|
||||
# Only show message if some messages were trimmed
|
||||
if len(result) < len(messages):
|
||||
print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
|
||||
# Calculate total tokens after trimming
|
||||
total_tokens_after = estimate_messages_tokens(result)
|
||||
print(f"New token total: {total_tokens_after}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
|
|
|
|||
Loading…
Reference in New Issue