# embed-bge-m3/FlagEmbedding/research/old-examples/search_demo/tool.py
# encoding=utf-8
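"""Conversational search demo: BM25 (Pyserini) retrieves candidate documents, a
BGE dense embedding model reranks them with FAISS, and GPT-3.5 rewrites
follow-up questions and generates the final answer.
"""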
import faiss
import numpy as np
import tiktoken
import torch
from datasets import load_from_disk
from langchain import PromptTemplate, LLMChain  # legacy (pre-0.1) LangChain API, kept as in the original demo
from langchain.chat_models import ChatOpenAI
from pyserini.search.lucene import LuceneSearcher
from transformers import AutoTokenizer, AutoModel


class LocalDatasetLoader:
    data_path: str = ""
    doc_emb: np.ndarray = None

    def __init__(self,
                 data_path,
                 embedding_path):
        # Load the corpus and join each title with its body text.
        dataset = load_from_disk(data_path)
        title = dataset['title']
        text = dataset['text']
        self.data = [title[i] + ' -- ' + text[i] for i in range(len(title))]
        # Pre-computed document embeddings, row-aligned with `self.data`.
        self.doc_emb = np.load(embedding_path)


class QueryGenerator:
    def __init__(self):
        # Prompt (Chinese): "Based on the history, rewrite the new question. If it
        # relates to the history, you must resolve pronouns to what they refer to,
        # in context, so the question is self-contained; otherwise keep the original
        # question. Only rewrite the question, do not answer it."
        prompt_template = """请基于历史,对新问题进行修改,如果新问题与历史相关,你必须结合语境将代词替换为相应的指代内容,让它的提问更加明确;否则保留原始的新问题。只修改新问题,不作任何回答:\n历史:{history}\n新问题:{question}\n修改后的新问题:"""
        llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")
        prompt = PromptTemplate(
            input_variables=["history", "question"],
            template=prompt_template,
        )
        self.llm_chain = LLMChain(llm=llm, prompt=prompt)

    def run(self, history, question):
        return self.llm_chain.predict(history=history, question=question).strip()


class AnswerGenerator:
    def __init__(self):
        # Prompt (Chinese): "Answer the question based on the references; do not
        # add any citation markers."
        prompt_template = """请基于参考,回答问题,不需要标注任何引用:\n历史:{history}\n问题:{question}\n参考:{references}\n答案:"""
        llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")
        prompt = PromptTemplate(
            input_variables=["history", "question", "references"],
            template=prompt_template,
        )
        self.llm_chain = LLMChain(llm=llm, prompt=prompt)

    def run(self, history, question, references):
        return self.llm_chain.predict(history=history, question=question, references=references).strip()


class BMVectorIndex:
    def __init__(self,
                 model_path,
                 bm_index_path,
                 data_loader):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)
        self.bm_searcher = LuceneSearcher(bm_index_path)
        self.loader = data_loader
        if '-en' in model_path:
            raise NotImplementedError("only Chinese models are supported currently")
        if 'noinstruct' in model_path:
            self.instruction = None
        else:
            # Retrieval instruction prepended to queries for Chinese BGE models.
            # (English: "Generate a representation for this sentence to retrieve related articles:")
            self.instruction = "为这个句子生成表示以用于检索相关文章:"

    def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
        # Stage 1: BM25 retrieves RANKING candidates from the Lucene index.
        hits = self.bm_searcher.search(query, RANKING)
        ids = [int(e.docid) for e in hits]
        use_docs = self.loader.doc_emb[ids]
        # Stage 2: rerank the candidates by dense similarity to the query embedding.
        if self.instruction is not None:
            query = self.instruction + query
        encoded_input = self.tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # Use the [CLS] embedding, L2-normalized so inner product equals cosine similarity.
            query_emb = model_output.last_hidden_state[:, 0]
            query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=-1).detach().cpu().numpy()
        # Build a throwaway FAISS index over just the BM25 candidates for this query.
        temp_index = faiss.IndexFlatIP(use_docs[0].shape[0])
        temp_index.add(use_docs)
        _, I = temp_index.search(query_emb, TOP_N)
        return "\n".join([self.loader.data[ids[I[0][i]]] for i in range(TOP_N)])


class Agent:
    def __init__(self, index):
        self.memory = ""
        self.index = index
        self.query_generator = QueryGenerator()
        self.answer_generator = AnswerGenerator()

    def empty_memory(self):
        self.memory = ""

    def update_memory(self, question, answer):
        # Memory is a running transcript of "问:" (Q) / "答:" (A) turns.
        self.memory += f"问:{question}\n答:{answer}\n"
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        # Keep the transcript under the prompt budget by dropping the oldest turns.
        while len(encoding.encode(self.memory)) > 3500:
            # Skip the leading "问:" marker (two characters) and cut at the next one.
            pos = self.memory[2:].find("问:") + 2
            self.memory = self.memory[pos:]

    def generate_query(self, question):
        if self.memory == "":
            return question
        else:
            # Rewrite the question into a standalone query using the dialogue history.
            return self.query_generator.run(self.memory, question)

    def generate_answer(self, query, references):
        return self.answer_generator.run(self.memory, query, references)

    def answer(self, question, RANKING=1000, TOP_N=5, verbose=True):
        query = self.generate_query(question)
        references = self.index.search_for_doc(query, RANKING=RANKING, TOP_N=TOP_N)
        answer = self.generate_answer(question, references)
        self.update_memory(question, answer)
        if verbose:
            print('\033[96m' + "查:" + query + '\033[0m')       # rewritten query
            print('\033[96m' + "参:" + references + '\033[0m')  # retrieved references
            print("答:" + answer)                                # final answer