From 00a455d586481f7b228560caecfefb8c8859f804 Mon Sep 17 00:00:00 2001
From: Jose M Leon
Date: Sat, 8 Feb 2025 20:28:10 -0500
Subject: [PATCH] FIX do not default to o1 model (#82)

---
 ra_aid/__main__.py       | 17 +++++++----------
 ra_aid/llm.py            |  2 +-
 ra_aid/tools/expert.py   |  5 +++--
 tests/ra_aid/test_llm.py |  4 ++--
 4 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 66ed881..21868bd 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -206,7 +206,7 @@ Examples:
 
     if parsed_args.provider == "openai":
         parsed_args.model = parsed_args.model or OPENAI_DEFAULT_MODEL
-    if parsed_args.provider == "anthropic":
+    elif parsed_args.provider == "anthropic":
         # Always use default model for Anthropic
         parsed_args.model = ANTHROPIC_DEFAULT_MODEL
     elif not parsed_args.model and not parsed_args.research_only:
@@ -215,15 +215,12 @@ Examples:
             f"--model is required when using provider '{parsed_args.provider}'"
         )
 
-    # Validate expert model requirement
-    if (
-        parsed_args.expert_provider != "openai"
-        and not parsed_args.expert_model
-        and not parsed_args.research_only
-    ):
-        parser.error(
-            f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'"
-        )
+    # Handle expert provider/model defaults
+    if not parsed_args.expert_provider:
+        # If no expert provider specified, use the main provider instead of defaulting
+        # to any particular model, since we do not know if we have access to any other model.
+        parsed_args.expert_provider = parsed_args.provider
+        parsed_args.expert_model = parsed_args.model
 
     # Validate temperature range if provided
     if parsed_args.temperature is not None and not (
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index f95509e..4a4038a 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -220,7 +220,7 @@ def initialize_llm(
 
 
 def initialize_expert_llm(
-    provider: str = "openai", model_name: str = "o1"
+    provider: str, model_name: str
 ) -> BaseChatModel:
     """Initialize an expert language model client based on the specified provider and model."""
     return create_llm_client(provider, model_name, temperature=None, is_expert=True)
diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py
index ab41d0a..3b8188f 100644
--- a/ra_aid/tools/expert.py
+++ b/ra_aid/tools/expert.py
@@ -17,8 +17,9 @@ def get_model():
     global _model
     try:
         if _model is None:
-            provider = _global_memory["config"]["expert_provider"] or "openai"
-            model = _global_memory["config"]["expert_model"] or "o1"
+            config = _global_memory["config"]
+            provider = config.get("expert_provider") or config.get("provider")
+            model = config.get("expert_model") or config.get("model")
             _model = initialize_expert_llm(provider, model)
     except Exception as e:
         _model = None
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index 1f16ad2..2e7ea10 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -50,9 +50,9 @@ def mock_openai():
 
 
 def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
-    """Test expert LLM initialization with default parameters."""
+    """Test expert LLM initialization with explicit parameters."""
     monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
-    _llm = initialize_expert_llm()
+    _llm = initialize_expert_llm("openai", "o1")
 
     mock_openai.assert_called_once_with(api_key="test-key", model="o1", reasoning_effort="high")
 
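
Note appended for review (not part of the commit): the hunks above make the
expert LLM fall back to the main provider/model instead of silently assuming
access to OpenAI's o1. A minimal standalone sketch of that fallback follows;
resolve_expert_settings is a hypothetical helper named here only for
illustration, the config keys mirror the _global_memory["config"] entries
read in ra_aid/tools/expert.py, and the provider/model strings are sample
values.

def resolve_expert_settings(config: dict):
    """Return (provider, model) for the expert tool, preferring explicit
    expert settings and otherwise reusing the main provider/model."""
    provider = config.get("expert_provider") or config.get("provider")
    model = config.get("expert_model") or config.get("model")
    return provider, model

if __name__ == "__main__":
    # No expert settings configured: falls back to the main provider/model.
    print(resolve_expert_settings({"provider": "anthropic", "model": "claude-3-5-sonnet"}))
    # Explicit expert settings take precedence when present.
    print(resolve_expert_settings({
        "provider": "anthropic",
        "model": "claude-3-5-sonnet",
        "expert_provider": "openai",
        "expert_model": "o1",
    }))

Because the lookups chain with "or" rather than relying on default arguments,
an expert_provider that is unset, None, or an empty string falls through to
the main settings, matching the new get_model() behavior.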