"""Utilities for handling token limits with Anthropic models.""" from functools import partial from typing import Any, Dict, List, Optional, Sequence, Union from langchain_anthropic import ChatAnthropic from langchain_core.messages import BaseMessage, trim_messages from langchain_core.messages.base import messages_to_dict from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import token_counter from ra_aid.agent_backends.ciayn_agent import CiaynAgent from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params from ra_aid.console.output import print_messages_compact logger = get_logger(__name__) def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: """Helper function to estimate total tokens in a sequence of messages. Args: messages: Sequence of messages to count tokens for Returns: Total estimated token count """ if not messages: return 0 estimate_tokens = CiaynAgent._estimate_tokens return sum(estimate_tokens(msg) for msg in messages) def create_token_counter_wrapper(model: str): """Create a wrapper for token counter that handles BaseMessage conversion. Args: model: The model name to use for token counting Returns: A function that accepts BaseMessage objects and returns token count """ # Create a partial function that already has the model parameter set base_token_counter = partial(token_counter, model=model) def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int: """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage. Args: messages: List of messages (either BaseMessage objects or dicts) Returns: Token count for the messages """ if not messages: return 0 if isinstance(messages[0], BaseMessage): messages_dicts = [msg["data"] for msg in messages_to_dict(messages)] return base_token_counter(messages=messages_dicts) else: # Already in dict format return base_token_counter(messages=messages) return wrapped_token_counter def state_modifier( state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT ) -> list[BaseMessage]: """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message. Args: state: The current agent state containing messages model: The language model to use for token counting max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) Returns: list[BaseMessage]: Trimmed list of messages that fits within token limit """ messages = state["messages"] if not messages: return [] first_message = messages[0] remaining_messages = messages[1:] wrapped_token_counter = create_token_counter_wrapper(model.model) first_tokens = wrapped_token_counter([first_message]) new_max_tokens = max_input_tokens - first_tokens print_messages_compact(messages) trimmed_remaining = trim_messages( remaining_messages, token_counter=wrapped_token_counter, max_tokens=new_max_tokens, strategy="last", allow_partial=False, ) return [first_message] + trimmed_remaining def sonnet_3_5_state_modifier( state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT ) -> list[BaseMessage]: """Given the agent state and max_tokens, return a trimmed list of messages. 
Args: state: The current agent state containing messages max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) Returns: list[BaseMessage]: Trimmed list of messages that fits within token limit """ messages = state["messages"] if not messages: return [] first_message = messages[0] remaining_messages = messages[1:] first_tokens = estimate_messages_tokens([first_message]) new_max_tokens = max_input_tokens - first_tokens trimmed_remaining = trim_messages( remaining_messages, token_counter=estimate_messages_tokens, max_tokens=new_max_tokens, strategy="last", allow_partial=False, ) return [first_message] + trimmed_remaining def get_model_token_limit( config: Dict[str, Any], agent_type: str = "default" ) -> Optional[int]: """Get the token limit for the current model configuration based on agent type. Args: config: Configuration dictionary containing provider and model information agent_type: Type of agent ("default", "research", or "planner") Returns: Optional[int]: The token limit if found, None otherwise """ try: # Try to get config from repository for production use try: config_from_repo = get_config_repository().get_all() # If we succeeded, use the repository config instead of passed config config = config_from_repo except RuntimeError: # In tests, this may fail because the repository isn't set up # So we'll use the passed config directly pass if agent_type == "research": provider = config.get("research_provider", "") or config.get("provider", "") model_name = config.get("research_model", "") or config.get("model", "") elif agent_type == "planner": provider = config.get("planner_provider", "") or config.get("provider", "") model_name = config.get("planner_model", "") or config.get("model", "") else: provider = config.get("provider", "") model_name = config.get("model", "") try: from litellm import get_model_info provider_model = model_name if not provider else f"{provider}/{model_name}" model_info = get_model_info(provider_model) max_input_tokens = model_info.get("max_input_tokens") if max_input_tokens: logger.debug( f"Using litellm token limit for {model_name}: {max_input_tokens}" ) return max_input_tokens except Exception as e: logger.debug( f"Error getting model info from litellm: {e}, falling back to models_params" ) # Fallback to models_params dict # Normalize model name for fallback lookup (e.g. claude-2 -> claude2) normalized_name = model_name.replace("-", "") provider_tokens = models_params.get(provider, {}) if normalized_name in provider_tokens: max_input_tokens = provider_tokens[normalized_name]["token_limit"] logger.debug( f"Found token limit for {provider}/{model_name}: {max_input_tokens}" ) else: max_input_tokens = None logger.debug(f"Could not find token limit for {provider}/{model_name}") return max_input_tokens except Exception as e: logger.warning(f"Failed to get model token limit: {e}") return None
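

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch (not part of the module's API): demonstrates how the
    # heuristic trimmer keeps the first message and drops older messages until
    # the remainder fits the budget. The messages and the 200-token budget below
    # are made-up example values chosen only for illustration.
    from langchain_core.messages import HumanMessage, SystemMessage

    demo_messages = [
        SystemMessage(content="You are a coding assistant."),
        HumanMessage(content="First question " * 50),
        HumanMessage(content="Latest question"),
    ]

    trimmed = sonnet_3_5_state_modifier(
        {"messages": demo_messages}, max_input_tokens=200
    )

    # The system message is always kept; the oldest remaining messages are
    # dropped first (strategy="last").
    print_messages_compact(trimmed)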