tools.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import os
  2. import pandas as pd
  3. import chromadb
  4. from tqdm import tqdm
  5. from dotenv import load_dotenv
  6. from typing import List, Any # Ajout de Any pour corriger le warning
  7. from langchain_community.tools import TavilySearchResults , DuckDuckGoSearchRun
  8. from langchain_openai import ChatOpenAI
  9. from langchain_chroma import Chroma
  10. from langchain_huggingface import HuggingFaceEmbeddings
  11. from langchain_core.documents import Document
  12. from langchain_core.retrievers import BaseRetriever
  13. from langchain_core.callbacks import CallbackManagerForRetrieverRun
  14. from langchain_core.prompts import ChatPromptTemplate
  15. # Correction des imports des chains (Chemins officiels 2026)
  16. from langchain_classic.chains import create_retrieval_chain
  17. from langchain_classic.chains.combine_documents import create_stuff_documents_chain
  18. # Correction de l'import Mixedbread
  19. from mixedbread_ai.client import MixedbreadAI
  20. from langgraph.prebuilt import create_react_agent
  21. from langgraph.checkpoint.memory import MemorySaver
  22. load_dotenv()
  23. # --- CONFIGURATION ---
  24. CHROMA_PATH = "./chroma_db_actuariat"
  25. PARQUET_PATH = "memoires_total_local.parquet"
  26. MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
  27. LLM_NAME = "gpt-4o-mini"
  28. EMB_MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
  29. mxbai_client = MixedbreadAI(api_key=os.environ.get("MXBAI_API_KEY"))
  30. # --- 1. INITIALISATION DE LA BASE (CHROMA) ---
  31. def get_vector_db():
  32. persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)
  33. if "embeddings_mxbai" not in [c.name for c in persistent_client.list_collections()]:
  34. print("Initialisation de la base de données vectorielle locale...")
  35. collection = persistent_client.create_collection(name="embeddings_mxbai")
  36. # Chargement du Parquet
  37. df = pd.read_parquet(PARQUET_PATH)
  38. batch_size = 5000
  39. for i in tqdm(range(0, len(df), batch_size), desc="Peuplement de ChromaDB"):
  40. batch = df.iloc[i : i + batch_size]
  41. # Sécurité : Conversion des ID en string
  42. ids = [str(idx) for idx in batch["id"].tolist()]
  43. collection.add(
  44. ids=ids,
  45. embeddings=batch["embedding"].tolist(),
  46. documents=batch["document"].tolist(),
  47. metadatas=[{"title": str(m["title"])} for m in batch["metadata"]]
  48. )
  49. print("✅ Base de données prête.")
  50. return Chroma(
  51. client=persistent_client,
  52. collection_name="embeddings_mxbai",
  53. embedding_function=HuggingFaceEmbeddings(model_name=EMB_MODEL_NAME)
  54. )
  55. # --- 2. RERANKER PERSONNALISÉ (Correction du Warning) ---
  56. class MixedbreadReranker(BaseRetriever):
  57. retriever: Any
  58. k: int = 4
  59. def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
  60. # 1. On récupère les documents dans ChromaDB (Local, donc toujours dispo !)
  61. initial_docs = self.retriever.invoke(query)
  62. try:
  63. # 2. Tentative de Reranking avec Mixedbread
  64. res = mxbai_client.reranking(
  65. model=MODEL_RRK,
  66. query=query,
  67. input=[doc.page_content for doc in initial_docs],
  68. top_k=self.k
  69. )
  70. final_docs = []
  71. for r in res.data:
  72. final_docs.append(Document(
  73. page_content=r.input,
  74. metadata=initial_docs[r.index].metadata
  75. ))
  76. return final_docs
  77. except Exception as e:
  78. # 3. LE SECOURS : Si l'API 503 répond, on renvoie les docs bruts de Chroma
  79. print(f"⚠️ Mixedbread indisponible ({e}). Passage en mode secours.")
  80. # On renvoie juste les 'k' premiers documents de ChromaDB sans les retrier
  81. return initial_docs[:self.k]
  82. # --- 3. INITIALISATION DU RAG TOOL ---
  83. def init_rag_tool():
  84. db = get_vector_db()
  85. base_retriever = db.as_retriever(search_kwargs={"k": 25})
  86. reranker = MixedbreadReranker(retriever=base_retriever, k=5)
  87. llm = ChatOpenAI(model=LLM_NAME, temperature=0)
  88. system_prompt = (
  89. "Tu es un expert en actuariat. Utilise le contexte suivant pour répondre : \n\n {context}\n\n"
  90. "Si la réponse n'est pas dans le contexte, utilise tes connaissances générales ou Web_search."
  91. )
  92. prompt = ChatPromptTemplate.from_messages([
  93. ("system", system_prompt),
  94. ("human", "{input}")
  95. ])
  96. qa_chain = create_stuff_documents_chain(llm, prompt)
  97. rag_chain = create_retrieval_chain(reranker, qa_chain)
  98. # Transformation de la chain en outil pour l'agent
  99. return rag_chain.as_tool(
  100. name="RAG_search",
  101. description="Cherche des informations techniques dans les mémoires d'actuariat."
  102. )
  103. def init_websearch_tool():
  104. return DuckDuckGoSearchRun(
  105. name="Web_search",
  106. max_results=3,
  107. description="Recherche d'informations actuelles sur le web",
  108. )
  109. def create_agent():
  110. rag_tool = init_rag_tool()
  111. web_search_tool = init_websearch_tool()
  112. memory = MemorySaver()
  113. llm_4o = ChatOpenAI(model=LLM_NAME, temperature=0, streaming=True)
  114. tools = [rag_tool, web_search_tool]
  115. system_message = """
  116. Tu es un expert actuariel. Ton rôle est d'assister les utilisateurs sur des sujets d'assurance.
  117. 1. Utilise 'RAG_search' pour les concepts théoriques, thèses et mémoires.
  118. 2. Utilise 'Web_search' pour les actualités ou réglementations très récentes.
  119. 3. Cite toujours tes sources si disponibles dans les métadonnées (titre du mémoire).
  120. """
  121. return create_react_agent(
  122. llm_4o,
  123. tools,
  124. prompt=system_message,
  125. checkpointer=memory
  126. )