feat(agent_utils.py): integrate litellm to retrieve model token limits for better flexibility (#51)
fix(agent_utils.py): rename max_tokens to max_input_tokens for clarity in state_modifier function
fix(models_tokens.py): update deepseek-reasoner token limit to 64000 for accuracy
test(agent_utils.py): add tests for litellm integration and fallback logic in get_model_token_limit function
commit 5240fb2617
parent c0499ab795
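In short: `get_model_token_limit` now asks litellm's model registry for a model's input-token limit before falling back to the hard-coded `models_tokens` table. A minimal sketch of the litellm call this new code path relies on (the "provider/model" id format and the `max_input_tokens` key are as used in the diff below; exact error behavior on unknown models varies by litellm version):

```python
from litellm import get_model_info

# litellm expects a "provider/model" id, e.g. "anthropic/claude-2"
info = get_model_info("anthropic/claude-2")
print(info.get("max_input_tokens"))  # the model's input-context limit, if litellm knows it
```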
pyproject.toml

@@ -45,6 +45,7 @@ dev = [
     "pytest-timeout>=2.2.0",
     "pytest>=7.0.0",
     "pytest-cov>=6.0.0",
+    "pytest-mock>=3.14.0",
 ]
 
 [project.scripts]
@@ -64,3 +65,4 @@ path = "ra_aid/__version__.py"
 
 [tool.hatch.build.targets.wheel]
 packages = ["ra_aid"]
+
ra_aid/agent_utils.py

@@ -5,6 +5,8 @@ import time
 import uuid
 from typing import Optional, Any, List, Dict, Sequence
 from langchain_core.messages import BaseMessage, trim_messages
+from litellm import get_model_info
+import litellm
 
 import signal
 
@@ -93,7 +95,7 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
 
 
 def state_modifier(
-    state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT
+    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
 ) -> list[BaseMessage]:
     """Given the agent state and max_tokens, return a trimmed list of messages.
 
@@ -112,7 +114,7 @@ def state_modifier(
     first_message = messages[0]
     remaining_messages = messages[1:]
     first_tokens = estimate_messages_tokens([first_message])
-    new_max_tokens = max_tokens - first_tokens
+    new_max_tokens = max_input_tokens - first_tokens
 
     trimmed_remaining = trim_messages(
         remaining_messages,
@@ -135,16 +137,29 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
         provider = config.get("provider", "")
         model_name = config.get("model", "")
 
+        try:
+            provider_model = model_name if not provider else f"{provider}/{model_name}"
+            model_info = get_model_info(provider_model)
+            max_input_tokens = model_info.get("max_input_tokens")
+            if max_input_tokens:
+                logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
+                return max_input_tokens
+        except litellm.exceptions.NotFoundError:
+            logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
+        except Exception as e:
+            logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
+
+        # Fallback to models_tokens dict
+        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
+        normalized_name = model_name.replace("-", "")
         provider_tokens = models_tokens.get(provider, {})
-        token_limit = provider_tokens.get(model_name, None)
-        if token_limit:
-            logger.debug(
-                f"Found token limit for {provider}/{model_name}: {token_limit}"
-            )
+        max_input_tokens = provider_tokens.get(normalized_name, None)
+        if max_input_tokens:
+            logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
         else:
             logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
-        return token_limit
+        return max_input_tokens
 
     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
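Taken together, the rewritten lookup tries litellm first and only consults the static table when litellm cannot resolve the model, normalizing the name for the table keys. A condensed, self-contained sketch of that order (the `models_tokens` entry here is illustrative, not the real value):

```python
from litellm import get_model_info

models_tokens = {"anthropic": {"claude2": 100000}}  # illustrative subset of the static table

def lookup_limit(provider: str, model_name: str):
    try:
        provider_model = model_name if not provider else f"{provider}/{model_name}"
        limit = get_model_info(provider_model).get("max_input_tokens")
        if limit:
            return limit
    except Exception:
        pass  # NotFoundError or any other litellm failure: use the fallback
    # normalize for the static table's keys, e.g. "claude-2" -> "claude2"
    return models_tokens.get(provider, {}).get(model_name.replace("-", ""))
```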
@@ -154,7 +169,7 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
 
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     config: Dict[str, Any] = None,
-    token_limit: Optional[int] = None,
+    max_input_tokens: Optional[int] = None,
 ) -> Dict[str, Any]:
     """Build kwargs dictionary for agent creation.
@@ -174,7 +189,7 @@ def build_agent_kwargs(
     if config.get("limit_tokens", True) and is_anthropic_claude(config):
 
         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            return state_modifier(state, max_tokens=token_limit)
+            return state_modifier(state, max_input_tokens=max_input_tokens)
 
         agent_kwargs["state_modifier"] = wrapped_state_modifier
 
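The wrapper exists because `create_react_agent`'s `state_modifier` hook receives only the agent state, so the token budget must be captured in a closure; the rename just changes which keyword gets bound. A minimal illustration of the binding pattern (the trimming body is a stand-in, not the real logic):

```python
def state_modifier(state, max_input_tokens=100000):
    return state[:max_input_tokens]  # stand-in for real message trimming

def build_wrapped(max_input_tokens):
    def wrapped_state_modifier(state):
        # budget is bound here, so the callback keeps a one-argument signature
        return state_modifier(state, max_input_tokens=max_input_tokens)
    return wrapped_state_modifier

trim = build_wrapped(250)
assert trim(list(range(1000))) == list(range(250))
```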
@@ -226,23 +241,23 @@ def create_agent(
     """
     try:
         config = _global_memory.get("config", {})
-        token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
+        max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
 
         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
-            agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
+            agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
             return create_react_agent(model, tools, **agent_kwargs)
         else:
             logger.debug("Using CiaynAgent agent instance")
-            return CiaynAgent(model, tools, max_tokens=token_limit)
+            return CiaynAgent(model, tools, max_tokens=max_input_tokens)
 
     except Exception as e:
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = _global_memory.get("config", {})
-        token_limit = get_model_token_limit(config)
-        agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
+        max_input_tokens = get_model_token_limit(config)
+        agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
         return create_react_agent(model, tools, **agent_kwargs)
 
 
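Note that in the main path `create_agent` still ORs the litellm-derived limit with `DEFAULT_TOKEN_LIMIT`, so a failed lookup degrades to the default budget rather than disabling trimming. A tiny check of that behavior:

```python
DEFAULT_TOKEN_LIMIT = 100000  # mirrors ra_aid.models_tokens

def effective_limit(looked_up):
    # get_model_token_limit() may return None on failure
    return looked_up or DEFAULT_TOKEN_LIMIT

assert effective_limit(None) == 100000
assert effective_limit(64000) == 64000
```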
ra_aid/models_tokens.py

@@ -4,6 +4,8 @@ List of model tokens
 
 DEFAULT_TOKEN_LIMIT = 100000
 
+
+
 models_tokens = {
     "openai": {
         "gpt-3.5-turbo-0125": 16385,
@@ -241,7 +243,7 @@ models_tokens = {
     "deepseek": {
         "deepseek-chat": 28672,
         "deepseek-coder": 16384,
-        "deepseek-reasoner": 65536,
+        "deepseek-reasoner": 64000,
     },
     "openrouter": {
         "deepseek/deepseek-r1": 65536,
tests (agent_utils test module)

@@ -2,10 +2,12 @@
 
 import pytest
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
-from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
-from ra_aid.agent_utils import state_modifier, AgentState
 from unittest.mock import Mock, patch
 from langchain_core.language_models import BaseChatModel
+import litellm
 
+from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
+from ra_aid.agent_utils import state_modifier, AgentState
+
 from ra_aid.agent_utils import create_agent, get_model_token_limit
 from ra_aid.models_tokens import models_tokens
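The tests below patch `ra_aid.agent_utils.get_model_info` rather than `litellm.get_model_info`: because `agent_utils` binds the name via `from litellm import get_model_info`, the patch must target the module where the name is looked up, not where it is defined. A sketch of the distinction (assumes the `ra_aid` package is importable):

```python
from unittest.mock import patch

# Effective: replaces the binding that agent_utils actually calls,
# since agent_utils did `from litellm import get_model_info`.
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
    mock_get_info.return_value = {"max_input_tokens": 100000}
    # ... exercise get_model_token_limit(config) here ...

# Patching 'litellm.get_model_info' instead would leave agent_utils'
# already-imported binding untouched, and the mock would never be hit.
```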
@@ -58,6 +60,45 @@ def test_get_model_token_limit_missing_config(mock_memory):
     assert token_limit is None
 
 
+def test_get_model_token_limit_litellm_success():
+    """Test get_model_token_limit successfully getting limit from litellm."""
+    config = {"provider": "anthropic", "model": "claude-2"}
+
+    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 100000}
+        token_limit = get_model_token_limit(config)
+        assert token_limit == 100000
+
+def test_get_model_token_limit_litellm_not_found():
+    """Test fallback to models_tokens when litellm raises NotFoundError."""
+    config = {"provider": "anthropic", "model": "claude-2"}
+
+    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+        mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+            message="Model not found",
+            model="claude-2",
+            llm_provider="anthropic"
+        )
+        token_limit = get_model_token_limit(config)
+        assert token_limit == models_tokens["anthropic"]["claude2"]
+
+def test_get_model_token_limit_litellm_error():
+    """Test fallback to models_tokens when litellm raises other exceptions."""
+    config = {"provider": "anthropic", "model": "claude-2"}
+
+    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+        mock_get_info.side_effect = Exception("Unknown error")
+        token_limit = get_model_token_limit(config)
+        assert token_limit == models_tokens["anthropic"]["claude2"]
+
+def test_get_model_token_limit_unexpected_error():
+    """Test returning None when unexpected errors occur."""
+    config = None  # This will cause an attribute error when accessed
+
+    token_limit = get_model_token_limit(config)
+    assert token_limit is None
+
+
 def test_create_agent_anthropic(mock_model, mock_memory):
     """Test create_agent with Anthropic Claude model."""
     mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
@@ -138,7 +179,7 @@ def test_state_modifier(mock_messages):
     ) as mock_estimate:
         mock_estimate.side_effect = lambda msg: 100 if msg else 0
 
-        result = state_modifier(state, max_tokens=250)
+        result = state_modifier(state, max_input_tokens=250)
 
         assert len(result) < len(mock_messages)
         assert isinstance(result[0], SystemMessage)