feat(agent_utils.py): enhance get_model_token_limit to resolve provider and model per agent type for better configuration management
test(agent_utils.py): add get_model_token_limit tests for the research and planner agent types
parent dd8b9c0d30
commit c2ba638a95
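For context, the resolution rule this commit introduces: the new `agent_type` argument decides which config keys drive the token-limit lookup, with the agent-specific `research_*`/`planner_*` keys taking precedence and the base `provider`/`model` serving as fallback. A minimal sketch of that behavior (the config values are illustrative, not from this commit):

```python
from ra_aid.agent_utils import get_model_token_limit

# With only base keys set, every agent type resolves to the same provider/model.
base_config = {"provider": "openai", "model": "gpt-4"}
assert get_model_token_limit(base_config, "planner") == get_model_token_limit(
    base_config, "default"
)

# Agent-specific keys override the base ones, but only for that agent type.
config = {
    "provider": "openai",
    "model": "gpt-4",
    "research_provider": "anthropic",  # consulted only when agent_type == "research"
    "research_model": "claude-2",
}
research_limit = get_model_token_limit(config, "research")  # looks up anthropic/claude-2
default_limit = get_model_token_limit(config, "default")    # looks up openai/gpt-4
```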
@@ -5,7 +5,7 @@ import sys
 import threading
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Literal, Optional, Sequence

 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@@ -122,15 +122,24 @@ def state_modifier(
     return [first_message] + trimmed_remaining


-def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
-    """Get the token limit for the current model configuration.
+def get_model_token_limit(
+    config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
+) -> Optional[int]:
+    """Get the token limit for the current model configuration based on agent type.

     Returns:
         Optional[int]: The token limit if found, None otherwise
     """
     try:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+        if agent_type == "research":
+            provider = config.get("research_provider", "") or config.get("provider", "")
+            model_name = config.get("research_model", "") or config.get("model", "")
+        elif agent_type == "planner":
+            provider = config.get("planner_provider", "") or config.get("provider", "")
+            model_name = config.get("planner_model", "") or config.get("model", "")
+        else:
+            provider = config.get("provider", "")
+            model_name = config.get("model", "")

         try:
             provider_model = model_name if not provider else f"{provider}/{model_name}"
@@ -224,6 +233,7 @@ def create_agent(
     tools: List[Any],
     *,
     checkpointer: Any = None,
+    agent_type: str = "default",
 ) -> Any:
     """Create a react agent with the given configuration.

@@ -245,7 +255,9 @@ def create_agent(
     """
     try:
         config = _global_memory.get("config", {})
-        max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
+        max_input_tokens = (
+            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+        )

         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
@@ -260,7 +272,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = _global_memory.get("config", {})
-        max_input_tokens = get_model_token_limit(config)
+        max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
         return create_react_agent(model, tools, **agent_kwargs)

@@ -326,7 +338,7 @@ def run_research_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")

     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@@ -349,9 +361,11 @@ def run_research_agent(

     prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
         base_task=base_task_or_query,
-        research_only_note=""
-        if research_only
-        else " Only request implementation if the user explicitly asked for changes to be made.",
+        research_only_note=(
+            ""
+            if research_only
+            else " Only request implementation if the user explicitly asked for changes to be made."
+        ),
         expert_section=expert_section,
         human_section=human_section,
         web_research_section=web_research_section,
@@ -455,7 +469,7 @@ def run_web_research_agent(

     tools = get_web_research_tools(expert_enabled=expert_enabled)

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")

     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@@ -536,7 +550,7 @@ def run_planning_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")

     expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
@@ -556,9 +570,11 @@ def run_planning_agent(
         key_facts=get_memory_value("key_facts"),
         key_snippets=get_memory_value("key_snippets"),
         work_log=get_memory_value("work_log"),
-        research_only_note=""
-        if config.get("research_only")
-        else " Only request implementation if the user explicitly asked for changes to be made.",
+        research_only_note=(
+            ""
+            if config.get("research_only")
+            else " Only request implementation if the user explicitly asked for changes to be made."
+        ),
     )

     config = _global_memory.get("config", {}) if not config else config
@@ -634,7 +650,7 @@ def run_task_implementation_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")

     prompt = IMPLEMENTATION_PROMPT.format(
         base_task=base_task,
@@ -647,12 +663,16 @@ def run_task_implementation_agent(
         research_notes=get_memory_value("research_notes"),
         work_log=get_memory_value("work_log"),
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
-        human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION
-        if _global_memory.get("config", {}).get("hil", False)
-        else "",
-        web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
-        if config.get("web_research_enabled")
-        else "",
+        human_section=(
+            HUMAN_PROMPT_SECTION_IMPLEMENTATION
+            if _global_memory.get("config", {}).get("hil", False)
+            else ""
+        ),
+        web_research_section=(
+            WEB_RESEARCH_PROMPT_SECTION_CHAT
+            if config.get("web_research_enabled")
+            else ""
+        ),
     )

     config = _global_memory.get("config", {}) if not config else config
@@ -35,7 +35,7 @@ def test_get_model_token_limit_anthropic(mock_memory):
     """Test get_model_token_limit with Anthropic model."""
     config = {"provider": "anthropic", "model": "claude2"}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]


@@ -43,7 +43,7 @@ def test_get_model_token_limit_openai(mock_memory):
     """Test get_model_token_limit with OpenAI model."""
     config = {"provider": "openai", "model": "gpt-4"}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]


@@ -51,7 +51,7 @@ def test_get_model_token_limit_unknown(mock_memory):
     """Test get_model_token_limit with unknown provider/model."""
     config = {"provider": "unknown", "model": "unknown-model"}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None


@@ -59,7 +59,7 @@ def test_get_model_token_limit_missing_config(mock_memory):
     """Test get_model_token_limit with missing configuration."""
     config = {}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None


@@ -69,7 +69,7 @@ def test_get_model_token_limit_litellm_success():

     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 100000}
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == 100000


@@ -81,7 +81,7 @@ def test_get_model_token_limit_litellm_not_found():
         mock_get_info.side_effect = litellm.exceptions.NotFoundError(
             message="Model not found", model="claude-2", llm_provider="anthropic"
         )
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]


@@ -91,7 +91,7 @@ def test_get_model_token_limit_litellm_error():

     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = Exception("Unknown error")
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]


@@ -99,7 +99,7 @@ def test_get_model_token_limit_unexpected_error():
     """Test returning None when unexpected errors occur."""
     config = None  # This will cause an attribute error when accessed

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None


@@ -247,3 +247,29 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)

     assert agent == "react_agent"
     mock_react.assert_called_once_with(mock_model, [])
+
+def test_get_model_token_limit_research(mock_memory):
+    """Test get_model_token_limit with research provider and model."""
+    config = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "research_provider": "anthropic",
+        "research_model": "claude-2"
+    }
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 150000}
+        token_limit = get_model_token_limit(config, "research")
+        assert token_limit == 150000
+
+def test_get_model_token_limit_planner(mock_memory):
+    """Test get_model_token_limit with planner provider and model."""
+    config = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "planner_provider": "deepseek",
+        "planner_model": "dsm-1"
+    }
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 120000}
+        token_limit = get_model_token_limit(config, "planner")
+        assert token_limit == 120000
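The two new tests pin down the explicit overrides; the `or config.get(...)` fallbacks in `get_model_token_limit` suggest one more case worth covering. A hypothetical companion test (not part of this commit), reusing the file's existing `mock_memory` fixture, imports, and patch target:

```python
def test_get_model_token_limit_planner_fallback(mock_memory):
    """Hypothetical: with no planner_* keys set, the planner lookup
    should fall back to the base provider/model."""
    config = {"provider": "openai", "model": "gpt-4"}
    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
        mock_get_info.return_value = {"max_input_tokens": 100000}
        token_limit = get_model_token_limit(config, "planner")
        assert token_limit == 100000
```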