Added Provider/Model Override Arguments to Separate Research/Planner Configurations (#53)

* feat: add research and planner provider/model options to enhance configurability for research and planning tasks
refactor: create get_effective_model_config function to streamline provider/model resolution logic
test: add unit tests for effective model configuration and environment validation for research and planner providers

* refactor(agent_utils.py): remove get_effective_model_config function to simplify code and improve readability
style(agent_utils.py): format debug log statements for better readability
fix(agent_utils.py): update run_agent functions to directly use config without effective model config
feat(programmer.py): enhance logging for command execution
test(tests): remove tests related to get_effective_model_config function as it has been removed

* chore(tests): remove outdated tests for research and planner agent configurations to clean up the test suite and improve maintainability

* style(tests): apply consistent formatting and spacing in test_provider_integration.py for improved readability and maintainability
Ariel Frischer 2025-01-24 15:00:47 -08:00 committed by GitHub
parent 54fdebfc3a
commit 6c4acfea8b
10 changed files with 3870 additions and 66 deletions


@ -166,6 +166,10 @@ ra-aid -m "Add new feature" --verbose
- `--research-only`: Only perform research without implementation
- `--provider`: The LLM provider to use (choices: anthropic, openai, openrouter, openai-compatible, gemini)
- `--model`: The model name to use (required for non-Anthropic providers)
- `--research-provider`: Provider to use specifically for research tasks (falls back to --provider if not specified)
- `--research-model`: Model to use specifically for research tasks (falls back to --model if not specified)
- `--planner-provider`: Provider to use specifically for planning tasks (falls back to --provider if not specified)
- `--planner-model`: Model to use specifically for planning tasks (falls back to --model if not specified)
- `--cowboy-mode`: Skip interactive approval for shell commands
- `--expert-provider`: The LLM provider to use for expert knowledge queries (choices: anthropic, openai, openrouter, openai-compatible, gemini)
- `--expert-model`: The model name to use for expert knowledge queries (required for non-OpenAI providers)
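
For example, research can be routed to a different provider while planning and implementation keep the base settings; any override flag that is omitted falls back to `--provider`/`--model`. A usage sketch (provider and model names mirror those used in the tests in this commit):

```bash
ra-aid -m "Add new feature" \
  --provider openai --model gpt-4 \
  --research-provider anthropic --research-model claude-3-haiku-20240307
```

Here the research stage runs on Anthropic, while planning falls back to the base OpenAI configuration.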


@ -7,11 +7,7 @@ from rich.console import Console
from langgraph.checkpoint.memory import MemorySaver
from ra_aid.config import DEFAULT_RECURSION_LIMIT
from ra_aid.env import validate_environment
from ra_aid.project_info import (
get_project_info,
format_project_info,
display_project_status,
)
from ra_aid.project_info import get_project_info, format_project_info
from ra_aid.tools.memory import _global_memory
from ra_aid.tools.human import ask_human
from ra_aid import print_stage_header, print_error
@ -81,6 +77,26 @@ Examples:
help="The LLM provider to use",
)
parser.add_argument("--model", type=str, help="The model name to use")
parser.add_argument(
"--research-provider",
type=str,
choices=VALID_PROVIDERS,
help="Provider to use specifically for research tasks",
)
parser.add_argument(
"--research-model",
type=str,
help="Model to use specifically for research tasks",
)
parser.add_argument(
"--planner-provider",
type=str,
choices=VALID_PROVIDERS,
help="Provider to use specifically for planning tasks",
)
parser.add_argument(
"--planner-model", type=str, help="Model to use specifically for planning tasks"
)
parser.add_argument(
"--cowboy-mode",
action="store_true",
@ -130,9 +146,7 @@ Examples:
help="Maximum recursion depth for agent operations (default: 100)",
)
parser.add_argument(
'--aider-config',
type=str,
help='Specify the aider config file path'
"--aider-config", type=str, help="Specify the aider config file path"
)
if args is None:
@ -223,7 +237,7 @@ def main():
if expert_missing:
console.print(
Panel(
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
+ "\n".join(f"- {m}" for m in expert_missing)
+ "\nSet the required environment variables or args to enable expert mode.",
title="Expert Tools Disabled",
@ -234,7 +248,7 @@ def main():
if web_research_missing:
console.print(
Panel(
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
+ "\n".join(f"- {m}" for m in web_research_missing)
+ "\nSet the required environment variables to enable web research.",
title="Web Research Disabled",
@ -242,11 +256,12 @@ def main():
)
)
# Create the base model after validation
model = initialize_llm(args.provider, args.model, temperature=args.temperature)
# Handle chat mode
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
sys.exit(1)
@ -291,7 +306,7 @@ def main():
# Create chat agent with appropriate tools
chat_agent = create_agent(
model,
chat_model,
get_chat_tools(
expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
@ -334,20 +349,39 @@ def main():
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store model configuration
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
# Store expert provider and model in config
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = args.research_model or args.model
# Run research stage
print_stage_header("Research Stage")
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model
research_model = initialize_llm(
research_provider, research_model_name, temperature=args.temperature
)
run_research_agent(
base_task,
model,
research_model,
expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil,
@ -357,10 +391,17 @@ def main():
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
base_task,
model,
planning_model,
expert_enabled=expert_enabled,
hil=args.hil,
memory=planning_memory,


@ -142,12 +142,18 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
model_info = get_model_info(provider_model)
max_input_tokens = model_info.get("max_input_tokens")
if max_input_tokens:
logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
logger.debug(
f"Using litellm token limit for {model_name}: {max_input_tokens}"
)
return max_input_tokens
except litellm.exceptions.NotFoundError:
logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
logger.debug(
f"Model {model_name} not found in litellm, falling back to models_tokens"
)
except Exception as e:
logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
logger.debug(
f"Error getting model info from litellm: {e}, falling back to models_tokens"
)
# Fallback to models_tokens dict
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
@ -155,7 +161,9 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
provider_tokens = models_tokens.get(provider, {})
max_input_tokens = provider_tokens.get(normalized_name, None)
if max_input_tokens:
logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
logger.debug(
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
)
else:
logger.debug(f"Could not find token limit for {provider}/{model_name}")
@ -360,7 +368,10 @@ def run_research_agent(
config = _global_memory.get("config", {}) if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
run_config = {
"configurable": {"thread_id": thread_id},
"recursion_limit": recursion_limit,
}
if config:
run_config.update(config)
@ -467,8 +478,12 @@ def run_web_research_agent(
)
config = _global_memory.get("config", {}) if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
run_config = {
"configurable": {"thread_id": thread_id},
"recursion_limit": recursion_limit,
}
if config:
run_config.update(config)
@ -551,8 +566,12 @@ def run_planning_agent(
)
config = _global_memory.get("config", {}) if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
run_config = {
"configurable": {"thread_id": thread_id},
"recursion_limit": recursion_limit,
}
if config:
run_config.update(config)
@ -642,7 +661,10 @@ def run_task_implementation_agent(
config = _global_memory.get("config", {}) if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
run_config = {
"configurable": {"thread_id": thread_id},
"recursion_limit": recursion_limit,
}
if config:
run_config.update(config)


@ -5,7 +5,9 @@ from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
def get_env_var(name: str, expert: bool = False) -> Optional[str]:
"""Get environment variable with optional expert prefix and fallback."""
@ -122,7 +124,18 @@ def create_llm_client(
if not config:
raise ValueError(f"Unsupported provider: {provider}")
# Only pass temperature if it's explicitly set and not in expert mode
logger.debug(
"Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
provider,
model_name,
temperature,
is_expert
)
# Handle temperature for expert mode
if is_expert:
temperature = 0
temp_kwargs = {}
if not is_expert and temperature is not None:
temp_kwargs = {"temperature": temperature}


@ -1,5 +1,6 @@
import os
from typing import List, Dict, Union
from ra_aid.logging_config import get_logger
from ra_aid.tools.memory import _global_memory
from langchain_core.tools import tool
from rich.console import Console
@ -10,6 +11,7 @@ from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
console = Console()
logger = get_logger(__name__)
@tool
def run_programming_task(instructions: str, files: List[str] = []) -> Dict[str, Union[str, int, bool]]:
@ -81,6 +83,7 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True
markdown_content = "".join(task_display)
console.print(Panel(Markdown(markdown_content), title="🤖 Aider Task", border_style="bright_blue"))
logger.debug(f"command: {command}")
try:
# Run the command interactively


@ -60,37 +60,39 @@ def test_get_model_token_limit_missing_config(mock_memory):
assert token_limit is None
def test_get_model_token_limit_litellm_success():
"""Test get_model_token_limit successfully getting limit from litellm."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 100000}
token_limit = get_model_token_limit(config)
assert token_limit == 100000
def test_get_model_token_limit_litellm_not_found():
"""Test fallback to models_tokens when litellm raises NotFoundError."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found",
model="claude-2",
llm_provider="anthropic"
message="Model not found", model="claude-2", llm_provider="anthropic"
)
token_limit = get_model_token_limit(config)
assert token_limit == models_tokens["anthropic"]["claude2"]
def test_get_model_token_limit_litellm_error():
"""Test fallback to models_tokens when litellm raises other exceptions."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.side_effect = Exception("Unknown error")
token_limit = get_model_token_limit(config)
assert token_limit == models_tokens["anthropic"]["claude2"]
def test_get_model_token_limit_unexpected_error():
"""Test returning None when unexpected errors occur."""
config = None # This will cause an attribute error when accessed


@ -11,6 +11,10 @@ class MockArgs:
expert_provider: str
model: Optional[str] = None
expert_model: Optional[str] = None
research_provider: Optional[str] = None
research_model: Optional[str] = None
planner_provider: Optional[str] = None
planner_model: Optional[str] = None
@pytest.fixture
def clean_env(monkeypatch):
@ -166,6 +170,7 @@ def test_different_providers_no_expert_key(clean_env, monkeypatch):
assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
"""Test behavior with openai-compatible expert and different main provider"""
args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")


@ -121,3 +121,45 @@ def test_missing_message():
# Verify message is captured when provided
args = parse_arguments(["-m", "test"])
assert args.message == "test"
def test_research_model_provider_args(mock_dependencies):
"""Test that research-specific model/provider args are correctly stored in config."""
from ra_aid.__main__ import main
import sys
from unittest.mock import patch
_global_memory.clear()
with patch.object(sys, 'argv', [
'ra-aid', '-m', 'test message',
'--research-provider', 'anthropic',
'--research-model', 'claude-3-haiku-20240307',
'--planner-provider', 'openai',
'--planner-model', 'gpt-4'
]):
main()
config = _global_memory["config"]
assert config["research_provider"] == "anthropic"
assert config["research_model"] == "claude-3-haiku-20240307"
assert config["planner_provider"] == "openai"
assert config["planner_model"] == "gpt-4"
def test_planner_model_provider_args(mock_dependencies):
"""Test that planner provider/model args fall back to main config when not specified."""
from ra_aid.__main__ import main
import sys
from unittest.mock import patch
_global_memory.clear()
with patch.object(sys, 'argv', [
'ra-aid', '-m', 'test message',
'--provider', 'openai',
'--model', 'gpt-4'
]):
main()
config = _global_memory["config"]
assert config["planner_provider"] == "openai"
assert config["planner_model"] == "gpt-4"


@ -13,16 +13,23 @@ from ra_aid.provider_strategy import (
OpenAIStrategy,
OpenAICompatibleStrategy,
OpenRouterStrategy,
GeminiStrategy
GeminiStrategy,
)
@dataclass
class MockArgs:
"""Mock arguments for testing."""
provider: str
expert_provider: Optional[str] = None
model: Optional[str] = None
expert_model: Optional[str] = None
research_provider: Optional[str] = None
research_model: Optional[str] = None
planner_provider: Optional[str] = None
planner_model: Optional[str] = None
@pytest.fixture
def clean_env():
@ -39,7 +46,7 @@ def clean_env():
"ANTHROPIC_MODEL",
"GEMINI_API_KEY",
"EXPERT_GEMINI_API_KEY",
"GEMINI_MODEL"
"GEMINI_MODEL",
]
# Store original values
@ -58,6 +65,7 @@ def clean_env():
elif var in os.environ:
del os.environ[var]
def test_provider_validation_respects_cli_args(clean_env):
"""Test that provider validation respects CLI args over defaults."""
# Set up environment with only OpenAI credentials
@ -65,7 +73,9 @@ def test_provider_validation_respects_cli_args(clean_env):
# Should succeed with OpenAI provider
args = MockArgs(provider="openai")
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
args
)
assert not expert_missing
# Should fail with Anthropic provider even though it's first alphabetically
@ -73,6 +83,7 @@ def test_provider_validation_respects_cli_args(clean_env):
with pytest.raises(SystemExit):
validate_environment(args)
def test_expert_provider_fallback(clean_env):
"""Test expert provider falls back to main provider keys."""
os.environ["OPENAI_API_KEY"] = "test-key"
@ -82,6 +93,7 @@ def test_expert_provider_fallback(clean_env):
assert expert_enabled
assert not expert_missing
def test_openai_compatible_base_url(clean_env):
"""Test OpenAI-compatible provider requires base URL."""
os.environ["OPENAI_API_KEY"] = "test-key"
@ -94,6 +106,7 @@ def test_openai_compatible_base_url(clean_env):
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_missing
def test_expert_provider_separate_keys(clean_env):
"""Test expert provider can use separate keys."""
os.environ["OPENAI_API_KEY"] = "main-key"
@ -104,22 +117,28 @@ def test_expert_provider_separate_keys(clean_env):
assert expert_enabled
assert not expert_missing
def test_web_research_independent(clean_env):
"""Test web research validation is independent of provider."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai")
# Without Tavily key
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
args
)
assert not web_enabled
assert web_missing
# With Tavily key
os.environ["TAVILY_API_KEY"] = "test-key"
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
args
)
assert web_enabled
assert not web_missing
def test_provider_factory_unknown_provider(clean_env):
"""Test provider factory handles unknown providers."""
strategy = ProviderFactory.create("unknown")
@ -129,6 +148,7 @@ def test_provider_factory_unknown_provider(clean_env):
with pytest.raises(SystemExit):
validate_environment(args)
def test_provider_strategy_validation(clean_env):
"""Test individual provider strategies."""
# Test Anthropic strategy
@ -154,6 +174,7 @@ def test_provider_strategy_validation(clean_env):
assert result.valid
assert not result.missing_vars
def test_missing_provider_arg():
"""Test handling of missing provider argument."""
with pytest.raises(SystemExit):
@ -162,11 +183,13 @@ def test_missing_provider_arg():
with pytest.raises(SystemExit):
validate_environment(MockArgs(provider=None))
def test_empty_provider_arg():
"""Test handling of empty provider argument."""
with pytest.raises(SystemExit):
validate_environment(MockArgs(provider=""))
def test_incomplete_openai_compatible_config(clean_env):
"""Test OpenAI-compatible provider with incomplete configuration."""
strategy = OpenAICompatibleStrategy()
@ -206,6 +229,7 @@ def test_incomplete_gemini_config(clean_env):
assert result.valid
assert not result.missing_vars
def test_incomplete_expert_config(clean_env):
"""Test expert provider with incomplete configuration."""
# Set main provider but not expert
@ -240,6 +264,7 @@ def test_empty_environment_variables(clean_env):
with pytest.raises(SystemExit):
validate_environment(args)
def test_openrouter_validation(clean_env):
"""Test OpenRouter provider validation."""
strategy = OpenRouterStrategy()
@ -261,6 +286,7 @@ def test_openrouter_validation(clean_env):
assert result.valid
assert not result.missing_vars
def test_multiple_expert_providers(clean_env):
"""Test validation with multiple expert providers."""
os.environ["OPENAI_API_KEY"] = "test-key"
@ -268,7 +294,11 @@ def test_multiple_expert_providers(clean_env):
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
# First expert provider valid, second invalid
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
args = MockArgs(
provider="openai",
expert_provider="anthropic",
expert_model="claude-3-haiku-20240307",
)
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
@ -278,4 +308,3 @@ def test_multiple_expert_providers(clean_env):
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert expert_missing

uv.lock: 3643 lines (diff suppressed because it is too large)