From 96b41458a1835a9d1cde4d29dba4114a91ef3223 Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Wed, 12 Feb 2025 15:35:31 -0800
Subject: [PATCH] feat(agent_utils.py): refactor agent stream handling to
 improve clarity and maintainability by introducing the
 reset_agent_completion_flags function

feat(fallback_handler.py): enhance fallback handling by accepting the RAgents
type and improving error handling

fix(config.py): add the RAgents type alias covering both CompiledGraph and
CiaynAgent for better type safety

refactor(fallback_handler.py): streamline fallback model invocation and
response handling for improved readability and functionality
---
 ra_aid/agent_utils.py      |  36 ++++---
 ra_aid/config.py           |   5 ++
 ra_aid/fallback_handler.py | 171 +++++++++++++++----------------------
 3 files changed, 94 insertions(+), 118 deletions(-)

diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index ad1adf5..75b95b7 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -28,7 +28,7 @@ from rich.markdown import Markdown
 from rich.panel import Panel
 
 from ra_aid.agents.ciayn_agent import CiaynAgent
-from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
+from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
@@ -807,16 +807,10 @@ def _decrement_agent_depth():
     _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
 
 
-def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
-    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
-        logger.debug("Agent output: %s", chunk)
-        check_interrupt()
-        print_agent_output(chunk)
-        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
-            _global_memory["plan_completed"] = False
-            _global_memory["task_completed"] = False
-            _global_memory["completion_message"] = ""
-            break
+def reset_agent_completion_flags():
+    _global_memory["plan_completed"] = False
+    _global_memory["task_completed"] = False
+    _global_memory["completion_message"] = ""
 
 
 def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
@@ -842,8 +836,21 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     time.sleep(0.1)
 
 
+def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
+    for chunk in agent.stream({"messages": msg_list}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            reset_agent_completion_flags()
+            break
+
+
 def run_agent_with_retry(
-    agent, prompt: str, config: dict, fallback_handler: FallbackHandler
+    agent: RAgents,
+    prompt: str,
+    config: dict,
+    fallback_handler: FallbackHandler,
 ) -> Optional[str]:
     """Run an agent with retry logic for API errors."""
     logger.debug("Running agent with prompt length: %d", len(prompt))
@@ -854,6 +866,7 @@ def run_agent_with_retry(
     _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
     auto_test = config.get("auto_test", False)
     original_prompt = prompt
+    msg_list = [HumanMessage(content=prompt)]
 
     with InterruptibleSection():
         try:
@@ -862,7 +875,7 @@ def run_agent_with_retry(
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    _run_agent_stream(agent, prompt, config)
+                    _run_agent_stream(agent, msg_list, config)
                     fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
                         _execute_test_command_wrapper(
@@ -878,7 +891,7 @@ def run_agent_with_retry(
                 except ToolExecutionError as e:
                     fallback_response = fallback_handler.handle_failure(e, agent)
                     if fallback_response:
-                        prompt = original_prompt + "\n" + str(fallback_response)
+                        msg_list.extend(fallback_response)
                         continue
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
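Note on the new message flow: fallback output now re-enters the agent as
structured messages appended to msg_list, rather than being concatenated onto
the prompt string. A minimal sketch of the intended contract (the message
contents below are made up for illustration, not part of the patch):

    from langchain_core.messages import HumanMessage, SystemMessage

    # The retry loop starts from the original prompt...
    msg_list = [HumanMessage(content="original prompt")]

    # ...and on ToolExecutionError, handle_failure() may hand back a list of
    # SystemMessages (raw fallback LLM response + tool call result), which is
    # appended so the next _run_agent_stream() call sees the full history:
    fallback_response = [
        SystemMessage("raw fallback model response"),
        SystemMessage("fallback tool call result"),
    ]
    msg_list.extend(fallback_response)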
diff --git a/ra_aid/config.py b/ra_aid/config.py
index 54d7995..8bfaf5c 100644
--- a/ra_aid/config.py
+++ b/ra_aid/config.py
@@ -15,3 +15,8 @@ VALID_PROVIDERS = [
     "deepseek",
     "gemini",
 ]
+
+from ra_aid.agents.ciayn_agent import CiaynAgent
+from langgraph.graph.graph import CompiledGraph
+
+RAgents = CompiledGraph | CiaynAgent
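Note on RAgents: both union members expose a compatible stream() interface,
so call sites can be typed once against the alias. The bottom-of-file imports
appear intended to sidestep a circular import between ra_aid.config and
ra_aid.agents (an assumption based on their placement). A small usage sketch
(the narrowing branch is illustrative only):

    from ra_aid.config import RAgents
    from ra_aid.agents.ciayn_agent import CiaynAgent

    def agent_kind(agent: RAgents) -> str:
        # Most call sites need no narrowing; isinstance() remains available
        # when the two agent types must be handled differently.
        return "ciayn" if isinstance(agent, CiaynAgent) else "langgraph"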
""" logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" @@ -150,12 +154,10 @@ class FallbackHandler: title="Fallback Notification", ) for fallback_model in self.fallback_tool_models: - if fallback_model.get("type", "prompt").lower() == "fc": - response = self.attempt_fallback_function(fallback_model) - else: - response = self.attempt_fallback_prompt(fallback_model) - if response: - return response + result_list = self.invoke_fallback(fallback_model) + if result_list: + msg_list_response = [SystemMessage(str(msg)) for msg in result_list] + return msg_list_response cpm("All fallback models have failed", title="Fallback Failed") return None @@ -165,8 +167,9 @@ class FallbackHandler: """ self.tool_failure_consecutive_failures = 0 self.failed_messages.clear() - self.tool_failure_used_fallbacks.clear() self.fallback_tool_models = self._load_fallback_tool_models(self.config) + self.current_failing_tool_name = "" + self.current_tool_to_bind = None def _reset_on_new_failure(self, failed_tool_call_name): if ( @@ -218,9 +221,21 @@ class FallbackHandler: ) return tool_to_bind - def attempt_fallback_prompt(self, fallback_model): + def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model): + if fallback_model.get("type", "prompt").lower() == "fc": + # Force tool calling with tool_choice param. + bound_model = simple_model.bind_tools( + [self.current_tool_to_bind], + tool_choice=self.current_failing_tool_name, + ) + else: + # Do not force tool calling (Prompt method) + bound_model = simple_model.bind_tools([self.current_tool_to_bind]) + return bound_model + + def invoke_fallback(self, fallback_model): """ - Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model. + Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model. Args: fallback_model (dict): The fallback model to use. @@ -229,88 +244,33 @@ class FallbackHandler: The response from the fallback model invocation, or None if failed. 
""" try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") + logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) - binded_model = simple_model.bind_tools( - [self.current_tool_to_bind], - tool_choice=self.current_failing_tool_name, - ) - retry_model = binded_model.with_retry( + + bound_model = self._bind_tool_model(simple_model, fallback_model) + + retry_model = bound_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) - # msg_list = [] + msg_list = self.construct_prompt_msg_list() - # response = retry_model.invoke(self.current_failing_tool_name) response = retry_model.invoke(msg_list) - cpm(f"response={response}") - self.tool_failure_used_fallbacks.add(fallback_model["model"]) + logger.debug(f"raw llm response={response}") tool_call = self.base_message_to_tool_call_dict(response) - if tool_call: - result = self.invoke_prompt_tool_call(tool_call) - cpm(f"result={result}") - logger.debug( - "Prompt-based fallback executed successfully with model: " - + fallback_model["model"] - ) - self.reset_fallback_handler() - return result - else: - cpm( - response.content if hasattr(response, "content") else response, - title="Fallback Model Response: " + fallback_model["model"], - ) - return response + + tool_call_result = self.invoke_prompt_tool_call(tool_call) + cpm(str(tool_call_result), title="Fallback Tool Call Result") + logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}") + + self.reset_fallback_handler() + return [response, tool_call_result] except Exception as e: if isinstance(e, KeyboardInterrupt): raise - logger.error( - f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" - ) - return None - - def attempt_fallback_function(self, fallback_model): - """ - Attempt a function-calling fallback by invoking the current failing tool with the given fallback model. - - Args: - fallback_model (dict): The fallback model to use. - - Returns: - The response from the fallback model invocation, or None if failed. - """ - try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") - simple_model = initialize_llm( - fallback_model["provider"], fallback_model["model"] - ) - binded_model = simple_model.bind_tools( - [self.current_tool_to_bind], - tool_choice=self.current_failing_tool_name, - ) - retry_model = binded_model.with_retry( - stop_after_attempt=RETRY_FALLBACK_COUNT - ) - response = retry_model.invoke(self.current_failing_tool_name) - cpm(f"response={response}") - self.tool_failure_used_fallbacks.add(fallback_model["model"]) - logger.debug( - "Function-calling fallback executed successfully with model: " - + fallback_model["model"] - ) - cpm( - response.content if hasattr(response, "content") else response, - title="Fallback Model Response: " + fallback_model["model"], - ) - return response - except Exception as e: - if isinstance(e, KeyboardInterrupt): - raise - logger.error( - f"Function-calling fallback with model {fallback_model['model']} failed: {e}" - ) + logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}") return None def construct_prompt_msg_list(self): @@ -367,19 +327,22 @@ class FallbackHandler: otherwise None. 
""" tool_calls = self.get_tool_calls(response) - if tool_calls: - if len(tool_calls) > 1: - logger.warning("Multiple tool calls detected, using the first one") - tool_call = tool_calls[0] - return { - "id": tool_call["id"], - "type": tool_call["type"], - "name": tool_call["function"]["name"], - "arguments": self._parse_tool_arguments( - tool_call["function"]["arguments"] - ), - } - return None + + if not tool_calls: + raise Exception( + f"Could not extract tool_call_dict from response: {response}" + ) + + if len(tool_calls) > 1: + logger.warning("Multiple tool calls detected, using the first one") + + tool_call = tool_calls[0] + return { + "id": tool_call["id"], + "type": tool_call["type"], + "name": tool_call["function"]["name"], + "arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]), + } def _parse_tool_arguments(self, tool_arguments): """