feat(agent_utils.py): refactor agent stream handling to improve clarity and maintainability by introducing a reset_agent_completion_flags function
feat(fallback_handler.py): enhance fallback handling by accepting the RAgents type and improving error handling
fix(config.py): update the RAgents type definition to include both CompiledGraph and CiaynAgent for better type safety
refactor(fallback_handler.py): streamline fallback model invocation and response handling for improved readability
parent 6e8b0f2e42
commit 96b41458a1
agent_utils.py
@@ -28,7 +28,7 @@ from rich.markdown import Markdown
 from rich.panel import Panel
 
 from ra_aid.agents.ciayn_agent import CiaynAgent
-from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
+from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
@@ -807,16 +807,10 @@ def _decrement_agent_depth():
     _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
 
 
-def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
-    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
-        logger.debug("Agent output: %s", chunk)
-        check_interrupt()
-        print_agent_output(chunk)
-        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
-            _global_memory["plan_completed"] = False
-            _global_memory["task_completed"] = False
-            _global_memory["completion_message"] = ""
-            break
+def reset_agent_completion_flags():
+    _global_memory["plan_completed"] = False
+    _global_memory["task_completed"] = False
+    _global_memory["completion_message"] = ""
 
 
 def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
@@ -842,8 +836,26 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     time.sleep(0.1)
 
 
+def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
+    for chunk in agent.stream({"messages": msg_list}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            reset_agent_completion_flags()
+            break
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            reset_agent_completion_flags()
+            break
+
+
 def run_agent_with_retry(
-    agent, prompt: str, config: dict, fallback_handler: FallbackHandler
+    agent: RAgents,
+    prompt: str,
+    config: dict,
+    fallback_handler: FallbackHandler,
 ) -> Optional[str]:
     """Run an agent with retry logic for API errors."""
     logger.debug("Running agent with prompt length: %d", len(prompt))
@@ -854,6 +866,7 @@ def run_agent_with_retry(
     _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
     auto_test = config.get("auto_test", False)
     original_prompt = prompt
+    msg_list = [HumanMessage(content=prompt)]
 
     with InterruptibleSection():
         try:
@@ -862,7 +875,7 @@ def run_agent_with_retry(
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    _run_agent_stream(agent, prompt, config)
+                    _run_agent_stream(agent, msg_list, config)
                     fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
                         _execute_test_command_wrapper(
@@ -878,7 +891,7 @@ def run_agent_with_retry(
                 except ToolExecutionError as e:
                     fallback_response = fallback_handler.handle_failure(e, agent)
                     if fallback_response:
-                        prompt = original_prompt + "\n" + str(fallback_response)
+                        msg_list.extend(fallback_response)
                         continue
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
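To illustrate the shape of the reworked retry flow above, here is a minimal runnable sketch of how one growing message list replaces string concatenation when a fallback fires. EchoAgent and the literal fallback message are hypothetical stand-ins for an RAgents instance and a real fallback response, not code from this commit:

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage


class EchoAgent:
    """Hypothetical stand-in for an RAgents instance (CompiledGraph or CiaynAgent)."""

    def stream(self, inputs: dict, config: dict):
        # Yield a single chunk echoing the conversation so far.
        yield {"messages": inputs["messages"]}


def run_once(agent: EchoAgent, prompt: str, config: dict) -> list[BaseMessage]:
    # One growing message list per run, as in the reworked run_agent_with_retry.
    msg_list: list[BaseMessage] = [HumanMessage(content=prompt)]
    for chunk in agent.stream({"messages": msg_list}, config):
        pass  # the real loop logs, checks interrupts, and prints each chunk
    # On ToolExecutionError, fallback messages are appended to the list
    # instead of being concatenated onto the original prompt string.
    msg_list.extend([SystemMessage(content="fallback tool result")])
    return msg_list


print(len(run_once(EchoAgent(), "hello", {})))  # 2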
config.py
@@ -15,3 +15,8 @@ VALID_PROVIDERS = [
     "deepseek",
     "gemini",
 ]
+
+from ra_aid.agents.ciayn_agent import CiaynAgent
+from langgraph.graph.graph import CompiledGraph
+
+RAgents = CompiledGraph | CiaynAgent
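The alias is an ordinary PEP 604 union, so it works both as an annotation and with isinstance dispatch. A self-contained sketch with stub classes (the stubs are illustrative, not the real CompiledGraph and CiaynAgent):

class CompiledGraphStub:
    """Illustrative stand-in for langgraph's CompiledGraph."""


class CiaynAgentStub:
    """Illustrative stand-in for ra_aid's CiaynAgent."""


RAgentsStub = CompiledGraphStub | CiaynAgentStub


def describe(agent: RAgentsStub) -> str:
    # Union aliases support isinstance checks on Python 3.10+.
    if isinstance(agent, CiaynAgentStub):
        return "prompt-driven CIAYN agent"
    return "compiled LangGraph agent"


print(describe(CiaynAgentStub()))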
fallback_handler.py
@@ -1,16 +1,16 @@
 import json
 import re
 
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.tools import BaseTool
-from langgraph.graph.graph import CompiledGraph
+from langgraph.graph.message import BaseMessage
+from langchain_core.messages import SystemMessage, HumanMessage
 
-from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import (
     DEFAULT_MAX_TOOL_FAILURES,
     FALLBACK_TOOL_MODEL_LIMIT,
     RETRY_FALLBACK_COUNT,
+    RAgents,
 )
 from ra_aid.console.output import cpm
 from ra_aid.exceptions import ToolExecutionError
@@ -47,11 +47,18 @@ class FallbackHandler:
         self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages: list[BaseMessage] = []
         self.tool_failure_used_fallbacks = set()
         self.current_failing_tool_name = ""
-        self.current_tool_to_bind = None
+        self.current_tool_to_bind: None | BaseTool = None
+        cpm(
+            "Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
+            title="Fallback Models",
+        )
 
-    def _load_fallback_tool_models(self, config):
+    def _format_model(self, m: dict) -> str:
+        return f"{m.get('model', '')} ({m.get('type', 'prompt')})"
+
+    def _load_fallback_tool_models(self, _config):
         """
         Load and return fallback tool models based on the provided configuration.
@@ -90,12 +97,9 @@ class FallbackHandler:
                 "\nSkipped top tool calling models due to missing provider ENV API keys: "
                 + ", ".join(skipped)
             )
             cpm(message, title="Fallback Models")
         return final_models
 
-    def handle_failure(
-        self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph
-    ):
+    def handle_failure(self, error: ToolExecutionError, agent: RAgents):
         """
         Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@@ -137,10 +141,10 @@ class FallbackHandler:
 
     def attempt_fallback(self):
         """
-        Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
+        Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call.
 
         Returns:
-            The response from a fallback model if any, otherwise None.
+            List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None.
         """
         logger.error(
             f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
@@ -150,12 +154,10 @@ class FallbackHandler:
             title="Fallback Notification",
         )
         for fallback_model in self.fallback_tool_models:
-            if fallback_model.get("type", "prompt").lower() == "fc":
-                response = self.attempt_fallback_function(fallback_model)
-            else:
-                response = self.attempt_fallback_prompt(fallback_model)
-            if response:
-                return response
+            result_list = self.invoke_fallback(fallback_model)
+            if result_list:
+                msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
+                return msg_list_response
         cpm("All fallback models have failed", title="Fallback Failed")
         return None
@@ -165,8 +167,9 @@ class FallbackHandler:
         """
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages.clear()
         self.tool_failure_used_fallbacks.clear()
         self.fallback_tool_models = self._load_fallback_tool_models(self.config)
         self.current_failing_tool_name = ""
         self.current_tool_to_bind = None
 
     def _reset_on_new_failure(self, failed_tool_call_name):
         if (
@@ -218,9 +221,21 @@ class FallbackHandler:
         )
         return tool_to_bind
 
-    def attempt_fallback_prompt(self, fallback_model):
+    def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model):
+        if fallback_model.get("type", "prompt").lower() == "fc":
+            # Force tool calling with tool_choice param.
+            bound_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+        else:
+            # Do not force tool calling (Prompt method)
+            bound_model = simple_model.bind_tools([self.current_tool_to_bind])
+        return bound_model
+
+    def invoke_fallback(self, fallback_model):
         """
-        Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
+        Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model.
 
         Args:
             fallback_model (dict): The fallback model to use.
@@ -229,88 +244,33 @@ class FallbackHandler:
             The response from the fallback model invocation, or None if failed.
         """
         try:
-            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
             simple_model = initialize_llm(
                 fallback_model["provider"], fallback_model["model"]
             )
-            binded_model = simple_model.bind_tools(
-                [self.current_tool_to_bind],
-                tool_choice=self.current_failing_tool_name,
-            )
-            retry_model = binded_model.with_retry(
+            bound_model = self._bind_tool_model(simple_model, fallback_model)
+
+            retry_model = bound_model.with_retry(
                 stop_after_attempt=RETRY_FALLBACK_COUNT
             )
-            # msg_list = []
             msg_list = self.construct_prompt_msg_list()
-            # response = retry_model.invoke(self.current_failing_tool_name)
             response = retry_model.invoke(msg_list)
-            cpm(f"response={response}")
             self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            logger.debug(f"raw llm response={response}")
             tool_call = self.base_message_to_tool_call_dict(response)
             if tool_call:
-                result = self.invoke_prompt_tool_call(tool_call)
-                cpm(f"result={result}")
-                logger.debug(
-                    "Prompt-based fallback executed successfully with model: "
-                    + fallback_model["model"]
-                )
+                tool_call_result = self.invoke_prompt_tool_call(tool_call)
+                cpm(str(tool_call_result), title="Fallback Tool Call Result")
+                logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
+
                 self.reset_fallback_handler()
-                return result
-            else:
-                cpm(
-                    response.content if hasattr(response, "content") else response,
-                    title="Fallback Model Response: " + fallback_model["model"],
-                )
-                return response
+                return [response, tool_call_result]
         except Exception as e:
             if isinstance(e, KeyboardInterrupt):
                 raise
-            logger.error(
-                f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
-            )
-            return None
-
-    def attempt_fallback_function(self, fallback_model):
-        """
-        Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
-
-        Args:
-            fallback_model (dict): The fallback model to use.
-
-        Returns:
-            The response from the fallback model invocation, or None if failed.
-        """
-        try:
-            logger.debug(f"Trying fallback model: {fallback_model['model']}")
-            simple_model = initialize_llm(
-                fallback_model["provider"], fallback_model["model"]
-            )
-            binded_model = simple_model.bind_tools(
-                [self.current_tool_to_bind],
-                tool_choice=self.current_failing_tool_name,
-            )
-            retry_model = binded_model.with_retry(
-                stop_after_attempt=RETRY_FALLBACK_COUNT
-            )
-            response = retry_model.invoke(self.current_failing_tool_name)
-            cpm(f"response={response}")
-            self.tool_failure_used_fallbacks.add(fallback_model["model"])
-            logger.debug(
-                "Function-calling fallback executed successfully with model: "
-                + fallback_model["model"]
-            )
-            cpm(
-                response.content if hasattr(response, "content") else response,
-                title="Fallback Model Response: " + fallback_model["model"],
-            )
-            return response
-        except Exception as e:
-            if isinstance(e, KeyboardInterrupt):
-                raise
-            logger.error(
-                f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
-            )
+            logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
             return None
 
     def construct_prompt_msg_list(self):
@@ -367,19 +327,22 @@ class FallbackHandler:
             otherwise None.
         """
         tool_calls = self.get_tool_calls(response)
-        if tool_calls:
+
+        if not tool_calls:
+            raise Exception(
+                f"Could not extract tool_call_dict from response: {response}"
+            )
+
         if len(tool_calls) > 1:
             logger.warning("Multiple tool calls detected, using the first one")
+
         tool_call = tool_calls[0]
         return {
             "id": tool_call["id"],
             "type": tool_call["type"],
             "name": tool_call["function"]["name"],
-            "arguments": self._parse_tool_arguments(
-                tool_call["function"]["arguments"]
-            ),
+            "arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]),
         }
-        return None
 
     def _parse_tool_arguments(self, tool_arguments):
         """
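The fc-versus-prompt split in _bind_tool_model maps onto LangChain's standard bind_tools API. A hedged sketch with a toy tool (the tool and helper below are illustrative assumptions, not part of this commit):

from langchain_core.tools import tool


@tool
def emit_result(text: str) -> str:
    """Toy tool standing in for the failing tool being re-bound."""
    return text


def bind_for_fallback(chat_model, fallback_model: dict):
    # Function-calling ("fc") models: force the tool via tool_choice,
    # mirroring the forced branch of _bind_tool_model.
    if fallback_model.get("type", "prompt").lower() == "fc":
        return chat_model.bind_tools([emit_result], tool_choice="emit_result")
    # Prompt models: bind the tool but let the model decide to call it.
    return chat_model.bind_tools([emit_result])


# Usage (assuming an OpenAI key and langchain-openai installed):
#   from langchain_openai import ChatOpenAI
#   bound = bind_for_fallback(ChatOpenAI(model="gpt-4o-mini"), {"type": "fc"})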