From f1274b3164ff16325c9c3859ee250862883ac050 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Sat, 15 Mar 2025 09:37:26 -0700 Subject: [PATCH] refactor(anthropic_token_limiter.py): update model parameter type in state_modifier to BaseChatModel for better compatibility feat(anthropic_token_limiter.py): add get_model_name_from_chat_model function to extract model name from BaseChatModel instances style(anthropic_token_limiter.py): format code for better readability and consistency in function definitions and logging messages --- ra_aid/anthropic_token_limiter.py | 64 +++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index 794459a..55b5b79 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -3,9 +3,9 @@ from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple from langchain_core.language_models import BaseChatModel +from ra_aid.config import DEFAULT_MODEL from ra_aid.model_detection import is_claude_37 -from langchain_anthropic import ChatAnthropic from langchain_core.messages import ( BaseMessage, trim_messages, @@ -91,7 +91,7 @@ def create_token_counter_wrapper(model: str): def state_modifier( - state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT + state: AgentState, model: BaseChatModel, max_input_tokens: int = DEFAULT_TOKEN_LIMIT ) -> list[BaseMessage]: """Given the agent state and max_tokens, return a trimmed list of messages. @@ -110,7 +110,8 @@ def state_modifier( if not messages: return [] - wrapped_token_counter = create_token_counter_wrapper(model.model) + model_name = get_model_name_from_chat_model(model) + wrapped_token_counter = create_token_counter_wrapper(model_name) result = anthropic_trim_messages( messages, @@ -123,7 +124,9 @@ def state_modifier( ) if len(result) < len(messages): - logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages") + logger.info( + f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages" + ) return result @@ -164,13 +167,15 @@ def sonnet_35_state_modifier( return result -def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]: +def get_provider_and_model_for_agent_type( + config: Dict[str, Any], agent_type: str +) -> Tuple[str, str]: """Get the provider and model name for the specified agent type. - + Args: config: Configuration dictionary containing provider and model information agent_type: Type of agent ("default", "research", or "planner") - + Returns: Tuple[str, str]: A tuple containing (provider, model_name) """ @@ -183,36 +188,61 @@ def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: st else: provider = config.get("provider", "") model_name = config.get("model", "") - + return provider, model_name -def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]: +def get_model_name_from_chat_model(model: Optional[BaseChatModel]) -> str: + """Extract the model name from a BaseChatModel instance. + + Args: + model: The BaseChatModel instance + + Returns: + str: The model name extracted from the instance, or DEFAULT_MODEL if not found + """ + if model is None: + return DEFAULT_MODEL + + if hasattr(model, "model"): + return model.model + elif hasattr(model, "model_name"): + return model.model_name + else: + logger.debug(f"Could not extract model name from {model}, using DEFAULT_MODEL") + return DEFAULT_MODEL + + +def adjust_claude_37_token_limit( + max_input_tokens: int, model: Optional[BaseChatModel] +) -> Optional[int]: """Adjust token limit for Claude 3.7 models by subtracting max_tokens. - + Args: max_input_tokens: The original token limit model: The model instance to check - + Returns: Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit """ if not max_input_tokens: return max_input_tokens - - if model and hasattr(model, 'model') and is_claude_37(model.model): - if hasattr(model, 'max_tokens') and model.max_tokens: + + if model and hasattr(model, "model") and is_claude_37(model.model): + if hasattr(model, "max_tokens") and model.max_tokens: effective_max_input_tokens = max_input_tokens - model.max_tokens logger.debug( f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}" ) return effective_max_input_tokens - + return max_input_tokens def get_model_token_limit( - config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None + config: Dict[str, Any], + agent_type: str = "default", + model: Optional[BaseChatModel] = None, ) -> Optional[int]: """Get the token limit for the current model configuration based on agent type. @@ -234,7 +264,7 @@ def get_model_token_limit( # In tests, this may fail because the repository isn't set up # So we'll use the passed config directly pass - + provider, model_name = get_provider_and_model_for_agent_type(config, agent_type) # Always attempt to get model info from litellm first