feat(agent_utils.py): enhance get_model_token_limit to support agent types for better configuration management
test(agent_utils.py): add tests for get_model_token_limit with different agent types to ensure correct functionality
commit c2ba638a95
parent dd8b9c0d30
@@ -5,7 +5,7 @@ import sys
 import threading
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Literal, Optional, Sequence

 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@@ -122,15 +122,24 @@ def state_modifier(
     return [first_message] + trimmed_remaining


-def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
-    """Get the token limit for the current model configuration.
+def get_model_token_limit(
+    config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
+) -> Optional[int]:
+    """Get the token limit for the current model configuration based on agent type.

     Returns:
         Optional[int]: The token limit if found, None otherwise
     """
     try:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+        if agent_type == "research":
+            provider = config.get("research_provider", "") or config.get("provider", "")
+            model_name = config.get("research_model", "") or config.get("model", "")
+        elif agent_type == "planner":
+            provider = config.get("planner_provider", "") or config.get("provider", "")
+            model_name = config.get("planner_model", "") or config.get("model", "")
+        else:
+            provider = config.get("provider", "")
+            model_name = config.get("model", "")

         try:
             provider_model = model_name if not provider else f"{provider}/{model_name}"
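For illustration, a minimal sketch of how the reworked lookup resolves provider and model per agent type. The keys are the ones read above; the values and the import path are examples (the tests below patch ra_aid.agent_utils, so the module path follows from there):

    from ra_aid.agent_utils import get_model_token_limit

    # research_* keys override the top-level provider/model for the research
    # agent; the planner falls back because no planner_* overrides are set.
    config = {
        "provider": "openai",
        "model": "gpt-4",
        "research_provider": "anthropic",
        "research_model": "claude-2",
    }
    get_model_token_limit(config, "research")  # looks up anthropic/claude-2
    get_model_token_limit(config, "planner")   # falls back to openai/gpt-4
    get_model_token_limit(config, "default")   # uses openai/gpt-4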
@@ -224,6 +233,7 @@ def create_agent(
     tools: List[Any],
     *,
     checkpointer: Any = None,
+    agent_type: str = "default",
 ) -> Any:
     """Create a react agent with the given configuration.

@@ -245,7 +255,9 @@ def create_agent(
     """
     try:
         config = _global_memory.get("config", {})
-        max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
+        max_input_tokens = (
+            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+        )

         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
@@ -260,7 +272,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = _global_memory.get("config", {})
-        max_input_tokens = get_model_token_limit(config)
+        max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
         return create_react_agent(model, tools, **agent_kwargs)

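With that, create_agent threads agent_type through to the token-limit lookup on both the normal path and the detection-fallback path, and callers that omit it keep the old behavior. The call-site updates below all follow one pattern; a condensed sketch:

    # research and web-research agents size their context for the research model
    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
    # planning and implementation agents size theirs for the planner model
    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
    # no agent_type: resolves against the top-level provider/model as before
    agent = create_agent(model, tools, checkpointer=memory)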
@@ -326,7 +338,7 @@ def run_research_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")

     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@@ -349,9 +361,11 @@ def run_research_agent(

     prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
         base_task=base_task_or_query,
-        research_only_note=""
-        if research_only
-        else " Only request implementation if the user explicitly asked for changes to be made.",
+        research_only_note=(
+            ""
+            if research_only
+            else " Only request implementation if the user explicitly asked for changes to be made."
+        ),
         expert_section=expert_section,
         human_section=human_section,
         web_research_section=web_research_section,
@@ -455,7 +469,7 @@ def run_web_research_agent(

     tools = get_web_research_tools(expert_enabled=expert_enabled)

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")

     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@@ -536,7 +550,7 @@ def run_planning_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")

     expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
@@ -556,9 +570,11 @@ def run_planning_agent(
         key_facts=get_memory_value("key_facts"),
         key_snippets=get_memory_value("key_snippets"),
         work_log=get_memory_value("work_log"),
-        research_only_note=""
-        if config.get("research_only")
-        else " Only request implementation if the user explicitly asked for changes to be made.",
+        research_only_note=(
+            ""
+            if config.get("research_only")
+            else " Only request implementation if the user explicitly asked for changes to be made."
+        ),
     )

     config = _global_memory.get("config", {}) if not config else config
@@ -634,7 +650,7 @@ def run_task_implementation_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )

-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")

     prompt = IMPLEMENTATION_PROMPT.format(
         base_task=base_task,
@@ -647,12 +663,16 @@ def run_task_implementation_agent(
         research_notes=get_memory_value("research_notes"),
         work_log=get_memory_value("work_log"),
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
-        human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION
-        if _global_memory.get("config", {}).get("hil", False)
-        else "",
-        web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
-        if config.get("web_research_enabled")
-        else "",
+        human_section=(
+            HUMAN_PROMPT_SECTION_IMPLEMENTATION
+            if _global_memory.get("config", {}).get("hil", False)
+            else ""
+        ),
+        web_research_section=(
+            WEB_RESEARCH_PROMPT_SECTION_CHAT
+            if config.get("web_research_enabled")
+            else ""
+        ),
     )

     config = _global_memory.get("config", {}) if not config else config
@@ -35,7 +35,7 @@ def test_get_model_token_limit_anthropic(mock_memory):
     """Test get_model_token_limit with Anthropic model."""
     config = {"provider": "anthropic", "model": "claude2"}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]


@@ -43,7 +43,7 @@ def test_get_model_token_limit_openai(mock_memory):
     """Test get_model_token_limit with OpenAI model."""
     config = {"provider": "openai", "model": "gpt-4"}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]


@@ -51,7 +51,7 @@ def test_get_model_token_limit_unknown(mock_memory):
     """Test get_model_token_limit with unknown provider/model."""
     config = {"provider": "unknown", "model": "unknown-model"}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None


@@ -59,7 +59,7 @@ def test_get_model_token_limit_missing_config(mock_memory):
     """Test get_model_token_limit with missing configuration."""
     config = {}

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None


@@ -69,7 +69,7 @@ def test_get_model_token_limit_litellm_success():

     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 100000}
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == 100000


@@ -81,7 +81,7 @@ def test_get_model_token_limit_litellm_not_found():
         mock_get_info.side_effect = litellm.exceptions.NotFoundError(
             message="Model not found", model="claude-2", llm_provider="anthropic"
         )
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]


@@ -91,7 +91,7 @@ def test_get_model_token_limit_litellm_error():

     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = Exception("Unknown error")
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]


@@ -99,7 +99,7 @@ def test_get_model_token_limit_unexpected_error():
     """Test returning None when unexpected errors occur."""
     config = None  # This will cause an attribute error when accessed

-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None


@@ -247,3 +247,29 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
     assert agent == "react_agent"
     mock_react.assert_called_once_with(mock_model, [])

+
+def test_get_model_token_limit_research(mock_memory):
+    """Test get_model_token_limit with research provider and model."""
+    config = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "research_provider": "anthropic",
+        "research_model": "claude-2"
+    }
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 150000}
+        token_limit = get_model_token_limit(config, "research")
+        assert token_limit == 150000
+
+def test_get_model_token_limit_planner(mock_memory):
+    """Test get_model_token_limit with planner provider and model."""
+    config = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "planner_provider": "deepseek",
+        "planner_model": "dsm-1"
+    }
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 120000}
+        token_limit = get_model_token_limit(config, "planner")
+        assert token_limit == 120000
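The two new tests exercise exactly that override path. The diff view does not show the test file's path, so a name-based pytest selector is the safe way to run them (assumed invocation):

    python -m pytest -k get_model_token_limit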