From 36e00d8227b030464c44f4d44c6bad9f3193e28a Mon Sep 17 00:00:00 2001
From: AI Christianson
Date: Fri, 13 Dec 2024 14:19:42 -0500
Subject: [PATCH] make expert model configurable

---
 README.md              | 26 ++++++++++++++++++++++++--
 ra_aid/__main__.py     | 37 +++++++++++++++++++++++++++++++++++++
 ra_aid/llm.py          | 42 ++++++++++++++++++++++++++++++++++++++++++
 ra_aid/tools/expert.py |  9 ++++++---
 4 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 18472b6..487b186 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,8 @@ ra-aid -m "Explain the authentication flow" --research-only
 - `--cowboy-mode`: Skip interactive approval for shell commands
 - `--provider`: Specify the model provider (See Model Configuration section)
 - `--model`: Specify the model name (See Model Configuration section)
+- `--expert-provider`: Specify the provider for the expert tool (defaults to OpenAI)
+- `--expert-model`: Specify the model name for the expert tool (defaults to o1-preview for OpenAI; required for all other providers)
 
 ### Model Configuration
 
@@ -155,15 +157,23 @@ RA.Aid supports multiple AI providers and models. The default model is Anthropic
 
 The programmer tool (aider) automatically selects its model based on your available API keys. It will use Claude models if ANTHROPIC_API_KEY is set, or fall back to OpenAI models if only OPENAI_API_KEY is available.
 
+Note: The expert tool can be configured to use different providers (OpenAI, Anthropic, OpenRouter, or any OpenAI-compatible endpoint) using the `--expert-provider` flag along with the corresponding `EXPERT_*_KEY` environment variables. Each provider requires its own API key, set through the appropriate environment variable.
+
 #### Environment Variables
 
 RA.Aid supports multiple providers through environment variables:
 
 - `ANTHROPIC_API_KEY`: Required for the default Anthropic provider
-- `OPENAI_API_KEY`: Required for OpenAI provider and expert tool
+- `OPENAI_API_KEY`: Required for OpenAI provider
 - `OPENROUTER_API_KEY`: Required for OpenRouter provider
 - `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY`
 
+Expert Tool Environment Variables:
+- `EXPERT_OPENAI_KEY`: API key for the expert tool when using the OpenAI provider
+- `EXPERT_ANTHROPIC_KEY`: API key for the expert tool when using the Anthropic provider
+- `EXPERT_OPENROUTER_KEY`: API key for the expert tool when using the OpenRouter provider
+- `EXPERT_OPENAI_BASE`: Base URL for the expert tool when using an OpenAI-compatible provider
+
 You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`):
 
 ```bash
@@ -180,7 +190,7 @@ export OPENROUTER_API_KEY=your_api_key_here
 export OPENAI_API_BASE=your_api_base_url
 ```
 
-Note: The expert tool always uses OpenAI's `o1-preview` model and requires `OPENAI_API_KEY` to be set, even if you're using a different provider for the main application.
+Note: The expert tool defaults to OpenAI's `o1-preview` model with the OpenAI provider, but this can be configured using the `--expert-provider` flag along with the corresponding `EXPERT_*_KEY` environment variables.
 
 #### Examples
 
@@ -203,6 +213,18 @@ Note: The expert tool always uses OpenAI's `o1-preview` model and requires `OPEN
    ra-aid -m "Your task" --provider openrouter --model mistralai/mistral-large-2411
    ```
 
+4. **Configuring Expert Provider**
+   ```bash
+   # Use Anthropic for the expert tool (non-OpenAI providers require --expert-model)
+   ra-aid -m "Your task" --expert-provider anthropic --expert-model claude-3-5-sonnet-20241022
+
+   # Use OpenRouter for the expert tool
+   ra-aid -m "Your task" --expert-provider openrouter --expert-model mistralai/mistral-large-2411
+
+   # Use the default OpenAI provider for the expert tool (o1-preview)
+   ra-aid -m "Your task" --expert-provider openai
+   ```
+
 **Important Notes:**
 - Performance varies between models. The default Claude 3 Sonnet model currently provides the best and most reliable results.
 - Model configuration is done via command line arguments: `--provider` and `--model`
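The README examples above cover the three named providers; for the `openai-compatible` case the expert tool needs both an API key and a base URL, per the validation added in `__main__.py` below. A minimal sketch, where the endpoint URL and model name are placeholders rather than values from this patch:

```bash
# Placeholder endpoint and model name; substitute your own values.
export EXPERT_OPENAI_KEY=your_api_key_here
export EXPERT_OPENAI_BASE=https://your-endpoint.example.com/v1
ra-aid -m "Your task" --expert-provider openai-compatible --expert-model your-model-name
```

`--expert-model` is mandatory here because the CLI only assumes a default model (`o1-preview`) for the `openai` provider.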
diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 22eff1e..ed5947b 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -67,6 +67,18 @@ Examples:
         action='store_true',
         help='Skip interactive approval for shell commands'
     )
+    parser.add_argument(
+        '--expert-provider',
+        type=str,
+        default='openai',
+        choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
+        help='The LLM provider to use for expert knowledge queries (default: openai)'
+    )
+    parser.add_argument(
+        '--expert-model',
+        type=str,
+        help='The model name to use for expert knowledge queries (required for non-OpenAI providers)'
+    )
 
     args = parser.parse_args()
 
@@ -77,6 +89,10 @@
     elif not args.model:
         parser.error(f"--model is required when using provider '{args.provider}'")
 
+    # Validate expert model requirement
+    if args.expert_provider != 'openai' and not args.expert_model:
+        parser.error(f"--expert-model is required when using expert provider '{args.expert_provider}'")
+
     return args
 
 # Create console instance
@@ -256,6 +272,7 @@
     """
     missing = []
     provider = args.provider
+    expert_provider = args.expert_provider
 
     # Check API keys based on provider
     if provider == "anthropic":
@@ -273,6 +290,22 @@
         if not os.environ.get('OPENAI_API_BASE'):
             missing.append('OPENAI_API_BASE environment variable is not set')
 
+    # Check expert provider keys
+    if expert_provider == "anthropic":
+        if not os.environ.get('EXPERT_ANTHROPIC_KEY'):
+            missing.append('EXPERT_ANTHROPIC_KEY environment variable is not set')
+    elif expert_provider == "openai":
+        if not os.environ.get('EXPERT_OPENAI_KEY'):
+            missing.append('EXPERT_OPENAI_KEY environment variable is not set')
+    elif expert_provider == "openrouter":
+        if not os.environ.get('EXPERT_OPENROUTER_KEY'):
+            missing.append('EXPERT_OPENROUTER_KEY environment variable is not set')
+    elif expert_provider == "openai-compatible":
+        if not os.environ.get('EXPERT_OPENAI_KEY'):
+            missing.append('EXPERT_OPENAI_KEY environment variable is not set')
+        if not os.environ.get('EXPERT_OPENAI_BASE'):
+            missing.append('EXPERT_OPENAI_BASE environment variable is not set')
+
     if missing:
         print_error("Missing required dependencies:")
         for item in missing:
@@ -307,6 +340,10 @@ def main():
     # Store config in global memory for access by is_informational_query
     _global_memory['config'] = config
 
+    # Store expert provider and model in config
+    _global_memory['config']['expert_provider'] = args.expert_provider
+    _global_memory['config']['expert_model'] = args.expert_model
+
     # Run research stage
     print_stage_header("Research Stage")
 
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index b94cbd8..6157140 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -43,3 +43,45 @@
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
+
+def initialize_expert_llm(provider: str = "openai", model_name: str = "o1-preview") -> BaseChatModel:
+    """Initialize an expert language model client based on the specified provider and model.
+
+    Note: Environment variables must be validated before calling this function.
+    Use validate_environment() to ensure all required variables are set.
+
+    Args:
+        provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible').
+            Defaults to 'openai'.
+        model_name: Name of the model to use. Defaults to 'o1-preview'.
+
+    Returns:
+        BaseChatModel: Configured expert language model client
+
+    Raises:
+        ValueError: If the provider is not supported
+    """
+    if provider == "openai":
+        return ChatOpenAI(
+            openai_api_key=os.getenv("EXPERT_OPENAI_KEY"),
+            model=model_name
+        )
+    elif provider == "anthropic":
+        return ChatAnthropic(
+            anthropic_api_key=os.getenv("EXPERT_ANTHROPIC_KEY"),
+            model=model_name
+        )
+    elif provider == "openrouter":
+        return ChatOpenAI(
+            openai_api_key=os.getenv("EXPERT_OPENROUTER_KEY"),
+            openai_api_base="https://openrouter.ai/api/v1",
+            model=model_name
+        )
+    elif provider == "openai-compatible":
+        return ChatOpenAI(
+            openai_api_key=os.getenv("EXPERT_OPENAI_KEY"),
+            openai_api_base=os.getenv("EXPERT_OPENAI_BASE"),
+            model=model_name
+        )
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")
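For illustration, the new helper can be exercised directly. This sketch is not part of the patch; it assumes the relevant `EXPERT_*` variables are already set and validated, and borrows the OpenRouter model name from the README example:

```python
# Illustration only, not part of the patch.
from ra_aid.llm import initialize_expert_llm

# Default: OpenAI provider with o1-preview, authenticated via EXPERT_OPENAI_KEY.
expert = initialize_expert_llm()

# Non-OpenAI providers require an explicit model name (enforced by the CLI).
expert = initialize_expert_llm("openrouter", "mistralai/mistral-large-2411")

# The returned object is a standard LangChain chat model.
reply = expert.invoke("Summarize the authentication flow.")
print(reply.content)
```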
diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py
index 5214b0a..10adcf2 100644
--- a/ra_aid/tools/expert.py
+++ b/ra_aid/tools/expert.py
@@ -4,7 +4,7 @@ from langchain_core.tools import tool
 from rich.console import Console
 from rich.panel import Panel
 from rich.markdown import Markdown
-from langchain_openai import ChatOpenAI
-from .memory import get_memory_value, get_related_files
+from ..llm import initialize_expert_llm
+from .memory import _global_memory, get_memory_value, get_related_files
 
 console = Console()
@@ -13,9 +13,12 @@
 _model = None
 
 def get_model():
     global _model
     if _model is None:
-        _model = ChatOpenAI(model_name="o1-preview")
+        # Read from _global_memory['config'], where main() stores the CLI
+        # arguments; get_memory_value() only looks up top-level memory entries.
+        config = _global_memory.get('config', {})
+        provider = config.get('expert_provider') or 'openai'
+        model = config.get('expert_model') or 'o1-preview'
+        _model = initialize_expert_llm(provider, model)
     return _model
 
 # Keep track of context globally
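Because `get_model()` caches the client in the module-level `_model`, the provider and model are read from `_global_memory['config']` only on the first expert query; later config changes have no effect within a running process. A hypothetical session demonstrating this, assuming `_global_memory` lives in `ra_aid.tools.memory` as the relative import above suggests, and using a placeholder model id:

```python
# Hypothetical usage, not part of the patch.
from ra_aid.tools.memory import _global_memory
from ra_aid.tools import expert

# Must be stored before the first expert query (main() does this at startup).
_global_memory['config'] = {
    'expert_provider': 'anthropic',
    'expert_model': 'claude-3-5-sonnet-20241022',  # placeholder model id
}

m1 = expert.get_model()  # builds the client via initialize_expert_llm(...)
m2 = expert.get_model()  # returns the same cached instance
assert m1 is m2
```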