utils.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/9/11 09:57
@Author : cao
@File : utils.py
@Desc : Helpers for loading the sentence-embedding model, the category
        config and the FAISS index.
"""
  9. import logging
  10. import json
  11. import faiss
  12. import os
  13. from sentence_transformers import SentenceTransformer
  14. import build_index
  15. import torch
  16. logger = logging.getLogger(__name__)
  17. INDEX_FILE = "./conf/index"
  18. CATE_FILE = "./conf/cate.cf"
  19. #RAW_DATA = "AI+人话术模板部分7.1.xlsx"
  20. RAW_DATA = "question.xlsx"
  21. DIRTY_DATA = "./conf/dirty"
  22. STOP_WORDS ="./conf/stopword"
  23. MODEL_DIR ="/home/model-server/m3e"
  24. def load_model(model_dir=MODEL_DIR):
  25. return SentenceTransformer(model_name_or_path=model_dir, device='cuda' if torch.cuda.is_available() else 'cpu')
  26. def load_cate(file=CATE_FILE):
  27. with open(file) as f:
  28. return json.loads(f.readline())
  29. def load_index(index_path=INDEX_FILE):
  30. if not os.path.exists(index_path):
  31. logger.info("faiss does not exists, need build index")
  32. build_index.index()
  33. if os.path.exists(index_path):
  34. index = faiss.read_index(index_path)
  35. nums = index.ntotal
  36. logger.info(f"load faiss {nums} sucess")
  37. return index
  38. index = load_index()
  39. cate = load_cate()
  40. if __name__ == "__main__":
  41. cate= load_cate("./conf/cate.cf")
  42. index = load_index()