diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py
index 97f235c..fc6d3ac 100644
--- a/ra_aid/anthropic_token_limiter.py
+++ b/ra_aid/anthropic_token_limiter.py
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import BaseMessage, trim_messages
-from langchain_core.messages.base import messages_to_dict
+from langchain_core.messages.base import message_to_dict
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import token_counter
 
@@ -34,38 +34,51 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
     return sum(estimate_tokens(msg) for msg in messages)
 
 
+def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
+    """Convert a BaseMessage to the format expected by litellm.
+
+    Args:
+        message: The BaseMessage to convert
+
+    Returns:
+        Dict in litellm format
+    """
+    message_dict = message_to_dict(message)
+    return {
+        "role": message_dict["type"],
+        "content": message_dict["data"]["content"],
+    }
+
+
 def create_token_counter_wrapper(model: str):
     """Create a wrapper for token counter that handles BaseMessage conversion.
-
+
     Args:
         model: The model name to use for token counting
-
+
     Returns:
         A function that accepts BaseMessage objects and returns token count
     """
-
+
     # Create a partial function that already has the model parameter set
     base_token_counter = partial(token_counter, model=model)
-
-    def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
+
+    def wrapped_token_counter(messages: List[BaseMessage]) -> int:
         """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
-
+
         Args:
-            messages: List of messages (either BaseMessage objects or dicts)
-
+            messages: List of BaseMessage objects
+
         Returns:
             Token count for the messages
         """
         if not messages:
             return 0
-
-        if isinstance(messages[0], BaseMessage):
-            messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
-            return base_token_counter(messages=messages_dicts)
-        else:
-            # Already in dict format
-            return base_token_counter(messages=messages)
-
+
+        litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
+        result = base_token_counter(messages=litellm_messages)
+        return result
+
     return wrapped_token_counter
 
 
@@ -90,12 +103,31 @@ def state_modifier(
     first_message = messages[0]
     remaining_messages = messages[1:]
 
     wrapped_token_counter = create_token_counter_wrapper(model.model)
-
+
+    print(f"max_input_tokens={max_input_tokens}")
     first_tokens = wrapped_token_counter([first_message])
+    print(f"first_tokens={first_tokens}")
     new_max_tokens = max_input_tokens - first_tokens
 
+    # Calculate total tokens before trimming
+    total_tokens_before = wrapped_token_counter(messages)
+    print(
+        f"Current token total: {total_tokens_before} (should be at least {first_tokens})"
+    )
+
+    # Verify the token count is correct
+    if total_tokens_before < first_tokens:
+        print("WARNING: Token count inconsistency detected! Recounting...")
+        # Count message by message to debug
+        for i, msg in enumerate(messages):
+            msg_tokens = wrapped_token_counter([msg])
+            print(f"  Message {i}: {msg_tokens} tokens")
+        # Try alternative counting method
+        alt_count = sum(wrapped_token_counter([msg]) for msg in messages)
+        print(f"  Alternative count method: {alt_count} tokens")
+
     print_messages_compact(messages)
 
     trimmed_remaining = trim_messages(
@@ -106,10 +138,19 @@
         allow_partial=False,
     )
 
-    return [first_message] + trimmed_remaining
+    result = [first_message] + trimmed_remaining
+
+    # Only show message if some messages were trimmed
+    if len(result) < len(messages):
+        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
+        # Calculate total tokens after trimming
+        total_tokens_after = wrapped_token_counter(result)
+        print(f"New token total: {total_tokens_after}")
+
+    return result
 
 
-def sonnet_3_5_state_modifier(
+def sonnet_35_state_modifier(
     state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
@@ -131,6 +172,10 @@ def sonnet_3_5_state_modifier(
     first_tokens = estimate_messages_tokens([first_message])
     new_max_tokens = max_input_tokens - first_tokens
 
+    # Calculate total tokens before trimming
+    total_tokens_before = estimate_messages_tokens(messages)
+    print(f"Current token total: {total_tokens_before}")
+
     trimmed_remaining = trim_messages(
         remaining_messages,
         token_counter=estimate_messages_tokens,
@@ -139,7 +184,16 @@
         allow_partial=False,
     )
 
-    return [first_message] + trimmed_remaining
+    result = [first_message] + trimmed_remaining
+
+    # Only show message if some messages were trimmed
+    if len(result) < len(messages):
+        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
+        # Calculate total tokens after trimming
+        total_tokens_after = estimate_messages_tokens(result)
+        print(f"New token total: {total_tokens_after}")
+
+    return result
 
 
 def get_model_token_limit(