From 149e8e22510bfa5b2b3a58f59abe2d34bef5646e Mon Sep 17 00:00:00 2001
From: AI Christianson
Date: Mon, 10 Feb 2025 11:41:27 -0500
Subject: [PATCH] set timeouts on llm clients

---
 ra_aid/chat_models/deepseek_chat.py |  4 ++--
 ra_aid/llm.py                       | 24 +++++++++++++++++++++++-
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/ra_aid/chat_models/deepseek_chat.py b/ra_aid/chat_models/deepseek_chat.py
index 0d04696..46a3cb5 100644
--- a/ra_aid/chat_models/deepseek_chat.py
+++ b/ra_aid/chat_models/deepseek_chat.py
@@ -10,8 +10,8 @@ from langchain_openai import ChatOpenAI
 class ChatDeepseekReasoner(ChatOpenAI):
     """ChatDeepseekReasoner with custom overrides for R1/reasoner models."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args, timeout: int = 180, max_retries: int = 5, **kwargs):
+        super().__init__(*args, timeout=timeout, max_retries=max_retries, **kwargs)
 
     def invocation_params(
         self, options: Optional[Dict[str, Any]] = None, **kwargs: Any
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index 4a4038a..3ab9aac 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -20,6 +20,10 @@ known_temp_providers = {
     "deepseek",
 }
 
+# Constants for API request configuration
+LLM_REQUEST_TIMEOUT = 180
+LLM_MAX_RETRIES = 5
+
 logger = get_logger(__name__)
 
 
@@ -51,6 +55,8 @@ def create_deepseek_client(
         if is_expert
         else (temperature if temperature is not None else 1),
         model=model_name,
+        timeout=LLM_REQUEST_TIMEOUT,
+        max_retries=LLM_MAX_RETRIES,
     )
 
     return ChatOpenAI(
@@ -58,6 +64,8 @@
         api_key=api_key,
         base_url=base_url,
         temperature=0 if is_expert else (temperature if temperature is not None else 1),
         model=model_name,
+        timeout=LLM_REQUEST_TIMEOUT,
+        max_retries=LLM_MAX_RETRIES,
     )
 
 
@@ -76,12 +84,16 @@ def create_openrouter_client(
         if is_expert
         else (temperature if temperature is not None else 1),
         model=model_name,
+        timeout=LLM_REQUEST_TIMEOUT,
+        max_retries=LLM_MAX_RETRIES,
     )
 
     return ChatOpenAI(
         api_key=api_key,
         base_url="https://openrouter.ai/api/v1",
         model=model_name,
+        timeout=LLM_REQUEST_TIMEOUT,
+        max_retries=LLM_MAX_RETRIES,
         **({"temperature": temperature} if temperature is not None else {}),
     )
 
@@ -188,11 +200,17 @@
         }
         if is_expert:
             openai_kwargs["reasoning_effort"] = "high"
-        return ChatOpenAI(**openai_kwargs)
+        return ChatOpenAI(**{
+            **openai_kwargs,
+            "timeout": LLM_REQUEST_TIMEOUT,
+            "max_retries": LLM_MAX_RETRIES,
+        })
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=config["api_key"],
             model_name=model_name,
+            timeout=LLM_REQUEST_TIMEOUT,
+            max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
         )
     elif provider == "openai-compatible":
@@ -200,12 +218,16 @@
         return ChatOpenAI(
             api_key=config["api_key"],
             base_url=config["base_url"],
             model=model_name,
+            timeout=LLM_REQUEST_TIMEOUT,
+            max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
         )
     elif provider == "gemini":
         return ChatGoogleGenerativeAI(
             api_key=config["api_key"],
             model=model_name,
+            timeout=LLM_REQUEST_TIMEOUT,
+            max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
         )
     else:
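
Note (not part of the patch): a minimal sketch of what the new defaults mean
at a call site, assuming langchain_openai is installed and OPENAI_API_KEY is
set; the model name below is an arbitrary example, not taken from the patch.
ChatOpenAI accepts `timeout` and `max_retries` and forwards them to the
underlying OpenAI client, so a hung request should be abandoned after 180
seconds and transient failures retried up to 5 times before an error surfaces.

    from langchain_openai import ChatOpenAI

    # Mirror the constants introduced in ra_aid/llm.py.
    LLM_REQUEST_TIMEOUT = 180  # seconds per request before giving up
    LLM_MAX_RETRIES = 5        # client-level retries on transient errors

    llm = ChatOpenAI(
        model="gpt-4o-mini",  # hypothetical example model
        timeout=LLM_REQUEST_TIMEOUT,
        max_retries=LLM_MAX_RETRIES,
    )
    print(llm.max_retries)  # -> 5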