Merge pull request #136 from ariel-frischer/fix-undefined-model-2

Fix undefined model.model when using OpenRouter Sonnet 3.7
Andrew I. Christianson 2025-03-15 12:41:09 -04:00 committed by GitHub
commit cde8eee4fa
1 changed file with 47 additions and 17 deletions
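
Background: `state_modifier` was typed against `ChatAnthropic` and read the model id from `model.model`. When Sonnet 3.7 runs through OpenRouter, the chat model is an OpenAI-compatible wrapper that exposes the id as `model_name` instead, so `model.model` raised `AttributeError`. A minimal sketch of the mismatch, assuming OpenRouter is driven through `langchain_openai.ChatOpenAI` (RA.Aid's exact wiring is not shown in this diff):

    # Illustrative only: model ids and construction are assumptions, not RA.Aid code.
    from langchain_anthropic import ChatAnthropic
    from langchain_openai import ChatOpenAI

    anthropic_model = ChatAnthropic(model="claude-3-7-sonnet-20250219")
    openrouter_model = ChatOpenAI(
        model="anthropic/claude-3.7-sonnet",
        base_url="https://openrouter.ai/api/v1",
    )

    anthropic_model.model        # "claude-3-7-sonnet-20250219"
    openrouter_model.model_name  # "anthropic/claude-3.7-sonnet"
    openrouter_model.model       # AttributeError: the crash this PR fixes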


@@ -3,9 +3,9 @@
 from functools import partial
 from typing import Any, Dict, List, Optional, Sequence, Tuple

+from langchain_core.language_models import BaseChatModel
+from ra_aid.config import DEFAULT_MODEL
 from ra_aid.model_detection import is_claude_37
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
     BaseMessage,
     trim_messages,
@@ -91,7 +91,7 @@ def create_token_counter_wrapper(model: str):
 def state_modifier(
-    state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
+    state: AgentState, model: BaseChatModel, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
@@ -110,7 +110,8 @@ def state_modifier(
     if not messages:
         return []

-    wrapped_token_counter = create_token_counter_wrapper(model.model)
+    model_name = get_model_name_from_chat_model(model)
+    wrapped_token_counter = create_token_counter_wrapper(model_name)

     result = anthropic_trim_messages(
         messages,
@@ -123,7 +124,9 @@
     )

     if len(result) < len(messages):
-        logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")
+        logger.info(
+            f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages"
+        )

     return result
@@ -164,7 +167,9 @@ def sonnet_35_state_modifier(
     return result


-def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
+def get_provider_and_model_for_agent_type(
+    config: Dict[str, Any], agent_type: str
+) -> Tuple[str, str]:
     """Get the provider and model name for the specified agent type.

     Args:
@@ -187,7 +192,30 @@ def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: st
     return provider, model_name


-def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
+def get_model_name_from_chat_model(model: Optional[BaseChatModel]) -> str:
+    """Extract the model name from a BaseChatModel instance.
+
+    Args:
+        model: The BaseChatModel instance
+
+    Returns:
+        str: The model name extracted from the instance, or DEFAULT_MODEL if not found
+    """
+    if model is None:
+        return DEFAULT_MODEL
+
+    if hasattr(model, "model"):
+        return model.model
+    elif hasattr(model, "model_name"):
+        return model.model_name
+    else:
+        logger.debug(f"Could not extract model name from {model}, using DEFAULT_MODEL")
+        return DEFAULT_MODEL
+
+
+def adjust_claude_37_token_limit(
+    max_input_tokens: int, model: Optional[BaseChatModel]
+) -> Optional[int]:
     """Adjust token limit for Claude 3.7 models by subtracting max_tokens.

     Args:
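
A hypothetical spot-check of the new helper, reusing the two models from the sketch above (import path assumed, since the diff does not show the file name):

    from ra_aid.anthropic_token_limiter import get_model_name_from_chat_model  # path assumed

    get_model_name_from_chat_model(anthropic_model)   # -> "claude-3-7-sonnet-20250219" via .model
    get_model_name_from_chat_model(openrouter_model)  # -> "anthropic/claude-3.7-sonnet" via .model_name
    get_model_name_from_chat_model(None)              # -> DEFAULT_MODEL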
@@ -200,8 +228,8 @@ def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChat
     if not max_input_tokens:
         return max_input_tokens

-    if model and hasattr(model, 'model') and is_claude_37(model.model):
-        if hasattr(model, 'max_tokens') and model.max_tokens:
+    if model and hasattr(model, "model") and is_claude_37(model.model):
+        if hasattr(model, "max_tokens") and model.max_tokens:
             effective_max_input_tokens = max_input_tokens - model.max_tokens
             logger.debug(
                 f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
@@ -212,7 +240,9 @@ def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChat
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
+    config: Dict[str, Any],
+    agent_type: str = "default",
+    model: Optional[BaseChatModel] = None,
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.