| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- import os
- import pandas as pd
- import chromadb
- from tqdm import tqdm
- from dotenv import load_dotenv
- from typing import List, Any # Ajout de Any pour corriger le warning
- from langchain_community.tools import TavilySearchResults , DuckDuckGoSearchRun
- from langchain_openai import ChatOpenAI
- from langchain_chroma import Chroma
- from langchain_huggingface import HuggingFaceEmbeddings
- from langchain_core.documents import Document
- from langchain_core.retrievers import BaseRetriever
- from langchain_core.callbacks import CallbackManagerForRetrieverRun
- from langchain_core.prompts import ChatPromptTemplate
- # Correction des imports des chains (Chemins officiels 2026)
- from langchain_classic.chains import create_retrieval_chain
- from langchain_classic.chains.combine_documents import create_stuff_documents_chain
- # Correction de l'import Mixedbread
- from mixedbread_ai.client import MixedbreadAI
- from langgraph.prebuilt import create_react_agent
- from langgraph.checkpoint.memory import MemorySaver
- load_dotenv()
- # --- CONFIGURATION ---
- CHROMA_PATH = "./chroma_db_actuariat"
- PARQUET_PATH = "memoires_total_local.parquet"
- MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
- LLM_NAME = "gpt-4o-mini"
- EMB_MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
- mxbai_client = MixedbreadAI(api_key=os.environ.get("MXBAI_API_KEY"))
- # --- 1. INITIALISATION DE LA BASE (CHROMA) ---
- def get_vector_db():
- persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)
-
- if "embeddings_mxbai" not in [c.name for c in persistent_client.list_collections()]:
- print("Initialisation de la base de données vectorielle locale...")
- collection = persistent_client.create_collection(name="embeddings_mxbai")
-
- # Chargement du Parquet
- df = pd.read_parquet(PARQUET_PATH)
-
- batch_size = 5000
- for i in tqdm(range(0, len(df), batch_size), desc="Peuplement de ChromaDB"):
- batch = df.iloc[i : i + batch_size]
-
- # Sécurité : Conversion des ID en string
- ids = [str(idx) for idx in batch["id"].tolist()]
-
- collection.add(
- ids=ids,
- embeddings=batch["embedding"].tolist(),
- documents=batch["document"].tolist(),
- metadatas=[{"title": str(m["title"])} for m in batch["metadata"]]
- )
- print("✅ Base de données prête.")
-
- return Chroma(
- client=persistent_client,
- collection_name="embeddings_mxbai",
- embedding_function=HuggingFaceEmbeddings(model_name=EMB_MODEL_NAME)
- )
- # --- 2. RERANKER PERSONNALISÉ (Correction du Warning) ---
- class MixedbreadReranker(BaseRetriever):
- retriever: Any
- k: int = 4
- def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
- # 1. On récupère les documents dans ChromaDB (Local, donc toujours dispo !)
- initial_docs = self.retriever.invoke(query)
-
- try:
- # 2. Tentative de Reranking avec Mixedbread
- res = mxbai_client.reranking(
- model=MODEL_RRK,
- query=query,
- input=[doc.page_content for doc in initial_docs],
- top_k=self.k
- )
-
- final_docs = []
- for r in res.data:
- final_docs.append(Document(
- page_content=r.input,
- metadata=initial_docs[r.index].metadata
- ))
- return final_docs
-
- except Exception as e:
- # 3. LE SECOURS : Si l'API 503 répond, on renvoie les docs bruts de Chroma
- print(f"⚠️ Mixedbread indisponible ({e}). Passage en mode secours.")
- # On renvoie juste les 'k' premiers documents de ChromaDB sans les retrier
- return initial_docs[:self.k]
- # --- 3. INITIALISATION DU RAG TOOL ---
- def init_rag_tool():
- db = get_vector_db()
- base_retriever = db.as_retriever(search_kwargs={"k": 25})
- reranker = MixedbreadReranker(retriever=base_retriever, k=5)
-
- llm = ChatOpenAI(model=LLM_NAME, temperature=0)
-
- system_prompt = (
- "Tu es un expert en actuariat. Utilise le contexte suivant pour répondre : \n\n {context}\n\n"
- "Si la réponse n'est pas dans le contexte, utilise tes connaissances générales ou Web_search."
- )
- prompt = ChatPromptTemplate.from_messages([
- ("system", system_prompt),
- ("human", "{input}")
- ])
-
- qa_chain = create_stuff_documents_chain(llm, prompt)
- rag_chain = create_retrieval_chain(reranker, qa_chain)
-
- # Transformation de la chain en outil pour l'agent
- return rag_chain.as_tool(
- name="RAG_search",
- description="Cherche des informations techniques dans les mémoires d'actuariat."
- )
- def init_websearch_tool():
- return DuckDuckGoSearchRun(
- name="Web_search",
- max_results=3,
- description="Recherche d'informations actuelles sur le web",
- )
- def create_agent():
- rag_tool = init_rag_tool()
- web_search_tool = init_websearch_tool()
- memory = MemorySaver()
-
- llm_4o = ChatOpenAI(model=LLM_NAME, temperature=0, streaming=True)
- tools = [rag_tool, web_search_tool]
-
- system_message = """
- Tu es un expert actuariel. Ton rôle est d'assister les utilisateurs sur des sujets d'assurance.
- 1. Utilise 'RAG_search' pour les concepts théoriques, thèses et mémoires.
- 2. Utilise 'Web_search' pour les actualités ou réglementations très récentes.
- 3. Cite toujours tes sources si disponibles dans les métadonnées (titre du mémoire).
- """
- return create_react_agent(
- llm_4o,
- tools,
- prompt=system_message,
- checkpointer=memory
- )
|