diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index 1186956..5214b0a 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -8,7 +8,13 @@ from langchain_openai import ChatOpenAI from .memory import get_memory_value, get_related_files console = Console() -model = ChatOpenAI(model_name="o1-preview") +_model = None + +def get_model(): + global _model + if _model is None: + _model = ChatOpenAI(model_name="o1-preview") + return _model # Keep track of context globally expert_context = [] @@ -168,7 +174,7 @@ def ask_expert(question: str) -> str: query = "\n".join(query_parts) # Get response - response = model.invoke(query) + response = get_model().invoke(query) # Format and display response console.print(Panel(