Added Provider/Model Override Arguments to Separate Research/Planner Configurations (#53)
* feat: add research and planner provider/model options to enhance configurability for research and planning tasks
  refactor: create get_effective_model_config function to streamline provider/model resolution logic
  test: add unit tests for effective model configuration and environment validation for research and planner providers
* refactor(agent_utils.py): remove get_effective_model_config function to simplify code and improve readability
  style(agent_utils.py): format debug log statements for better readability
  fix(agent_utils.py): update run_agent functions to directly use config without effective model config
  feat(agent_utils.py): enhance logging for command execution in programmer.py
  test(tests): remove tests related to get_effective_model_config function as it has been removed
* chore(tests): remove outdated tests for research and planner agent configurations to clean up the test suite and improve maintainability
* style(tests): apply consistent formatting and spacing in test_provider_integration.py for improved readability and maintainability
parent 54fdebfc3a
commit 6c4acfea8b
@@ -166,6 +166,10 @@ ra-aid -m "Add new feature" --verbose
 - `--research-only`: Only perform research without implementation
 - `--provider`: The LLM provider to use (choices: anthropic, openai, openrouter, openai-compatible, gemini)
 - `--model`: The model name to use (required for non-Anthropic providers)
+- `--research-provider`: Provider to use specifically for research tasks (falls back to --provider if not specified)
+- `--research-model`: Model to use specifically for research tasks (falls back to --model if not specified)
+- `--planner-provider`: Provider to use specifically for planning tasks (falls back to --provider if not specified)
+- `--planner-model`: Model to use specifically for planning tasks (falls back to --model if not specified)
 - `--cowboy-mode`: Skip interactive approval for shell commands
 - `--expert-provider`: The LLM provider to use for expert knowledge queries (choices: anthropic, openai, openrouter, openai-compatible, gemini)
 - `--expert-model`: The model name to use for expert knowledge queries (required for non-OpenAI providers)
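For context, an illustrative invocation that combines the new per-stage overrides (the model names below are examples only, not taken from this commit):

    ra-aid -m "Add new feature" --provider anthropic --research-provider openai --research-model gpt-4o --planner-model claude-3-5-sonnet-20241022

Any stage without an explicit override falls back to the base --provider/--model pair, so overridden and inherited stages can be mixed in a single run.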
@@ -7,11 +7,7 @@ from rich.console import Console
 from langgraph.checkpoint.memory import MemorySaver
 from ra_aid.config import DEFAULT_RECURSION_LIMIT
 from ra_aid.env import validate_environment
-from ra_aid.project_info import (
-    get_project_info,
-    format_project_info,
-    display_project_status,
-)
+from ra_aid.project_info import get_project_info, format_project_info
 from ra_aid.tools.memory import _global_memory
 from ra_aid.tools.human import ask_human
 from ra_aid import print_stage_header, print_error
@@ -81,6 +77,26 @@ Examples:
         help="The LLM provider to use",
     )
     parser.add_argument("--model", type=str, help="The model name to use")
+    parser.add_argument(
+        "--research-provider",
+        type=str,
+        choices=VALID_PROVIDERS,
+        help="Provider to use specifically for research tasks",
+    )
+    parser.add_argument(
+        "--research-model",
+        type=str,
+        help="Model to use specifically for research tasks",
+    )
+    parser.add_argument(
+        "--planner-provider",
+        type=str,
+        choices=VALID_PROVIDERS,
+        help="Provider to use specifically for planning tasks",
+    )
+    parser.add_argument(
+        "--planner-model", type=str, help="Model to use specifically for planning tasks"
+    )
     parser.add_argument(
         "--cowboy-mode",
         action="store_true",
@@ -130,9 +146,7 @@ Examples:
         help="Maximum recursion depth for agent operations (default: 100)",
     )
     parser.add_argument(
-        '--aider-config',
-        type=str,
-        help='Specify the aider config file path'
+        "--aider-config", type=str, help="Specify the aider config file path"
     )
 
     if args is None:
@@ -223,7 +237,7 @@ def main():
     if expert_missing:
         console.print(
             Panel(
-                f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
+                "[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
                 + "\n".join(f"- {m}" for m in expert_missing)
                 + "\nSet the required environment variables or args to enable expert mode.",
                 title="Expert Tools Disabled",
@@ -234,7 +248,7 @@ def main():
     if web_research_missing:
         console.print(
             Panel(
-                f"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
+                "[yellow]Web research disabled due to missing configuration:[/yellow]\n"
                 + "\n".join(f"- {m}" for m in web_research_missing)
                 + "\nSet the required environment variables to enable web research.",
                 title="Web Research Disabled",
@@ -242,11 +256,12 @@ def main():
             )
         )
 
-    # Create the base model after validation
-    model = initialize_llm(args.provider, args.model, temperature=args.temperature)
-
     # Handle chat mode
     if args.chat:
+        # Initialize chat model with default provider/model
+        chat_model = initialize_llm(
+            args.provider, args.model, temperature=args.temperature
+        )
         if args.research_only:
             print_error("Chat mode cannot be used with --research-only")
             sys.exit(1)
@@ -291,7 +306,7 @@ def main():
 
         # Create chat agent with appropriate tools
         chat_agent = create_agent(
-            model,
+            chat_model,
             get_chat_tools(
                 expert_enabled=expert_enabled,
                 web_research_enabled=web_research_enabled,
@@ -334,20 +349,39 @@ def main():
     # Store config in global memory for access by is_informational_query
     _global_memory["config"] = config
 
-    # Store model configuration
+    # Store base provider/model configuration
     _global_memory["config"]["provider"] = args.provider
     _global_memory["config"]["model"] = args.model
 
-    # Store expert provider and model in config
+    # Store expert provider/model (no fallback)
     _global_memory["config"]["expert_provider"] = args.expert_provider
     _global_memory["config"]["expert_model"] = args.expert_model
 
+    # Store planner config with fallback to base values
+    _global_memory["config"]["planner_provider"] = (
+        args.planner_provider or args.provider
+    )
+    _global_memory["config"]["planner_model"] = args.planner_model or args.model
+
+    # Store research config with fallback to base values
+    _global_memory["config"]["research_provider"] = (
+        args.research_provider or args.provider
+    )
+    _global_memory["config"]["research_model"] = args.research_model or args.model
+
     # Run research stage
     print_stage_header("Research Stage")
 
+    # Initialize research model with potential overrides
+    research_provider = args.research_provider or args.provider
+    research_model_name = args.research_model or args.model
+    research_model = initialize_llm(
+        research_provider, research_model_name, temperature=args.temperature
+    )
+
     run_research_agent(
         base_task,
-        model,
+        research_model,
         expert_enabled=expert_enabled,
         research_only=args.research_only,
         hil=args.hil,
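To illustrate the fallback these assignments implement: when only the base flags are given, the research and planner entries resolve to the same provider/model (again an illustrative command; the model name is just an example):

    ra-aid -m "Add new feature" --provider openai --model gpt-4o

Here research_provider and planner_provider both resolve to openai, and research_model/planner_model to gpt-4o, because no per-stage override was passed.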
@@ -357,10 +391,17 @@ def main():
 
     # Proceed with planning and implementation if not an informational query
     if not is_informational_query():
+        # Initialize planning model with potential overrides
+        planner_provider = args.planner_provider or args.provider
+        planner_model_name = args.planner_model or args.model
+        planning_model = initialize_llm(
+            planner_provider, planner_model_name, temperature=args.temperature
+        )
+
         # Run planning agent
         run_planning_agent(
             base_task,
-            model,
+            planning_model,
             expert_enabled=expert_enabled,
             hil=args.hil,
             memory=planning_memory,
@@ -142,12 +142,18 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
-            logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
+            logger.debug(
+                f"Using litellm token limit for {model_name}: {max_input_tokens}"
+            )
             return max_input_tokens
     except litellm.exceptions.NotFoundError:
-        logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
+        logger.debug(
+            f"Model {model_name} not found in litellm, falling back to models_tokens"
+        )
     except Exception as e:
-        logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
+        logger.debug(
+            f"Error getting model info from litellm: {e}, falling back to models_tokens"
+        )
 
     # Fallback to models_tokens dict
     # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
@@ -155,7 +161,9 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
     provider_tokens = models_tokens.get(provider, {})
     max_input_tokens = provider_tokens.get(normalized_name, None)
     if max_input_tokens:
-        logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
+        logger.debug(
+            f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
+        )
     else:
         logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
@@ -360,7 +368,10 @@ def run_research_agent(
 
     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)
 
@@ -467,8 +478,12 @@ def run_web_research_agent(
     )
 
     config = _global_memory.get("config", {}) if not config else config
+
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)
 
@@ -551,8 +566,12 @@ def run_planning_agent(
     )
 
     config = _global_memory.get("config", {}) if not config else config
+
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)
 
@@ -642,7 +661,10 @@ def run_task_implementation_agent(
 
     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)
 
@@ -5,7 +5,9 @@ from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_google_genai import ChatGoogleGenerativeAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
+from ra_aid.logging_config import get_logger
 
+logger = get_logger(__name__)
 
 def get_env_var(name: str, expert: bool = False) -> Optional[str]:
     """Get environment variable with optional expert prefix and fallback."""
@@ -122,7 +124,18 @@ def create_llm_client(
     if not config:
         raise ValueError(f"Unsupported provider: {provider}")
 
-    # Only pass temperature if it's explicitly set and not in expert mode
+    logger.debug(
+        "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
+        provider,
+        model_name,
+        temperature,
+        is_expert
+    )
+
+    # Handle temperature for expert mode
+    if is_expert:
+        temperature = 0
+
     temp_kwargs = {}
     if not is_expert and temperature is not None:
         temp_kwargs = {"temperature": temperature}
@@ -1,5 +1,6 @@
 import os
 from typing import List, Dict, Union
+from ra_aid.logging_config import get_logger
 from ra_aid.tools.memory import _global_memory
 from langchain_core.tools import tool
 from rich.console import Console
@@ -10,6 +11,7 @@ from ra_aid.proc.interactive import run_interactive_command
 from ra_aid.text.processing import truncate_output
 
 console = Console()
+logger = get_logger(__name__)
 
 @tool
 def run_programming_task(instructions: str, files: List[str] = []) -> Dict[str, Union[str, int, bool]]:
@@ -81,6 +83,7 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True
 
     markdown_content = "".join(task_display)
     console.print(Panel(Markdown(markdown_content), title="🤖 Aider Task", border_style="bright_blue"))
+    logger.debug(f"command: {command}")
 
     try:
         # Run the command interactively
@@ -60,41 +60,43 @@ def test_get_model_token_limit_missing_config(mock_memory):
     assert token_limit is None
 
 
 def test_get_model_token_limit_litellm_success():
     """Test get_model_token_limit successfully getting limit from litellm."""
     config = {"provider": "anthropic", "model": "claude-2"}
 
-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 100000}
         token_limit = get_model_token_limit(config)
         assert token_limit == 100000
 
 
 def test_get_model_token_limit_litellm_not_found():
     """Test fallback to models_tokens when litellm raises NotFoundError."""
     config = {"provider": "anthropic", "model": "claude-2"}
 
-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found",
-            model="claude-2",
-            llm_provider="anthropic"
+            message="Model not found", model="claude-2", llm_provider="anthropic"
         )
         token_limit = get_model_token_limit(config)
         assert token_limit == models_tokens["anthropic"]["claude2"]
 
 
 def test_get_model_token_limit_litellm_error():
     """Test fallback to models_tokens when litellm raises other exceptions."""
     config = {"provider": "anthropic", "model": "claude-2"}
 
-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = Exception("Unknown error")
         token_limit = get_model_token_limit(config)
         assert token_limit == models_tokens["anthropic"]["claude2"]
 
 
 def test_get_model_token_limit_unexpected_error():
     """Test returning None when unexpected errors occur."""
     config = None  # This will cause an attribute error when accessed
 
     token_limit = get_model_token_limit(config)
     assert token_limit is None
@@ -11,6 +11,10 @@ class MockArgs:
     expert_provider: str
     model: Optional[str] = None
     expert_model: Optional[str] = None
+    research_provider: Optional[str] = None
+    research_model: Optional[str] = None
+    planner_provider: Optional[str] = None
+    planner_model: Optional[str] = None
 
 @pytest.fixture
 def clean_env(monkeypatch):
@@ -166,6 +170,7 @@ def test_different_providers_no_expert_key(clean_env, monkeypatch):
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
 
+
 def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
     """Test behavior with openai-compatible expert and different main provider"""
     args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")
@@ -121,3 +121,45 @@ def test_missing_message():
     # Verify message is captured when provided
     args = parse_arguments(["-m", "test"])
     assert args.message == "test"
+
+
+def test_research_model_provider_args(mock_dependencies):
+    """Test that research-specific model/provider args are correctly stored in config."""
+    from ra_aid.__main__ import main
+    import sys
+    from unittest.mock import patch
+
+    _global_memory.clear()
+
+    with patch.object(sys, 'argv', [
+        'ra-aid', '-m', 'test message',
+        '--research-provider', 'anthropic',
+        '--research-model', 'claude-3-haiku-20240307',
+        '--planner-provider', 'openai',
+        '--planner-model', 'gpt-4'
+    ]):
+        main()
+        config = _global_memory["config"]
+        assert config["research_provider"] == "anthropic"
+        assert config["research_model"] == "claude-3-haiku-20240307"
+        assert config["planner_provider"] == "openai"
+        assert config["planner_model"] == "gpt-4"
+
+
+def test_planner_model_provider_args(mock_dependencies):
+    """Test that planner provider/model args fall back to main config when not specified."""
+    from ra_aid.__main__ import main
+    import sys
+    from unittest.mock import patch
+
+    _global_memory.clear()
+
+    with patch.object(sys, 'argv', [
+        'ra-aid', '-m', 'test message',
+        '--provider', 'openai',
+        '--model', 'gpt-4'
+    ]):
+        main()
+        config = _global_memory["config"]
+        assert config["planner_provider"] == "openai"
+        assert config["planner_model"] == "gpt-4"
@@ -13,16 +13,23 @@ from ra_aid.provider_strategy import (
     OpenAIStrategy,
     OpenAICompatibleStrategy,
     OpenRouterStrategy,
-    GeminiStrategy
+    GeminiStrategy,
 )
 
 
 @dataclass
 class MockArgs:
     """Mock arguments for testing."""
 
     provider: str
     expert_provider: Optional[str] = None
     model: Optional[str] = None
     expert_model: Optional[str] = None
+    research_provider: Optional[str] = None
+    research_model: Optional[str] = None
+    planner_provider: Optional[str] = None
+    planner_model: Optional[str] = None
 
 
 @pytest.fixture
 def clean_env():
@@ -39,18 +46,18 @@ def clean_env():
         "ANTHROPIC_MODEL",
         "GEMINI_API_KEY",
         "EXPERT_GEMINI_API_KEY",
-        "GEMINI_MODEL"
+        "GEMINI_MODEL",
     ]
 
     # Store original values
     original_values = {}
     for var in env_vars:
         original_values[var] = os.environ.get(var)
         if var in os.environ:
             del os.environ[var]
 
     yield
 
     # Restore original values
     for var, value in original_values.items():
         if value is not None:
@@ -58,6 +65,7 @@ def clean_env():
         elif var in os.environ:
             del os.environ[var]
 
+
 def test_provider_validation_respects_cli_args(clean_env):
     """Test that provider validation respects CLI args over defaults."""
     # Set up environment with only OpenAI credentials
@@ -65,7 +73,9 @@ def test_provider_validation_respects_cli_args(clean_env):
 
     # Should succeed with OpenAI provider
     args = MockArgs(provider="openai")
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert not expert_missing
 
     # Should fail with Anthropic provider even though it's first alphabetically
@@ -73,62 +83,72 @@ def test_provider_validation_respects_cli_args(clean_env):
     with pytest.raises(SystemExit):
         validate_environment(args)
 
 
 def test_expert_provider_fallback(clean_env):
     """Test expert provider falls back to main provider keys."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai", expert_provider="openai")
 
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing
 
 
 def test_openai_compatible_base_url(clean_env):
     """Test OpenAI-compatible provider requires base URL."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai-compatible")
 
     with pytest.raises(SystemExit):
         validate_environment(args)
 
     os.environ["OPENAI_API_BASE"] = "http://test"
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_missing
 
 
 def test_expert_provider_separate_keys(clean_env):
     """Test expert provider can use separate keys."""
     os.environ["OPENAI_API_KEY"] = "main-key"
     os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
 
     args = MockArgs(provider="openai", expert_provider="openai")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing
 
 
 def test_web_research_independent(clean_env):
     """Test web research validation is independent of provider."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai")
 
     # Without Tavily key
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert not web_enabled
     assert web_missing
 
     # With Tavily key
     os.environ["TAVILY_API_KEY"] = "test-key"
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert web_enabled
     assert not web_missing
 
 
 def test_provider_factory_unknown_provider(clean_env):
     """Test provider factory handles unknown providers."""
     strategy = ProviderFactory.create("unknown")
     assert strategy is None
 
     args = MockArgs(provider="unknown")
     with pytest.raises(SystemExit):
         validate_environment(args)
 
 
 def test_provider_strategy_validation(clean_env):
     """Test individual provider strategies."""
     # Test Anthropic strategy
@@ -154,35 +174,38 @@ def test_provider_strategy_validation(clean_env):
     assert result.valid
     assert not result.missing_vars
 
 
 def test_missing_provider_arg():
     """Test handling of missing provider argument."""
     with pytest.raises(SystemExit):
         validate_environment(None)
 
     with pytest.raises(SystemExit):
         validate_environment(MockArgs(provider=None))
 
 
 def test_empty_provider_arg():
     """Test handling of empty provider argument."""
     with pytest.raises(SystemExit):
         validate_environment(MockArgs(provider=""))
 
 
 def test_incomplete_openai_compatible_config(clean_env):
     """Test OpenAI-compatible provider with incomplete configuration."""
     strategy = OpenAICompatibleStrategy()
 
     # No configuration
     result = strategy.validate()
     assert not result.valid
     assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
     assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
 
     # Only API key
     os.environ["OPENAI_API_KEY"] = "test-key"
     result = strategy.validate()
     assert not result.valid
     assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
 
     # Only base URL
     os.environ.pop("OPENAI_API_KEY")
     os.environ["OPENAI_API_BASE"] = "http://test"
@@ -194,29 +217,30 @@ def test_incomplete_openai_compatible_config(clean_env):
 def test_incomplete_gemini_config(clean_env):
     """Test Gemini provider with incomplete configuration."""
     strategy = GeminiStrategy()
 
     # No configuration
     result = strategy.validate()
     assert not result.valid
     assert "GEMINI_API_KEY environment variable is not set" in result.missing_vars
 
     # Valid API key
     os.environ["GEMINI_API_KEY"] = "test-key"
     result = strategy.validate()
     assert result.valid
     assert not result.missing_vars
 
 
 def test_incomplete_expert_config(clean_env):
     """Test expert provider with incomplete configuration."""
     # Set main provider but not expert
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai", expert_provider="openai-compatible")
 
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_enabled
     assert len(expert_missing) == 1
     assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
 
     # Set expert key but not base URL
     os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
     expert_enabled, expert_missing, _, _ = validate_environment(args)
@@ -232,7 +256,7 @@ def test_empty_environment_variables(clean_env):
     args = MockArgs(provider="openai")
     with pytest.raises(SystemExit):
         validate_environment(args)
 
     # Empty base URL
     os.environ["OPENAI_API_KEY"] = "test-key"
     os.environ["OPENAI_API_BASE"] = ""
@@ -240,10 +264,11 @@ def test_empty_environment_variables(clean_env):
     with pytest.raises(SystemExit):
         validate_environment(args)
 
 
 def test_openrouter_validation(clean_env):
     """Test OpenRouter provider validation."""
     strategy = OpenRouterStrategy()
 
     # No API key
     result = strategy.validate()
     assert not result.valid
@@ -254,13 +279,14 @@ def test_openrouter_validation(clean_env):
     result = strategy.validate()
     assert not result.valid
     assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
 
     # Valid API key
     os.environ["OPENROUTER_API_KEY"] = "test-key"
     result = strategy.validate()
     assert result.valid
     assert not result.missing_vars
 
 
 def test_multiple_expert_providers(clean_env):
     """Test validation with multiple expert providers."""
     os.environ["OPENAI_API_KEY"] = "test-key"
@@ -268,14 +294,17 @@ def test_multiple_expert_providers(clean_env):
     os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
 
     # First expert provider valid, second invalid
-    args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
+    args = MockArgs(
+        provider="openai",
+        expert_provider="anthropic",
+        expert_model="claude-3-haiku-20240307",
+    )
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing
 
     # Switch to invalid provider
     args = MockArgs(provider="openai", expert_provider="openai-compatible")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_enabled
     assert expert_missing