refactor(fallback_handler.py): clean up code by removing unused imports and comments to enhance readability

refactor(fallback_handler.py): extract tool call extraction logic into a separate method for better organization and maintainability
refactor(fallback_handler.py): introduce _parse_tool_arguments method to handle argument parsing, improving code clarity and reusability
This commit is contained in:
Ariel Frischer 2025-02-11 18:38:52 -08:00
parent 67ecf72a6c
commit a7322eaef2
1 changed files with 37 additions and 22 deletions

View File

@ -1,4 +1,3 @@
from typing import Dict
from langchain_core.tools import BaseTool
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import BaseMessage
@ -17,9 +16,6 @@ from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console
from ra_aid.llm import initialize_llm, validate_provider_env
# from langgraph.graph.message import BaseMessage, BaseMessageChunk
# from langgraph.prebuilt import ToolNode
logger = get_logger(__name__)
@ -304,13 +300,11 @@ class FallbackHandler:
tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools}
name = tool_call_request["name"]
arguments = tool_call_request["arguments"]
# return tool_name_to_tool[name].invoke(arguments)
# tool_call_dict = {"arguments": arguments}
return tool_name_to_tool[name].invoke(arguments)
def base_message_to_tool_call_dict(self, response: BaseMessage):
"""
Extracts a tool call dictionary from a fallback response.
Extracts a tool call dictionary from a BaseMessage.
Args:
response: The response object containing tool call data.
@ -319,6 +313,41 @@ class FallbackHandler:
A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found,
otherwise None.
"""
tool_calls = self.get_tool_calls(response)
if tool_calls:
if len(tool_calls) > 1:
logger.warning("Multiple tool calls detected, using the first one")
tool_call = tool_calls[0]
return {
"id": tool_call["id"],
"type": tool_call["type"],
"name": tool_call["function"]["name"],
"arguments": self._parse_tool_arguments(
tool_call["function"]["arguments"]
),
}
return None
def _parse_tool_arguments(self, tool_arguments):
"""
Helper method to parse tool call arguments.
If tool_arguments is a string, it returns the JSON-parsed dictionary.
Otherwise, returns tool_arguments as is.
"""
if isinstance(tool_arguments, str):
return json.loads(tool_arguments)
return tool_arguments
def get_tool_calls(self, response: BaseMessage):
"""
Extracts tool calls list from a fallback response.
Args:
response: The response object containing tool call data.
Returns:
The tool calls list if present, otherwise None.
"""
tool_calls = None
if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
"tool_calls"
@ -330,18 +359,4 @@ class FallbackHandler:
"tool_calls"
):
tool_calls = response.get("additional_kwargs").get("tool_calls")
if tool_calls:
if len(tool_calls) > 1:
logger.warning("Multiple tool calls detected, using the first one")
tool_call = tool_calls[0]
return {
"id": tool_call["id"],
"type": tool_call["type"],
"name": tool_call["function"]["name"],
"arguments": (
json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"]
),
}
return None
return tool_calls