refactor(fallback_handler.py): clean up code by removing unused imports and comments to enhance readability
refactor(fallback_handler.py): extract tool call extraction logic into a separate method for better organization and maintainability refactor(fallback_handler.py): introduce _parse_tool_arguments method to handle argument parsing, improving code clarity and reusability
This commit is contained in:
parent
67ecf72a6c
commit
a7322eaef2
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Dict
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
from langgraph.graph.message import BaseMessage
|
||||
|
|
@ -17,9 +16,6 @@ from ra_aid.tool_leaderboard import supported_top_tool_models
|
|||
from rich.console import Console
|
||||
from ra_aid.llm import initialize_llm, validate_provider_env
|
||||
|
||||
# from langgraph.graph.message import BaseMessage, BaseMessageChunk
|
||||
# from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
|
@ -304,13 +300,11 @@ class FallbackHandler:
|
|||
tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools}
|
||||
name = tool_call_request["name"]
|
||||
arguments = tool_call_request["arguments"]
|
||||
# return tool_name_to_tool[name].invoke(arguments)
|
||||
# tool_call_dict = {"arguments": arguments}
|
||||
return tool_name_to_tool[name].invoke(arguments)
|
||||
|
||||
def base_message_to_tool_call_dict(self, response: BaseMessage):
|
||||
"""
|
||||
Extracts a tool call dictionary from a fallback response.
|
||||
Extracts a tool call dictionary from a BaseMessage.
|
||||
|
||||
Args:
|
||||
response: The response object containing tool call data.
|
||||
|
|
@ -319,6 +313,41 @@ class FallbackHandler:
|
|||
A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found,
|
||||
otherwise None.
|
||||
"""
|
||||
tool_calls = self.get_tool_calls(response)
|
||||
if tool_calls:
|
||||
if len(tool_calls) > 1:
|
||||
logger.warning("Multiple tool calls detected, using the first one")
|
||||
tool_call = tool_calls[0]
|
||||
return {
|
||||
"id": tool_call["id"],
|
||||
"type": tool_call["type"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": self._parse_tool_arguments(
|
||||
tool_call["function"]["arguments"]
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
def _parse_tool_arguments(self, tool_arguments):
|
||||
"""
|
||||
Helper method to parse tool call arguments.
|
||||
If tool_arguments is a string, it returns the JSON-parsed dictionary.
|
||||
Otherwise, returns tool_arguments as is.
|
||||
"""
|
||||
if isinstance(tool_arguments, str):
|
||||
return json.loads(tool_arguments)
|
||||
return tool_arguments
|
||||
|
||||
def get_tool_calls(self, response: BaseMessage):
|
||||
"""
|
||||
Extracts tool calls list from a fallback response.
|
||||
|
||||
Args:
|
||||
response: The response object containing tool call data.
|
||||
|
||||
Returns:
|
||||
The tool calls list if present, otherwise None.
|
||||
"""
|
||||
tool_calls = None
|
||||
if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
|
||||
"tool_calls"
|
||||
|
|
@ -330,18 +359,4 @@ class FallbackHandler:
|
|||
"tool_calls"
|
||||
):
|
||||
tool_calls = response.get("additional_kwargs").get("tool_calls")
|
||||
if tool_calls:
|
||||
if len(tool_calls) > 1:
|
||||
logger.warning("Multiple tool calls detected, using the first one")
|
||||
tool_call = tool_calls[0]
|
||||
return {
|
||||
"id": tool_call["id"],
|
||||
"type": tool_call["type"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": (
|
||||
json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], str)
|
||||
else tool_call["function"]["arguments"]
|
||||
),
|
||||
}
|
||||
return None
|
||||
return tool_calls
|
||||
|
|
|
|||
Loading…
Reference in New Issue