build_index.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. @Time : 2024/9/10 17:55
  5. @Author : cao
  6. @File : build_index.py
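@Desc : Build a FAISS inner-product index over embeddings of the FAQ "similar question" texts and dump the matching rows to a JSON metadata file.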
  8. """
  9. import json
  10. import torch
  11. import logging
  12. logger = logging.getLogger(__name__)
  13. import pandas as pd
  14. import faiss
  15. import numpy as np
  16. from sentence_transformers import SentenceTransformer
  17. logger.setLevel(logging.INFO)
  18. import jieba
  19. INDEX_FILE = "./conf/index"
  20. CATE_FILE = "./conf/cate.cf"
  21. #RAW_DATA = "AI+人话术模板部分7.1.xlsx"
  22. RAW_DATA = "question.xlsx"
  23. MODEL_Dir = "/home/model-server/m3e"
  24. def index():
  25. """有问答数据集生成训练样本
  26. """
  27. df = pd.read_excel(RAW_DATA)
  28. df.columns = ["意图类别","FAQ答案", '相似问',"class","cate"]
  29. df = df[df['相似问'].isna()==False]
  30. data = df[["意图类别",'相似问', "FAQ答案","class","cate"]]
  31. data=data.reset_index(drop=True)
  32. build_emb_index(data)
  33. def build_emb_index(data):
  34. model = load_model()
  35. emb = model.encode(data['相似问'].tolist(), normalize_embeddings=True)
  36. doc_intent = data.T.to_dict()
  37. with open(CATE_FILE, "w") as f:
  38. f.write(json.dumps(doc_intent, ensure_ascii=False))
  39. size = len(emb[0])
  40. index = faiss.IndexIDMap(faiss.IndexFlatIP(size))
  41. # faiss.normalize_L2(emb)
  42. index.add_with_ids(emb, np.array(range(0, data.shape[0])))
  43. faiss.write_index(index, INDEX_FILE)
  44. nums = index.ntotal
  45. logger.info("build faiss sucess {nums}")
  46. def load_model(model_dir=MODEL_Dir):
  47. return SentenceTransformer(model_name_or_path=model_dir)
# if __name__ == "__main__":
#     index()  # index() reads RAW_DATA and calls build_emb_index(data); build_emb_index needs a DataFrame argument