Merge pull request #132 from ariel-frischer/fix-token-limiter-2
Fix Sonnet 3.7 Token Limiter - Adjust Effective Max Input Tokens
commit aaf09c5df6
@@ -5,24 +5,8 @@ import sys
 import uuid
 from datetime import datetime
-
-# Add litellm import
-import litellm
-
-# Configure litellm to suppress debug logs
-os.environ["LITELLM_LOG"] = "ERROR"
-litellm.suppress_debug_info = True
-litellm.set_verbose = False
-
-# Explicitly configure LiteLLM's loggers
-for logger_name in ["litellm", "LiteLLM"]:
-    litellm_logger = logging.getLogger(logger_name)
-    litellm_logger.setLevel(logging.WARNING)
-    litellm_logger.propagate = True
-
-# Use litellm's internal method to disable debugging
-if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
-    litellm._logging._disable_debugging()
 
 from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
@@ -99,6 +83,21 @@ from ra_aid.tools.human import ask_human
 
 logger = get_logger(__name__)
 
+# Configure litellm to suppress debug logs
+os.environ["LITELLM_LOG"] = "ERROR"
+litellm.suppress_debug_info = True
+litellm.set_verbose = False
+
+# Explicitly configure LiteLLM's loggers
+for logger_name in ["litellm", "LiteLLM"]:
+    litellm_logger = logging.getLogger(logger_name)
+    litellm_logger.setLevel(logging.WARNING)
+    litellm_logger.propagate = True
+
+# Use litellm's internal method to disable debugging
+if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
+    litellm._logging._disable_debugging()
+
 
 def launch_webui(host: str, port: int):
     """Launch the RA.Aid web interface."""
@@ -51,7 +51,12 @@ from ra_aid.database.repositories.human_input_repository import (
 )
 from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
 from ra_aid.database.repositories.config_repository import get_config_repository
-from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
+from ra_aid.anthropic_token_limiter import (
+    sonnet_35_state_modifier,
+    state_modifier,
+    get_model_token_limit,
+)
+from ra_aid.model_detection import is_anthropic_claude
 
 console = Console()
@@ -67,8 +72,6 @@ def output_markdown_message(message: str) -> str:
     return "Message output."
 
 
-
-
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     model: ChatAnthropic = None,
@@ -99,8 +102,13 @@ def build_agent_kwargs(
     ):
 
         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
-                return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+            if any(
+                pattern in model.model
+                for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
+            ):
+                return sonnet_35_state_modifier(
+                    state, max_input_tokens=max_input_tokens
+                )
 
             return state_modifier(state, model, max_input_tokens=max_input_tokens)
 
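The reflowed conditional above is the dispatch point between the two trimming strategies: any Claude 3.5 spelling gets the Sonnet 3.5 modifier, while everything else, including Claude 3.7, goes through the generic state_modifier. A minimal sketch of the same matching logic (hypothetical function name; only the pattern list mirrors the diff):

    def pick_state_modifier(model_name: str) -> str:
        # Substring matching, as in build_agent_kwargs above.
        if any(p in model_name for p in ["claude-3.5", "claude3.5", "claude-3-5"]):
            return "sonnet_35_state_modifier"
        return "state_modifier"

    assert pick_state_modifier("claude-3-5-sonnet-20241022") == "sonnet_35_state_modifier"
    assert pick_state_modifier("claude-3-7-sonnet-20250219") == "state_modifier"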
@@ -110,27 +118,6 @@ def build_agent_kwargs(
     return agent_kwargs
 
 
-def is_anthropic_claude(config: Dict[str, Any]) -> bool:
-    """Check if the provider and model name indicate an Anthropic Claude model.
-
-    Args:
-        config: Configuration dictionary containing provider and model information
-
-    Returns:
-        bool: True if this is an Anthropic Claude model
-    """
-    # For backwards compatibility, allow passing of config directly
-    provider = config.get("provider", "")
-    model_name = config.get("model", "")
-    result = (
-        provider.lower() == "anthropic"
-        and model_name
-        and "claude" in model_name.lower()
-    ) or (
-        provider.lower() == "openrouter"
-        and model_name.lower().startswith("anthropic/claude-")
-    )
-    return result
-
-
 def create_agent(
@@ -169,7 +156,7 @@ def create_agent(
         # So we'll use the passed config directly
         pass
     max_input_tokens = (
-        get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+        get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
    )

    # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
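Since get_model_token_limit can return None when neither litellm nor models_params knows the model, the `or DEFAULT_TOKEN_LIMIT` guard keeps the agent usable. A sketch of that fallback (the constant here is illustrative; the real value lives in ra_aid.models_params):

    from typing import Optional

    DEFAULT_TOKEN_LIMIT = 100_000  # illustrative stand-in

    def effective_limit(looked_up: Optional[int]) -> int:
        # None (lookup failed) falls back to the default, matching
        # the `or DEFAULT_TOKEN_LIMIT` expression in create_agent.
        return looked_up or DEFAULT_TOKEN_LIMIT

    assert effective_limit(None) == DEFAULT_TOKEN_LIMIT
    assert effective_limit(136_000) == 136_000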
@@ -188,7 +175,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
-        max_input_tokens = get_model_token_limit(config, agent_type)
+        max_input_tokens = get_model_token_limit(config, agent_type, model)
     agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
     return create_react_agent(
         model, tools, interrupt_after=["tools"], **agent_kwargs
@@ -301,7 +288,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
         record_type="error",
         human_input_id=human_input_id,
         is_error=True,
-        error_message=error_message
+        error_message=error_message,
     )
 
     print_error(error_message)
@@ -464,7 +451,9 @@ def run_agent_with_retry(
 
         try:
             _run_agent_stream(agent, msg_list)
-            if fallback_handler and hasattr(fallback_handler, 'reset_fallback_handler'):
+            if fallback_handler and hasattr(
+                fallback_handler, "reset_fallback_handler"
+            ):
                 fallback_handler.reset_fallback_handler()
             should_break, prompt, auto_test, test_attempts = (
                 _execute_test_command_wrapper(
@@ -1,31 +1,27 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence
-from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+from langchain_core.language_models import BaseChatModel
+from ra_aid.model_detection import is_claude_37
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
     RemoveMessage,
     ToolMessage,
     trim_messages,
 )
 from langchain_core.messages.base import message_to_dict
 
 from ra_aid.anthropic_message_utils import (
     anthropic_trim_messages,
     has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import token_counter
+from litellm import token_counter, get_model_info
 
 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
 from ra_aid.console.output import cpm, print_messages_compact
 
 logger = get_logger(__name__)
@@ -168,14 +164,62 @@ def sonnet_35_state_modifier(
     return result
 
 
+def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
+    """Get the provider and model name for the specified agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Tuple[str, str]: A tuple containing (provider, model_name)
+    """
+    if agent_type == "research":
+        provider = config.get("research_provider", "") or config.get("provider", "")
+        model_name = config.get("research_model", "") or config.get("model", "")
+    elif agent_type == "planner":
+        provider = config.get("planner_provider", "") or config.get("provider", "")
+        model_name = config.get("planner_model", "") or config.get("model", "")
+    else:
+        provider = config.get("provider", "")
+        model_name = config.get("model", "")
+
+    return provider, model_name
+
+
+def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
+    """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
+
+    Args:
+        max_input_tokens: The original token limit
+        model: The model instance to check
+
+    Returns:
+        Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
+    """
+    if not max_input_tokens:
+        return max_input_tokens
+
+    if model and hasattr(model, 'model') and is_claude_37(model.model):
+        if hasattr(model, 'max_tokens') and model.max_tokens:
+            effective_max_input_tokens = max_input_tokens - model.max_tokens
+            logger.debug(
+                f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
+            )
+            return effective_max_input_tokens
+
+    return max_input_tokens
+
+
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default"
+    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.
 
     Args:
         config: Configuration dictionary containing provider and model information
         agent_type: Type of agent ("default", "research", or "planner")
+        model: Optional BaseChatModel instance to check for model-specific attributes
 
     Returns:
         Optional[int]: The token limit if found, None otherwise
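A usage sketch for the extracted helper (assuming the ra_aid package as laid out in this diff): per-agent keys override the base provider/model pair, and missing or empty overrides fall through to the base values.

    from ra_aid.anthropic_token_limiter import get_provider_and_model_for_agent_type

    config = {
        "provider": "anthropic",
        "model": "claude-3-7-sonnet-20250219",
        "research_provider": "openai",
        "research_model": "gpt-4",
    }

    # Research has its own pair; planner has no override, so it falls back.
    assert get_provider_and_model_for_agent_type(config, "research") == ("openai", "gpt-4")
    assert get_provider_and_model_for_agent_type(config, "planner") == (
        "anthropic",
        "claude-3-7-sonnet-20250219",
    )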
@@ -190,27 +234,20 @@ def get_model_token_limit(
         # In tests, this may fail because the repository isn't set up
         # So we'll use the passed config directly
         pass
-    if agent_type == "research":
-        provider = config.get("research_provider", "") or config.get("provider", "")
-        model_name = config.get("research_model", "") or config.get("model", "")
-    elif agent_type == "planner":
-        provider = config.get("planner_provider", "") or config.get("provider", "")
-        model_name = config.get("planner_model", "") or config.get("model", "")
-    else:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+    provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
 
+    # Always attempt to get model info from litellm first
+    provider_model = model_name if not provider else f"{provider}/{model_name}"
 
     try:
-        from litellm import get_model_info
-
-        provider_model = model_name if not provider else f"{provider}/{model_name}"
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
             logger.debug(
                 f"Using litellm token limit for {model_name}: {max_input_tokens}"
             )
-            return max_input_tokens
+            return adjust_claude_37_token_limit(max_input_tokens, model)
     except Exception as e:
         logger.debug(
             f"Error getting model info from litellm: {e}, falling back to models_params"
@@ -229,7 +266,7 @@ def get_model_token_limit(
             max_input_tokens = None
             logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
-        return max_input_tokens
+        return adjust_claude_37_token_limit(max_input_tokens, model)
 
     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
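To see the headline fix end to end (a sketch, not part of the commit; it relies only on the two attributes the helper reads): with litellm reporting a 200,000-token input limit for Claude 3.7 Sonnet and create_llm_client pinning max_tokens to 64,000, the effective input budget becomes 136,000.

    from types import SimpleNamespace

    from ra_aid.anthropic_token_limiter import adjust_claude_37_token_limit

    # Stand-in for a ChatAnthropic instance; only .model and .max_tokens are read.
    claude_37 = SimpleNamespace(model="claude-3-7-sonnet-20250219", max_tokens=64000)
    assert adjust_claude_37_token_limit(200_000, claude_37) == 136_000

    # Non-3.7 models and a missing model instance pass through unchanged.
    claude_35 = SimpleNamespace(model="claude-3-5-sonnet-20241022", max_tokens=8192)
    assert adjust_claude_37_token_limit(200_000, claude_35) == 200_000
    assert adjust_claude_37_token_limit(200_000, None) == 200_000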
@@ -10,6 +10,7 @@ from openai import OpenAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
 from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger
+from ra_aid.model_detection import is_claude_37
 
 from .models_params import models_params
@@ -218,7 +219,6 @@ def create_llm_client(
         is_expert,
     )
 
-    # Get model configuration
     model_config = models_params.get(provider, {}).get(model_name, {})
 
     # Default to True for known providers that support temperature if not specified
@@ -228,6 +228,10 @@ def create_llm_client(
         supports_temperature = model_config["supports_temperature"]
     supports_thinking = model_config.get("supports_thinking", False)
 
+    other_kwargs = {}
+    if is_claude_37(model_name):
+        other_kwargs = {"max_tokens": 64000}
+
     # Handle temperature settings
     if is_expert:
         temp_kwargs = {"temperature": 0} if supports_temperature else {}
@@ -235,8 +239,12 @@ def create_llm_client(
         if temperature is None:
             temperature = 0.7
             # Import repository classes directly to avoid circular imports
-            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
-            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
+            from ra_aid.database.repositories.trajectory_repository import (
+                TrajectoryRepository,
+            )
+            from ra_aid.database.repositories.human_input_repository import (
+                HumanInputRepository,
+            )
             from ra_aid.database.connection import get_db
 
             # Create repositories directly
@@ -250,7 +258,7 @@ def create_llm_client(
                 "display_title": "Information",
             },
             record_type="info",
-            human_input_id=human_input_id
+            human_input_id=human_input_id,
         )
         cpm(
             "This model supports temperature argument but none was given. Setting default temperature to 0.7."
@@ -302,9 +310,9 @@ def create_llm_client(
             model_name=model_name,
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
-            max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
             **thinking_kwargs,
+            **other_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
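The net effect of the two create_llm_client hunks above: max_tokens is no longer forced to 64,000 for every Anthropic model; only Claude 3.7 receives it, via **other_kwargs. A condensed sketch (hypothetical helper; the selection logic mirrors the diff):

    from ra_aid.model_detection import is_claude_37

    def anthropic_extra_kwargs(model_name: str) -> dict:
        # Only Claude 3.7 models get an explicit 64k max_tokens; other
        # models keep the library defaults.
        return {"max_tokens": 64000} if is_claude_37(model_name) else {}

    assert anthropic_extra_kwargs("claude-3-7-sonnet-20250219") == {"max_tokens": 64000}
    assert anthropic_extra_kwargs("claude-3-5-sonnet-20241022") == {}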
@@ -0,0 +1,39 @@
+"""Utilities for detecting and working with specific model types."""
+
+from typing import Optional, Dict, Any
+
+
+def is_claude_37(model: str) -> bool:
+    """Check if the model is a Claude 3.7 model.
+
+    Args:
+        model: The model name to check
+
+    Returns:
+        bool: True if the model is a Claude 3.7 model, False otherwise
+    """
+    patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
+    return any(pattern in model for pattern in patterns)
+
+
+def is_anthropic_claude(config: Dict[str, Any]) -> bool:
+    """Check if the provider and model name indicate an Anthropic Claude model.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+
+    Returns:
+        bool: True if this is an Anthropic Claude model
+    """
+    # For backwards compatibility, allow passing of config directly
+    provider = config.get("provider", "")
+    model_name = config.get("model", "")
+    result = (
+        provider.lower() == "anthropic"
+        and model_name
+        and "claude" in model_name.lower()
+    ) or (
+        provider.lower() == "openrouter"
+        and model_name.lower().startswith("anthropic/claude-")
+    )
+    return result
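Quick usage of the new module (runnable as-is once the package is installed):

    from ra_aid.model_detection import is_anthropic_claude, is_claude_37

    # is_claude_37 matches all three spellings used across providers.
    assert is_claude_37("anthropic/claude-3.7-sonnet")
    assert not is_claude_37("claude-3-5-sonnet-20241022")

    # is_anthropic_claude works on a config dict, including OpenRouter routing.
    assert is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-2"})
    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})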
@@ -13,8 +13,7 @@ from ra_aid.agent_context import (
 )
 from ra_aid.agent_utils import (
     AgentState,
-    create_agent,
-    is_anthropic_claude,
+    create_agent
 )
 from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
@@ -453,31 +452,6 @@ def test_handle_api_error_retry(monkeypatch):
         _handle_api_error(Exception("error code 429"), 0, 5, 1)
 
 
-def test_is_anthropic_claude():
-    """Test is_anthropic_claude function with various configurations."""
-    # Test Anthropic provider cases
-    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
-    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
-    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
-
-    # Test OpenRouter provider cases
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-2"}
-    )
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-instant"}
-    )
-    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
-
-    # Test edge cases
-    assert not is_anthropic_claude({})  # Empty config
-    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
-    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
-    assert not is_anthropic_claude(
-        {"provider": "other", "model": "claude-2"}
-    )  # Wrong provider
-
-
 def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
     """Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
     from ra_aid.agent_context import agent_context, mark_agent_crashed
@@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
     state_modifier,
     sonnet_35_state_modifier,
-    convert_message_to_litellm_format
+    convert_message_to_litellm_format,
+    adjust_claude_37_token_limit
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
-from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
+from ra_aid.models_params import models_params
 
 
 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 0)
 
     @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
-    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
-    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
+    def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
         # Setup a proper token counter function that returns integers
         def token_counter(msgs):
             # Return token count based on number of messages
@@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         model.model = "claude-3-opus-20240229"
 
         with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
-             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
-             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
             # Setup mock to return a fixed token count per message
             mock_wrapper.return_value = lambda msgs: len(msgs) * 100
             # Setup mock to return a subset of messages
@@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(call_args["include_system"], True)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.is_claude_37")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
+        # Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
+        model_name = "claude-3-7-sonnet-20250219"
 
         # Setup mocks
-        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_config = {"provider": "anthropic", "model": model_name}
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock litellm's get_model_info to return a token limit
         mock_get_model_info.return_value = {"max_input_tokens": 100000}
 
+        # Mock is_claude_37 to return True
+        mock_is_claude_37.return_value = True
+
+        # Mock adjust_claude_37_token_limit to return the original value
+        mock_adjust.return_value = 100000
+
         # Test getting token limit
         result = get_model_token_limit(mock_config)
         self.assertEqual(result, 100000)
 
         # Verify get_model_info was called with the right model
-        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+        mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
+
+        # Verify adjust_claude_37_token_limit was called
+        mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_research(self):
         """Test get_model_token_limit with research provider and model."""
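One detail worth noting in the reworked test above: stacked @patch decorators are applied bottom-up, so the mock for the decorator nearest the def arrives first in the argument list, which is why mock_adjust leads the signature. A minimal, self-contained sketch of that rule:

    from unittest.mock import patch

    @patch("os.getcwd")        # outermost patch -> last mock argument
    @patch("os.path.exists")   # innermost patch -> first mock argument
    def demo(mock_exists, mock_getcwd):
        # Mocks are injected in bottom-up order of the decorators.
        return mock_exists is not mock_getcwd

    assert demo()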
@@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             "provider": "openai",
             "model": "gpt-4",
             "research_provider": "anthropic",
-            "research_model": "claude-2",
+            "research_model": "claude-3-7-sonnet-20250219",
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 150000}
+            mock_adjust.return_value = 150000
 
             # Call the function to check the return value
             token_limit = get_model_token_limit(config, "research")
             self.assertEqual(token_limit, 150000)
 
             # Verify get_model_info was called with the research model
-            mock_get_info.assert_called_with("anthropic/claude-2")
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(150000, None)
 
     def test_get_model_token_limit_planner(self):
         """Test get_model_token_limit with planner provider and model."""
@@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 120000}
+            mock_adjust.return_value = 120000
 
             # Call the function to check the return value
             token_limit = get_model_token_limit(config, "planner")
             self.assertEqual(token_limit, 120000)
 
             # Verify get_model_info was called with the planner model
-            mock_get_info.assert_called_with("deepseek/dsm-1")
+            mock_get_info.assert_called_once_with("deepseek/dsm-1")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(120000, None)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
     def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
         # Setup mocks
         mock_config = {"provider": "anthropic", "model": "claude-2"}
@@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 100000)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo):
+        # Use specific model names instead of DEFAULT_MODEL to avoid test dependency
+        claude_model = "claude-3-7-sonnet-20250219"
 
         # Setup mocks for different agent types
         mock_config = {
             "provider": "anthropic",
-            "model": DEFAULT_MODEL,
+            "model": claude_model,
             "research_provider": "openai",
             "research_model": "gpt-4",
             "planner_provider": "anthropic",
-            "planner_model": "claude-3-sonnet-20240229"
+            "planner_model": "claude-3-7-opus-20250301"
         }
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock different returns for different models
         def model_info_side_effect(model_name):
-            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
+            if "claude-3-7-sonnet" in model_name:
                 return {"max_input_tokens": 200000}
             elif "gpt-4" in model_name:
                 return {"max_input_tokens": 8192}
-            elif "claude-3-sonnet" in model_name:
-                return {"max_input_tokens": 100000}
+            elif "claude-3-7-opus" in model_name:
+                return {"max_input_tokens": 250000}
             else:
                 raise Exception(f"Unknown model: {model_name}")
 
         mock_get_model_info.side_effect = model_info_side_effect
 
+        # Mock adjust_claude_37_token_limit to return the same values
+        mock_adjust.side_effect = lambda tokens, model: tokens
+
         # Test default agent type
         result = get_model_token_limit(mock_config, "default")
         self.assertEqual(result, 200000)
+        mock_get_model_info.assert_called_with(f"anthropic/{claude_model}")
 
         # Reset mock
         mock_get_model_info.reset_mock()
 
         # Test research agent type
         result = get_model_token_limit(mock_config, "research")
         self.assertEqual(result, 8192)
+        mock_get_model_info.assert_called_with("openai/gpt-4")
 
         # Reset mock
         mock_get_model_info.reset_mock()
 
         # Test planner agent type
         result = get_model_token_limit(mock_config, "planner")
-        self.assertEqual(result, 100000)
+        self.assertEqual(result, 250000)
+        mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301")
 
     def test_get_model_token_limit_anthropic(self):
         """Test get_model_token_limit with Anthropic model."""
-        config = {"provider": "anthropic", "model": "claude2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
-        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+
+            # Setup mocks
             mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Model not found")
+
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_openai(self):
         """Test get_model_token_limit with OpenAI model."""
@@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
 
     def test_get_model_token_limit_litellm_success(self):
         """Test get_model_token_limit successfully getting limit from litellm."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 100000}
+            mock_adjust.return_value = 100000
 
             # Call the function to check the return value
             token_limit = get_model_token_limit(config, "default")
             self.assertEqual(token_limit, 100000)
-            mock_get_info.assert_called_with("anthropic/claude-2")
 
+            # Verify get_model_info was called with the right model
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_litellm_not_found(self):
         """Test fallback to models_tokens when litellm raises NotFoundError."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
 
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-                message="Model not found", model="claude-2", llm_provider="anthropic"
+                message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
             )
 
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_litellm_error(self):
         """Test fallback to models_tokens when litellm raises other exceptions."""
@@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         token_limit = get_model_token_limit(config, "default")
         self.assertIsNone(token_limit)
 
+    def test_adjust_claude_37_token_limit(self):
+        """Test adjust_claude_37_token_limit function."""
+        # Create a mock model
+        mock_model = MagicMock()
+        mock_model.model = "claude-3.7-sonnet"
+        mock_model.max_tokens = 4096
+
+        # Test with Claude 3.7 model
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 95904)  # 100000 - 4096
+
+        # Test with non-Claude 3.7 model
+        mock_model.model = "claude-3-opus"
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 100000)  # No adjustment
+
+        # Test with None max_input_tokens
+        result = adjust_claude_37_token_limit(None, mock_model)
+        self.assertIsNone(result)
+
+        # Test with None model
+        result = adjust_claude_37_token_limit(100000, None)
+        self.assertEqual(result, 100000)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message
@@ -0,0 +1,50 @@
+"""Unit tests for model_detection.py."""
+
+import pytest
+from ra_aid.model_detection import is_anthropic_claude, is_claude_37
+
+
+def test_is_anthropic_claude():
+    """Test is_anthropic_claude function with various configurations."""
+    # Test Anthropic provider cases
+    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
+    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
+    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
+
+    # Test OpenRouter provider cases
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-2"}
+    )
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-instant"}
+    )
+    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
+
+    # Test edge cases
+    assert not is_anthropic_claude({})  # Empty config
+    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
+    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
+    assert not is_anthropic_claude(
+        {"provider": "other", "model": "claude-2"}
+    )  # Wrong provider
+
+
+def test_is_claude_37():
+    """Test is_claude_37 function with various model names."""
+    # Test positive cases
+    assert is_claude_37("claude-3.7")
+    assert is_claude_37("claude3.7")
+    assert is_claude_37("claude-3-7")
+    assert is_claude_37("anthropic/claude-3.7")
+    assert is_claude_37("anthropic/claude3.7")
+    assert is_claude_37("anthropic/claude-3-7")
+    assert is_claude_37("claude-3.7-sonnet")
+    assert is_claude_37("claude3.7-haiku")
+
+    # Test negative cases
+    assert not is_claude_37("claude-3")
+    assert not is_claude_37("claude-3.5")
+    assert not is_claude_37("claude3.5")
+    assert not is_claude_37("claude-3-5")
+    assert not is_claude_37("gpt-4")
+    assert not is_claude_37("")