12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- @Time : 2024/9/10 17:55
- @Author : cao
- @File : build_index.py
- @Desc :
- """
- import json
- import torch
- import logging
- logger = logging.getLogger(__name__)
- import pandas as pd
- import faiss
- import numpy as np
- from sentence_transformers import SentenceTransformer
- logger.setLevel(logging.INFO)
- import jieba
- INDEX_FILE = "./conf/index"
- CATE_FILE = "./conf/cate.cf"
- #RAW_DATA = "AI+人话术模板部分7.1.xlsx"
- RAW_DATA = "question.xlsx"
- MODEL_Dir = "/home/model-server/m3e"
def index():
    """Build the embedding index from the question/answer spreadsheet.

    Reads ``RAW_DATA`` with pandas, relabels its columns, discards every row
    whose similar-question cell is empty, renumbers the remaining rows from
    zero, and passes them to ``build_emb_index``.
    """
    frame = pd.read_excel(RAW_DATA)
    frame.columns = ["意图类别", "FAQ答案", '相似问', "class", "cate"]

    # Keep only rows that actually carry a similar-question text.
    has_question = frame['相似问'].notna()
    selected = frame[has_question][["意图类别", '相似问', "FAQ答案", "class", "cate"]]

    # Renumber rows 0..n-1 before handing them on.
    build_emb_index(selected.reset_index(drop=True))
def build_emb_index(data):
    """Encode the similar questions and persist a FAISS index plus row metadata.

    Args:
        data: DataFrame with a ``相似问`` column holding the texts to embed.
            Row ``i`` (by position) is stored in the index under id ``i``;
            the metadata file is keyed by the frame's row labels, so the two
            only line up when the index is 0..n-1 (``index()`` resets it
            before calling this function).

    Side effects:
        Writes all rows as a JSON object to ``CATE_FILE`` and writes the
        FAISS index to ``INDEX_FILE``.
    """
    model = load_model()
    # The encoder is asked for normalised embeddings, so no separate
    # faiss.normalize_L2 pass is applied before adding to the
    # inner-product index.
    emb = model.encode(data['相似问'].tolist(), normalize_embeddings=True)

    doc_intent = data.T.to_dict()
    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(CATE_FILE, "w", encoding="utf-8") as f:
        f.write(json.dumps(doc_intent, ensure_ascii=False))

    size = len(emb[0])
    index = faiss.IndexIDMap(faiss.IndexFlatIP(size))
    # FAISS ids are 64-bit integers; build them explicitly rather than
    # depending on numpy's platform-dependent default integer width.
    ids = np.arange(data.shape[0], dtype=np.int64)
    index.add_with_ids(emb, ids)
    faiss.write_index(index, INDEX_FILE)

    # Lazy %-formatting so the vector count is actually interpolated.
    logger.info("build faiss sucess %s", index.ntotal)
def load_model(model_dir=MODEL_Dir):
    """Load the sentence-embedding model stored at ``model_dir``.

    Args:
        model_dir: Passed to ``SentenceTransformer`` as
            ``model_name_or_path``; defaults to ``MODEL_Dir``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(model_name_or_path=model_dir)
    return encoder
# if __name__ == "__main__":
#     index()
|