feat(agent_utils.py): refactor agent stream handling to improve clarity and maintainability by introducing reset_agent_completion_flags function
feat(fallback_handler.py): enhance fallback handling by allowing RAgents type and improving error handling fix(config.py): update RAgents type definition to include both CompiledGraph and CiaynAgent for better type safety refactor(fallback_handler.py): streamline fallback model invocation and response handling for improved readability and functionality
This commit is contained in:
parent
6e8b0f2e42
commit
96b41458a1
|
|
@ -28,7 +28,7 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
|
||||||
from ra_aid.console.formatting import print_error, print_stage_header
|
from ra_aid.console.formatting import print_error, print_stage_header
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
|
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
|
||||||
|
|
@ -807,16 +807,10 @@ def _decrement_agent_depth():
|
||||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||||
|
|
||||||
|
|
||||||
def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
|
def reset_agent_completion_flags():
|
||||||
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
_global_memory["plan_completed"] = False
|
||||||
logger.debug("Agent output: %s", chunk)
|
_global_memory["task_completed"] = False
|
||||||
check_interrupt()
|
_global_memory["completion_message"] = ""
|
||||||
print_agent_output(chunk)
|
|
||||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
|
||||||
_global_memory["plan_completed"] = False
|
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
_global_memory["completion_message"] = ""
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
||||||
|
|
@ -842,8 +836,26 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
||||||
|
for chunk in agent.stream({"messages": msg_list}, config):
|
||||||
|
logger.debug("Agent output: %s", chunk)
|
||||||
|
check_interrupt()
|
||||||
|
print_agent_output(chunk)
|
||||||
|
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||||
|
reset_agent_completion_flags()
|
||||||
|
break
|
||||||
|
check_interrupt()
|
||||||
|
print_agent_output(chunk)
|
||||||
|
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||||
|
reset_agent_completion_flags()
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def run_agent_with_retry(
|
def run_agent_with_retry(
|
||||||
agent, prompt: str, config: dict, fallback_handler: FallbackHandler
|
agent: RAgents,
|
||||||
|
prompt: str,
|
||||||
|
config: dict,
|
||||||
|
fallback_handler: FallbackHandler,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Run an agent with retry logic for API errors."""
|
"""Run an agent with retry logic for API errors."""
|
||||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||||
|
|
@ -854,6 +866,7 @@ def run_agent_with_retry(
|
||||||
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
||||||
auto_test = config.get("auto_test", False)
|
auto_test = config.get("auto_test", False)
|
||||||
original_prompt = prompt
|
original_prompt = prompt
|
||||||
|
msg_list = [HumanMessage(content=prompt)]
|
||||||
|
|
||||||
with InterruptibleSection():
|
with InterruptibleSection():
|
||||||
try:
|
try:
|
||||||
|
|
@ -862,7 +875,7 @@ def run_agent_with_retry(
|
||||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
try:
|
try:
|
||||||
_run_agent_stream(agent, prompt, config)
|
_run_agent_stream(agent, msg_list, config)
|
||||||
fallback_handler.reset_fallback_handler()
|
fallback_handler.reset_fallback_handler()
|
||||||
should_break, prompt, auto_test, test_attempts = (
|
should_break, prompt, auto_test, test_attempts = (
|
||||||
_execute_test_command_wrapper(
|
_execute_test_command_wrapper(
|
||||||
|
|
@ -878,7 +891,7 @@ def run_agent_with_retry(
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
fallback_response = fallback_handler.handle_failure(e, agent)
|
||||||
if fallback_response:
|
if fallback_response:
|
||||||
prompt = original_prompt + "\n" + str(fallback_response)
|
msg_list.extend(fallback_response)
|
||||||
continue
|
continue
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -15,3 +15,8 @@ VALID_PROVIDERS = [
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"gemini",
|
"gemini",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
|
from langgraph.graph.graph import CompiledGraph
|
||||||
|
|
||||||
|
RAgents = CompiledGraph | CiaynAgent
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,16 @@
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.graph.graph import CompiledGraph
|
|
||||||
from langgraph.graph.message import BaseMessage
|
from langgraph.graph.message import BaseMessage
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
|
||||||
|
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
|
||||||
from ra_aid.config import (
|
from ra_aid.config import (
|
||||||
DEFAULT_MAX_TOOL_FAILURES,
|
DEFAULT_MAX_TOOL_FAILURES,
|
||||||
FALLBACK_TOOL_MODEL_LIMIT,
|
FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
RETRY_FALLBACK_COUNT,
|
RETRY_FALLBACK_COUNT,
|
||||||
|
RAgents,
|
||||||
)
|
)
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
@ -47,11 +47,18 @@ class FallbackHandler:
|
||||||
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.tool_failure_consecutive_failures = 0
|
||||||
self.failed_messages: list[BaseMessage] = []
|
self.failed_messages: list[BaseMessage] = []
|
||||||
self.tool_failure_used_fallbacks = set()
|
|
||||||
self.current_failing_tool_name = ""
|
self.current_failing_tool_name = ""
|
||||||
self.current_tool_to_bind = None
|
self.current_tool_to_bind: None | BaseTool = None
|
||||||
|
|
||||||
def _load_fallback_tool_models(self, config):
|
cpm(
|
||||||
|
"Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
|
||||||
|
title="Fallback Models",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_model(self, m: dict) -> str:
|
||||||
|
return f"{m.get('model', '')} ({m.get('type', 'prompt')})"
|
||||||
|
|
||||||
|
def _load_fallback_tool_models(self, _config):
|
||||||
"""
|
"""
|
||||||
Load and return fallback tool models based on the provided configuration.
|
Load and return fallback tool models based on the provided configuration.
|
||||||
|
|
||||||
|
|
@ -90,12 +97,9 @@ class FallbackHandler:
|
||||||
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
||||||
+ ", ".join(skipped)
|
+ ", ".join(skipped)
|
||||||
)
|
)
|
||||||
cpm(message, title="Fallback Models")
|
|
||||||
return final_models
|
return final_models
|
||||||
|
|
||||||
def handle_failure(
|
def handle_failure(self, error: ToolExecutionError, agent: RAgents):
|
||||||
self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||||
|
|
||||||
|
|
@ -137,10 +141,10 @@ class FallbackHandler:
|
||||||
|
|
||||||
def attempt_fallback(self):
|
def attempt_fallback(self):
|
||||||
"""
|
"""
|
||||||
Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
|
Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The response from a fallback model if any, otherwise None.
|
List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None.
|
||||||
"""
|
"""
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
||||||
|
|
@ -150,12 +154,10 @@ class FallbackHandler:
|
||||||
title="Fallback Notification",
|
title="Fallback Notification",
|
||||||
)
|
)
|
||||||
for fallback_model in self.fallback_tool_models:
|
for fallback_model in self.fallback_tool_models:
|
||||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
result_list = self.invoke_fallback(fallback_model)
|
||||||
response = self.attempt_fallback_function(fallback_model)
|
if result_list:
|
||||||
else:
|
msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
|
||||||
response = self.attempt_fallback_prompt(fallback_model)
|
return msg_list_response
|
||||||
if response:
|
|
||||||
return response
|
|
||||||
cpm("All fallback models have failed", title="Fallback Failed")
|
cpm("All fallback models have failed", title="Fallback Failed")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -165,8 +167,9 @@ class FallbackHandler:
|
||||||
"""
|
"""
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.tool_failure_consecutive_failures = 0
|
||||||
self.failed_messages.clear()
|
self.failed_messages.clear()
|
||||||
self.tool_failure_used_fallbacks.clear()
|
|
||||||
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
|
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
|
||||||
|
self.current_failing_tool_name = ""
|
||||||
|
self.current_tool_to_bind = None
|
||||||
|
|
||||||
def _reset_on_new_failure(self, failed_tool_call_name):
|
def _reset_on_new_failure(self, failed_tool_call_name):
|
||||||
if (
|
if (
|
||||||
|
|
@ -218,9 +221,21 @@ class FallbackHandler:
|
||||||
)
|
)
|
||||||
return tool_to_bind
|
return tool_to_bind
|
||||||
|
|
||||||
def attempt_fallback_prompt(self, fallback_model):
|
def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model):
|
||||||
|
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||||
|
# Force tool calling with tool_choice param.
|
||||||
|
bound_model = simple_model.bind_tools(
|
||||||
|
[self.current_tool_to_bind],
|
||||||
|
tool_choice=self.current_failing_tool_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Do not force tool calling (Prompt method)
|
||||||
|
bound_model = simple_model.bind_tools([self.current_tool_to_bind])
|
||||||
|
return bound_model
|
||||||
|
|
||||||
|
def invoke_fallback(self, fallback_model):
|
||||||
"""
|
"""
|
||||||
Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
|
Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fallback_model (dict): The fallback model to use.
|
fallback_model (dict): The fallback model to use.
|
||||||
|
|
@ -229,88 +244,33 @@ class FallbackHandler:
|
||||||
The response from the fallback model invocation, or None if failed.
|
The response from the fallback model invocation, or None if failed.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
|
||||||
simple_model = initialize_llm(
|
simple_model = initialize_llm(
|
||||||
fallback_model["provider"], fallback_model["model"]
|
fallback_model["provider"], fallback_model["model"]
|
||||||
)
|
)
|
||||||
binded_model = simple_model.bind_tools(
|
|
||||||
[self.current_tool_to_bind],
|
bound_model = self._bind_tool_model(simple_model, fallback_model)
|
||||||
tool_choice=self.current_failing_tool_name,
|
|
||||||
)
|
retry_model = bound_model.with_retry(
|
||||||
retry_model = binded_model.with_retry(
|
|
||||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
)
|
)
|
||||||
# msg_list = []
|
|
||||||
msg_list = self.construct_prompt_msg_list()
|
msg_list = self.construct_prompt_msg_list()
|
||||||
# response = retry_model.invoke(self.current_failing_tool_name)
|
|
||||||
response = retry_model.invoke(msg_list)
|
response = retry_model.invoke(msg_list)
|
||||||
|
|
||||||
cpm(f"response={response}")
|
logger.debug(f"raw llm response={response}")
|
||||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
|
||||||
tool_call = self.base_message_to_tool_call_dict(response)
|
tool_call = self.base_message_to_tool_call_dict(response)
|
||||||
if tool_call:
|
|
||||||
result = self.invoke_prompt_tool_call(tool_call)
|
tool_call_result = self.invoke_prompt_tool_call(tool_call)
|
||||||
cpm(f"result={result}")
|
cpm(str(tool_call_result), title="Fallback Tool Call Result")
|
||||||
logger.debug(
|
logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
|
||||||
"Prompt-based fallback executed successfully with model: "
|
|
||||||
+ fallback_model["model"]
|
self.reset_fallback_handler()
|
||||||
)
|
return [response, tool_call_result]
|
||||||
self.reset_fallback_handler()
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
cpm(
|
|
||||||
response.content if hasattr(response, "content") else response,
|
|
||||||
title="Fallback Model Response: " + fallback_model["model"],
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, KeyboardInterrupt):
|
if isinstance(e, KeyboardInterrupt):
|
||||||
raise
|
raise
|
||||||
logger.error(
|
logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
|
||||||
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def attempt_fallback_function(self, fallback_model):
|
|
||||||
"""
|
|
||||||
Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fallback_model (dict): The fallback model to use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The response from the fallback model invocation, or None if failed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
|
||||||
simple_model = initialize_llm(
|
|
||||||
fallback_model["provider"], fallback_model["model"]
|
|
||||||
)
|
|
||||||
binded_model = simple_model.bind_tools(
|
|
||||||
[self.current_tool_to_bind],
|
|
||||||
tool_choice=self.current_failing_tool_name,
|
|
||||||
)
|
|
||||||
retry_model = binded_model.with_retry(
|
|
||||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
|
||||||
)
|
|
||||||
response = retry_model.invoke(self.current_failing_tool_name)
|
|
||||||
cpm(f"response={response}")
|
|
||||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
|
||||||
logger.debug(
|
|
||||||
"Function-calling fallback executed successfully with model: "
|
|
||||||
+ fallback_model["model"]
|
|
||||||
)
|
|
||||||
cpm(
|
|
||||||
response.content if hasattr(response, "content") else response,
|
|
||||||
title="Fallback Model Response: " + fallback_model["model"],
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, KeyboardInterrupt):
|
|
||||||
raise
|
|
||||||
logger.error(
|
|
||||||
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def construct_prompt_msg_list(self):
|
def construct_prompt_msg_list(self):
|
||||||
|
|
@ -367,19 +327,22 @@ class FallbackHandler:
|
||||||
otherwise None.
|
otherwise None.
|
||||||
"""
|
"""
|
||||||
tool_calls = self.get_tool_calls(response)
|
tool_calls = self.get_tool_calls(response)
|
||||||
if tool_calls:
|
|
||||||
if len(tool_calls) > 1:
|
if not tool_calls:
|
||||||
logger.warning("Multiple tool calls detected, using the first one")
|
raise Exception(
|
||||||
tool_call = tool_calls[0]
|
f"Could not extract tool_call_dict from response: {response}"
|
||||||
return {
|
)
|
||||||
"id": tool_call["id"],
|
|
||||||
"type": tool_call["type"],
|
if len(tool_calls) > 1:
|
||||||
"name": tool_call["function"]["name"],
|
logger.warning("Multiple tool calls detected, using the first one")
|
||||||
"arguments": self._parse_tool_arguments(
|
|
||||||
tool_call["function"]["arguments"]
|
tool_call = tool_calls[0]
|
||||||
),
|
return {
|
||||||
}
|
"id": tool_call["id"],
|
||||||
return None
|
"type": tool_call["type"],
|
||||||
|
"name": tool_call["function"]["name"],
|
||||||
|
"arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]),
|
||||||
|
}
|
||||||
|
|
||||||
def _parse_tool_arguments(self, tool_arguments):
|
def _parse_tool_arguments(self, tool_arguments):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue