refactor(agent_utils.py): remove the _handle_tool_execution_error function and simplify error handling in run_agent_with_retry

feat(fallback_handler.py): enhance handle_failure method to extract tool name from ToolExecutionError and improve fallback logic
fix(exceptions.py): update ToolExecutionError to include base_message for better error context
feat(output.py): add base_message to ToolExecutionError for improved debugging
chore(tool_configs.py): update get_all_tools function to specify return type
style(logging_config.py): reorder imports for consistency
test(tests): add tests for new error handling and fallback logic in agent_utils and fallback_handler
This commit is contained in:
Ariel Frischer 2025-02-12 13:07:12 -08:00
parent a7322eaef2
commit af9f95ceb1
8 changed files with 252 additions and 208 deletions

View File

@ -9,7 +9,6 @@ import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence
from langgraph.graph.graph import CompiledGraph
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from langchain_core.language_models import BaseChatModel
@ -20,6 +19,7 @@ from langchain_core.messages import (
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import get_model_info
@ -876,9 +876,7 @@ def run_agent_with_retry(
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except ToolExecutionError as e:
fallback_response = _handle_tool_execution_error(
fallback_handler, agent, e
)
fallback_response = fallback_handler.handle_failure(e, agent)
if fallback_response:
prompt = original_prompt + "\n" + fallback_response
continue
@ -895,42 +893,3 @@ def run_agent_with_retry(
finally:
_decrement_agent_depth()
_restore_interrupt_handling(original_handler)
def _handle_tool_execution_error(
fallback_handler: FallbackHandler,
agent: CiaynAgent | CompiledGraph,
error: ToolExecutionError,
):
logger.debug("Entering _handle_tool_execution_error with error: %s", error)
if error.tool_name:
failed_tool_call_name = error.tool_name
logger.debug(
"Extracted failed_tool_call_name from error.tool_name: %s",
failed_tool_call_name,
)
else:
import re
msg = str(error)
logger.debug("Error message: %s", msg)
match = re.search(r"name=['\"](\w+)['\"]", msg)
if match:
failed_tool_call_name = match.group(1)
logger.debug(
"Extracted failed_tool_call_name using regex: %s", failed_tool_call_name
)
else:
failed_tool_call_name = "Tool execution error"
logger.debug(
"Defaulting failed_tool_call_name to: %s", failed_tool_call_name
)
logger.debug(
"Calling fallback_handler.handle_failure with failed_tool_call_name: %s",
failed_tool_call_name,
)
fallback_response = fallback_handler.handle_failure(
failed_tool_call_name, error, agent
)
logger.debug("Fallback response received: %s", fallback_response)
return fallback_response

View File

@ -44,7 +44,9 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
)
)
tool_name = getattr(msg, "name", None)
raise ToolExecutionError(err_msg, tool_name=tool_name)
cpm(f"type(msg): {type(msg)}")
cpm(f"msg: {msg}")
raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:

View File

@ -1,5 +1,9 @@
"""Custom exceptions for RA.Aid."""
from typing import Optional
from langchain_core.messages import BaseMessage
class AgentInterrupt(Exception):
"""Exception raised when an agent's execution is interrupted.
@ -17,6 +21,13 @@ class ToolExecutionError(Exception):
This exception is used to distinguish tool execution failures
from other types of errors in the agent system.
"""
def __init__(self, message, tool_name=None):
def __init__(
self,
message: str,
base_message: Optional[BaseMessage] = None,
tool_name: Optional[str] = None,
):
super().__init__(message)
self.base_message = base_message
self.tool_name = tool_name

View File

@ -1,20 +1,22 @@
import json
import re
from langchain_core.tools import BaseTool
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import BaseMessage
from ra_aid.console.output import cpm
import json
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
)
from ra_aid.logging_config import get_logger
from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console
from ra_aid.console.output import cpm
from ra_aid.exceptions import ToolExecutionError
from ra_aid.llm import initialize_llm, validate_provider_env
from ra_aid.logging_config import get_logger
from ra_aid.tool_configs import get_all_tools
from ra_aid.tool_leaderboard import supported_top_tool_models
logger = get_logger(__name__)
@ -41,9 +43,12 @@ class FallbackHandler:
self.tools: list[BaseTool] = tools
self.fallback_enabled = config.get("fallback_tool_enabled", True)
self.fallback_tool_models = self._load_fallback_tool_models(config)
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
self.tool_failure_consecutive_failures = 0
self.failed_messages = set()
self.tool_failure_used_fallbacks = set()
self.console = Console()
self.current_failing_tool_name = ""
self.current_tool_to_bind = None
def _load_fallback_tool_models(self, config):
"""
@ -87,66 +92,104 @@ class FallbackHandler:
cpm(message, title="Fallback Models")
return final_models
def handle_failure(
self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
):
def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph):
"""
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
Args:
code (str): The code that failed to execute.
error (Exception): The exception raised during execution.
logger: Logger instance for logging.
error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded.
agent: The agent instance on which fallback may be executed.
"""
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
)
self.tool_failure_consecutive_failures += 1
max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
logger.debug(
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
)
if not self.fallback_enabled:
return None
failed_tool_call_name = self.extract_failed_tool_name(error)
if (
self.fallback_enabled
and self.tool_failure_consecutive_failures >= max_failures
self.current_failing_tool_name
and failed_tool_call_name != self.current_failing_tool_name
):
logger.debug(
"New failing tool name identified. Resetting consecutive tool failures."
)
self.reset_fallback_handler()
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
)
self.current_failing_tool_name = failed_tool_call_name
self.current_tool_to_bind = self._find_tool_to_bind(
agent, failed_tool_call_name
)
if hasattr(error, "base_message") and error.base_message:
self.failed_messages.add(str(error.base_message))
self.tool_failure_consecutive_failures += 1
logger.debug(
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
)
if (
self.tool_failure_consecutive_failures >= self.max_failures
and self.fallback_tool_models
):
logger.debug(
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
)
return self.attempt_fallback(code, logger, agent)
return self.attempt_fallback()
def attempt_fallback(self, code: str, logger, agent):
def attempt_fallback(self):
"""
Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method.
Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
Args:
code (str): The tool code that triggered the fallback.
logger: Logger instance for logging messages.
agent: The agent for which fallback is being executed.
Returns:
The response from a fallback model if any, otherwise None.
"""
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
fallback_model = self.fallback_tool_models[0]
failed_tool_call_name = code
logger.error(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
)
cpm(
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.",
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
title="Fallback Notification",
)
if fallback_model.get("type", "prompt").lower() == "fc":
self.attempt_fallback_function(code, logger, agent)
else:
self.attempt_fallback_prompt(code, logger, agent)
for fallback_model in self.fallback_tool_models:
if fallback_model.get("type", "prompt").lower() == "fc":
response = self.attempt_fallback_function(fallback_model)
else:
response = self.attempt_fallback_prompt(fallback_model)
if response:
return response
cpm("All fallback models have failed", title="Fallback Failed")
return None
def reset_fallback_handler(self):
"""
Reset the fallback handler's internal failure counters and clear the record of used fallback models.
"""
self.tool_failure_consecutive_failures = 0
self.failed_messages.clear()
self.tool_failure_used_fallbacks.clear()
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
def extract_failed_tool_name(self, error: ToolExecutionError):
if error.tool_name:
failed_tool_call_name = error.tool_name
else:
msg = str(error)
logger.debug("Error message: %s", msg)
match = re.search(r"name=['\"](\w+)['\"]", msg)
if match:
failed_tool_call_name = str(match.group(1))
logger.debug(
"Extracted failed_tool_call_name using regex: %s",
failed_tool_call_name,
)
else:
failed_tool_call_name = "Tool execution error"
raise Exception("Fallback failed: Could not extract failed tool name.")
return failed_tool_call_name
def _find_tool_to_bind(self, agent, failed_tool_call_name):
logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
@ -157,135 +200,108 @@ class FallbackHandler:
None,
)
if tool_to_bind is None:
from ra_aid.tool_configs import get_all_tools
all_tools = get_all_tools()
tool_to_bind = next(
(t for t in all_tools if t.func.__name__ == failed_tool_call_name),
None,
)
if tool_to_bind is None:
available = [t.func.__name__ for t in get_all_tools()]
logger.debug(
f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}"
# TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name
raise Exception(
f"Fallback failed: {failed_tool_call_name} not found in all tools."
)
raise Exception(f"Tool {failed_tool_call_name} not found in all tools.")
return tool_to_bind
def attempt_fallback_prompt(self, code: str, logger, agent):
def attempt_fallback_prompt(self, fallback_model):
"""
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
Args:
code (str): The tool code to invoke via fallback.
logger: Logger instance for logging messages.
agent: The agent instance to update with the new model upon success.
fallback_model (dict): The fallback model to use.
Returns:
The response from the fallback model invocation.
Raises:
Exception: If all prompt-based fallback models fail.
The response from the fallback model invocation, or None if failed.
"""
logger.debug("Attempting prompt-based fallback using fallback models")
failed_tool_call_name = code
for fallback_model in self.fallback_tool_models:
try:
logger.debug(f"Trying fallback model: {fallback_model['model']}")
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
binded_model = simple_model.bind_tools(
[tool_to_bind], tool_choice=failed_tool_call_name
)
# retry_model = binded_model.with_retry(
# stop_after_attempt=RETRY_FALLBACK_COUNT
# )
response = binded_model.invoke(code)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"])
tool_call = self.base_message_to_tool_call_dict(response)
if tool_call:
result = self.invoke_prompt_tool_call(tool_call)
cpm(f"result={result}")
logger.debug(
"Prompt-based fallback executed successfully with model: "
+ fallback_model["model"]
)
self.reset_fallback_handler()
return result
else:
cpm(
response.content if hasattr(response, "content") else response,
title="Fallback Model Response: " + fallback_model["model"],
)
return response
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
)
raise Exception("All prompt-based fallback models failed")
def attempt_fallback_function(self, code: str, logger, agent):
"""
Attempt a function-calling fallback by iterating over fallback models and invoking the provided code.
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
Args:
code (str): The tool code to invoke via fallback.
logger: Logger instance for logging messages.
agent: The agent instance to update with the new model upon success.
Returns:
The response from the fallback model invocation.
Raises:
Exception: If all function-calling fallback models fail.
"""
logger.debug("Attempting function-calling fallback using fallback models")
failed_tool_call_name = code
for fallback_model in self.fallback_tool_models:
try:
logger.debug(f"Trying fallback model: {fallback_model['model']}")
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
binded_model = simple_model.bind_tools(
[tool_to_bind], tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(code)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"])
self.reset_fallback_handler()
try:
logger.debug(f"Trying fallback model: {fallback_model['model']}")
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
binded_model = simple_model.bind_tools(
[self.current_tool_to_bind],
tool_choice=self.current_failing_tool_name,
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(self.current_failing_tool_name)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"])
tool_call = self.base_message_to_tool_call_dict(response)
if tool_call:
result = self.invoke_prompt_tool_call(tool_call)
cpm(f"result={result}")
logger.debug(
"Function-calling fallback executed successfully with model: "
"Prompt-based fallback executed successfully with model: "
+ fallback_model["model"]
)
self.reset_fallback_handler()
return result
else:
cpm(
response.content if hasattr(response, "content") else response,
title="Fallback Model Response: " + fallback_model["model"],
)
return response
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
)
raise Exception("All function-calling fallback models failed")
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
)
return None
def attempt_fallback_function(self, fallback_model):
"""
Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
Args:
fallback_model (dict): The fallback model to use.
Returns:
The response from the fallback model invocation, or None if failed.
"""
try:
logger.debug(f"Trying fallback model: {fallback_model['model']}")
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
binded_model = simple_model.bind_tools(
[self.current_tool_to_bind],
tool_choice=self.current_failing_tool_name,
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(self.current_failing_tool_name)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"])
logger.debug(
"Function-calling fallback executed successfully with model: "
+ fallback_model["model"]
)
cpm(
response.content if hasattr(response, "content") else response,
title="Fallback Model Response: " + fallback_model["model"],
)
return response
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
)
return None
def invoke_prompt_tool_call(self, tool_call_request: dict):
"""

View File

@ -1,9 +1,10 @@
import logging
import sys
from typing import Optional
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.panel import Panel
class PrettyHandler(logging.Handler):

View File

@ -1,3 +1,5 @@
from langchain_core.tools import BaseTool
from ra_aid.tools import (
ask_expert,
ask_human,
@ -61,11 +63,13 @@ def get_read_only_tools(
return tools
def get_all_tools_simple():
"""Return a list containing all available tools using existing group methods."""
return get_all_tools()
def get_all_tools():
def get_all_tools() -> list[BaseTool]:
"""Return a list containing all available tools from different groups."""
all_tools = []
all_tools.extend(get_read_only_tools())
@ -176,7 +180,7 @@ def get_implementation_tools(
return tools
def get_web_research_tools(expert_enabled: bool = True) -> list:
def get_web_research_tools(expert_enabled: bool = True):
"""Get the list of tools available for web research.
Args:
@ -196,9 +200,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
return tools
def get_chat_tools(
expert_enabled: bool = True, web_research_enabled: bool = False
) -> list:
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False):
"""Get the list of tools available in chat mode.
Chat mode includes research and implementation capabilities but excludes

View File

@ -276,11 +276,19 @@ def test_get_model_token_limit_planner(mock_memory):
token_limit = get_model_token_limit(config, "planner")
assert token_limit == 120000
# New tests for private helper methods in agent_utils.py
def test_setup_and_restore_interrupt_handling():
import signal, threading
from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
import signal
from ra_aid.agent_utils import (
_request_interrupt,
_restore_interrupt_handling,
_setup_interrupt_handling,
)
original_handler = signal.getsignal(signal.SIGINT)
handler = _setup_interrupt_handling()
# Verify the SIGINT handler is set to _request_interrupt
@ -289,66 +297,93 @@ def test_setup_and_restore_interrupt_handling():
# Verify the SIGINT handler is restored to the original
assert signal.getsignal(signal.SIGINT) == original_handler
def test_increment_and_decrement_agent_depth():
from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
from ra_aid.agent_utils import (
_decrement_agent_depth,
_global_memory,
_increment_agent_depth,
)
_global_memory["agent_depth"] = 10
_increment_agent_depth()
assert _global_memory["agent_depth"] == 11
_decrement_agent_depth()
assert _global_memory["agent_depth"] == 10
def test_run_agent_stream(monkeypatch):
from ra_aid.agent_utils import _run_agent_stream, _global_memory
from ra_aid.agent_utils import _global_memory, _run_agent_stream
# Create a dummy agent that yields one chunk
class DummyAgent:
def stream(self, msg, cfg):
yield {"content": "chunk1"}
dummy_agent = DummyAgent()
# Set flags so that _run_agent_stream will reset them
_global_memory["plan_completed"] = True
_global_memory["task_completed"] = True
_global_memory["completion_message"] = "existing"
call_flag = {"called": False}
def fake_print_agent_output(chunk):
call_flag["called"] = True
monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
monkeypatch.setattr(
"ra_aid.agent_utils.print_agent_output", fake_print_agent_output
)
_run_agent_stream(dummy_agent, "dummy prompt", {})
assert call_flag["called"]
assert _global_memory["plan_completed"] is False
assert _global_memory["task_completed"] is False
assert _global_memory["completion_message"] == ""
def test_execute_test_command_wrapper(monkeypatch):
from ra_aid.agent_utils import _execute_test_command_wrapper
# Patch execute_test_command to return a testable tuple
def fake_execute(config, orig, tests, auto):
return (True, "new prompt", auto, tests + 1)
monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
result = _execute_test_command_wrapper("orig", {}, 0, False)
assert result == (True, "new prompt", False, 1)
def test_handle_api_error_valueerror():
from ra_aid.agent_utils import _handle_api_error
import pytest
from ra_aid.agent_utils import _handle_api_error
# ValueError not containing "code" or "429" should be re-raised
with pytest.raises(ValueError):
_handle_api_error(ValueError("some error"), 0, 5, 1)
def test_handle_api_error_max_retries():
from ra_aid.agent_utils import _handle_api_error
import pytest
from ra_aid.agent_utils import _handle_api_error
# When attempt reaches max retries, a RuntimeError should be raised
with pytest.raises(RuntimeError):
_handle_api_error(Exception("error code 429"), 4, 5, 1)
def test_handle_api_error_retry(monkeypatch):
from ra_aid.agent_utils import _handle_api_error
import time
from ra_aid.agent_utils import _handle_api_error
# Patch time.monotonic and time.sleep to simulate immediate delay expiration
fake_time = [0]
def fake_monotonic():
fake_time[0] += 0.5
return fake_time[0]
monkeypatch.setattr(time, "monotonic", fake_monotonic)
monkeypatch.setattr(time, "sleep", lambda s: None)
# Should not raise error when attempt is lower than max retries

View File

@ -1,28 +1,41 @@
import unittest
from ra_aid.fallback_handler import FallbackHandler
class DummyLogger:
def debug(self, msg):
pass
def error(self, msg):
pass
class DummyAgent:
provider = "openai"
tools = []
model = None
class TestFallbackHandler(unittest.TestCase):
def setUp(self):
self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
self.config = {
"max_tool_failures": 2,
"fallback_tool_models": "dummy-fallback-model",
}
self.fallback_handler = FallbackHandler(self.config)
self.logger = DummyLogger()
self.agent = DummyAgent()
def test_handle_failure_increments_counter(self):
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)
self.fallback_handler.handle_failure(
"dummy_call()", Exception("Test error"), self.logger, self.agent
)
self.assertEqual(
self.fallback_handler.tool_failure_consecutive_failures,
initial_failures + 1,
)
def test_attempt_fallback_resets_counter(self):
# Monkey-patch dummy functions for fallback components
@ -30,6 +43,7 @@ class TestFallbackHandler(unittest.TestCase):
class DummyModel:
def bind_tools(self, tools, tool_choice):
pass
return DummyModel()
def dummy_merge_chat_history():
@ -39,6 +53,7 @@ class TestFallbackHandler(unittest.TestCase):
return True
import ra_aid.llm as llm
original_initialize = llm.initialize_llm
original_merge = llm.merge_chat_history
original_validate = llm.validate_provider_env
@ -47,12 +62,15 @@ class TestFallbackHandler(unittest.TestCase):
llm.validate_provider_env = dummy_validate_provider_env
self.fallback_handler.tool_failure_consecutive_failures = 2
self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
self.fallback_handler.attempt_fallback(
"dummy_tool_call()", self.logger, self.agent
)
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
llm.initialize_llm = original_initialize
llm.merge_chat_history = original_merge
llm.validate_provider_env = original_validate
if __name__ == "__main__":
unittest.main()