RA.Aid/ra_aid/llm.py


import os
from typing import Any, Dict, List, Optional
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from openai import OpenAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.logging_config import get_logger
from .models_params import models_params


def get_available_openai_models() -> List[str]:
    """Fetch available OpenAI models using the OpenAI client.

    Returns:
        List of available model names
    """
    try:
        # Use OpenAI client to fetch models
        client = OpenAI()
        models = client.models.list()
        return [str(model.id) for model in models.data]
    except Exception:
        # Return empty list if unable to fetch models
        return []
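
# Note: OpenAI() picks up OPENAI_API_KEY from the environment, so this helper
# only returns a non-empty list when that variable is set and the request
# succeeds; any failure is swallowed and callers get [] to fall back on.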


def select_expert_model(provider: str, model: Optional[str] = None) -> Optional[str]:
    """Select an appropriate expert model based on provider and availability.

    Args:
        provider: The LLM provider
        model: Optional explicitly specified model name

    Returns:
        Selected model name, or None if no suitable model is found
    """
    if provider != "openai" or model is not None:
        return model

    # Try to get available models
    available_models = get_available_openai_models()

    # Priority order for expert models
    priority_models = ["o3-mini", "o1", "o1-preview"]

    # Return the first available model from the priority list
    for model_name in priority_models:
        if model_name in available_models:
            return model_name

    return None
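
# Illustrative behavior (follows directly from the logic above; the ids listed
# are whatever the account's model list actually exposes):
#   select_expert_model("anthropic")      -> None (only "openai" is auto-selected)
#   select_expert_model("openai", "o1")   -> "o1" (an explicit model wins)
#   select_expert_model("openai")         -> first of "o3-mini", "o1",
#                                            "o1-preview" present in the
#                                            account's model list, else None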


known_temp_providers = {
    "openai",
    "anthropic",
    "openrouter",
    "openai-compatible",
    "gemini",
    "deepseek",
}

# Constants for API request configuration
LLM_REQUEST_TIMEOUT = 180
LLM_MAX_RETRIES = 5

logger = get_logger(__name__)


def get_env_var(name: str, expert: bool = False) -> Optional[str]:
    """Get environment variable with optional expert prefix and fallback."""
    prefix = "EXPERT_" if expert else ""
    value = os.getenv(f"{prefix}{name}")

    # If expert mode and no expert value, fall back to base value
    if expert and not value:
        value = os.getenv(name)

    return value
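
# Illustrative fallback: with EXPERT_DEEPSEEK_API_KEY unset and DEEPSEEK_API_KEY
# set, get_env_var("DEEPSEEK_API_KEY", expert=True) returns the base
# DEEPSEEK_API_KEY value; if the EXPERT_-prefixed variable is set, it takes
# precedence.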


def create_deepseek_client(
    model_name: str,
    api_key: str,
    base_url: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create DeepSeek client with appropriate configuration."""
    if model_name.lower() == "deepseek-reasoner":
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url=base_url,
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url=base_url,
        temperature=0 if is_expert else (temperature if temperature is not None else 1),
        model=model_name,
        timeout=LLM_REQUEST_TIMEOUT,
        max_retries=LLM_MAX_RETRIES,
    )


def create_openrouter_client(
    model_name: str,
    api_key: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create OpenRouter client with appropriate configuration."""
    if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url="https://openrouter.ai/api/v1",
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url="https://openrouter.ai/api/v1",
        model=model_name,
        timeout=LLM_REQUEST_TIMEOUT,
        max_retries=LLM_MAX_RETRIES,
        **({"temperature": temperature} if temperature is not None else {}),
    )
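
# Routing note: an OpenRouter model id such as "deepseek/deepseek-r1" (example
# id only) matches the check above and is handled by ChatDeepseekReasoner;
# every other OpenRouter model goes through plain ChatOpenAI pointed at the
# OpenRouter base URL.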


def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any]:
    """Get provider-specific configuration."""
    configs = {
        "openai": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": None,
        },
        "anthropic": {
            "api_key": get_env_var("ANTHROPIC_API_KEY", is_expert),
            "base_url": None,
        },
        "openrouter": {
            "api_key": get_env_var("OPENROUTER_API_KEY", is_expert),
            "base_url": "https://openrouter.ai/api/v1",
        },
        "openai-compatible": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": get_env_var("OPENAI_API_BASE", is_expert),
        },
        "gemini": {
            "api_key": get_env_var("GEMINI_API_KEY", is_expert),
            "base_url": None,
        },
        "deepseek": {
            "api_key": get_env_var("DEEPSEEK_API_KEY", is_expert),
            "base_url": "https://api.deepseek.com",
        },
    }
    return configs.get(provider, {})
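
# Any provider not listed above yields an empty dict, which create_llm_client
# below treats as an unsupported provider and rejects with a ValueError.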


def create_llm_client(
    provider: str,
    model_name: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create a language model client with appropriate configuration.

    Args:
        provider: The LLM provider to use
        model_name: Name of the model to use
        temperature: Optional temperature setting (0.0-2.0)
        is_expert: Whether this is an expert model (uses deterministic output)

    Returns:
        Configured language model client
    """
    config = get_provider_config(provider, is_expert)
    if not config:
        raise ValueError(f"Unsupported provider: {provider}")

    if is_expert and provider == "openai":
        model_name = select_expert_model(provider, model_name)
        if not model_name:
            raise ValueError("No suitable expert model available")

    logger.debug(
        "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
        provider,
        model_name,
        temperature,
        is_expert,
    )

    # Get model configuration
    model_config = models_params.get(provider, {}).get(model_name, {})

    # Default to True for known providers that support temperature if not specified
    if "supports_temperature" not in model_config:
        model_config["supports_temperature"] = provider in known_temp_providers

    supports_temperature = model_config["supports_temperature"]

    # Handle temperature settings
    if is_expert:
        temp_kwargs = {"temperature": 0} if supports_temperature else {}
    elif supports_temperature:
        if temperature is None:
            raise ValueError(
                f"Temperature must be provided for model {model_name} which supports temperature"
            )
        temp_kwargs = {"temperature": temperature}
    else:
        temp_kwargs = {}

    if provider == "deepseek":
        return create_deepseek_client(
            model_name=model_name,
            api_key=config["api_key"],
            base_url=config["base_url"],
            **temp_kwargs,
            is_expert=is_expert,
        )
    elif provider == "openrouter":
        return create_openrouter_client(
            model_name=model_name,
            api_key=config["api_key"],
            **temp_kwargs,
            is_expert=is_expert,
        )
    elif provider == "openai":
        openai_kwargs = {
            "api_key": config["api_key"],
            "model": model_name,
            **temp_kwargs,
        }
        if is_expert:
            openai_kwargs["reasoning_effort"] = "high"
        return ChatOpenAI(
            **{
                **openai_kwargs,
                "timeout": LLM_REQUEST_TIMEOUT,
                "max_retries": LLM_MAX_RETRIES,
            }
        )
    elif provider == "anthropic":
        return ChatAnthropic(
            api_key=config["api_key"],
            model_name=model_name,
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
            **temp_kwargs,
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
            api_key=config["api_key"],
            base_url=config["base_url"],
            model=model_name,
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
            **temp_kwargs,
        )
    elif provider == "gemini":
        return ChatGoogleGenerativeAI(
            api_key=config["api_key"],
            model=model_name,
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
            **temp_kwargs,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")


def initialize_llm(
    provider: str, model_name: str, temperature: float | None = None
) -> BaseChatModel:
    """Initialize a language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature, is_expert=False)


def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel:
    """Initialize an expert language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature=None, is_expert=True)
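
# Minimal usage sketch (assumes the relevant API keys are set in the
# environment; the model names below are examples only, not defaults of this
# module):
#
#     llm = initialize_llm("anthropic", "claude-3-5-sonnet-20241022", temperature=0.7)
#     print(llm.invoke("Say hello in one short sentence.").content)
#
#     # Expert path: temperature is pinned to 0 or omitted, per the model's config.
#     expert = initialize_expert_llm("openai", "o3-mini")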