From 3607803bf5e8f7e82b25a2adbef195200345783f Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Fri, 7 Mar 2025 17:31:27 -0500 Subject: [PATCH] use interrupt for tighter control of agent loop --- ra_aid/agent_utils.py | 138 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 129 insertions(+), 9 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index f3a8be6..6fe8ed5 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -310,7 +310,7 @@ def create_agent( if is_anthropic_claude(config): logger.debug("Using create_react_agent to instantiate agent.") agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens) - return create_react_agent(model, tools, **agent_kwargs) + return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs) else: logger.debug("Using CiaynAgent agent instance") return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config) @@ -321,7 +321,7 @@ def create_agent( config = get_config_repository().get_all() max_input_tokens = get_model_token_limit(config, agent_type) agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens) - return create_react_agent(model, tools, **agent_kwargs) + return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs) def run_research_agent( @@ -991,16 +991,136 @@ def _handle_fallback_response( msg_list.extend(msg_list_response) +# def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): +# for chunk in agent.stream({"messages": msg_list}, config): +# logger.debug("Agent output: %s", chunk) +# check_interrupt() +# agent_type = get_agent_type(agent) +# print_agent_output(chunk, agent_type) +# if is_completed() or should_exit(): +# reset_completion_flags() +# break +# def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): +# while True: ## WE NEED TO ONLY KEEP ITERATING IF IT IS AN INTERRUPT, NOT UNCONDITIONALLY +# stream = agent.stream({"messages": msg_list}, config) +# for chunk in stream: +# logger.debug("Agent output: %s", chunk) +# check_interrupt() +# agent_type = get_agent_type(agent) +# print_agent_output(chunk, agent_type) +# if is_completed() or should_exit(): +# reset_completion_flags() +# return True +# print("HERE!") + +# def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): +# while True: +# for chunk in agent.stream({"messages": msg_list}, config): +# print("Chunk received:", chunk) +# check_interrupt() +# agent_type = get_agent_type(agent) +# print_agent_output(chunk, agent_type) +# if is_completed() or should_exit(): +# reset_completion_flags() +# return True +# print("HERE!") +# print("Config passed to _run_agent_stream:", config) +# print("Config keys:", list(config.keys())) + +# # Ensure the configuration for state retrieval contains a 'configurable' key. +# state_config = config.copy() +# if "configurable" not in state_config: +# print("Key 'configurable' not found in config. Adding it as an empty dict.") +# state_config["configurable"] = {} +# print("Using state_config for agent.get_state():", state_config) + +# try: +# state = agent.get_state(state_config) +# print("Agent state retrieved:", state) +# print("State type:", type(state)) +# print("State attributes:", dir(state)) +# except Exception as e: +# print("Error retrieving agent state with state_config", state_config, ":", e) +# raise + +# # Since state.current is not available, we rely solely on state.next. +# try: +# next_node = state.next +# print("State next value:", next_node) +# except Exception as e: +# print("Error accessing state.next:", e) +# next_node = None + +# # Resume execution if state.next is truthy (indicating further steps remain). +# if next_node: +# print("Resuming execution because state.next is nonempty:", next_node) +# agent.invoke(None, config) +# continue +# else: +# print("No further steps indicated; breaking out of loop.") +# break + +# return True + + def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): - for chunk in agent.stream({"messages": msg_list}, config): - logger.debug("Agent output: %s", chunk) - check_interrupt() - agent_type = get_agent_type(agent) - print_agent_output(chunk, agent_type) - if is_completed() or should_exit(): - reset_completion_flags() + """ + Streams agent output while handling completion and interruption. + + For each chunk, it logs the output, calls check_interrupt(), prints agent output, + and then checks if is_completed() or should_exit() are true. If so, it resets completion + flags and returns. After finishing a stream iteration (i.e. the for-loop over chunks), + the function retrieves the agent's state. If the state indicates further steps (i.e. state.next is non-empty), + it resumes execution via agent.invoke(None, config); otherwise, it exits the loop. + + This function adheres to the latest LangGraph best practices (as of March 2025) for handling + human-in-the-loop interruptions using interrupt_after=["tools"]. + """ + while True: + print("HERE") + # Process each chunk from the agent stream. + for chunk in agent.stream({"messages": msg_list}, config): + print("HERE IN FOR CHUNK") + logger.debug("Agent output: %s", chunk) + check_interrupt() + agent_type = get_agent_type(agent) + print_agent_output(chunk, agent_type) + if is_completed() or should_exit(): + print("IS COMPLETED OR SHOULD EXIT TRIGGERED") + reset_completion_flags() + return True # Exit immediately when finished or signaled to exit. + logger.debug("Stream iteration ended; checking agent state for continuation.") + + # Prepare state configuration, ensuring 'configurable' is present. + state_config = config.copy() + if "configurable" not in state_config: + logger.debug("Key 'configurable' not found in config; adding it as an empty dict.") + state_config["configurable"] = {} + logger.debug("Using state_config for agent.get_state(): %s", state_config) + + try: + state = agent.get_state(state_config) + logger.debug("Agent state retrieved: %s", state) + except Exception as e: + logger.error("Error retrieving agent state with state_config %s: %s", state_config, e) + raise + + # If the state indicates that further steps remain (i.e. state.next is non-empty), + # then resume execution by invoking the agent with no new input. + if state.next: + print("HAS NEXT STATE") + logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next) + agent.invoke(None, config) + continue + else: + print("NO NEXT STATE") + logger.debug("No continuation indicated in state; exiting stream loop.") break + print("WHILE EXITED") + return True + + def run_agent_with_retry( agent: RAgents,