diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 2a15872..46bd819 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -8,6 +8,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
+from ra_aid.models_params import models_params
 
 from ra_aid import print_error, print_stage_header
 from ra_aid.__version__ import __version__
@@ -23,10 +24,12 @@ from ra_aid.config import (
     DEFAULT_RECURSION_LIMIT,
     VALID_PROVIDERS,
 )
+from ra_aid.console.output import cpm
 from ra_aid.dependencies import check_dependencies
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm
 from ra_aid.logging_config import get_logger, setup_logging
+from ra_aid.models_params import DEFAULT_TEMPERATURE
 from ra_aid.project_info import format_project_info, get_project_info
 from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
 from ra_aid.tool_configs import get_chat_tools
@@ -309,7 +312,6 @@ def main():
         logger.debug("Environment validation successful")
 
         # Validate model configuration early
-        from ra_aid.models_params import models_params
         model_config = models_params.get(args.provider, {}).get(args.model or "", {})
 
         supports_temperature = model_config.get(
@@ -321,10 +323,10 @@ def main():
         if supports_temperature and args.temperature is None:
             args.temperature = model_config.get("default_temperature")
             if args.temperature is None:
-                print_error(
-                    f"Temperature must be provided for model {args.model} which supports temperature"
+                cpm(
+                    f"This model supports the temperature argument, but none was given. Setting the default temperature to {DEFAULT_TEMPERATURE}."
                 )
-                sys.exit(1)
+                args.temperature = DEFAULT_TEMPERATURE
             logger.debug(
                 f"Using default temperature {args.temperature} for model {args.model}"
             )
diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
index 18907ce..ca5a8db 100644
--- a/ra_aid/agents/ciayn_agent.py
+++ b/ra_aid/agents/ciayn_agent.py
@@ -143,8 +143,6 @@ class CiaynAgent:
         if code.endswith("```"):
             code = code[:-3].strip()
 
-        # logger.debug(f"_execute_tool: stripped code: {code}")
-
         # if the eval fails, try to extract it via a model call
         if validate_function_call_pattern(code):
             functions_list = "\n\n".join(self.available_functions)
@@ -176,8 +174,8 @@ class CiaynAgent:
         msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
         # Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
         # msg += f"{fallback_response[0]}\n"
-        msg += f"{e.tool_name}"
-        msg += f"{fallback_response[1]}"
+        msg += f"{e.tool_name}\n"
+        msg += f"\n{fallback_response[1]}\n\n"
         return msg
 
     def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
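Note on the `__main__.py` hunks: instead of exiting when a temperature-capable model is run without an explicit temperature, the CLI now warns via `cpm` and falls back to `DEFAULT_TEMPERATURE`. A minimal sketch of the resulting resolution order (the `models_params` dict shape and the `cpm` stub below are simplified stand-ins for the real ra_aid structures, not their actual definitions):

```python
from typing import Optional

DEFAULT_TEMPERATURE = 0.7  # assumed value; the real constant lives in ra_aid.models_params

# Simplified stand-in for ra_aid's per-provider model metadata.
models_params = {
    "openai": {"test-model": {"supports_temperature": True, "default_temperature": None}},
}

def cpm(message: str) -> None:
    # Stand-in for ra_aid.console.output.cpm (rich console message).
    print(message)

def resolve_temperature(provider: str, model: str, cli_temperature: Optional[float]) -> Optional[float]:
    model_config = models_params.get(provider, {}).get(model or "", {})
    if not model_config.get("supports_temperature", False):
        return None  # model ignores temperature entirely
    if cli_temperature is not None:
        return cli_temperature  # an explicit CLI flag always wins
    temperature = model_config.get("default_temperature")
    if temperature is None:
        # Previously: print_error(...) followed by sys.exit(1).
        cpm(
            f"This model supports the temperature argument, but none was given. "
            f"Setting the default temperature to {DEFAULT_TEMPERATURE}."
        )
        temperature = DEFAULT_TEMPERATURE
    return temperature

assert resolve_temperature("openai", "test-model", None) == 0.7
```

The resolution order is: explicit CLI value, then the model's own `default_temperature`, then the global default; only the final warn-and-default step is new in this patch.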
# msg += f"{fallback_response[0]}\n" - msg += f"{e.tool_name}" - msg += f"{fallback_response[1]}" + msg += f"{e.tool_name}\n" + msg += f"\n{fallback_response[1]}\n\n" return msg def _create_agent_chunk(self, content: str) -> Dict[str, Any]: diff --git a/ra_aid/config.py b/ra_aid/config.py index 2f5eab0..9393ba0 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 + VALID_PROVIDERS = [ "anthropic", "openai", diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index fdf45c6..1958a00 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -349,7 +349,7 @@ class FallbackHandler: ): return self.current_tool_to_bind.invoke(arguments) else: - raise Exception(f"Tool '{name}' not found in available tools.") + raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.") def base_message_to_tool_call_dict(self, response: BaseMessage): """ @@ -365,7 +365,7 @@ class FallbackHandler: tool_calls = self.get_tool_calls(response) if not tool_calls: - raise Exception( + raise FallbackToolExecutionError( f"Could not extract tool_call_dict from response: {response}" ) diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 4558beb..506028c 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -9,6 +9,7 @@ from langchain_openai import ChatOpenAI from openai import OpenAI from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner +from ra_aid.console.output import cpm from ra_aid.logging_config import get_logger from .models_params import models_params @@ -228,8 +229,9 @@ def create_llm_client( temp_kwargs = {"temperature": 0} if supports_temperature else {} elif supports_temperature: if temperature is None: - raise ValueError( - f"Temperature must be provided for model {model_name} which supports temperature" + temperature = 0.7 + cpm( + "This model supports temperature argument but none was given. Setting default temperature to 0.7." 
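The `fallback_handler.py` hunks swap bare `Exception` for `FallbackToolExecutionError`, so callers can catch fallback failures without also catching unrelated bugs, and `create_llm_client` gets the same warn-and-default treatment as the CLI. A rough sketch of why the narrower exception type matters (the class body and call site below are simplified assumptions, not ra_aid's actual code):

```python
class FallbackToolExecutionError(Exception):
    """Raised when the fallback tool handler cannot execute a requested tool."""

def invoke_fallback_tool(name: str, available_tools: dict, arguments: dict):
    # Mirrors the diff's behavior: an unknown tool name now raises the
    # dedicated error type rather than a bare Exception.
    if name in available_tools:
        return available_tools[name](**arguments)
    raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.")

try:
    invoke_fallback_tool("missing_tool", {}, {})
except FallbackToolExecutionError as exc:
    # A bare `except Exception` here would also mask programming errors;
    # the specific type keeps the retry/fallback path targeted.
    print(f"fallback failed: {exc}")
```

One design wrinkle worth noting: `create_llm_client` hardcodes `0.7` while `__main__.py` uses the shared `DEFAULT_TEMPERATURE` constant, and the tests below pin the literal `0.7`.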
            )
        temp_kwargs = {"temperature": temperature}
     else:
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index 8314367..853be14 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -237,7 +237,7 @@ def test_initialize_openai_compatible(clean_env, mock_openai):
 
 def test_initialize_unsupported_provider(clean_env):
     """Test initialization with unsupported provider raises ValueError"""
-    with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
+    with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
         initialize_llm("unknown", "model")
 
 
@@ -259,15 +259,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin
         max_retries=5,
     )
 
-    # Test error when no temperature provided for models that support it
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("openai", "test-model")
+    # Test default temperature when none is provided for models that support it
+    initialize_llm("openai", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )
 
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("anthropic", "test-model")
+    initialize_llm("anthropic", "test-model")
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )
 
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("gemini", "test-model")
+    initialize_llm("gemini", "test-model")
+    mock_gemini.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )
 
     # Test expert models don't require temperature
     initialize_expert_llm("openai", "o1")
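The rewritten assertions follow the standard `unittest.mock` pattern: `assert_called_with` compares the mock's most recent call against the expected keyword arguments. A self-contained illustration of that pattern, independent of ra_aid's fixtures:

```python
from unittest.mock import MagicMock

mock_openai = MagicMock()

# Simulate what initialize_llm("openai", "test-model") now does internally:
# construct the client with the defaulted temperature.
mock_openai(api_key="test-key", model="test-model", temperature=0.7,
            timeout=180, max_retries=5)

# Passes only if the most recent call used exactly these keyword arguments;
# a missing or differing kwarg raises AssertionError.
mock_openai.assert_called_with(
    api_key="test-key",
    model="test-model",
    temperature=0.7,
    timeout=180,
    max_retries=5,
)
```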