feat(main.py): import models_params and set default temperature for models that support it to improve user experience

fix(ciayn_agent.py): update fallback tool error messages to use FallbackToolExecutionError for better error handling
fix(config.py): remove unnecessary blank line to maintain code style consistency
fix(fallback_handler.py): raise FallbackToolExecutionError for better error clarity when tools are not found
fix(llm.py): set default temperature to 0.7 and notify user when not provided for models that support it
test(test_llm.py): update tests to check for default temperature behavior and improve error messages for unsupported providers
Ariel Frischer 2025-02-14 13:50:32 -08:00
parent 6970a885e4
commit 0df5d43333
6 changed files with 41 additions and 20 deletions

main.py

@@ -8,6 +8,7 @@ from langgraph.checkpoint.memory import MemorySaver
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from ra_aid.models_params import models_params
from ra_aid import print_error, print_stage_header
from ra_aid.__version__ import __version__
@@ -23,10 +24,12 @@ from ra_aid.config import (
DEFAULT_RECURSION_LIMIT,
VALID_PROVIDERS,
)
from ra_aid.console.output import cpm
from ra_aid.dependencies import check_dependencies
from ra_aid.env import validate_environment
from ra_aid.llm import initialize_llm
from ra_aid.logging_config import get_logger, setup_logging
from ra_aid.models_params import DEFAULT_TEMPERATURE
from ra_aid.project_info import format_project_info, get_project_info
from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
from ra_aid.tool_configs import get_chat_tools
@@ -309,7 +312,6 @@ def main():
logger.debug("Environment validation successful")
# Validate model configuration early
from ra_aid.models_params import models_params
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
supports_temperature = model_config.get(
@@ -321,10 +323,10 @@ def main():
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
print_error(
f"Temperature must be provided for model {args.model} which supports temperature"
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
)
sys.exit(1)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
)
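For reference, the new flow in main() can be read as a small self-contained sketch of the behavior described in the commit message. This is a hedged approximation: the exact keys inside models_params and the shape of the parsed-arguments object are assumptions, and print stands in for the project's cpm console helper.

    from dataclasses import dataclass
    from typing import Optional

    DEFAULT_TEMPERATURE = 0.7  # assumed to mirror ra_aid.models_params.DEFAULT_TEMPERATURE

    # Minimal stand-in for ra_aid.models_params.models_params; the real dict
    # and its config keys may differ.
    models_params = {
        "openai": {"test-model": {"supports_temperature": True}},
    }

    @dataclass
    class Args:  # simplified stand-in for the parsed CLI arguments
        provider: str
        model: Optional[str]
        temperature: Optional[float]

    def resolve_temperature(args: Args) -> Optional[float]:
        model_config = models_params.get(args.provider, {}).get(args.model or "", {})
        if model_config.get("supports_temperature") and args.temperature is None:
            # Prefer a per-model default, then the global default; the old
            # code printed an error and exited here instead.
            args.temperature = model_config.get("default_temperature")
            if args.temperature is None:
                print(f"No temperature given; using default {DEFAULT_TEMPERATURE}.")
                args.temperature = DEFAULT_TEMPERATURE
        return args.temperature

    print(resolve_temperature(Args("openai", "test-model", None)))  # -> 0.7

The net effect is that omitting the temperature argument no longer aborts the run; the value is filled in and reported to the user.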

ciayn_agent.py

@@ -143,8 +143,6 @@ class CiaynAgent:
if code.endswith("```"):
code = code[:-3].strip()
# logger.debug(f"_execute_tool: stripped code: {code}")
# if the eval fails, try to extract it via a model call
if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions)
@@ -176,8 +174,8 @@ class CiaynAgent:
msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
# Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
# msg += f"<fallback llm raw invocation>{fallback_response[0]}</fallback llm raw invocation>\n"
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>"
msg += f"<fallback tool call result>{fallback_response[1]}</fallback tool call result>"
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
return msg
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
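FallbackToolExecutionError itself is not defined in this diff. Below is a minimal sketch of what the agent-side handling above appears to assume; the tool_name attribute is inferred from the e.tool_name access, and the class body and example tool name are assumptions rather than the project's actual definition.

    class FallbackToolExecutionError(Exception):
        """Sketch only: raised when a fallback tool call cannot be executed."""

        def __init__(self, message: str, tool_name: str = ""):
            super().__init__(message)
            self.tool_name = tool_name

    # The catch-and-report pattern used when building the fallback result message:
    try:
        raise FallbackToolExecutionError(
            "Tool 'run_shell' not found in available tools.", tool_name="run_shell"
        )
    except FallbackToolExecutionError as e:
        print(f"<fallback tool name>{e.tool_name}</fallback tool name>")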

config.py

@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
FALLBACK_TOOL_MODEL_LIMIT = 5
RETRY_FALLBACK_COUNT = 3
VALID_PROVIDERS = [
"anthropic",
"openai",

fallback_handler.py

@@ -349,7 +349,7 @@ class FallbackHandler:
):
return self.current_tool_to_bind.invoke(arguments)
else:
raise Exception(f"Tool '{name}' not found in available tools.")
raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.")
def base_message_to_tool_call_dict(self, response: BaseMessage):
"""
@@ -365,7 +365,7 @@ class FallbackHandler:
tool_calls = self.get_tool_calls(response)
if not tool_calls:
raise Exception(
raise FallbackToolExecutionError(
f"Could not extract tool_call_dict from response: {response}"
)

llm.py

@@ -9,6 +9,7 @@ from langchain_openai import ChatOpenAI
from openai import OpenAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.console.output import cpm
from ra_aid.logging_config import get_logger
from .models_params import models_params
@@ -228,8 +229,9 @@ def create_llm_client(
temp_kwargs = {"temperature": 0} if supports_temperature else {}
elif supports_temperature:
if temperature is None:
raise ValueError(
f"Temperature must be provided for model {model_name} which supports temperature"
temperature = 0.7
cpm(
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
)
temp_kwargs = {"temperature": temperature}
else:
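The client-construction side applies the same default when building the keyword arguments passed to the provider class. A hedged sketch of that branch in isolation follows; is_expert stands in for the guard condition not visible in this hunk, and print replaces cpm.

    def build_temp_kwargs(is_expert: bool, supports_temperature: bool,
                          temperature: float | None) -> dict:
        if is_expert:
            # Expert models are pinned to temperature 0 when they accept one at all.
            return {"temperature": 0} if supports_temperature else {}
        if supports_temperature:
            if temperature is None:
                temperature = 0.7
                print("This model supports temperature argument but none was given. "
                      "Setting default temperature to 0.7.")
            return {"temperature": temperature}
        return {}

    # The resulting dict is splatted into the client constructor, e.g.
    # ChatOpenAI(api_key=..., model=..., timeout=180, max_retries=5, **temp_kwargs).
    print(build_temp_kwargs(is_expert=False, supports_temperature=True, temperature=None))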

test_llm.py

@@ -237,7 +237,7 @@ def test_initialize_openai_compatible(clean_env, mock_openai):
def test_initialize_unsupported_provider(clean_env):
"""Test initialization with unsupported provider raises ValueError"""
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
initialize_llm("unknown", "model")
@@ -259,15 +259,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):
max_retries=5,
)
# Test error when no temperature provided for models that support it
with pytest.raises(ValueError, match="Temperature must be provided for model"):
initialize_llm("openai", "test-model")
# Test default temperature when none is provided for models that support it
initialize_llm("openai", "test-model")
mock_openai.assert_called_with(
api_key="test-key",
model="test-model",
temperature=0.7,
timeout=180,
max_retries=5,
)
with pytest.raises(ValueError, match="Temperature must be provided for model"):
initialize_llm("anthropic", "test-model")
initialize_llm("anthropic", "test-model")
mock_anthropic.assert_called_with(
api_key="test-key",
model_name="test-model",
temperature=0.7,
timeout=180,
max_retries=5,
)
with pytest.raises(ValueError, match="Temperature must be provided for model"):
initialize_llm("gemini", "test-model")
initialize_llm("gemini", "test-model")
mock_gemini.assert_called_with(
api_key="test-key",
model="test-model",
temperature=0.7,
timeout=180,
max_retries=5,
)
# Test expert models don't require temperature
initialize_expert_llm("openai", "o1")
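The assert_called_with checks rely on mock_openai / mock_anthropic / mock_gemini fixtures defined elsewhere in the test module, not in this diff. A hedged sketch of how such a fixture could look; the patch target, environment variable, and fixture body are assumptions.

    import pytest
    from unittest.mock import MagicMock, patch

    @pytest.fixture
    def mock_openai(monkeypatch):
        # Provide the API key the provider check expects and intercept the client class.
        monkeypatch.setenv("OPENAI_API_KEY", "test-key")
        with patch("ra_aid.llm.ChatOpenAI") as mock_cls:
            mock_cls.return_value = MagicMock()
            yield mock_cls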