Adding Gemini API due to openrouter's limitations. (#34)

arthrod 2025-01-04 07:19:05 -05:00 committed by GitHub
parent 937c8c2a6a
commit ffe5138a99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 159 additions and 21 deletions

@ -127,6 +127,9 @@ export OPENROUTER_API_KEY=your_api_key_here
# For OpenAI-compatible providers (optional)
export OPENAI_API_BASE=your_api_base_url
# For Gemini provider (optional)
export GEMINI_API_KEY=your_api_key_here
# For web research capabilities
export TAVILY_API_KEY=your_api_key_here
```
@ -140,6 +143,7 @@ You can get your API keys from:
- Anthropic API key: https://console.anthropic.com/
- OpenAI API key: https://platform.openai.com/api-keys
- OpenRouter API key: https://openrouter.ai/keys
- Gemini API key: https://aistudio.google.com/app/apikey
## Usage
@ -165,7 +169,7 @@ ra-aid -m "Add new feature" --verbose
- `--provider`: Specify the model provider (See Model Configuration section)
- `--model`: Specify the model name (See Model Configuration section)
- `--expert-provider`: Specify the provider for the expert tool (defaults to OpenAI)
-- `--expert-model`: Specify the model name for the expert tool (defaults to o1-preview for OpenAI)
+- `--expert-model`: Specify the model name for the expert tool (defaults to o1 for OpenAI)
- `--chat`: Enable chat mode for interactive assistance
- `--verbose`: Enable detailed logging output for debugging and monitoring
@ -276,7 +280,7 @@ RA.Aid supports multiple AI providers and models. The default model is Anthropic
The programmer tool (aider) automatically selects its model based on your available API keys. It will use Claude models if ANTHROPIC_API_KEY is set, or fall back to OpenAI models if only OPENAI_API_KEY is available.
-Note: The expert tool can be configured to use different providers (OpenAI, Anthropic, OpenRouter) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables. Each provider requires its own API key set through the appropriate environment variable.
+Note: The expert tool can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables. Each provider requires its own API key set through the appropriate environment variable.
#### Environment Variables
@ -286,12 +290,14 @@ RA.Aid supports multiple providers through environment variables:
- `OPENAI_API_KEY`: Required for OpenAI provider
- `OPENROUTER_API_KEY`: Required for OpenRouter provider
- `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY`
- `GEMINI_API_KEY`: Required for Gemini provider
Expert Tool Environment Variables:
- `EXPERT_OPENAI_API_KEY`: API key for expert tool using OpenAI provider
- `EXPERT_ANTHROPIC_API_KEY`: API key for expert tool using Anthropic provider
- `EXPERT_OPENROUTER_API_KEY`: API key for expert tool using OpenRouter provider
- `EXPERT_OPENAI_API_BASE`: Base URL for expert tool using OpenAI-compatible provider
- `EXPERT_GEMINI_API_KEY`: API key for expert tool using Gemini provider
You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`):
@ -307,6 +313,9 @@ export OPENROUTER_API_KEY=your_api_key_here
# For OpenAI-compatible providers
export OPENAI_API_BASE=your_api_base_url
# For Gemini provider
export GEMINI_API_KEY=your_api_key_here
```
### Custom Model Examples
@ -332,7 +341,7 @@ export OPENAI_API_BASE=your_api_base_url
4. **Configuring Expert Provider**
-The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables.
+The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini, openai-compatible) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables.
```bash
# Use Anthropic for expert tool
@ -345,7 +354,11 @@ export OPENAI_API_BASE=your_api_base_url
# Use default OpenAI for expert tool
export EXPERT_OPENAI_API_KEY=your_openai_api_key
-ra-aid -m "Your task" --expert-provider openai --expert-model o1-preview
+ra-aid -m "Your task" --expert-provider openai --expert-model o1
# Use Gemini for expert tool
export EXPERT_GEMINI_API_KEY=your_gemini_api_key
ra-aid -m "Your task" --expert-provider gemini --expert-model gemini-2.0-flash-thinking-exp-1219
```
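Before running `ra-aid`, it can help to confirm which provider keys are actually visible to the process. The following is a minimal sketch, not part of RA.Aid: the variable names come from this README, while `PROVIDER_KEYS` and `configured_providers` are hypothetical helpers introduced only for illustration.

```python
import os

# Variable names are taken from the README; the mapping and helper are hypothetical.
PROVIDER_KEYS = {
    "anthropic": "ANTHROPIC_API_KEY",
    "openai": "OPENAI_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
    "gemini": "GEMINI_API_KEY",
}


def configured_providers(env=None):
    """Return the providers whose API key is set to a non-empty value."""
    env = os.environ if env is None else env
    return [name for name, var in PROVIDER_KEYS.items() if env.get(var)]


if __name__ == "__main__":
    found = configured_providers()
    print("Configured providers:", ", ".join(found) or "none")
```

An empty string counts as unset here, which mirrors how the validation code in this commit treats empty keys.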
Aider specific Environment Variables you can add:

@ -24,6 +24,7 @@ classifiers = [
dependencies = [
"langchain-anthropic>=0.3.1",
"langchain-openai",
"langchain-google-genai",
"langgraph>=0.2.60",
"langgraph-checkpoint>=2.0.9",
"langgraph-sdk>=0.1.48",

@ -33,7 +33,7 @@ import os
logger = get_logger(__name__)
def parse_arguments(args=None):
-VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible']
+VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible', 'gemini']
ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022'
OPENAI_DEFAULT_MODEL = 'gpt-4o'
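Adding `'gemini'` to `VALID_PROVIDERS` is what lets the CLI accept the new provider name. How `parse_arguments` consumes the list is not shown in this hunk, so the sketch below assumes it is passed to `argparse` as `choices`; `build_parser` and the defaults are illustrative only.

```python
import argparse

VALID_PROVIDERS = ["anthropic", "openai", "openrouter", "openai-compatible", "gemini"]


def build_parser():
    """Hypothetical parser: assumes VALID_PROVIDERS is used as argparse choices."""
    parser = argparse.ArgumentParser(prog="ra-aid")
    # Defaults below are assumptions based on the README text, not on the hunk.
    parser.add_argument("--provider", choices=VALID_PROVIDERS, default="anthropic")
    parser.add_argument("--expert-provider", choices=VALID_PROVIDERS, default="openai")
    return parser


args = build_parser().parse_args(["--provider", "gemini"])
print(args.provider, args.expert_provider)
```

With `choices` in place, an unknown provider name makes `argparse` exit with a usage error before any LLM client is constructed.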

@ -46,6 +46,10 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
},
'openrouter': {
'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY'
},
'gemini': {
'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY',
'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL'
}
}
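The mapping above pairs each base variable with its expert counterpart. The body of `copy_base_to_expert_vars` is outside this hunk, so the following is only a sketch of how such a mapping might be applied; `copy_missing_expert_vars` is a hypothetical name, and the "copy only when the expert variable is unset" rule is an assumption.

```python
import os

# Mapping shape copied from the diff; the copy logic below is an assumption.
GEMINI_VARS = {
    "GEMINI_API_KEY": "EXPERT_GEMINI_API_KEY",
    "GEMINI_MODEL": "EXPERT_GEMINI_MODEL",
}


def copy_missing_expert_vars(mapping, env=None):
    """Copy each base variable to its expert counterpart unless the expert one is set."""
    env = os.environ if env is None else env
    for base_var, expert_var in mapping.items():
        if not env.get(expert_var) and env.get(base_var):
            env[expert_var] = env[base_var]
    return env


example = copy_missing_expert_vars(GEMINI_VARS, {"GEMINI_API_KEY": "base-key"})
print(example)
```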

@ -2,6 +2,9 @@ import os
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
def initialize_llm(provider: str, model_name: str, temperature: float | None = None) -> BaseChatModel:
"""Initialize a language model client based on the specified provider and model.
@ -10,7 +13,7 @@ def initialize_llm(provider: str, model_name: str, temperature: float | None = N
Use validate_environment() to ensure all required variables are set.
Args:
-provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible')
+provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible', 'gemini')
model_name: Name of the model to use
temperature: Optional temperature setting for controlling randomness (0.0-2.0).
If not specified, provider-specific defaults are used.
@ -47,19 +50,25 @@ def initialize_llm(provider: str, model_name: str, temperature: float | None = N
temperature=temperature if temperature is not None else 0.3,
model=model_name,
)
elif provider == "gemini":
return ChatGoogleGenerativeAI(
api_key=os.getenv("GEMINI_API_KEY"),
model=model_name,
**({"temperature": temperature} if temperature is not None else {})
)
else:
raise ValueError(f"Unsupported provider: {provider}")
-def initialize_expert_llm(provider: str = "openai", model_name: str = "o1-preview") -> BaseChatModel:
+def initialize_expert_llm(provider: str = "openai", model_name: str = "o1") -> BaseChatModel:
"""Initialize an expert language model client based on the specified provider and model.
Note: Environment variables must be validated before calling this function.
Use validate_environment() to ensure all required variables are set.
Args:
-provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible').
+provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible', 'gemini').
Defaults to 'openai'.
-model_name: Name of the model to use. Defaults to 'o1-preview'.
+model_name: Name of the model to use. Defaults to 'o1'.
Returns:
BaseChatModel: Configured expert language model client
@ -89,5 +98,10 @@ def initialize_expert_llm(provider: str = "openai", model_name: str = "o1-previe
base_url=os.getenv("EXPERT_OPENAI_API_BASE"),
model=model_name,
)
elif provider == "gemini":
return ChatGoogleGenerativeAI(
api_key=os.getenv("EXPERT_GEMINI_API_KEY"),
model=model_name,
)
else:
raise ValueError(f"Unsupported provider: {provider}")
raise ValueError(f"Unsupported provider: {provider}")
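The Gemini branch of `initialize_llm` forwards `temperature` only when the caller supplied one, using `**({...} if ... else {})`, so the client's own default applies otherwise. The sketch below isolates that idiom with a stand-in class; `FakeChatModel` and `make_client` are hypothetical and do not touch `langchain_google_genai`.

```python
class FakeChatModel:
    """Stand-in for a chat client; records the keyword arguments it receives."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def make_client(model, temperature=None):
    # Forward temperature only when it was given, so 0.0 is still passed through
    # while None leaves the client's own default untouched.
    optional = {"temperature": temperature} if temperature is not None else {}
    return FakeChatModel(model=model, **optional)


print(make_client("test-model").kwargs)
print(make_client("test-model", temperature=0.0).kwargs)
```

Testing against `None` rather than truthiness matters: a temperature of `0.0` is a valid explicit setting and would be dropped by a plain `if temperature:` check.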

@ -213,6 +213,32 @@ class OpenRouterStrategy(ProviderStrategy):
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class GeminiStrategy(ProviderStrategy):
    """Gemini provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate Gemini environment variables."""
        missing = []
        # Check if we're validating expert config
        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'gemini':
            key = os.environ.get('EXPERT_GEMINI_API_KEY')
            if not key:
                # Fall back to the base key when no expert key is set
                base_key = os.environ.get('GEMINI_API_KEY')
                if base_key:
                    os.environ['EXPERT_GEMINI_API_KEY'] = base_key
                    key = base_key
            if not key:
                missing.append('EXPERT_GEMINI_API_KEY environment variable is not set')
        else:
            key = os.environ.get('GEMINI_API_KEY')
            if not key:
                missing.append('GEMINI_API_KEY environment variable is not set')
        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class OllamaStrategy(ProviderStrategy):
"""Ollama provider validation strategy."""
@ -245,6 +271,7 @@ class ProviderFactory:
'openai-compatible': OpenAICompatibleStrategy(),
'anthropic': AnthropicStrategy(),
'openrouter': OpenRouterStrategy(),
'gemini': GeminiStrategy(),
'ollama': OllamaStrategy()
}
strategy = strategies.get(provider)
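The `GeminiStrategy` added in this file validates one key per mode and, for the expert tool, falls back to the base key. As a rough standalone sketch of that rule (operating on a plain dict instead of `os.environ`, with a hypothetical `Result` type standing in for `ValidationResult`):

```python
from dataclasses import dataclass, field


@dataclass
class Result:
    """Hypothetical stand-in for ValidationResult."""
    valid: bool
    missing_vars: list = field(default_factory=list)


def validate_gemini(env, expert=False):
    """Mirror the fallback rule: the expert key falls back to the base key."""
    missing = []
    if expert:
        key = env.get("EXPERT_GEMINI_API_KEY") or env.get("GEMINI_API_KEY")
        if key:
            # Persist the resolved key so later lookups see the expert variable.
            env["EXPERT_GEMINI_API_KEY"] = key
        else:
            missing.append("EXPERT_GEMINI_API_KEY environment variable is not set")
    elif not env.get("GEMINI_API_KEY"):
        missing.append("GEMINI_API_KEY environment variable is not set")
    return Result(valid=not missing, missing_vars=missing)


env = {"GEMINI_API_KEY": "base-key"}
print(validate_gemini(env, expert=True), env)
```

Note that, as in the strategy itself, nothing here checks `GEMINI_MODEL`; only the API key decides validity.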

@ -15,7 +15,7 @@ def get_model():
try:
if _model is None:
provider = _global_memory['config']['expert_provider'] or 'openai'
-model = _global_memory['config']['expert_model'] or 'o1-preview'
+model = _global_memory['config']['expert_model'] or 'o1'
_model = initialize_expert_llm(provider, model)
except Exception as e:
_model = None
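The `or` fallbacks in `get_model` treat both `None` and an empty string as "not configured". A small sketch of that resolution step follows; `resolve_expert_settings` is a hypothetical helper, and it reads the config with `.get` rather than direct indexing.

```python
def resolve_expert_settings(config):
    """Apply the same `or` defaults as get_model to a config mapping."""
    provider = config.get("expert_provider") or "openai"
    model = config.get("expert_model") or "o1"
    return provider, model


print(resolve_expert_settings({}))
print(resolve_expert_settings({"expert_provider": "gemini", "expert_model": ""}))
```

One consequence worth keeping in mind: the model default is independent of the provider, so selecting the Gemini expert provider without an expert model would resolve to `o1` at this step unless a model is filled in elsewhere.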

@ -3,6 +3,7 @@ import pytest
from unittest.mock import patch, Mock
from langchain_openai.chat_models import ChatOpenAI
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from dataclasses import dataclass
from ra_aid.agents.ciayn_agent import CiaynAgent
@ -16,7 +17,7 @@ def clean_env(monkeypatch):
env_vars = [
'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY',
'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY',
-'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE'
+'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'GEMINI_API_KEY', 'EXPERT_GEMINI_API_KEY'
]
for var in env_vars:
monkeypatch.delenv(var, raising=False)
@ -38,7 +39,7 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
mock_openai.assert_called_once_with(
api_key="test-key",
-model="o1-preview"
+model="o1"
)
def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
@ -51,6 +52,16 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
model="gpt-4-preview"
)
def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):
    """Test expert Gemini initialization."""
    monkeypatch.setenv("EXPERT_GEMINI_API_KEY", "test-key")
    llm = initialize_expert_llm("gemini", "gemini-2.0-flash-thinking-exp-1219")

    mock_gemini.assert_called_once_with(
        api_key="test-key",
        model="gemini-2.0-flash-thinking-exp-1219"
    )
def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
"""Test expert Anthropic initialization."""
monkeypatch.setenv("EXPERT_ANTHROPIC_API_KEY", "test-key")
@ -114,6 +125,16 @@ def test_initialize_openai(clean_env, mock_openai):
model="gpt-4"
)
def test_initialize_gemini(clean_env, mock_gemini):
    """Test Gemini provider initialization"""
    os.environ["GEMINI_API_KEY"] = "test-key"
    model = initialize_llm("gemini", "gemini-2.0-flash-thinking-exp-1219")

    mock_gemini.assert_called_once_with(
        api_key="test-key",
        model="gemini-2.0-flash-thinking-exp-1219"
    )
def test_initialize_anthropic(clean_env, mock_anthropic):
"""Test Anthropic provider initialization"""
os.environ["ANTHROPIC_API_KEY"] = "test-key"
@ -159,7 +180,7 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic):
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["ANTHROPIC_API_KEY"] = "test-key"
os.environ["OPENAI_API_BASE"] = "http://test-url"
os.environ["GEMINI_API_KEY"] = "test-key"
# Test openai-compatible default temperature
initialize_llm("openai-compatible", "test-model")
mock_openai.assert_called_with(
@ -181,12 +202,19 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic):
api_key="test-key",
model_name="test-model"
)
initialize_llm("gemini", "test-model")
mock_gemini.assert_called_with(
api_key="test-key",
model="test-model"
)
def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemini):
"""Test explicit temperature setting for each provider."""
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["ANTHROPIC_API_KEY"] = "test-key"
os.environ["OPENROUTER_API_KEY"] = "test-key"
os.environ["GEMINI_API_KEY"] = "test-key"
test_temp = 0.7
@ -198,6 +226,14 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic):
temperature=test_temp
)
# Test Gemini
initialize_llm("gemini", "test-model", temperature=test_temp)
mock_gemini.assert_called_with(
api_key="test-key",
model="test-model",
temperature=test_temp
)
# Test Anthropic
initialize_llm("anthropic", "test-model", temperature=test_temp)
mock_anthropic.assert_called_with(
@ -234,7 +270,7 @@ def test_temperature_validation(clean_env, mock_openai):
def test_provider_name_validation():
"""Test provider name validation and normalization."""
# Test all supported providers
-providers = ["openai", "anthropic", "openrouter", "openai-compatible"]
+providers = ["openai", "anthropic", "openrouter", "openai-compatible", "gemini"]
for provider in providers:
try:
with patch('ra_aid.llm.ChatOpenAI'), patch('ra_aid.llm.ChatAnthropic'), patch('ra_aid.llm.ChatGoogleGenerativeAI'):
@ -247,7 +283,7 @@ def test_provider_name_validation():
with pytest.raises(ValueError):
initialize_llm("OpenAI", "test-model")
-def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, monkeypatch):
+def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, mock_gemini, monkeypatch):
"""Test initializing different providers in sequence."""
# Initialize OpenAI
monkeypatch.setenv("OPENAI_API_KEY", "openai-key")
@ -257,6 +293,10 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, m
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
llm2 = initialize_llm("anthropic", "claude-3")
# Initialize Gemini
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
llm3 = initialize_llm("gemini", "gemini-2.0-flash-thinking-exp-1219")
# Verify both were initialized correctly
mock_openai.assert_called_once_with(
api_key="openai-key",
@ -266,7 +306,7 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, m
api_key="anthropic-key",
model_name="claude-3"
)
mock_gemini.assert_called_once_with(
api_key="gemini-key",
model="gemini-2.0-flash-thinking-exp-1219"
)
def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
"""Test environment variable precedence and fallback."""
from ra_aid.env import validate_environment
@ -284,6 +324,7 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "base-key")
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key")
monkeypatch.setenv("TAVILY_API_KEY", "tavily-key")
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
args = Args(provider="openai", expert_provider="openai")
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert expert_enabled
@ -294,7 +335,7 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
llm = initialize_expert_llm()
mock_openai.assert_called_with(
api_key="expert-key",
-model="o1-preview"
+model="o1"
)
# Test empty key validation
@ -302,6 +343,7 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") # Add for provider validation
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
@ -319,3 +361,10 @@ def mock_anthropic():
with patch('ra_aid.llm.ChatAnthropic') as mock:
mock.return_value = Mock(spec=ChatAnthropic)
yield mock
@pytest.fixture
def mock_gemini():
    """Mock ChatGoogleGenerativeAI class for testing Gemini provider initialization."""
    with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock:
        mock.return_value = Mock(spec=ChatGoogleGenerativeAI)
        yield mock
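The `mock_gemini` fixture patches the class where it is looked up (`ra_aid.llm`), so no real client is constructed during the tests. The same pattern can be tried in isolation; in this sketch `RealClient`, `llm`, and `initialize` are hypothetical stand-ins for the Gemini client class, the `ra_aid.llm` module, and `initialize_llm`.

```python
from types import SimpleNamespace
from unittest.mock import Mock, patch


class RealClient:
    """Stand-in for a client class that must not be constructed in tests."""

    def __init__(self, api_key, model):
        raise RuntimeError("would contact a real service")


# Hypothetical stand-in for a module such as ra_aid.llm.
llm = SimpleNamespace(Client=RealClient)


def initialize(model_name):
    # Look the class up at call time so the patched attribute is the one used.
    return llm.Client(api_key="test-key", model=model_name)


with patch.object(llm, "Client") as mock_client:
    mock_client.return_value = Mock(spec=RealClient)
    client = initialize("test-model")
    mock_client.assert_called_once_with(api_key="test-key", model="test-model")

print("patched call verified; original restored:", llm.Client is RealClient)
```

Using `Mock(spec=...)` keeps the returned object compatible with `isinstance` checks against the real class while still avoiding its constructor.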

@ -12,7 +12,8 @@ from ra_aid.provider_strategy import (
AnthropicStrategy,
OpenAIStrategy,
OpenAICompatibleStrategy,
-OpenRouterStrategy
+OpenRouterStrategy,
+GeminiStrategy
)
@dataclass
@ -35,7 +36,10 @@ def clean_env():
"EXPERT_OPENAI_API_KEY",
"EXPERT_OPENAI_API_BASE",
"TAVILY_API_KEY",
"ANTHROPIC_MODEL"
"ANTHROPIC_MODEL",
"GEMINI_API_KEY",
"EXPERT_GEMINI_API_KEY",
"GEMINI_MODEL"
]
# Store original values
@ -186,6 +190,30 @@ def test_incomplete_openai_compatible_config(clean_env):
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
def test_incomplete_gemini_config(clean_env):
    """Test Gemini provider with incomplete configuration."""
    strategy = GeminiStrategy()

    # No configuration: the API key is the only variable GeminiStrategy checks
    result = strategy.validate()
    assert not result.valid
    assert "GEMINI_API_KEY environment variable is not set" in result.missing_vars

    # Only API key: sufficient, since GeminiStrategy does not require GEMINI_MODEL
    os.environ["GEMINI_API_KEY"] = "test-key"
    result = strategy.validate()
    assert result.valid

    # Only model: still invalid without the API key
    os.environ.pop("GEMINI_API_KEY")
    os.environ["GEMINI_MODEL"] = "gemini-2.0-flash-exp"
    result = strategy.validate()
    assert not result.valid
    assert "GEMINI_API_KEY environment variable is not set" in result.missing_vars
def test_incomplete_expert_config(clean_env):
"""Test expert provider with incomplete configuration."""
# Set main provider but not expert
@ -204,6 +232,7 @@ def test_incomplete_expert_config(clean_env):
assert len(expert_missing) == 1
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
def test_empty_environment_variables(clean_env):
"""Test handling of empty environment variables."""
# Empty API key
@ -257,3 +286,4 @@ def test_multiple_expert_providers(clean_env):
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert expert_missing