feat(agent_utils.py): add debug print for config to assist in troubleshooting
feat(ciayn_agent.py): pass config to CiaynAgent for improved functionality fix(ciayn_agent.py): handle tool execution errors more gracefully with msg_list feat(fallback_handler.py): enhance handle_failure method to utilize msg_list for better context feat(fallback_handler.py): implement init_msg_list to manage message history effectively test(test_fallback_handler.py): add unit tests for init_msg_list to ensure correct behavior
This commit is contained in:
parent
9caa46bc78
commit
15ce534f8f
|
|
@ -270,6 +270,7 @@ def create_agent(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
config = _global_memory.get("config", {})
|
config = _global_memory.get("config", {})
|
||||||
|
print(f"config={config}")
|
||||||
max_input_tokens = (
|
max_input_tokens = (
|
||||||
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
||||||
)
|
)
|
||||||
|
|
@ -281,7 +282,7 @@ def create_agent(
|
||||||
return create_react_agent(model, tools, **agent_kwargs)
|
return create_react_agent(model, tools, **agent_kwargs)
|
||||||
else:
|
else:
|
||||||
logger.debug("Using CiaynAgent agent instance")
|
logger.debug("Using CiaynAgent agent instance")
|
||||||
return CiaynAgent(model, tools, max_tokens=max_input_tokens)
|
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Default to REACT agent if provider/model detection fails
|
# Default to REACT agent if provider/model detection fails
|
||||||
|
|
@ -882,7 +883,7 @@ def _handle_fallback_response(
|
||||||
"""
|
"""
|
||||||
if not fallback_handler:
|
if not fallback_handler:
|
||||||
return
|
return
|
||||||
fallback_response = fallback_handler.handle_failure(error, agent)
|
fallback_response = fallback_handler.handle_failure(error, agent, msg_list)
|
||||||
agent_type = get_agent_type(agent)
|
agent_type = get_agent_type(agent)
|
||||||
if fallback_response and agent_type == "React":
|
if fallback_response and agent_type == "React":
|
||||||
msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response]
|
msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response]
|
||||||
|
|
|
||||||
|
|
@ -139,18 +139,24 @@ class CiaynAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code = code.strip()
|
code = code.strip()
|
||||||
|
if code.startswith("```"):
|
||||||
|
code = code[3:].strip()
|
||||||
|
if code.endswith("```"):
|
||||||
|
code = code[:-3].strip()
|
||||||
|
|
||||||
|
raise ToolExecutionError(
|
||||||
|
"err", base_message=msg, tool_name="ripgrep_search"
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"_execute_tool: stripped code: {code}")
|
logger.debug(f"_execute_tool: stripped code: {code}")
|
||||||
|
|
||||||
# if the eval fails, try to extract it via a model call
|
# if the eval fails, try to extract it via a model call
|
||||||
if validate_function_call_pattern(code):
|
if validate_function_call_pattern(code):
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
code = self._extract_tool_call(code, functions_list)
|
code = self._extract_tool_call(code, functions_list)
|
||||||
|
pass
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
|
|
||||||
)
|
|
||||||
result = eval(code.strip(), globals_dict)
|
result = eval(code.strip(), globals_dict)
|
||||||
logger.debug(f"_execute_tool: result: {result}")
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error: {str(e)} \n Could not excute code: {code}"
|
error_msg = f"Error: {str(e)} \n Could not excute code: {code}"
|
||||||
|
|
@ -230,6 +236,7 @@ class CiaynAgent:
|
||||||
raise ToolExecutionError("Failed to extract tool call")
|
raise ToolExecutionError("Failed to extract tool call")
|
||||||
ma = matches[0][0].strip()
|
ma = matches[0][0].strip()
|
||||||
mb = matches[0][1].strip().replace("\n", " ")
|
mb = matches[0][1].strip().replace("\n", " ")
|
||||||
|
logger.debug(f"Extracted tool call: {ma}({mb})")
|
||||||
return f"{ma}({mb})"
|
return f"{ma}({mb})"
|
||||||
|
|
||||||
def _trim_chat_history(
|
def _trim_chat_history(
|
||||||
|
|
@ -284,13 +291,14 @@ class CiaynAgent:
|
||||||
response = self.model.invoke([self.sys_message] + full_history)
|
response = self.model.invoke([self.sys_message] + full_history)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# logger.debug(f"Code generated by agent: {response.content}")
|
|
||||||
last_result = self._execute_tool(response)
|
last_result = self._execute_tool(response)
|
||||||
self.chat_history.append(response)
|
self.chat_history.append(response)
|
||||||
self.fallback_handler.reset_fallback_handler()
|
self.fallback_handler.reset_fallback_handler()
|
||||||
yield {}
|
yield {}
|
||||||
|
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
fallback_response = self.fallback_handler.handle_failure(e, self)
|
fallback_response = self.fallback_handler.handle_failure(
|
||||||
|
e, self, self.chat_history
|
||||||
|
)
|
||||||
last_result = self.handle_fallback_response(fallback_response, e)
|
last_result = self.handle_fallback_response(fallback_response, e)
|
||||||
yield {}
|
yield {}
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ class FallbackHandler:
|
||||||
self.failed_messages: list[BaseMessage] = []
|
self.failed_messages: list[BaseMessage] = []
|
||||||
self.current_failing_tool_name = ""
|
self.current_failing_tool_name = ""
|
||||||
self.current_tool_to_bind: None | BaseTool = None
|
self.current_tool_to_bind: None | BaseTool = None
|
||||||
|
self.msg_list: list[BaseMessage] = []
|
||||||
|
|
||||||
cpm(
|
cpm(
|
||||||
"Fallback models selected: "
|
"Fallback models selected: "
|
||||||
|
|
@ -100,7 +101,9 @@ class FallbackHandler:
|
||||||
)
|
)
|
||||||
return final_models
|
return final_models
|
||||||
|
|
||||||
def handle_failure(self, error: ToolExecutionError, agent: RAgents):
|
def handle_failure(
|
||||||
|
self, error: ToolExecutionError, agent: RAgents, msg_list: list[BaseMessage]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||||
|
|
||||||
|
|
@ -111,6 +114,9 @@ class FallbackHandler:
|
||||||
if not self.fallback_enabled:
|
if not self.fallback_enabled:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if self.tool_failure_consecutive_failures == 0:
|
||||||
|
self.init_msg_list(msg_list)
|
||||||
|
|
||||||
failed_tool_call_name = self.extract_failed_tool_name(error)
|
failed_tool_call_name = self.extract_failed_tool_name(error)
|
||||||
self._reset_on_new_failure(failed_tool_call_name)
|
self._reset_on_new_failure(failed_tool_call_name)
|
||||||
|
|
||||||
|
|
@ -177,6 +183,7 @@ class FallbackHandler:
|
||||||
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
|
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
|
||||||
self.current_failing_tool_name = ""
|
self.current_failing_tool_name = ""
|
||||||
self.current_tool_to_bind = None
|
self.current_tool_to_bind = None
|
||||||
|
self.msg_list = []
|
||||||
|
|
||||||
def _reset_on_new_failure(self, failed_tool_call_name):
|
def _reset_on_new_failure(self, failed_tool_call_name):
|
||||||
if (
|
if (
|
||||||
|
|
@ -296,22 +303,29 @@ class FallbackHandler:
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of chat messages.
|
list: A list of chat messages.
|
||||||
"""
|
"""
|
||||||
msg_list: list[BaseMessage] = []
|
prompt_msg_list: list[BaseMessage] = []
|
||||||
msg_list.append(
|
prompt_msg_list.append(
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages."
|
content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Have some way to use the correct message type in the future, dont just convert everything to system message.
|
||||||
|
# This may be difficult as each model type may require different chat structures and throw API errors.
|
||||||
|
prompt_msg_list.extend(SystemMessage(str(msg)) for msg in self.msg_list)
|
||||||
|
|
||||||
if self.failed_messages:
|
if self.failed_messages:
|
||||||
# Convert to system messages to avoid API errors asking for correct msg structure
|
# Convert to system messages to avoid API errors asking for correct msg structure
|
||||||
msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages])
|
prompt_msg_list.extend(
|
||||||
|
[SystemMessage(str(msg)) for msg in self.failed_messages]
|
||||||
|
)
|
||||||
|
|
||||||
msg_list.append(
|
prompt_msg_list.append(
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments."
|
content=f"Retry using the tool: '{self.current_failing_tool_name}' with correct arguments and formatting."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return msg_list
|
return prompt_msg_list
|
||||||
|
|
||||||
def invoke_prompt_tool_call(self, tool_call_request: dict):
|
def invoke_prompt_tool_call(self, tool_call_request: dict):
|
||||||
"""
|
"""
|
||||||
|
|
@ -323,12 +337,17 @@ class FallbackHandler:
|
||||||
Returns:
|
Returns:
|
||||||
The result of invoking the tool.
|
The result of invoking the tool.
|
||||||
"""
|
"""
|
||||||
tool_name_to_tool = {getattr(tool.func, "__name__", None): tool for tool in self.tools}
|
tool_name_to_tool = {
|
||||||
|
getattr(tool.func, "__name__", None): tool for tool in self.tools
|
||||||
|
}
|
||||||
name = tool_call_request["name"]
|
name = tool_call_request["name"]
|
||||||
arguments = tool_call_request["arguments"]
|
arguments = tool_call_request["arguments"]
|
||||||
if name in tool_name_to_tool:
|
if name in tool_name_to_tool:
|
||||||
return tool_name_to_tool[name].invoke(arguments)
|
return tool_name_to_tool[name].invoke(arguments)
|
||||||
elif self.current_tool_to_bind is not None and getattr(self.current_tool_to_bind.func, "__name__", None) == name:
|
elif (
|
||||||
|
self.current_tool_to_bind is not None
|
||||||
|
and getattr(self.current_tool_to_bind.func, "__name__", None) == name
|
||||||
|
):
|
||||||
return self.current_tool_to_bind.invoke(arguments)
|
return self.current_tool_to_bind.invoke(arguments)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Tool '{name}' not found in available tools.")
|
raise Exception(f"Tool '{name}' not found in available tools.")
|
||||||
|
|
@ -406,3 +425,12 @@ class FallbackHandler:
|
||||||
if fallback_response and agent_type == "React":
|
if fallback_response and agent_type == "React":
|
||||||
return [SystemMessage(str(msg)) for msg in fallback_response]
|
return [SystemMessage(str(msg)) for msg in fallback_response]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def init_msg_list(self, full_msg_list: list[BaseMessage]) -> None:
|
||||||
|
first_two = full_msg_list[:2]
|
||||||
|
last_two = full_msg_list[-2:]
|
||||||
|
merged = first_two.copy()
|
||||||
|
for msg in last_two:
|
||||||
|
if msg not in merged:
|
||||||
|
merged.append(msg)
|
||||||
|
self.msg_list = merged
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
error_obj = ToolExecutionError(
|
error_obj = ToolExecutionError(
|
||||||
"Test error", base_message="dummy_call()", tool_name="dummy_tool"
|
"Test error", base_message="dummy_call()", tool_name="dummy_tool"
|
||||||
)
|
)
|
||||||
self.fallback_handler.handle_failure(error_obj, self.agent)
|
self.fallback_handler.handle_failure(error_obj, self.agent, [])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.fallback_handler.tool_failure_consecutive_failures,
|
self.fallback_handler.tool_failure_consecutive_failures,
|
||||||
initial_failures + 1,
|
initial_failures + 1,
|
||||||
|
|
@ -295,5 +295,19 @@ class TestFallbackHandler(unittest.TestCase):
|
||||||
self.assertIsNone(response_non)
|
self.assertIsNone(response_non)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_msg_list_non_overlapping(self):
|
||||||
|
# Test when the first two and last two messages do not overlap.
|
||||||
|
full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
|
||||||
|
self.fallback_handler.init_msg_list(full_list)
|
||||||
|
# Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
|
||||||
|
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"])
|
||||||
|
|
||||||
|
def test_init_msg_list_with_overlap(self):
|
||||||
|
# Test when the last two messages overlap with the first two.
|
||||||
|
full_list = ["msg1", "msg2", "msg1", "msg3"]
|
||||||
|
self.fallback_handler.init_msg_list(full_list)
|
||||||
|
# Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
|
||||||
|
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue