diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index 7befb1d..a1c3603 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -1,6 +1,7 @@
 import os
 from typing import Any, Dict, Optional
+from openai import OpenAI
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_google_genai import ChatGoogleGenerativeAI
@@ -8,9 +9,51 @@ from langchain_openai import ChatOpenAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
 from ra_aid.logging_config import get_logger
+from typing import List
 
 from .models_params import models_params
 
+def get_available_openai_models() -> List[str]:
+    """Fetch available OpenAI models using OpenAI client.
+
+    Returns:
+        List of available model names
+    """
+    try:
+        # Use OpenAI client to fetch models
+        client = OpenAI()
+        models = client.models.list()
+        return [str(model.id) for model in models.data]
+    except Exception:
+        # Return empty list if unable to fetch models
+        return []
+
+def select_expert_model(provider: str, model: Optional[str] = None) -> Optional[str]:
+    """Select appropriate expert model based on provider and availability.
+
+    Args:
+        provider: The LLM provider
+        model: Optional explicitly specified model name
+
+    Returns:
+        Selected model name or None if no suitable model found
+    """
+    if provider != "openai" or model is not None:
+        return model
+
+    # Try to get available models
+    available_models = get_available_openai_models()
+
+    # Priority order for expert models
+    priority_models = ["o3-mini", "o1", "o1-preview"]
+
+    # Return first available model from priority list
+    for model_name in priority_models:
+        if model_name in available_models:
+            return model_name
+
+    return None
+
 
 known_temp_providers = {
     "openai",
     "anthropic",
@@ -150,6 +193,11 @@ def create_llm_client(
     if not config:
         raise ValueError(f"Unsupported provider: {provider}")
 
+    if is_expert and provider == "openai":
+        model_name = select_expert_model(provider, model_name)
+        if not model_name:
+            raise ValueError("No suitable expert model available")
+
     logger.debug(
         "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
         provider,
diff --git a/ra_aid/provider_strategy.py b/ra_aid/provider_strategy.py
index 735ef3a..173ef9d 100644
--- a/ra_aid/provider_strategy.py
+++ b/ra_aid/provider_strategy.py
@@ -47,32 +47,20 @@ class OpenAIStrategy(ProviderStrategy):
 
             if not key:
                 missing.append("EXPERT_OPENAI_API_KEY environment variable is not set")
 
-            # Check expert model only for research-only mode
-            if hasattr(args, "research_only") and args.research_only:
-                model = args.expert_model if hasattr(args, "expert_model") else None
-                if not model:
-                    model = os.environ.get("EXPERT_OPENAI_MODEL")
-                if not model:
-                    model = os.environ.get("OPENAI_MODEL")
-                if not model:
-                    missing.append(
-                        "Model is required for OpenAI provider in research-only mode"
-                    )
+            # Handle expert model selection if none specified
+            if hasattr(args, "expert_model") and not args.expert_model:
+                from ra_aid.llm import select_expert_model
+                model = select_expert_model("openai")
+                if model:
+                    args.expert_model = model
+                elif hasattr(args, "research_only") and args.research_only:
+                    missing.append("No suitable expert model available")
+
         else:
             key = os.environ.get("OPENAI_API_KEY")
             if not key:
                 missing.append("OPENAI_API_KEY environment variable is not set")
 
-            # Check model only for research-only mode
-            if hasattr(args, "research_only") and args.research_only:
-                model = args.model if hasattr(args, "model") else None
-                if not model:
-                    model = os.environ.get("OPENAI_MODEL")
-                if not model:
-                    missing.append(
-                        "Model is required for OpenAI provider in research-only mode"
-                    )
-
         return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
 
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index d38fb17..e6fe9e8 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -1,5 +1,6 @@
 import os
 from dataclasses import dataclass
+from unittest import mock
 from unittest.mock import Mock, patch
 
 import pytest
@@ -12,10 +13,12 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.env import validate_environment
 from ra_aid.llm import (
     create_llm_client,
+    get_available_openai_models,
     get_env_var,
     get_provider_config,
     initialize_expert_llm,
     initialize_llm,
+    select_expert_model,
 )
 
 
@@ -281,6 +284,53 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin
     )
 
 
+def test_get_available_openai_models_success():
+    """Test successful retrieval of OpenAI models."""
+    mock_model = Mock()
+    mock_model.id = "gpt-4"
+    mock_models = Mock()
+    mock_models.data = [mock_model]
+
+    with mock.patch("ra_aid.llm.OpenAI") as mock_client:
+        mock_client.return_value.models.list.return_value = mock_models
+        models = get_available_openai_models()
+        assert models == ["gpt-4"]
+        mock_client.return_value.models.list.assert_called_once()
+
+def test_get_available_openai_models_failure():
+    """Test graceful handling of model retrieval failure."""
+    with mock.patch("ra_aid.llm.OpenAI") as mock_client:
+        mock_client.return_value.models.list.side_effect = Exception("API Error")
+        models = get_available_openai_models()
+        assert models == []
+        mock_client.return_value.models.list.assert_called_once()
+
+def test_select_expert_model_explicit():
+    """Test model selection with explicitly specified model."""
+    model = select_expert_model("openai", "gpt-4")
+    assert model == "gpt-4"
+
+def test_select_expert_model_non_openai():
+    """Test model selection for non-OpenAI provider."""
+    model = select_expert_model("anthropic", None)
+    assert model is None
+
+def test_select_expert_model_priority():
+    """Test model selection follows priority order."""
+    available_models = ["gpt-4", "o1", "o3-mini"]
+
+    with mock.patch("ra_aid.llm.get_available_openai_models", return_value=available_models):
+        model = select_expert_model("openai")
+        assert model == "o3-mini"
+
+def test_select_expert_model_no_match():
+    """Test model selection when no priority models available."""
+    available_models = ["gpt-4", "gpt-3.5"]
+
+    with mock.patch("ra_aid.llm.get_available_openai_models", return_value=available_models):
+        model = select_expert_model("openai")
+        assert model is None
+
 def test_temperature_validation(clean_env, mock_openai):
     """Test temperature validation in command line arguments."""
    from ra_aid.__main__ import parse_arguments
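
For reviewers, a minimal sketch (not part of the patch) of the selection behavior this diff introduces, exercised the same way the new tests do, by patching ra_aid.llm.get_available_openai_models; the model lists below are illustrative stand-ins, not real API responses:

# Sketch only: assumes the ra_aid package is importable; the mocked
# model lists are made up for illustration.
from unittest import mock

from ra_aid.llm import select_expert_model

# An explicitly requested model is passed through untouched.
assert select_expert_model("openai", "gpt-4") == "gpt-4"

# With no model given, the first hit in the priority list
# ["o3-mini", "o1", "o1-preview"] wins; the ordering of the
# account's own model list does not matter.
with mock.patch(
    "ra_aid.llm.get_available_openai_models",
    return_value=["gpt-4", "o1-preview", "o1"],
):
    assert select_expert_model("openai") == "o1"

# With none of the priority models available, None comes back, which
# create_llm_client turns into ValueError("No suitable expert model available").
with mock.patch(
    "ra_aid.llm.get_available_openai_models",
    return_value=["gpt-4", "gpt-3.5"],
):
    assert select_expert_model("openai") is None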