feat(agent_utils.py): integrate litellm to retrieve model token limits for better flexibility (#51)

fix(agent_utils.py): rename max_tokens to max_input_tokens for clarity in state_modifier function
fix(models_tokens.py): update deepseek-reasoner token limit to 64000 for accuracy
test(agent_utils.py): add tests for litellm integration and fallback logic in get_model_token_limit function
Ariel Frischer 2025-01-23 11:48:30 -08:00 committed by GitHub
parent c0499ab795
commit 5240fb2617
4 changed files with 79 additions and 19 deletions
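For readers skimming the diff below, the heart of this change is a two-step lookup: ask litellm's get_model_info for the model's max_input_tokens first, and fall back to the local models_tokens table only when litellm does not know the model or errors. A minimal standalone sketch of that flow follows; FALLBACK_TOKENS and lookup_max_input_tokens are illustrative stand-ins, not names from this commit.

    from typing import Any, Dict, Optional

    import litellm
    from litellm import get_model_info

    # Illustrative stand-in for ra_aid.models_tokens.models_tokens
    FALLBACK_TOKENS = {"anthropic": {"claude2": 100000}}

    def lookup_max_input_tokens(config: Dict[str, Any]) -> Optional[int]:
        provider = config.get("provider", "")
        model_name = config.get("model", "")
        try:
            # litellm keys models as "provider/model" when a provider is set
            provider_model = model_name if not provider else f"{provider}/{model_name}"
            info = get_model_info(provider_model)
            if info.get("max_input_tokens"):
                return info["max_input_tokens"]
        except litellm.exceptions.NotFoundError:
            pass  # model unknown to litellm: fall through to the local table
        except Exception:
            pass  # any other litellm failure also falls back
        # Fallback table keys drop dashes, e.g. "claude-2" -> "claude2"
        return FALLBACK_TOKENS.get(provider, {}).get(model_name.replace("-", ""))

    # Prints the resolved limit (taken from litellm when available)
    print(lookup_max_input_tokens({"provider": "anthropic", "model": "claude-2"}))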


@@ -45,6 +45,7 @@ dev = [
"pytest-timeout>=2.2.0",
"pytest>=7.0.0",
"pytest-cov>=6.0.0",
"pytest-mock>=3.14.0",
]
[project.scripts]
@@ -64,3 +65,4 @@ path = "ra_aid/__version__.py"
[tool.hatch.build.targets.wheel]
packages = ["ra_aid"]


@@ -5,6 +5,8 @@ import time
import uuid
from typing import Optional, Any, List, Dict, Sequence
from langchain_core.messages import BaseMessage, trim_messages
from litellm import get_model_info
import litellm
import signal
@@ -93,7 +95,7 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
def state_modifier(
state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
"""Given the agent state and max_tokens, return a trimmed list of messages.
@@ -112,7 +114,7 @@ def state_modifier(
first_message = messages[0]
remaining_messages = messages[1:]
first_tokens = estimate_messages_tokens([first_message])
new_max_tokens = max_tokens - first_tokens
new_max_tokens = max_input_tokens - first_tokens
trimmed_remaining = trim_messages(
remaining_messages,
@@ -135,16 +137,29 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
provider = config.get("provider", "")
model_name = config.get("model", "")
try:
provider_model = model_name if not provider else f"{provider}/{model_name}"
model_info = get_model_info(provider_model)
max_input_tokens = model_info.get("max_input_tokens")
if max_input_tokens:
logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
return max_input_tokens
except litellm.exceptions.NotFoundError:
logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
except Exception as e:
logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
# Fallback to models_tokens dict
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
normalized_name = model_name.replace("-", "")
provider_tokens = models_tokens.get(provider, {})
token_limit = provider_tokens.get(model_name, None)
if token_limit:
logger.debug(
f"Found token limit for {provider}/{model_name}: {token_limit}"
)
max_input_tokens = provider_tokens.get(normalized_name, None)
if max_input_tokens:
logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
else:
logger.debug(f"Could not find token limit for {provider}/{model_name}")
return token_limit
return max_input_tokens
except Exception as e:
logger.warning(f"Failed to get model token limit: {e}")
@@ -154,7 +169,7 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
def build_agent_kwargs(
checkpointer: Optional[Any] = None,
config: Dict[str, Any] = None,
token_limit: Optional[int] = None,
max_input_tokens: Optional[int] = None,
) -> Dict[str, Any]:
"""Build kwargs dictionary for agent creation.
@@ -174,7 +189,7 @@
if config.get("limit_tokens", True) and is_anthropic_claude(config):
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
return state_modifier(state, max_tokens=token_limit)
return state_modifier(state, max_input_tokens=max_input_tokens)
agent_kwargs["state_modifier"] = wrapped_state_modifier
@@ -226,23 +241,23 @@ def create_agent(
"""
try:
config = _global_memory.get("config", {})
token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
if is_anthropic_claude(config):
logger.debug("Using create_react_agent to instantiate agent.")
agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs)
else:
logger.debug("Using CiaynAgent agent instance")
return CiaynAgent(model, tools, max_tokens=token_limit)
return CiaynAgent(model, tools, max_tokens=max_input_tokens)
except Exception as e:
# Default to REACT agent if provider/model detection fails
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
config = _global_memory.get("config", {})
token_limit = get_model_token_limit(config)
agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
max_input_tokens = get_model_token_limit(config)
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs)
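Downstream, create_agent consumes the limit under the new name but with the same fallback to DEFAULT_TOKEN_LIMIT as before. A hedged usage sketch of that call flow; the config values are illustrative, and the exact contents of agent_kwargs depend on is_anthropic_claude().

    from ra_aid.agent_utils import build_agent_kwargs, get_model_token_limit
    from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT

    config = {"provider": "anthropic", "model": "claude-2", "limit_tokens": True}
    max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT

    # For Anthropic Claude configs this should come back with a state_modifier
    # that trims message history to max_input_tokens before each model call.
    agent_kwargs = build_agent_kwargs(checkpointer=None, config=config, max_input_tokens=max_input_tokens)
    print(max_input_tokens, sorted(agent_kwargs))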


@@ -4,6 +4,8 @@ List of model tokens
DEFAULT_TOKEN_LIMIT = 100000
models_tokens = {
"openai": {
"gpt-3.5-turbo-0125": 16385,
@@ -241,7 +243,7 @@ models_tokens = {
"deepseek": {
"deepseek-chat": 28672,
"deepseek-coder": 16384,
"deepseek-reasoner": 65536,
"deepseek-reasoner": 64000,
},
"openrouter": {
"deepseek/deepseek-r1": 65536,


@@ -2,10 +2,12 @@
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
from ra_aid.agent_utils import state_modifier, AgentState
from unittest.mock import Mock, patch
from langchain_core.language_models import BaseChatModel
import litellm
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
from ra_aid.agent_utils import state_modifier, AgentState
from ra_aid.agent_utils import create_agent, get_model_token_limit
from ra_aid.models_tokens import models_tokens
@@ -58,6 +60,45 @@ def test_get_model_token_limit_missing_config(mock_memory):
assert token_limit is None
def test_get_model_token_limit_litellm_success():
"""Test get_model_token_limit successfully getting limit from litellm."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 100000}
token_limit = get_model_token_limit(config)
assert token_limit == 100000
def test_get_model_token_limit_litellm_not_found():
"""Test fallback to models_tokens when litellm raises NotFoundError."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found",
model="claude-2",
llm_provider="anthropic"
)
token_limit = get_model_token_limit(config)
assert token_limit == models_tokens["anthropic"]["claude2"]
def test_get_model_token_limit_litellm_error():
"""Test fallback to models_tokens when litellm raises other exceptions."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
mock_get_info.side_effect = Exception("Unknown error")
token_limit = get_model_token_limit(config)
assert token_limit == models_tokens["anthropic"]["claude2"]
def test_get_model_token_limit_unexpected_error():
"""Test returning None when unexpected errors occur."""
config = None # This will cause an attribute error when accessed
token_limit = get_model_token_limit(config)
assert token_limit is None
def test_create_agent_anthropic(mock_model, mock_memory):
"""Test create_agent with Anthropic Claude model."""
mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
@@ -138,7 +179,7 @@ def test_state_modifier(mock_messages):
) as mock_estimate:
mock_estimate.side_effect = lambda msg: 100 if msg else 0
result = state_modifier(state, max_tokens=250)
result = state_modifier(state, max_input_tokens=250)
assert len(result) < len(mock_messages)
assert isinstance(result[0], SystemMessage)
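
The new tests patch ra_aid.agent_utils.get_model_info via unittest.mock; pytest-mock, added to the dev extras in this commit, provides an equivalent mocker fixture. A hedged sketch of the same success-path test written that way (not part of the commit):

    from ra_aid.agent_utils import get_model_token_limit

    def test_get_model_token_limit_litellm_success_mocker(mocker):
        # mocker is the pytest-mock fixture; its patches are undone automatically
        config = {"provider": "anthropic", "model": "claude-2"}
        mock_get_info = mocker.patch("ra_aid.agent_utils.get_model_info")
        mock_get_info.return_value = {"max_input_tokens": 100000}
        assert get_model_token_limit(config) == 100000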