RA.Aid/ra_aid/anthropic_token_limiter.py

"""Utilities for handling token limits with Anthropic models."""
from functools import partial
from typing import Any, Dict, List, Optional, Sequence
from dataclasses import dataclass
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import (
AIMessage,
BaseMessage,
RemoveMessage,
ToolMessage,
trim_messages,
)
from langchain_core.messages.base import message_to_dict
from ra_aid.anthropic_message_utils import (
fix_anthropic_message_content,
anthropic_trim_messages,
has_tool_use,
)
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import token_counter
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.console.output import cpm, print_messages_compact
logger = get_logger(__name__)


def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
    """Helper function to estimate total tokens in a sequence of messages.

    Args:
        messages: Sequence of messages to count tokens for

    Returns:
        Total estimated token count
    """
    if not messages:
        return 0

    estimate_tokens = CiaynAgent._estimate_tokens
    return sum(estimate_tokens(msg) for msg in messages)
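
# Illustrative usage (a sketch; the estimate comes from CiaynAgent's heuristic
# rather than a real tokenizer, so treat the number as approximate):
#
#     from langchain_core.messages import AIMessage, HumanMessage
#     msgs = [HumanMessage(content="hello"), AIMessage(content="hi there")]
#     total = estimate_messages_tokens(msgs)  # approximate combined count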


def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
    """Convert a BaseMessage to the format expected by litellm.

    Args:
        message: The BaseMessage to convert

    Returns:
        Dict in litellm format
    """
    message_dict = message_to_dict(message)
    # LangChain reports message types as "human"/"ai"; litellm expects
    # OpenAI-style "user"/"assistant" roles, so map them before counting.
    role_map = {"human": "user", "ai": "assistant"}
    message_type = message_dict["type"]
    return {
        "role": role_map.get(message_type, message_type),
        "content": message_dict["data"]["content"],
    }
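
# Illustrative conversion (a sketch; the input/output shapes follow
# langchain_core's message_to_dict and litellm's message format):
#
#     from langchain_core.messages import HumanMessage
#     convert_message_to_litellm_format(HumanMessage(content="hello"))
#     # -> {"role": "user", "content": "hello"}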


def create_token_counter_wrapper(model: str):
    """Create a wrapper for token counter that handles BaseMessage conversion.

    Args:
        model: The model name to use for token counting

    Returns:
        A function that accepts BaseMessage objects and returns token count
    """
    # Create a partial function that already has the model parameter set
    base_token_counter = partial(token_counter, model=model)

    def wrapped_token_counter(messages: List[BaseMessage]) -> int:
        """Count tokens in a list of messages, converting each BaseMessage to the dict format litellm expects.

        Args:
            messages: List of BaseMessage objects

        Returns:
            Token count for the messages
        """
        if not messages:
            return 0
        litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
        return base_token_counter(messages=litellm_messages)

    return wrapped_token_counter
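
# Illustrative usage (a sketch; the model name is hypothetical and assumes
# litellm recognizes it):
#
#     counter = create_token_counter_wrapper("claude-3-5-sonnet-20241022")
#     n = counter([HumanMessage(content="hello")])  # litellm-based count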


def state_modifier(
    state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_input_tokens, return a trimmed list of messages.

    This uses anthropic_trim_messages, which always keeps the first 2 messages.

    Args:
        state: The current agent state containing messages
        model: The language model to use for token counting
        max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within the token limit
    """
    messages = state["messages"]
    if not messages:
        return []

    wrapped_token_counter = create_token_counter_wrapper(model.model)
    logger.debug("Starting token trimming with max_tokens: %s", max_input_tokens)
    logger.debug("Current token total: %s", wrapped_token_counter(messages))

    # Log message details to help debug tool_use/tool_result pairing
    for i, msg in enumerate(messages):
        if isinstance(msg, AIMessage):
            logger.debug("AIMessage[%d] content type: %s", i, type(msg.content))
            logger.debug("AIMessage[%d] has_tool_use: %s", i, has_tool_use(msg))
            if has_tool_use(msg) and i < len(messages) - 1:
                logger.debug(
                    "Next message is ToolMessage: %s",
                    isinstance(messages[i + 1], ToolMessage),
                )

    result = anthropic_trim_messages(
        messages,
        token_counter=wrapped_token_counter,
        max_tokens=max_input_tokens,
        strategy="last",
        allow_partial=False,
        include_system=True,
        num_messages_to_keep=2,
    )

    # Only report when trimming actually removed messages
    if len(result) < len(messages):
        logger.debug(
            "Trimmed %d messages to %d; new token total: %d",
            len(messages),
            len(result),
            wrapped_token_counter(result),
        )
        logger.debug("Messages before trimming:")
        print_messages_compact(messages)
        logger.debug("Messages after trimming:")
        print_messages_compact(result)
    return result
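
# Illustrative wiring (a sketch; the exact hook name that accepts a state
# modifier can vary across langgraph versions):
#
#     from functools import partial
#     modifier = partial(state_modifier, model=chat_model, max_input_tokens=100_000)
#     # `modifier` is then usable wherever langgraph expects a
#     # state -> list[BaseMessage] callable; `chat_model` is a ChatAnthropic
#     # instance created elsewhere.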


def sonnet_35_state_modifier(
    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_input_tokens, return a trimmed list of messages.

    The first message is always kept; its token cost is subtracted from the
    budget before the remaining messages are trimmed.

    Args:
        state: The current agent state containing messages
        max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within the token limit
    """
    messages = state["messages"]
    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]
    first_tokens = estimate_messages_tokens([first_message])
    new_max_tokens = max_input_tokens - first_tokens

    # Calculate total tokens before trimming
    total_tokens_before = estimate_messages_tokens(messages)
    logger.debug("Current token total: %s", total_tokens_before)

    # Trim remaining messages
    trimmed_remaining = anthropic_trim_messages(
        remaining_messages,
        token_counter=estimate_messages_tokens,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
        include_system=True,
    )
    result = [first_message] + trimmed_remaining

    # Only report when some messages were trimmed
    if len(result) < len(messages):
        total_tokens_after = estimate_messages_tokens(result)
        logger.debug(
            "Trimmed %d messages to %d; new token total: %d",
            len(messages),
            len(result),
            total_tokens_after,
        )

    # No need to fix message content as anthropic_trim_messages already handles this
    return result
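
# Illustrative budget arithmetic (a sketch with made-up numbers): if the pinned
# first message is estimated at 1,500 tokens and max_input_tokens is 21,000,
# the remaining messages are trimmed against 21,000 - 1,500 = 19,500 tokens,
# and the first message is prepended to whatever survives trimming.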


def get_model_token_limit(
    config: Dict[str, Any], agent_type: str = "default"
) -> Optional[int]:
    """Get the token limit for the current model configuration based on agent type.

    Args:
        config: Configuration dictionary containing provider and model information
        agent_type: Type of agent ("default", "research", or "planner")

    Returns:
        Optional[int]: The token limit if found, None otherwise
    """
    try:
        # Try to get config from repository for production use
        try:
            config_from_repo = get_config_repository().get_all()
            # If we succeeded, use the repository config instead of passed config
            config = config_from_repo
        except RuntimeError:
            # In tests, this may fail because the repository isn't set up,
            # so we'll use the passed config directly
            pass

        if agent_type == "research":
            provider = config.get("research_provider", "") or config.get("provider", "")
            model_name = config.get("research_model", "") or config.get("model", "")
        elif agent_type == "planner":
            provider = config.get("planner_provider", "") or config.get("provider", "")
            model_name = config.get("planner_model", "") or config.get("model", "")
        else:
            provider = config.get("provider", "")
            model_name = config.get("model", "")

        try:
            from litellm import get_model_info

            provider_model = model_name if not provider else f"{provider}/{model_name}"
            model_info = get_model_info(provider_model)
            max_input_tokens = model_info.get("max_input_tokens")
            if max_input_tokens:
                logger.debug(
                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
                )
                return max_input_tokens
        except Exception as e:
            logger.debug(
                f"Error getting model info from litellm: {e}, falling back to models_params"
            )

        # Fallback to models_params dict.
        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2).
        normalized_name = model_name.replace("-", "")
        provider_tokens = models_params.get(provider, {})
        if normalized_name in provider_tokens:
            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
            logger.debug(
                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
            )
        else:
            max_input_tokens = None
            logger.debug(f"Could not find token limit for {provider}/{model_name}")

        return max_input_tokens
    except Exception as e:
        logger.warning(f"Failed to get model token limit: {e}")
        return None
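
# Illustrative usage (a sketch; provider and model values are hypothetical):
#
#     limit = get_model_token_limit(
#         {"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"},
#         agent_type="research",
#     )
#     # -> max input tokens from litellm's model info, the models_params
#     #    fallback, or None if neither source knows the model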