feat(agent_utils.py): enhance get_model_token_limit to support agent types for better configuration management

test(agent_utils.py): add tests for get_model_token_limit with different agent types to ensure correct functionality
This commit is contained in:
Ariel Frischer 2025-02-01 12:55:36 -08:00
parent dd8b9c0d30
commit c2ba638a95
2 changed files with 77 additions and 31 deletions

View File

@ -5,7 +5,7 @@ import sys
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Dict, List, Literal, Optional, Sequence
import litellm import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@ -122,13 +122,22 @@ def state_modifier(
return [first_message] + trimmed_remaining return [first_message] + trimmed_remaining
def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]: def get_model_token_limit(
"""Get the token limit for the current model configuration. config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
) -> Optional[int]:
"""Get the token limit for the current model configuration based on agent type.
Returns: Returns:
Optional[int]: The token limit if found, None otherwise Optional[int]: The token limit if found, None otherwise
""" """
try: try:
if agent_type == "research":
provider = config.get("research_provider", "") or config.get("provider", "")
model_name = config.get("research_model", "") or config.get("model", "")
elif agent_type == "planner":
provider = config.get("planner_provider", "") or config.get("provider", "")
model_name = config.get("planner_model", "") or config.get("model", "")
else:
provider = config.get("provider", "") provider = config.get("provider", "")
model_name = config.get("model", "") model_name = config.get("model", "")
@ -224,6 +233,7 @@ def create_agent(
tools: List[Any], tools: List[Any],
*, *,
checkpointer: Any = None, checkpointer: Any = None,
agent_type: str = "default",
) -> Any: ) -> Any:
"""Create a react agent with the given configuration. """Create a react agent with the given configuration.
@ -245,7 +255,9 @@ def create_agent(
""" """
try: try:
config = _global_memory.get("config", {}) config = _global_memory.get("config", {})
max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT max_input_tokens = (
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
)
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
if is_anthropic_claude(config): if is_anthropic_claude(config):
@ -260,7 +272,7 @@ def create_agent(
# Default to REACT agent if provider/model detection fails # Default to REACT agent if provider/model detection fails
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
config = _global_memory.get("config", {}) config = _global_memory.get("config", {})
max_input_tokens = get_model_token_limit(config) max_input_tokens = get_model_token_limit(config, agent_type)
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens) agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs) return create_react_agent(model, tools, **agent_kwargs)
@ -326,7 +338,7 @@ def run_research_agent(
web_research_enabled=config.get("web_research_enabled", False), web_research_enabled=config.get("web_research_enabled", False),
) )
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@ -349,9 +361,11 @@ def run_research_agent(
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
base_task=base_task_or_query, base_task=base_task_or_query,
research_only_note="" research_only_note=(
""
if research_only if research_only
else " Only request implementation if the user explicitly asked for changes to be made.", else " Only request implementation if the user explicitly asked for changes to be made."
),
expert_section=expert_section, expert_section=expert_section,
human_section=human_section, human_section=human_section,
web_research_section=web_research_section, web_research_section=web_research_section,
@ -455,7 +469,7 @@ def run_web_research_agent(
tools = get_web_research_tools(expert_enabled=expert_enabled) tools = get_web_research_tools(expert_enabled=expert_enabled)
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@ -536,7 +550,7 @@ def run_planning_agent(
web_research_enabled=config.get("web_research_enabled", False), web_research_enabled=config.get("web_research_enabled", False),
) )
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
@ -556,9 +570,11 @@ def run_planning_agent(
key_facts=get_memory_value("key_facts"), key_facts=get_memory_value("key_facts"),
key_snippets=get_memory_value("key_snippets"), key_snippets=get_memory_value("key_snippets"),
work_log=get_memory_value("work_log"), work_log=get_memory_value("work_log"),
research_only_note="" research_only_note=(
""
if config.get("research_only") if config.get("research_only")
else " Only request implementation if the user explicitly asked for changes to be made.", else " Only request implementation if the user explicitly asked for changes to be made."
),
) )
config = _global_memory.get("config", {}) if not config else config config = _global_memory.get("config", {}) if not config else config
@ -634,7 +650,7 @@ def run_task_implementation_agent(
web_research_enabled=config.get("web_research_enabled", False), web_research_enabled=config.get("web_research_enabled", False),
) )
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
prompt = IMPLEMENTATION_PROMPT.format( prompt = IMPLEMENTATION_PROMPT.format(
base_task=base_task, base_task=base_task,
@ -647,12 +663,16 @@ def run_task_implementation_agent(
research_notes=get_memory_value("research_notes"), research_notes=get_memory_value("research_notes"),
work_log=get_memory_value("work_log"), work_log=get_memory_value("work_log"),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION human_section=(
HUMAN_PROMPT_SECTION_IMPLEMENTATION
if _global_memory.get("config", {}).get("hil", False) if _global_memory.get("config", {}).get("hil", False)
else "", else ""
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT ),
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT
if config.get("web_research_enabled") if config.get("web_research_enabled")
else "", else ""
),
) )
config = _global_memory.get("config", {}) if not config else config config = _global_memory.get("config", {}) if not config else config

View File

@ -35,7 +35,7 @@ def test_get_model_token_limit_anthropic(mock_memory):
"""Test get_model_token_limit with Anthropic model.""" """Test get_model_token_limit with Anthropic model."""
config = {"provider": "anthropic", "model": "claude2"} config = {"provider": "anthropic", "model": "claude2"}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
@ -43,7 +43,7 @@ def test_get_model_token_limit_openai(mock_memory):
"""Test get_model_token_limit with OpenAI model.""" """Test get_model_token_limit with OpenAI model."""
config = {"provider": "openai", "model": "gpt-4"} config = {"provider": "openai", "model": "gpt-4"}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
@ -51,7 +51,7 @@ def test_get_model_token_limit_unknown(mock_memory):
"""Test get_model_token_limit with unknown provider/model.""" """Test get_model_token_limit with unknown provider/model."""
config = {"provider": "unknown", "model": "unknown-model"} config = {"provider": "unknown", "model": "unknown-model"}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit is None assert token_limit is None
@ -59,7 +59,7 @@ def test_get_model_token_limit_missing_config(mock_memory):
"""Test get_model_token_limit with missing configuration.""" """Test get_model_token_limit with missing configuration."""
config = {} config = {}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit is None assert token_limit is None
@ -69,7 +69,7 @@ def test_get_model_token_limit_litellm_success():
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 100000} mock_get_info.return_value = {"max_input_tokens": 100000}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == 100000 assert token_limit == 100000
@ -81,7 +81,7 @@ def test_get_model_token_limit_litellm_not_found():
mock_get_info.side_effect = litellm.exceptions.NotFoundError( mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found", model="claude-2", llm_provider="anthropic" message="Model not found", model="claude-2", llm_provider="anthropic"
) )
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
@ -91,7 +91,7 @@ def test_get_model_token_limit_litellm_error():
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.side_effect = Exception("Unknown error") mock_get_info.side_effect = Exception("Unknown error")
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
@ -99,7 +99,7 @@ def test_get_model_token_limit_unexpected_error():
"""Test returning None when unexpected errors occur.""" """Test returning None when unexpected errors occur."""
config = None # This will cause an attribute error when accessed config = None # This will cause an attribute error when accessed
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit is None assert token_limit is None
@ -247,3 +247,29 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
assert agent == "react_agent" assert agent == "react_agent"
mock_react.assert_called_once_with(mock_model, []) mock_react.assert_called_once_with(mock_model, [])
def test_get_model_token_limit_research(mock_memory):
    """Check the limit returned for the "research" agent type.

    The config carries both the general provider/model pair and a
    research-specific pair; the patched ``get_model_info`` reports a
    ``max_input_tokens`` of 150000, which is the value expected back.

    NOTE(review): the mock returns the same payload for any argument, so this
    test does not by itself show that the ``research_*`` keys were the ones
    used -- consider also asserting on the mock's call arguments.
    """
    expected_limit = 150000
    settings = {
        "provider": "openai",
        "model": "gpt-4",
        "research_provider": "anthropic",
        "research_model": "claude-2",
    }

    with patch(
        "ra_aid.agent_utils.get_model_info",
        return_value={"max_input_tokens": expected_limit},
    ):
        assert get_model_token_limit(settings, "research") == expected_limit
def test_get_model_token_limit_planner(mock_memory):
    """Check the limit returned for the "planner" agent type.

    The config carries both the general provider/model pair and a
    planner-specific pair; the patched ``get_model_info`` reports a
    ``max_input_tokens`` of 120000, which is the value expected back.

    NOTE(review): the mock returns the same payload for any argument, so this
    test does not by itself show that the ``planner_*`` keys were the ones
    used -- consider also asserting on the mock's call arguments.
    """
    expected_limit = 120000
    settings = {
        "provider": "openai",
        "model": "gpt-4",
        "planner_provider": "deepseek",
        "planner_model": "dsm-1",
    }

    with patch(
        "ra_aid.agent_utils.get_model_info",
        return_value={"max_input_tokens": expected_limit},
    ):
        assert get_model_token_limit(settings, "planner") == expected_limit