refactor(anthropic_token_limiter.py): rename messages_to_dict to message_to_dict for consistency and clarity
feat(anthropic_token_limiter.py): add convert_message_to_litellm_format function to standardize message format for litellm
fix(anthropic_token_limiter.py): update wrapped_token_counter to handle only BaseMessage objects and improve token counting logic
chore(anthropic_token_limiter.py): add debug print statements to track token counts before and after trimming messages
commit 09ba1ee0b9
parent 5c9a1e81d2
anthropic_token_limiter.py
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import BaseMessage, trim_messages
-from langchain_core.messages.base import messages_to_dict
+from langchain_core.messages.base import message_to_dict
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import token_counter
 
 
@@ -34,38 +34,51 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
     return sum(estimate_tokens(msg) for msg in messages)
 
 
+def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
+    """Convert a BaseMessage to the format expected by litellm.
+
+    Args:
+        message: The BaseMessage to convert
+
+    Returns:
+        Dict in litellm format
+    """
+    message_dict = message_to_dict(message)
+    return {
+        "role": message_dict["type"],
+        "content": message_dict["data"]["content"],
+    }
+
+
 def create_token_counter_wrapper(model: str):
     """Create a wrapper for token counter that handles BaseMessage conversion.
 
     Args:
         model: The model name to use for token counting
 
     Returns:
         A function that accepts BaseMessage objects and returns token count
     """
 
     # Create a partial function that already has the model parameter set
     base_token_counter = partial(token_counter, model=model)
 
-    def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
+    def wrapped_token_counter(messages: List[BaseMessage]) -> int:
         """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
 
         Args:
-            messages: List of messages (either BaseMessage objects or dicts)
+            messages: List of BaseMessage objects
 
         Returns:
             Token count for the messages
         """
         if not messages:
             return 0
 
-        if isinstance(messages[0], BaseMessage):
-            messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
-            return base_token_counter(messages=messages_dicts)
-        else:
-            # Already in dict format
-            return base_token_counter(messages=messages)
+        litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
+        result = base_token_counter(messages=litellm_messages)
+        return result
 
     return wrapped_token_counter
 
 
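A minimal usage sketch of the two helpers this hunk adds and rewires. The import path anthropic_token_limiter and the model id claude-3-5-sonnet-20240620 are assumptions for illustration, not taken from the commit:

# Usage sketch; assumes the module is importable as `anthropic_token_limiter`
# and that litellm recognises the model id used below.
from langchain_core.messages import AIMessage, HumanMessage

from anthropic_token_limiter import (
    convert_message_to_litellm_format,
    create_token_counter_wrapper,
)

messages = [HumanMessage(content="Hello"), AIMessage(content="Hi there!")]

# message_to_dict() returns {"type": "human", "data": {...}} for a
# HumanMessage, so the litellm dict carries roles "human"/"ai" rather than
# OpenAI-style "user"/"assistant".
print(convert_message_to_litellm_format(messages[0]))
# {'role': 'human', 'content': 'Hello'}

counter = create_token_counter_wrapper("claude-3-5-sonnet-20240620")  # assumed model id
print(counter(messages))  # total token count for the conversation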
@@ -90,12 +103,31 @@ def state_modifier(
     first_message = messages[0]
     remaining_messages = messages[1:]
 
 
     wrapped_token_counter = create_token_counter_wrapper(model.model)
 
+    print(f"max_input_tokens={max_input_tokens}")
+    max_input_tokens = 17000
     first_tokens = wrapped_token_counter([first_message])
+    print(f"first_tokens={first_tokens}")
     new_max_tokens = max_input_tokens - first_tokens
 
+    # Calculate total tokens before trimming
+    total_tokens_before = wrapped_token_counter(messages)
+    print(
+        f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
+    )
+
+    # Verify the token count is correct
+    if total_tokens_before < first_tokens:
+        print(f"WARNING: Token count inconsistency detected! Recounting...")
+        # Count message by message to debug
+        for i, msg in enumerate(messages):
+            msg_tokens = wrapped_token_counter([msg])
+            print(f"  Message {i}: {msg_tokens} tokens")
+        # Try alternative counting method
+        alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
+        print(f"  Alternative count method: {alt_count} tokens")
+
     print_messages_compact(messages)
 
     trimmed_remaining = trim_messages(
@@ -106,10 +138,19 @@ def state_modifier(
         allow_partial=False,
     )
 
-    return [first_message] + trimmed_remaining
+    result = [first_message] + trimmed_remaining
+
+    # Only show message if some messages were trimmed
+    if len(result) < len(messages):
+        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
+        # Calculate total tokens after trimming
+        total_tokens_after = wrapped_token_counter(result)
+        print(f"New token total: {total_tokens_after}")
+
+    return result
 
 
-def sonnet_3_5_state_modifier(
+def sonnet_35_state_modifier(
     state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
@@ -131,6 +172,10 @@ def sonnet_3_5_state_modifier(
     first_tokens = estimate_messages_tokens([first_message])
     new_max_tokens = max_input_tokens - first_tokens
 
+    # Calculate total tokens before trimming
+    total_tokens_before = estimate_messages_tokens(messages)
+    print(f"Current token total: {total_tokens_before}")
+
     trimmed_remaining = trim_messages(
         remaining_messages,
         token_counter=estimate_messages_tokens,
@@ -139,7 +184,16 @@ def sonnet_3_5_state_modifier(
         allow_partial=False,
     )
 
-    return [first_message] + trimmed_remaining
+    result = [first_message] + trimmed_remaining
+
+    # Only show message if some messages were trimmed
+    if len(result) < len(messages):
+        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
+        # Calculate total tokens after trimming
+        total_tokens_after = estimate_messages_tokens(result)
+        print(f"New token total: {total_tokens_after}")
+
+    return result
 
 
 def get_model_token_limit(
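For context, the reserve-then-trim pattern both state modifiers share, as a standalone sketch. The rough_counter stand-in and the strategy="last" argument are assumptions (the visible hunks begin below trim_messages' first arguments); only the reserve arithmetic and the [first_message] + trimmed_remaining recombination come from the commit:

# Standalone sketch of the reserve-then-trim pattern; rough_counter is a
# crude character-based stand-in for the litellm-backed counter, and
# strategy="last" is an assumption.
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)


def rough_counter(msgs: list[BaseMessage]) -> int:
    # ~4 characters per token plus a small per-message overhead
    return sum(len(str(m.content)) // 4 + 3 for m in msgs)


messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="First question"),
    AIMessage(content="First answer"),
    HumanMessage(content="Second question"),
]

max_input_tokens = 20
first_message, remaining = messages[0], messages[1:]

# Reserve room for the first message, then trim only the remainder
new_max_tokens = max_input_tokens - rough_counter([first_message])

trimmed_remaining = trim_messages(
    remaining,
    token_counter=rough_counter,
    max_tokens=new_max_tokens,
    strategy="last",  # keep the most recent messages that fit
    allow_partial=False,
)
result = [first_message] + trimmed_remaining
print(f"TRIMMED: {len(messages)} messages -> {len(result)} messages")

With these numbers the first message costs 10 tokens, leaving 10 for the remainder, so only the last message survives and the script prints "TRIMMED: 4 messages -> 2 messages".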