feat(main.py): import models_params and set default temperature for models that support it to improve user experience

fix(ciayn_agent.py): update fallback tool error messages to use FallbackToolExecutionError for better error handling
fix(config.py): remove unnecessary blank line to maintain code style consistency
fix(fallback_handler.py): raise FallbackToolExecutionError for better error clarity when tools are not found
fix(llm.py): set default temperature to 0.7 and notify user when not provided for models that support it
test(test_llm.py): update tests to check for default temperature behavior and improve error messages for unsupported providers
Ariel Frischer 2025-02-14 13:50:32 -08:00
parent 6970a885e4
commit 0df5d43333
6 changed files with 41 additions and 20 deletions
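
The user-facing change is the same in main.py and llm.py below: when a model's entry in models_params says it supports a temperature but the caller did not pass one, the code now falls back to a default (0.7, exposed as DEFAULT_TEMPERATURE) and prints a notice via cpm instead of erroring out and exiting. Here is a minimal, self-contained sketch of that lookup-and-default logic; the stand-in models_params table, the supports_temperature key name, and the plain print used in place of cpm are illustrative assumptions, not the project's actual values.

from typing import Optional

DEFAULT_TEMPERATURE = 0.7  # assumed to match ra_aid.models_params.DEFAULT_TEMPERATURE

# Stand-in for ra_aid.models_params.models_params: provider -> model -> config dict.
# Entries and key names here are illustrative only.
models_params = {
    "openai": {
        "example-chat-model": {"supports_temperature": True},
        "example-reasoning-model": {"supports_temperature": False},
    },
}


def resolve_temperature(provider: str, model: str, temperature: Optional[float]) -> Optional[float]:
    """Loosely mirror the commit's logic: default the temperature only for models that support it."""
    model_config = models_params.get(provider, {}).get(model or "", {})
    if not model_config.get("supports_temperature", False):
        return None  # no temperature kwarg is passed to the client at all

    if temperature is None:
        temperature = model_config.get("default_temperature")
        if temperature is None:
            # The real code routes this message through cpm(); print keeps the sketch dependency-free.
            print(
                "This model supports temperature argument but none was given. "
                f"Setting default temperature to {DEFAULT_TEMPERATURE}."
            )
            temperature = DEFAULT_TEMPERATURE
    return temperature


print(resolve_temperature("openai", "example-chat-model", None))       # 0.7, prints the notice
print(resolve_temperature("openai", "example-chat-model", 0.2))        # 0.2, no notice
print(resolve_temperature("openai", "example-reasoning-model", None))  # None, temperature unsupported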

main.py

@@ -8,6 +8,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
+from ra_aid.models_params import models_params
 from ra_aid import print_error, print_stage_header
 from ra_aid.__version__ import __version__
@@ -23,10 +24,12 @@ from ra_aid.config import (
     DEFAULT_RECURSION_LIMIT,
     VALID_PROVIDERS,
 )
+from ra_aid.console.output import cpm
 from ra_aid.dependencies import check_dependencies
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm
 from ra_aid.logging_config import get_logger, setup_logging
+from ra_aid.models_params import DEFAULT_TEMPERATURE
 from ra_aid.project_info import format_project_info, get_project_info
 from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
 from ra_aid.tool_configs import get_chat_tools
@@ -309,7 +312,6 @@ def main():
     logger.debug("Environment validation successful")

     # Validate model configuration early
-    from ra_aid.models_params import models_params
     model_config = models_params.get(args.provider, {}).get(args.model or "", {})
     supports_temperature = model_config.get(
@@ -321,10 +323,10 @@ def main():
     if supports_temperature and args.temperature is None:
         args.temperature = model_config.get("default_temperature")
         if args.temperature is None:
-            print_error(
-                f"Temperature must be provided for model {args.model} which supports temperature"
-            )
-            sys.exit(1)
+            cpm(
+                f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
+            )
+            args.temperature = DEFAULT_TEMPERATURE
         logger.debug(
             f"Using default temperature {args.temperature} for model {args.model}"
         )

ciayn_agent.py

@@ -143,8 +143,6 @@ class CiaynAgent:
         if code.endswith("```"):
             code = code[:-3].strip()

-        # logger.debug(f"_execute_tool: stripped code: {code}")
-
         # if the eval fails, try to extract it via a model call
         if validate_function_call_pattern(code):
             functions_list = "\n\n".join(self.available_functions)
@@ -176,8 +174,8 @@
         msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
         # Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
         # msg += f"<fallback llm raw invocation>{fallback_response[0]}</fallback llm raw invocation>\n"
-        msg += f"<fallback tool name>{e.tool_name}</fallback tool name>"
-        msg += f"<fallback tool call result>{fallback_response[1]}</fallback tool call result>"
+        msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
+        msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
         return msg

     def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
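
For reference, this is roughly how the reworked fallback message renders once the extra newlines are in place; the tool name and output below are made-up placeholders, only the surrounding template comes from the diff above.

DEFAULT_MAX_TOOL_FAILURES = 3  # value taken from config.py below
tool_name = "example_tool"                    # hypothetical failing tool
tool_result = "line one of output\nline two"  # hypothetical fallback tool output

msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
msg += f"<fallback tool name>{tool_name}</fallback tool name>\n"
msg += f"<fallback tool call result>\n{tool_result}\n</fallback tool call result>\n"

# The trailing "\n" on each tag keeps the tags on their own lines and gives the
# (often multi-line) tool output its own block instead of running into the tags.
print(msg)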

config.py

@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
 FALLBACK_TOOL_MODEL_LIMIT = 5
 RETRY_FALLBACK_COUNT = 3
 VALID_PROVIDERS = [
     "anthropic",
     "openai",

fallback_handler.py

@@ -349,7 +349,7 @@ class FallbackHandler:
         ):
             return self.current_tool_to_bind.invoke(arguments)
         else:
-            raise Exception(f"Tool '{name}' not found in available tools.")
+            raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.")

     def base_message_to_tool_call_dict(self, response: BaseMessage):
         """
@@ -365,7 +365,7 @@
         tool_calls = self.get_tool_calls(response)

         if not tool_calls:
-            raise Exception(
+            raise FallbackToolExecutionError(
                 f"Could not extract tool_call_dict from response: {response}"
             )
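
Raising FallbackToolExecutionError instead of a bare Exception matters because the agent's fallback path catches that specific type and folds it into the message shown earlier, while unrelated bugs keep propagating normally. A minimal sketch of the pattern, under the assumption that the stand-in class and invoke_tool helper below stand in for ra_aid's real definitions, which are not part of this diff:

class FallbackToolExecutionError(Exception):
    """Stand-in for ra_aid's exception of the same name (its real definition is elsewhere)."""


def invoke_tool(name: str, available_tools: dict, arguments: dict):
    tool = available_tools.get(name)
    if tool is None:
        # A dedicated exception type lets callers tell "the fallback tool call failed"
        # apart from genuine programming errors, which a bare Exception cannot.
        raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.")
    return tool(**arguments)


try:
    invoke_tool("does_not_exist", available_tools={}, arguments={})
except FallbackToolExecutionError as exc:
    print(f"Fallback tool call failed: {exc}")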

llm.py

@@ -9,6 +9,7 @@ from langchain_openai import ChatOpenAI
 from openai import OpenAI

 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
+from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger

 from .models_params import models_params
@@ -228,8 +229,9 @@ def create_llm_client(
         temp_kwargs = {"temperature": 0} if supports_temperature else {}
     elif supports_temperature:
         if temperature is None:
-            raise ValueError(
-                f"Temperature must be provided for model {model_name} which supports temperature"
+            temperature = 0.7
+            cpm(
+                "This model supports temperature argument but none was given. Setting default temperature to 0.7."
             )
         temp_kwargs = {"temperature": temperature}
     else:

test_llm.py

@@ -237,7 +237,7 @@ def test_initialize_openai_compatible(clean_env, mock_openai):
 def test_initialize_unsupported_provider(clean_env):
     """Test initialization with unsupported provider raises ValueError"""
-    with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
+    with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
         initialize_llm("unknown", "model")
@@ -259,15 +259,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin
         max_retries=5,
     )

-    # Test error when no temperature provided for models that support it
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("openai", "test-model")
+    # Test default temperature when none is provided for models that support it
+    initialize_llm("openai", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )

-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("anthropic", "test-model")
+    initialize_llm("anthropic", "test-model")
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )

-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("gemini", "test-model")
+    initialize_llm("gemini", "test-model")
+    mock_gemini.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )

     # Test expert models don't require temperature
     initialize_expert_llm("openai", "o1")