Auto-detect OpenAI expert models

This commit is contained in:
AI Christianson 2025-02-12 15:40:21 -05:00
parent c9d7e90312
commit e3a705eb9b
3 changed files with 107 additions and 21 deletions

View File

@ -1,6 +1,7 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from openai import OpenAI
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
@ -8,9 +9,51 @@ from langchain_openai import ChatOpenAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
from typing import List
from .models_params import models_params from .models_params import models_params
def get_available_openai_models() -> List[str]:
    """Fetch the identifiers of models exposed by the OpenAI API.

    Returns:
        A list of model id strings, or an empty list when the listing
        cannot be retrieved (missing credentials, network failure, etc.).
    """
    try:
        # One round trip: instantiate the client and pull the model listing.
        listing = OpenAI().models.list()
        return [str(entry.id) for entry in listing.data]
    except Exception:
        # Model discovery is best-effort; callers treat an empty list
        # as "nothing available" rather than an error.
        return []
def select_expert_model(provider: str, model: Optional[str] = None) -> Optional[str]:
    """Pick an expert model for the given provider.

    Args:
        provider: The LLM provider name.
        model: Explicitly requested model name, if any.

    Returns:
        The explicit model (possibly None) for non-OpenAI providers or
        when a model was already chosen; otherwise the highest-priority
        available OpenAI reasoning model, or None if none match.
    """
    # Auto-selection only applies to OpenAI when no model was specified.
    if provider != "openai" or model is not None:
        return model

    available = get_available_openai_models()

    # Preference order: strongest reasoning models first.
    for candidate in ("o3-mini", "o1", "o1-preview"):
        if candidate in available:
            return candidate
    return None
known_temp_providers = { known_temp_providers = {
"openai", "openai",
"anthropic", "anthropic",
@ -150,6 +193,11 @@ def create_llm_client(
if not config: if not config:
raise ValueError(f"Unsupported provider: {provider}") raise ValueError(f"Unsupported provider: {provider}")
if is_expert and provider == "openai":
model_name = select_expert_model(provider, model_name)
if not model_name:
raise ValueError("No suitable expert model available")
logger.debug( logger.debug(
"Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s", "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
provider, provider,

View File

@ -47,32 +47,20 @@ class OpenAIStrategy(ProviderStrategy):
if not key: if not key:
missing.append("EXPERT_OPENAI_API_KEY environment variable is not set") missing.append("EXPERT_OPENAI_API_KEY environment variable is not set")
# Check expert model only for research-only mode # Handle expert model selection if none specified
if hasattr(args, "research_only") and args.research_only: if hasattr(args, "expert_model") and not args.expert_model:
model = args.expert_model if hasattr(args, "expert_model") else None from ra_aid.llm import select_expert_model
if not model: model = select_expert_model("openai")
model = os.environ.get("EXPERT_OPENAI_MODEL") if model:
if not model: args.expert_model = model
model = os.environ.get("OPENAI_MODEL") elif hasattr(args, "research_only") and args.research_only:
if not model: missing.append("No suitable expert model available")
missing.append(
"Model is required for OpenAI provider in research-only mode"
)
else: else:
key = os.environ.get("OPENAI_API_KEY") key = os.environ.get("OPENAI_API_KEY")
if not key: if not key:
missing.append("OPENAI_API_KEY environment variable is not set") missing.append("OPENAI_API_KEY environment variable is not set")
# Check model only for research-only mode
if hasattr(args, "research_only") and args.research_only:
model = args.model if hasattr(args, "model") else None
if not model:
model = os.environ.get("OPENAI_MODEL")
if not model:
missing.append(
"Model is required for OpenAI provider in research-only mode"
)
return ValidationResult(valid=len(missing) == 0, missing_vars=missing) return ValidationResult(valid=len(missing) == 0, missing_vars=missing)

View File

@ -1,5 +1,6 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -12,10 +13,12 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.env import validate_environment from ra_aid.env import validate_environment
from ra_aid.llm import ( from ra_aid.llm import (
create_llm_client, create_llm_client,
get_available_openai_models,
get_env_var, get_env_var,
get_provider_config, get_provider_config,
initialize_expert_llm, initialize_expert_llm,
initialize_llm, initialize_llm,
select_expert_model,
) )
@ -281,6 +284,53 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
) )
def test_get_available_openai_models_success():
    """Model ids from a successful API listing are returned as strings."""
    fake_model = Mock()
    fake_model.id = "gpt-4"
    fake_listing = Mock()
    fake_listing.data = [fake_model]

    with mock.patch("ra_aid.llm.OpenAI") as fake_client_cls:
        fake_client_cls.return_value.models.list.return_value = fake_listing
        result = get_available_openai_models()

    assert result == ["gpt-4"]
    fake_client_cls.return_value.models.list.assert_called_once()
def test_get_available_openai_models_failure():
    """An API error yields an empty model list instead of raising."""
    with mock.patch("ra_aid.llm.OpenAI") as fake_client_cls:
        fake_client_cls.return_value.models.list.side_effect = Exception("API Error")
        result = get_available_openai_models()

    assert result == []
    fake_client_cls.return_value.models.list.assert_called_once()
def test_select_expert_model_explicit():
    """An explicitly requested model is returned untouched."""
    assert select_expert_model("openai", "gpt-4") == "gpt-4"
def test_select_expert_model_non_openai():
    """Providers other than OpenAI get no automatic selection."""
    assert select_expert_model("anthropic", None) is None
def test_select_expert_model_priority():
    """The highest-priority reasoning model wins when several exist."""
    offered = ["gpt-4", "o1", "o3-mini"]
    with mock.patch(
        "ra_aid.llm.get_available_openai_models", return_value=offered
    ):
        assert select_expert_model("openai") == "o3-mini"
def test_select_expert_model_no_match():
    """No model is selected when none of the preferred ones are offered."""
    offered = ["gpt-4", "gpt-3.5"]
    with mock.patch(
        "ra_aid.llm.get_available_openai_models", return_value=offered
    ):
        assert select_expert_model("openai") is None
def test_temperature_validation(clean_env, mock_openai): def test_temperature_validation(clean_env, mock_openai):
"""Test temperature validation in command line arguments.""" """Test temperature validation in command line arguments."""
from ra_aid.__main__ import parse_arguments from ra_aid.__main__ import parse_arguments