import os import pandas as pd import chromadb from tqdm import tqdm from dotenv import load_dotenv from typing import List, Any # Ajout de Any pour corriger le warning from langchain_community.tools import TavilySearchResults , DuckDuckGoSearchRun from langchain_openai import ChatOpenAI from langchain_chroma import Chroma from langchain_huggingface import HuggingFaceEmbeddings from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.prompts import ChatPromptTemplate # Correction des imports des chains (Chemins officiels 2026) from langchain_classic.chains import create_retrieval_chain from langchain_classic.chains.combine_documents import create_stuff_documents_chain # Correction de l'import Mixedbread from mixedbread_ai.client import MixedbreadAI from langgraph.prebuilt import create_react_agent from langgraph.checkpoint.memory import MemorySaver load_dotenv() # --- CONFIGURATION --- CHROMA_PATH = "./chroma_db_actuariat" PARQUET_PATH = "memoires_total_local.parquet" MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1" LLM_NAME = "gpt-4o-mini" EMB_MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1" mxbai_client = MixedbreadAI(api_key=os.environ.get("MXBAI_API_KEY")) # --- 1. INITIALISATION DE LA BASE (CHROMA) --- def get_vector_db(): persistent_client = chromadb.PersistentClient(path=CHROMA_PATH) if "embeddings_mxbai" not in [c.name for c in persistent_client.list_collections()]: print("Initialisation de la base de données vectorielle locale...") collection = persistent_client.create_collection(name="embeddings_mxbai") # Chargement du Parquet df = pd.read_parquet(PARQUET_PATH) batch_size = 5000 for i in tqdm(range(0, len(df), batch_size), desc="Peuplement de ChromaDB"): batch = df.iloc[i : i + batch_size] # Sécurité : Conversion des ID en string ids = [str(idx) for idx in batch["id"].tolist()] collection.add( ids=ids, embeddings=batch["embedding"].tolist(), documents=batch["document"].tolist(), metadatas=[{"title": str(m["title"])} for m in batch["metadata"]] ) print("✅ Base de données prête.") return Chroma( client=persistent_client, collection_name="embeddings_mxbai", embedding_function=HuggingFaceEmbeddings(model_name=EMB_MODEL_NAME) ) # --- 2. RERANKER PERSONNALISÉ (Correction du Warning) --- class MixedbreadReranker(BaseRetriever): retriever: Any k: int = 4 def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]: # 1. On récupère les documents dans ChromaDB (Local, donc toujours dispo !) initial_docs = self.retriever.invoke(query) try: # 2. Tentative de Reranking avec Mixedbread res = mxbai_client.reranking( model=MODEL_RRK, query=query, input=[doc.page_content for doc in initial_docs], top_k=self.k ) final_docs = [] for r in res.data: final_docs.append(Document( page_content=r.input, metadata=initial_docs[r.index].metadata )) return final_docs except Exception as e: # 3. LE SECOURS : Si l'API 503 répond, on renvoie les docs bruts de Chroma print(f"⚠️ Mixedbread indisponible ({e}). Passage en mode secours.") # On renvoie juste les 'k' premiers documents de ChromaDB sans les retrier return initial_docs[:self.k] # --- 3. INITIALISATION DU RAG TOOL --- def init_rag_tool(): db = get_vector_db() base_retriever = db.as_retriever(search_kwargs={"k": 25}) reranker = MixedbreadReranker(retriever=base_retriever, k=5) llm = ChatOpenAI(model=LLM_NAME, temperature=0) system_prompt = ( "Tu es un expert en actuariat. Utilise le contexte suivant pour répondre : \n\n {context}\n\n" "Si la réponse n'est pas dans le contexte, utilise tes connaissances générales ou Web_search." ) prompt = ChatPromptTemplate.from_messages([ ("system", system_prompt), ("human", "{input}") ]) qa_chain = create_stuff_documents_chain(llm, prompt) rag_chain = create_retrieval_chain(reranker, qa_chain) # Transformation de la chain en outil pour l'agent return rag_chain.as_tool( name="RAG_search", description="Cherche des informations techniques dans les mémoires d'actuariat." ) def init_websearch_tool(): return DuckDuckGoSearchRun( name="Web_search", max_results=3, description="Recherche d'informations actuelles sur le web", ) def create_agent(): rag_tool = init_rag_tool() web_search_tool = init_websearch_tool() memory = MemorySaver() llm_4o = ChatOpenAI(model=LLM_NAME, temperature=0, streaming=True) tools = [rag_tool, web_search_tool] system_message = """ Tu es un expert actuariel. Ton rôle est d'assister les utilisateurs sur des sujets d'assurance. 1. Utilise 'RAG_search' pour les concepts théoriques, thèses et mémoires. 2. Utilise 'Web_search' pour les actualités ou réglementations très récentes. 3. Cite toujours tes sources si disponibles dans les métadonnées (titre du mémoire). """ return create_react_agent( llm_4o, tools, prompt=system_message, checkpointer=memory )