Merge pull request #136 from ariel-frischer/fix-undefined-model-2
Fix undefined model.model when using openrouter sonnet 3.7
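Why the crash happened: for OpenRouter, the chat model appears to be built on LangChain's OpenAI-compatible client, and langchain_openai.ChatOpenAI stores the model id in `model_name` (accepting `model` only as a constructor alias), while langchain_anthropic.ChatAnthropic exposes a `model` attribute. The old code read `model.model` unconditionally, so OpenRouter-backed Sonnet 3.7 raised AttributeError. A minimal repro sketch; the model id, endpoint URL, and key below are illustrative assumptions, not values taken from this diff:

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(
        model="anthropic/claude-3.7-sonnet",      # stored internally as model_name
        base_url="https://openrouter.ai/api/v1",  # OpenRouter's OpenAI-compatible endpoint
        api_key="sk-or-...",                      # placeholder key
    )
    print(llm.model_name)  # "anthropic/claude-3.7-sonnet"
    llm.model              # raises AttributeError: the pre-fix failure mode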
commit cde8eee4fa
@@ -3,9 +3,9 @@
 from functools import partial
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 from langchain_core.language_models import BaseChatModel
+from ra_aid.config import DEFAULT_MODEL
 from ra_aid.model_detection import is_claude_37
 
-from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
     BaseMessage,
     trim_messages,
@@ -91,7 +91,7 @@ def create_token_counter_wrapper(model: str):
 
 
 def state_modifier(
-    state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
+    state: AgentState, model: BaseChatModel, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
 
@@ -110,7 +110,8 @@ def state_modifier(
     if not messages:
         return []
 
-    wrapped_token_counter = create_token_counter_wrapper(model.model)
+    model_name = get_model_name_from_chat_model(model)
+    wrapped_token_counter = create_token_counter_wrapper(model_name)
 
     result = anthropic_trim_messages(
         messages,
@@ -123,7 +124,9 @@ def state_modifier(
     )
 
     if len(result) < len(messages):
-        logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")
+        logger.info(
+            f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages"
+        )
 
     return result
 
@@ -164,13 +167,15 @@ def sonnet_35_state_modifier(
     return result
 
 
-def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
+def get_provider_and_model_for_agent_type(
+    config: Dict[str, Any], agent_type: str
+) -> Tuple[str, str]:
    """Get the provider and model name for the specified agent type.
 
     Args:
         config: Configuration dictionary containing provider and model information
         agent_type: Type of agent ("default", "research", or "planner")
 
     Returns:
         Tuple[str, str]: A tuple containing (provider, model_name)
     """
@@ -183,36 +188,61 @@ def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: st
     else:
         provider = config.get("provider", "")
         model_name = config.get("model", "")
 
     return provider, model_name
 
 
-def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
+def get_model_name_from_chat_model(model: Optional[BaseChatModel]) -> str:
+    """Extract the model name from a BaseChatModel instance.
+
+    Args:
+        model: The BaseChatModel instance
+
+    Returns:
+        str: The model name extracted from the instance, or DEFAULT_MODEL if not found
+    """
+    if model is None:
+        return DEFAULT_MODEL
+
+    if hasattr(model, "model"):
+        return model.model
+    elif hasattr(model, "model_name"):
+        return model.model_name
+    else:
+        logger.debug(f"Could not extract model name from {model}, using DEFAULT_MODEL")
+        return DEFAULT_MODEL
+
+
+def adjust_claude_37_token_limit(
+    max_input_tokens: int, model: Optional[BaseChatModel]
+) -> Optional[int]:
     """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
 
     Args:
         max_input_tokens: The original token limit
         model: The model instance to check
 
     Returns:
         Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
     """
     if not max_input_tokens:
         return max_input_tokens
 
-    if model and hasattr(model, 'model') and is_claude_37(model.model):
-        if hasattr(model, 'max_tokens') and model.max_tokens:
+    if model and hasattr(model, "model") and is_claude_37(model.model):
+        if hasattr(model, "max_tokens") and model.max_tokens:
             effective_max_input_tokens = max_input_tokens - model.max_tokens
             logger.debug(
                 f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
             )
             return effective_max_input_tokens
 
     return max_input_tokens
 
 
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
+    config: Dict[str, Any],
+    agent_type: str = "default",
+    model: Optional[BaseChatModel] = None,
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.
 
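The new helper makes name resolution provider-agnostic. An illustrative use; the import path is a guess at this file's module, the model ids are examples, and API keys are assumed to be set in the environment:

    from langchain_anthropic import ChatAnthropic
    from langchain_openai import ChatOpenAI

    # assumed module path for the file shown in this diff
    from ra_aid.anthropic_token_limiter import get_model_name_from_chat_model

    claude = ChatAnthropic(model="claude-3-7-sonnet-latest")
    via_openrouter = ChatOpenAI(
        model="anthropic/claude-3.7-sonnet",
        base_url="https://openrouter.ai/api/v1",
    )

    get_model_name_from_chat_model(claude)          # hits the .model branch
    get_model_name_from_chat_model(via_openrouter)  # falls through to .model_name
    get_model_name_from_chat_model(None)            # returns DEFAULT_MODEL

The adjust_claude_37_token_limit edits above are quote-style and line-wrapping only; the logic is unchanged: for a Claude 3.7 model with a 200000-token input limit and max_tokens=64000 reserved for output, the effective input budget is 200000 - 64000 = 136000.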
@@ -234,7 +264,7 @@ def get_model_token_limit(
         # In tests, this may fail because the repository isn't set up
         # So we'll use the passed config directly
         pass
 
     provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
 
     # Always attempt to get model info from litellm first
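For context on the "get model info from litellm" comment: litellm ships a metadata lookup, litellm.get_model_info, that reports per-model limits. A hedged sketch of that kind of lookup, with an example model id; the exact call site is not shown in this diff:

    import litellm

    try:
        info = litellm.get_model_info("claude-3-7-sonnet-20250219")
        max_input_tokens = info.get("max_input_tokens")  # e.g. 200000
    except Exception:
        max_input_tokens = None  # unmapped model: caller falls back to defaults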