From 5240fb2617a920544935f5a2753f41098a094f42 Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Thu, 23 Jan 2025 11:48:30 -0800
Subject: [PATCH] feat(agent_utils.py): integrate litellm to retrieve model
 token limits for better flexibility (#51)

fix(agent_utils.py): rename max_tokens to max_input_tokens for clarity in state_modifier function
fix(models_tokens.py): update deepseek-reasoner token limit to 64000 for accuracy
test(agent_utils.py): add tests for litellm integration and fallback logic in get_model_token_limit function
---
 pyproject.toml                   |  2 ++
 ra_aid/agent_utils.py            | 45 ++++++++++++++++++++----------
 ra_aid/models_tokens.py          |  4 ++-
 tests/ra_aid/test_agent_utils.py | 47 ++++++++++++++++++++++++++++++--
 4 files changed, 79 insertions(+), 19 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 87aec73..14b19bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dev = [
     "pytest-timeout>=2.2.0",
     "pytest>=7.0.0",
     "pytest-cov>=6.0.0",
+    "pytest-mock>=3.14.0",
 ]
 
 [project.scripts]
@@ -64,3 +65,4 @@ path = "ra_aid/__version__.py"
 
 [tool.hatch.build.targets.wheel]
 packages = ["ra_aid"]
+
diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 7d52450..7db4408 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -5,6 +5,8 @@ import time
 import uuid
 from typing import Optional, Any, List, Dict, Sequence
 from langchain_core.messages import BaseMessage, trim_messages
+from litellm import get_model_info
+import litellm
 
 import signal
 
@@ -93,7 +95,7 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
 
 
 def state_modifier(
-    state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT
+    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
 
@@ -112,7 +114,7 @@ def state_modifier(
     first_message = messages[0]
     remaining_messages = messages[1:]
     first_tokens = estimate_messages_tokens([first_message])
-    new_max_tokens = max_tokens - first_tokens
+    new_max_tokens = max_input_tokens - first_tokens
 
     trimmed_remaining = trim_messages(
         remaining_messages,
@@ -135,16 +137,29 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
         provider = config.get("provider", "")
         model_name = config.get("model", "")
 
+        try:
+            provider_model = model_name if not provider else f"{provider}/{model_name}"
+            model_info = get_model_info(provider_model)
+            max_input_tokens = model_info.get("max_input_tokens")
+            if max_input_tokens:
+                logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
+                return max_input_tokens
+        except litellm.exceptions.NotFoundError:
+            logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
+        except Exception as e:
+            logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
+
+        # Fallback to models_tokens dict
+        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
+        normalized_name = model_name.replace("-", "")
         provider_tokens = models_tokens.get(provider, {})
-        token_limit = provider_tokens.get(model_name, None)
-        if token_limit:
-            logger.debug(
-                f"Found token limit for {provider}/{model_name}: {token_limit}"
-            )
+        max_input_tokens = provider_tokens.get(normalized_name, None)
+        if max_input_tokens:
+            logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
         else:
             logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
-        return token_limit
+        return max_input_tokens
 
     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
@@ -154,7 +169,7 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     config: Dict[str, Any] = None,
-    token_limit: Optional[int] = None,
+    max_input_tokens: Optional[int] = None,
 ) -> Dict[str, Any]:
     """Build kwargs dictionary for agent creation.
 
@@ -174,7 +189,7 @@ def build_agent_kwargs(
     if config.get("limit_tokens", True) and is_anthropic_claude(config):
 
         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            return state_modifier(state, max_tokens=token_limit)
+            return state_modifier(state, max_input_tokens=max_input_tokens)
 
         agent_kwargs["state_modifier"] = wrapped_state_modifier
 
@@ -226,23 +241,23 @@ def create_agent(
     """
     try:
         config = _global_memory.get("config", {})
-        token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
+        max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
 
         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
-            agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
+            agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
             return create_react_agent(model, tools, **agent_kwargs)
         else:
             logger.debug("Using CiaynAgent agent instance")
-            return CiaynAgent(model, tools, max_tokens=token_limit)
+            return CiaynAgent(model, tools, max_tokens=max_input_tokens)
 
     except Exception as e:
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = _global_memory.get("config", {})
-        token_limit = get_model_token_limit(config)
-        agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
+        max_input_tokens = get_model_token_limit(config)
+        agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
         return create_react_agent(model, tools, **agent_kwargs)
 
 
diff --git a/ra_aid/models_tokens.py b/ra_aid/models_tokens.py
index 9307634..01aa2ae 100644
--- a/ra_aid/models_tokens.py
+++ b/ra_aid/models_tokens.py
@@ -4,6 +4,8 @@ List of model tokens
 
 DEFAULT_TOKEN_LIMIT = 100000
 
+
+
 models_tokens = {
     "openai": {
         "gpt-3.5-turbo-0125": 16385,
@@ -241,7 +243,7 @@ models_tokens = {
     "deepseek": {
         "deepseek-chat": 28672,
         "deepseek-coder": 16384,
-        "deepseek-reasoner": 65536,
+        "deepseek-reasoner": 64000,
     },
     "openrouter": {
         "deepseek/deepseek-r1": 65536,
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index 87216f6..184390d 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -2,10 +2,12 @@
 
 import pytest
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
-from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
-from ra_aid.agent_utils import state_modifier, AgentState
 from unittest.mock import Mock, patch
 from langchain_core.language_models import BaseChatModel
+import litellm
+
+from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
+from ra_aid.agent_utils import state_modifier, AgentState
 from ra_aid.agent_utils import create_agent, get_model_token_limit
 from ra_aid.models_tokens import models_tokens
 
@@ -58,6 +60,45 @@ def test_get_model_token_limit_missing_config(mock_memory):
     assert token_limit is None
 
 
+def test_get_model_token_limit_litellm_success():
+    """Test get_model_token_limit successfully getting limit from litellm."""
+    config = {"provider": "anthropic", "model": "claude-2"}
+
+    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 100000}
+        token_limit = get_model_token_limit(config)
+        assert token_limit == 100000
+
+def test_get_model_token_limit_litellm_not_found():
+    """Test fallback to models_tokens when litellm raises NotFoundError."""
+    config = {"provider": "anthropic", "model": "claude-2"}
+
+    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+        mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+            message="Model not found",
+            model="claude-2",
+            llm_provider="anthropic"
+        )
+        token_limit = get_model_token_limit(config)
+        assert token_limit == models_tokens["anthropic"]["claude2"]
+
+def test_get_model_token_limit_litellm_error():
+    """Test fallback to models_tokens when litellm raises other exceptions."""
+    config = {"provider": "anthropic", "model": "claude-2"}
+
+    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+        mock_get_info.side_effect = Exception("Unknown error")
+        token_limit = get_model_token_limit(config)
+        assert token_limit == models_tokens["anthropic"]["claude2"]
+
+def test_get_model_token_limit_unexpected_error():
+    """Test returning None when unexpected errors occur."""
+    config = None  # This will cause an attribute error when accessed
+
+    token_limit = get_model_token_limit(config)
+    assert token_limit is None
+
+
 def test_create_agent_anthropic(mock_model, mock_memory):
     """Test create_agent with Anthropic Claude model."""
     mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
@@ -138,7 +179,7 @@ def test_state_modifier(mock_messages):
     ) as mock_estimate:
         mock_estimate.side_effect = lambda msg: 100 if msg else 0
 
-        result = state_modifier(state, max_tokens=250)
+        result = state_modifier(state, max_input_tokens=250)
 
         assert len(result) < len(mock_messages)
         assert isinstance(result[0], SystemMessage)