feat(main.py): import models_params and set default temperature for models that support it to improve user experience

fix(ciayn_agent.py): update fallback tool error messages to use FallbackToolExecutionError for better error handling
fix(config.py): remove unnecessary blank line to maintain code style consistency
fix(fallback_handler.py): raise FallbackToolExecutionError for better error clarity when tools are not found
fix(llm.py): set default temperature to 0.7 and notify user when not provided for models that support it
test(test_llm.py): update tests to check for default temperature behavior and improve error messages for unsupported providers

parent 6970a885e4
commit 0df5d43333
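The defaulting rule this commit introduces in main() can be summarized by a small standalone sketch (not the project code): an explicit temperature wins, then the model's own default_temperature from models_params, then the package-wide DEFAULT_TEMPERATURE. The 0.7 value is an assumption about that constant, mirroring the llm.py hunk below.

from typing import Optional

DEFAULT_TEMPERATURE = 0.7  # stand-in; assumed to match ra_aid.models_params.DEFAULT_TEMPERATURE


def resolve_temperature(cli_value: Optional[float], model_default: Optional[float]) -> float:
    """Mirror the defaulting order added to main(): CLI arg, then model default, then constant."""
    if cli_value is not None:
        return cli_value
    if model_default is not None:
        return model_default
    # Before this commit, reaching this point printed an error and exited.
    return DEFAULT_TEMPERATURE


print(resolve_temperature(None, None))  # 0.7
print(resolve_temperature(None, 1.0))   # 1.0
print(resolve_temperature(0.2, 1.0))    # 0.2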
main.py
@@ -8,6 +8,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
+from ra_aid.models_params import models_params
 
 from ra_aid import print_error, print_stage_header
 from ra_aid.__version__ import __version__
@@ -23,10 +24,12 @@ from ra_aid.config import (
     DEFAULT_RECURSION_LIMIT,
     VALID_PROVIDERS,
 )
+from ra_aid.console.output import cpm
 from ra_aid.dependencies import check_dependencies
 from ra_aid.env import validate_environment
 from ra_aid.llm import initialize_llm
 from ra_aid.logging_config import get_logger, setup_logging
+from ra_aid.models_params import DEFAULT_TEMPERATURE
 from ra_aid.project_info import format_project_info, get_project_info
 from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
 from ra_aid.tool_configs import get_chat_tools
@@ -309,7 +312,6 @@ def main():
     logger.debug("Environment validation successful")
 
     # Validate model configuration early
-    from ra_aid.models_params import models_params
 
     model_config = models_params.get(args.provider, {}).get(args.model or "", {})
     supports_temperature = model_config.get(
@@ -321,10 +323,10 @@ def main():
     if supports_temperature and args.temperature is None:
         args.temperature = model_config.get("default_temperature")
         if args.temperature is None:
-            print_error(
-                f"Temperature must be provided for model {args.model} which supports temperature"
+            cpm(
+                f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
             )
-            sys.exit(1)
+            args.temperature = DEFAULT_TEMPERATURE
         logger.debug(
             f"Using default temperature {args.temperature} for model {args.model}"
         )
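For context, the lookup in the hunk above reads per-model settings out of models_params. A hedged sketch with a made-up entry follows; the real table lives in ra_aid.models_params, and only the "default_temperature" key is visible in this hunk, so the other key names and values are assumptions.

# Hypothetical shape of a models_params entry; key names other than
# "default_temperature" (which appears in the diff) are assumptions.
models_params = {
    "openai": {
        "gpt-4o": {                        # hypothetical model entry
            "supports_temperature": True,  # assumed flag name
            "default_temperature": 1.0,    # assumed per-model default
        }
    }
}

provider, model = "openai", "gpt-4o"
model_config = models_params.get(provider, {}).get(model or "", {})
print(model_config.get("default_temperature"))  # 1.0 for this made-up entry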
ciayn_agent.py
@@ -143,8 +143,6 @@ class CiaynAgent:
         if code.endswith("```"):
             code = code[:-3].strip()
 
-        # logger.debug(f"_execute_tool: stripped code: {code}")
-
         # if the eval fails, try to extract it via a model call
         if validate_function_call_pattern(code):
             functions_list = "\n\n".join(self.available_functions)
@@ -176,8 +174,8 @@ class CiaynAgent:
         msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
         # Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
         # msg += f"<fallback llm raw invocation>{fallback_response[0]}</fallback llm raw invocation>\n"
-        msg += f"<fallback tool name>{e.tool_name}</fallback tool name>"
-        msg += f"<fallback tool call result>{fallback_response[1]}</fallback tool call result>"
+        msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
+        msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
         return msg
 
     def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
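The trailing \n added to the two f-strings above puts each tag on its own line in the fallback report handed back to the agent. A minimal sketch with hypothetical values, showing the resulting layout:

DEFAULT_MAX_TOOL_FAILURES = 3  # mirrors the constant in config.py

tool_name = "example_tool"  # hypothetical fallback tool name
fallback_result = "tool output line 1\ntool output line 2"  # hypothetical result

msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
msg += f"<fallback tool name>{tool_name}</fallback tool name>\n"
msg += f"<fallback tool call result>\n{fallback_result}\n</fallback tool call result>\n"
print(msg)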
config.py
@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
 FALLBACK_TOOL_MODEL_LIMIT = 5
 RETRY_FALLBACK_COUNT = 3
+
 
 VALID_PROVIDERS = [
     "anthropic",
     "openai",
fallback_handler.py
@@ -349,7 +349,7 @@ class FallbackHandler:
         ):
             return self.current_tool_to_bind.invoke(arguments)
         else:
-            raise Exception(f"Tool '{name}' not found in available tools.")
+            raise FallbackToolExecutionError(f"Tool '{name}' not found in available tools.")
 
     def base_message_to_tool_call_dict(self, response: BaseMessage):
         """
@@ -365,7 +365,7 @@ class FallbackHandler:
         tool_calls = self.get_tool_calls(response)
 
         if not tool_calls:
-            raise Exception(
+            raise FallbackToolExecutionError(
                 f"Could not extract tool_call_dict from response: {response}"
            )
 
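Raising FallbackToolExecutionError instead of a bare Exception lets callers handle fallback failures without catching everything else. A sketch under stated assumptions: the exception's import path is not shown in this diff, and safe_tool_call_dict is a made-up helper, not project code.

from ra_aid.exceptions import FallbackToolExecutionError  # assumed import path


def safe_tool_call_dict(handler, response):
    """Return a tool-call dict, or None if the fallback response is unusable."""
    try:
        # base_message_to_tool_call_dict appears in the hunk above.
        return handler.base_message_to_tool_call_dict(response)
    except FallbackToolExecutionError as exc:
        # Previously a bare Exception was raised here, so callers could not
        # catch this case without also swallowing unrelated errors.
        print(f"Fallback tool call failed: {exc}")
        return None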
llm.py
@@ -9,6 +9,7 @@ from langchain_openai import ChatOpenAI
 from openai import OpenAI
 
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
+from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger
 
 from .models_params import models_params
@@ -228,8 +229,9 @@ def create_llm_client(
         temp_kwargs = {"temperature": 0} if supports_temperature else {}
     elif supports_temperature:
         if temperature is None:
-            raise ValueError(
-                f"Temperature must be provided for model {model_name} which supports temperature"
+            temperature = 0.7
+            cpm(
+                "This model supports temperature argument but none was given. Setting default temperature to 0.7."
             )
         temp_kwargs = {"temperature": temperature}
     else:
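From the caller's side, initialize_llm no longer raises when a temperature-capable model is configured without a temperature; it prints a notice via cpm and uses 0.7. A hedged usage sketch: the model name is a placeholder, the provider API key is assumed to be set in the environment, and passing temperature as a keyword is inferred from the surrounding tests rather than shown in this hunk.

from ra_aid.llm import initialize_llm

# Prints the "Setting default temperature to 0.7." notice and builds the
# client with temperature=0.7 (previously this raised ValueError).
llm = initialize_llm("openai", "gpt-4o")

# An explicitly provided temperature still takes precedence over the default.
llm_cold = initialize_llm("openai", "gpt-4o", temperature=0.2)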
test_llm.py
@@ -237,7 +237,7 @@ def test_initialize_openai_compatible(clean_env, mock_openai):
 
 def test_initialize_unsupported_provider(clean_env):
     """Test initialization with unsupported provider raises ValueError"""
-    with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
+    with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
         initialize_llm("unknown", "model")
 
 
@@ -259,15 +259,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini
         max_retries=5,
     )
 
-    # Test error when no temperature provided for models that support it
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("openai", "test-model")
+    # Test default temperature when none is provided for models that support it
+    initialize_llm("openai", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )
 
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("anthropic", "test-model")
+    initialize_llm("anthropic", "test-model")
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )
 
-    with pytest.raises(ValueError, match="Temperature must be provided for model"):
-        initialize_llm("gemini", "test-model")
+    initialize_llm("gemini", "test-model")
+    mock_gemini.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=0.7,
+        timeout=180,
+        max_retries=5,
+    )
 
     # Test expert models don't require temperature
     initialize_expert_llm("openai", "o1")