auto detect openai expert models
This commit is contained in:
parent
c9d7e90312
commit
e3a705eb9b
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
@ -8,9 +9,51 @@ from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
|
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from .models_params import models_params
|
from .models_params import models_params
|
||||||
|
|
||||||
|
def get_available_openai_models() -> List[str]:
    """Fetch the model identifiers available to the configured OpenAI account.

    Returns:
        List of available model names; an empty list when the lookup
        fails for any reason (missing API key, network error, ...).
    """
    try:
        # Ask the OpenAI API which models this account can access.
        listing = OpenAI().models.list()
        return [str(entry.id) for entry in listing.data]
    except Exception:
        # Best-effort lookup: any failure is treated as "no models known".
        return []
|
||||||
|
|
||||||
|
def select_expert_model(provider: str, model: Optional[str] = None) -> Optional[str]:
    """Select appropriate expert model based on provider and availability.

    Args:
        provider: The LLM provider name.
        model: Optional explicitly specified model name.

    Returns:
        Selected model name, or None if no suitable model was found.
    """
    # Auto-selection only applies to OpenAI with no model specified;
    # otherwise echo back whatever the caller passed (possibly None).
    if model is not None or provider != "openai":
        return model

    # Query the account for the models it can actually use.
    offered = get_available_openai_models()

    # Preferred expert models, strongest first.
    for candidate in ("o3-mini", "o1", "o1-preview"):
        if candidate in offered:
            return candidate

    return None
|
||||||
|
|
||||||
known_temp_providers = {
|
known_temp_providers = {
|
||||||
"openai",
|
"openai",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
|
|
@ -150,6 +193,11 @@ def create_llm_client(
|
||||||
if not config:
|
if not config:
|
||||||
raise ValueError(f"Unsupported provider: {provider}")
|
raise ValueError(f"Unsupported provider: {provider}")
|
||||||
|
|
||||||
|
if is_expert and provider == "openai":
|
||||||
|
model_name = select_expert_model(provider, model_name)
|
||||||
|
if not model_name:
|
||||||
|
raise ValueError("No suitable expert model available")
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
|
"Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
|
||||||
provider,
|
provider,
|
||||||
|
|
|
||||||
|
|
@ -47,32 +47,20 @@ class OpenAIStrategy(ProviderStrategy):
|
||||||
if not key:
|
if not key:
|
||||||
missing.append("EXPERT_OPENAI_API_KEY environment variable is not set")
|
missing.append("EXPERT_OPENAI_API_KEY environment variable is not set")
|
||||||
|
|
||||||
# Check expert model only for research-only mode
|
# Handle expert model selection if none specified
|
||||||
if hasattr(args, "research_only") and args.research_only:
|
if hasattr(args, "expert_model") and not args.expert_model:
|
||||||
model = args.expert_model if hasattr(args, "expert_model") else None
|
from ra_aid.llm import select_expert_model
|
||||||
if not model:
|
model = select_expert_model("openai")
|
||||||
model = os.environ.get("EXPERT_OPENAI_MODEL")
|
if model:
|
||||||
if not model:
|
args.expert_model = model
|
||||||
model = os.environ.get("OPENAI_MODEL")
|
elif hasattr(args, "research_only") and args.research_only:
|
||||||
if not model:
|
missing.append("No suitable expert model available")
|
||||||
missing.append(
|
|
||||||
"Model is required for OpenAI provider in research-only mode"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
key = os.environ.get("OPENAI_API_KEY")
|
key = os.environ.get("OPENAI_API_KEY")
|
||||||
if not key:
|
if not key:
|
||||||
missing.append("OPENAI_API_KEY environment variable is not set")
|
missing.append("OPENAI_API_KEY environment variable is not set")
|
||||||
|
|
||||||
# Check model only for research-only mode
|
|
||||||
if hasattr(args, "research_only") and args.research_only:
|
|
||||||
model = args.model if hasattr(args, "model") else None
|
|
||||||
if not model:
|
|
||||||
model = os.environ.get("OPENAI_MODEL")
|
|
||||||
if not model:
|
|
||||||
missing.append(
|
|
||||||
"Model is required for OpenAI provider in research-only mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from unittest import mock
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -12,10 +13,12 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.env import validate_environment
|
from ra_aid.env import validate_environment
|
||||||
from ra_aid.llm import (
|
from ra_aid.llm import (
|
||||||
create_llm_client,
|
create_llm_client,
|
||||||
|
get_available_openai_models,
|
||||||
get_env_var,
|
get_env_var,
|
||||||
get_provider_config,
|
get_provider_config,
|
||||||
initialize_expert_llm,
|
initialize_expert_llm,
|
||||||
initialize_llm,
|
initialize_llm,
|
||||||
|
select_expert_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -281,6 +284,53 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_available_openai_models_success():
    """Test successful retrieval of OpenAI models."""
    # Fake the OpenAI models.list() response shape: .data -> [obj with .id].
    fake_model = Mock(id="gpt-4")
    fake_listing = Mock(data=[fake_model])

    with mock.patch("ra_aid.llm.OpenAI") as client_cls:
        client_cls.return_value.models.list.return_value = fake_listing
        assert get_available_openai_models() == ["gpt-4"]
        client_cls.return_value.models.list.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_available_openai_models_failure():
    """Test graceful handling of model retrieval failure."""
    with mock.patch("ra_aid.llm.OpenAI") as client_cls:
        # Any exception from the API should collapse to an empty list.
        client_cls.return_value.models.list.side_effect = Exception("API Error")
        assert get_available_openai_models() == []
        client_cls.return_value.models.list.assert_called_once()
|
||||||
|
|
||||||
|
def test_select_expert_model_explicit():
    """Test model selection with explicitly specified model."""
    # An explicit model is returned untouched, no API lookup needed.
    assert select_expert_model("openai", "gpt-4") == "gpt-4"
|
||||||
|
|
||||||
|
def test_select_expert_model_non_openai():
    """Test model selection for non-OpenAI provider."""
    # Auto-selection is OpenAI-only; other providers pass None through.
    assert select_expert_model("anthropic", None) is None
|
||||||
|
|
||||||
|
def test_select_expert_model_priority():
    """Test model selection follows priority order."""
    offered = ["gpt-4", "o1", "o3-mini"]

    with mock.patch(
        "ra_aid.llm.get_available_openai_models", return_value=offered
    ):
        # o3-mini outranks o1 regardless of listing order.
        assert select_expert_model("openai") == "o3-mini"
|
||||||
|
|
||||||
|
def test_select_expert_model_no_match():
    """Test model selection when no priority models available."""
    offered = ["gpt-4", "gpt-3.5"]

    with mock.patch(
        "ra_aid.llm.get_available_openai_models", return_value=offered
    ):
        # None of the priority expert models are offered -> None.
        assert select_expert_model("openai") is None
|
||||||
|
|
||||||
def test_temperature_validation(clean_env, mock_openai):
|
def test_temperature_validation(clean_env, mock_openai):
|
||||||
"""Test temperature validation in command line arguments."""
|
"""Test temperature validation in command line arguments."""
|
||||||
from ra_aid.__main__ import parse_arguments
|
from ra_aid.__main__ import parse_arguments
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue