auto detect openai expert models
parent c9d7e90312
commit e3a705eb9b
@@ -1,6 +1,7 @@
import os
from typing import Any, Dict, Optional

from openai import OpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -8,9 +9,51 @@ from langchain_openai import ChatOpenAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.logging_config import get_logger
from typing import List

from .models_params import models_params


def get_available_openai_models() -> List[str]:
    """Fetch available OpenAI models using OpenAI client.

    Returns:
        List of available model names
    """
    try:
        # Use OpenAI client to fetch models
        client = OpenAI()
        models = client.models.list()
        return [str(model.id) for model in models.data]
    except Exception:
        # Return empty list if unable to fetch models
        return []


def select_expert_model(provider: str, model: Optional[str] = None) -> Optional[str]:
    """Select appropriate expert model based on provider and availability.

    Args:
        provider: The LLM provider
        model: Optional explicitly specified model name

    Returns:
        Selected model name or None if no suitable model found
    """
    if provider != "openai" or model is not None:
        return model

    # Try to get available models
    available_models = get_available_openai_models()

    # Priority order for expert models
    priority_models = ["o3-mini", "o1", "o1-preview"]

    # Return first available model from priority list
    for model_name in priority_models:
        if model_name in available_models:
            return model_name

    return None


known_temp_providers = {
    "openai",
    "anthropic",
@@ -150,6 +193,11 @@ def create_llm_client(
    if not config:
        raise ValueError(f"Unsupported provider: {provider}")

    if is_expert and provider == "openai":
        model_name = select_expert_model(provider, model_name)
        if not model_name:
            raise ValueError("No suitable expert model available")

    logger.debug(
        "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
        provider,
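
Taken together, the two new helpers give expert mode a deterministic fallback when no model is named. A minimal sketch of the intended behaviour, where the patched model list is invented for illustration and the monkeypatching mirrors how the new tests stub get_available_openai_models:

from unittest import mock

from ra_aid.llm import select_expert_model

# Pretend the account only exposes these models (illustrative values only).
fake_models = ["gpt-4o", "o1", "o1-preview"]

with mock.patch("ra_aid.llm.get_available_openai_models", return_value=fake_models):
    # No explicit model: walk the priority list o3-mini -> o1 -> o1-preview.
    assert select_expert_model("openai") == "o1"
    # An explicitly requested model is returned unchanged.
    assert select_expert_model("openai", "gpt-4o") == "gpt-4o"
    # Non-OpenAI providers pass through untouched (None stays None).
    assert select_expert_model("anthropic") is None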
@@ -47,32 +47,20 @@ class OpenAIStrategy(ProviderStrategy):
            if not key:
                missing.append("EXPERT_OPENAI_API_KEY environment variable is not set")

            # Check expert model only for research-only mode
            if hasattr(args, "research_only") and args.research_only:
                model = args.expert_model if hasattr(args, "expert_model") else None
                if not model:
                    model = os.environ.get("EXPERT_OPENAI_MODEL")
                if not model:
                    model = os.environ.get("OPENAI_MODEL")
                if not model:
                    missing.append(
                        "Model is required for OpenAI provider in research-only mode"
                    )
            # Handle expert model selection if none specified
            if hasattr(args, "expert_model") and not args.expert_model:
                from ra_aid.llm import select_expert_model
                model = select_expert_model("openai")
                if model:
                    args.expert_model = model
                elif hasattr(args, "research_only") and args.research_only:
                    missing.append("No suitable expert model available")

        else:
            key = os.environ.get("OPENAI_API_KEY")
            if not key:
                missing.append("OPENAI_API_KEY environment variable is not set")

            # Check model only for research-only mode
            if hasattr(args, "research_only") and args.research_only:
                model = args.model if hasattr(args, "model") else None
                if not model:
                    model = os.environ.get("OPENAI_MODEL")
                if not model:
                    missing.append(
                        "Model is required for OpenAI provider in research-only mode"
                    )

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
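
The reworked branch above swaps the hard research-only model requirement for auto-detection. A rough, self-contained sketch of that flow, where the SimpleNamespace is only a stand-in for the parsed CLI arguments and the attribute names are taken from the diff:

from types import SimpleNamespace

from ra_aid.llm import select_expert_model

# Stand-in for the parsed CLI args the strategy inspects (shape assumed).
args = SimpleNamespace(expert_model=None, research_only=True)
missing = []

if hasattr(args, "expert_model") and not args.expert_model:
    model = select_expert_model("openai")
    if model:
        args.expert_model = model
    elif hasattr(args, "research_only") and args.research_only:
        missing.append("No suitable expert model available")

# With no OpenAI credentials configured, model discovery fails closed and
# research-only mode reports a missing expert model instead of crashing.
print(args.expert_model, missing)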
@@ -1,5 +1,6 @@
import os
from dataclasses import dataclass
from unittest import mock
from unittest.mock import Mock, patch

import pytest
@@ -12,10 +13,12 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.env import validate_environment
from ra_aid.llm import (
    create_llm_client,
    get_available_openai_models,
    get_env_var,
    get_provider_config,
    initialize_expert_llm,
    initialize_llm,
    select_expert_model,
)
@@ -281,6 +284,53 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
    )


def test_get_available_openai_models_success():
    """Test successful retrieval of OpenAI models."""
    mock_model = Mock()
    mock_model.id = "gpt-4"
    mock_models = Mock()
    mock_models.data = [mock_model]

    with mock.patch("ra_aid.llm.OpenAI") as mock_client:
        mock_client.return_value.models.list.return_value = mock_models
        models = get_available_openai_models()
        assert models == ["gpt-4"]
        mock_client.return_value.models.list.assert_called_once()


def test_get_available_openai_models_failure():
    """Test graceful handling of model retrieval failure."""
    with mock.patch("ra_aid.llm.OpenAI") as mock_client:
        mock_client.return_value.models.list.side_effect = Exception("API Error")
        models = get_available_openai_models()
        assert models == []
        mock_client.return_value.models.list.assert_called_once()


def test_select_expert_model_explicit():
    """Test model selection with explicitly specified model."""
    model = select_expert_model("openai", "gpt-4")
    assert model == "gpt-4"


def test_select_expert_model_non_openai():
    """Test model selection for non-OpenAI provider."""
    model = select_expert_model("anthropic", None)
    assert model is None


def test_select_expert_model_priority():
    """Test model selection follows priority order."""
    available_models = ["gpt-4", "o1", "o3-mini"]

    with mock.patch("ra_aid.llm.get_available_openai_models", return_value=available_models):
        model = select_expert_model("openai")
        assert model == "o3-mini"


def test_select_expert_model_no_match():
    """Test model selection when no priority models available."""
    available_models = ["gpt-4", "gpt-3.5"]

    with mock.patch("ra_aid.llm.get_available_openai_models", return_value=available_models):
        model = select_expert_model("openai")
        assert model is None


def test_temperature_validation(clean_env, mock_openai):
    """Test temperature validation in command line arguments."""
    from ra_aid.__main__ import parse_arguments