diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 05b7bde..74ce26d 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -5,7 +5,7 @@ import sys
 import threading
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Literal, Optional, Sequence
 
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@@ -122,15 +122,24 @@ def state_modifier(
     return [first_message] + trimmed_remaining
 
 
-def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
-    """Get the token limit for the current model configuration.
+def get_model_token_limit(
+    config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
+) -> Optional[int]:
+    """Get the token limit for the current model configuration based on agent type.
 
     Returns:
         Optional[int]: The token limit if found, None otherwise
     """
     try:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+        if agent_type == "research":
+            provider = config.get("research_provider", "") or config.get("provider", "")
+            model_name = config.get("research_model", "") or config.get("model", "")
+        elif agent_type == "planner":
+            provider = config.get("planner_provider", "") or config.get("provider", "")
+            model_name = config.get("planner_model", "") or config.get("model", "")
+        else:
+            provider = config.get("provider", "")
+            model_name = config.get("model", "")
 
         try:
             provider_model = model_name if not provider else f"{provider}/{model_name}"
@@ -224,6 +233,7 @@ def create_agent(
     tools: List[Any],
     *,
     checkpointer: Any = None,
+    agent_type: str = "default",
 ) -> Any:
     """Create a react agent with the given configuration.
 
@@ -245,7 +255,9 @@ def create_agent(
     """
     try:
         config = _global_memory.get("config", {})
-        max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
+        max_input_tokens = (
+            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+        )
 
         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
@@ -260,7 +272,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = _global_memory.get("config", {})
-        max_input_tokens = get_model_token_limit(config)
+        max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
         return create_react_agent(model, tools, **agent_kwargs)
 
@@ -326,7 +338,7 @@ def run_research_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )
 
-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
 
     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
    human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@@ -349,9 +361,11 @@ def run_research_agent(
 
     prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
         base_task=base_task_or_query,
-        research_only_note=""
-        if research_only
-        else " Only request implementation if the user explicitly asked for changes to be made.",
+        research_only_note=(
+            ""
+            if research_only
+            else " Only request implementation if the user explicitly asked for changes to be made."
+        ),
         expert_section=expert_section,
         human_section=human_section,
         web_research_section=web_research_section,
@@ -455,7 +469,7 @@ def run_web_research_agent(
 
     tools = get_web_research_tools(expert_enabled=expert_enabled)
 
-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
 
     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@@ -536,7 +550,7 @@ def run_planning_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )
 
-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
 
     expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
@@ -556,9 +570,11 @@ def run_planning_agent(
         key_facts=get_memory_value("key_facts"),
         key_snippets=get_memory_value("key_snippets"),
         work_log=get_memory_value("work_log"),
-        research_only_note=""
-        if config.get("research_only")
-        else " Only request implementation if the user explicitly asked for changes to be made.",
+        research_only_note=(
+            ""
+            if config.get("research_only")
+            else " Only request implementation if the user explicitly asked for changes to be made."
+        ),
     )
 
     config = _global_memory.get("config", {}) if not config else config
@@ -634,7 +650,7 @@ def run_task_implementation_agent(
         web_research_enabled=config.get("web_research_enabled", False),
     )
 
-    agent = create_agent(model, tools, checkpointer=memory)
+    agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
 
     prompt = IMPLEMENTATION_PROMPT.format(
         base_task=base_task,
@@ -647,12 +663,16 @@ def run_task_implementation_agent(
         research_notes=get_memory_value("research_notes"),
         work_log=get_memory_value("work_log"),
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
-        human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION
-        if _global_memory.get("config", {}).get("hil", False)
-        else "",
-        web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
-        if config.get("web_research_enabled")
-        else "",
+        human_section=(
+            HUMAN_PROMPT_SECTION_IMPLEMENTATION
+            if _global_memory.get("config", {}).get("hil", False)
+            else ""
+        ),
+        web_research_section=(
+            WEB_RESEARCH_PROMPT_SECTION_CHAT
+            if config.get("web_research_enabled")
+            else ""
+        ),
     )
 
     config = _global_memory.get("config", {}) if not config else config
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index d557c97..549a862 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -35,7 +35,7 @@ def test_get_model_token_limit_anthropic(mock_memory):
     """Test get_model_token_limit with Anthropic model."""
     config = {"provider": "anthropic", "model": "claude2"}
 
-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
 
 
@@ -43,7 +43,7 @@ def test_get_model_token_limit_openai(mock_memory):
     """Test get_model_token_limit with OpenAI model."""
     config = {"provider": "openai", "model": "gpt-4"}
 
-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
 
 
@@ -51,7 +51,7 @@ def test_get_model_token_limit_unknown(mock_memory):
     """Test get_model_token_limit with unknown provider/model."""
     config = {"provider": "unknown", "model": "unknown-model"}
 
-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None
 
 
@@ -59,7 +59,7 @@ def test_get_model_token_limit_missing_config(mock_memory):
     """Test get_model_token_limit with missing configuration."""
     config = {}
 
-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None
 
 
@@ -69,7 +69,7 @@ def test_get_model_token_limit_litellm_success():
     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 100000}
 
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == 100000
 
 
@@ -81,7 +81,7 @@ def test_get_model_token_limit_litellm_not_found():
         mock_get_info.side_effect = litellm.exceptions.NotFoundError(
             message="Model not found", model="claude-2", llm_provider="anthropic"
         )
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
 
 
@@ -91,7 +91,7 @@ def test_get_model_token_limit_litellm_error():
     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = Exception("Unknown error")
 
-        token_limit = get_model_token_limit(config)
+        token_limit = get_model_token_limit(config, "default")
         assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
 
 
@@ -99,7 +99,7 @@ def test_get_model_token_limit_unexpected_error():
     """Test returning None when unexpected errors occur."""
     config = None  # This will cause an attribute error when accessed
 
-    token_limit = get_model_token_limit(config)
+    token_limit = get_model_token_limit(config, "default")
     assert token_limit is None
 
 
@@ -247,3 +247,29 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
 
     assert agent == "react_agent"
     mock_react.assert_called_once_with(mock_model, [])
+
+def test_get_model_token_limit_research(mock_memory):
+    """Test get_model_token_limit with research provider and model."""
+    config = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "research_provider": "anthropic",
+        "research_model": "claude-2"
+    }
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 150000}
+        token_limit = get_model_token_limit(config, "research")
+        assert token_limit == 150000
+
+def test_get_model_token_limit_planner(mock_memory):
+    """Test get_model_token_limit with planner provider and model."""
+    config = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "planner_provider": "deepseek",
+        "planner_model": "dsm-1"
+    }
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
+        mock_get_info.return_value = {"max_input_tokens": 120000}
+        token_limit = get_model_token_limit(config, "planner")
+        assert token_limit == 120000