12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- @Time : 2024/9/11 09:57
- @Author : cao
- @File : utils.py
- @Desc :
- """
- import logging
- import json
- import faiss
- import os
- from sentence_transformers import SentenceTransformer
- import build_index
- import torch
logger = logging.getLogger(__name__)

# Serialized faiss index; default path for load_index(), which tries to build
# it via build_index.index() when the file is missing.
INDEX_FILE = "./conf/index"
# Category data; default path for load_cate(), which JSON-decodes its first line.
CATE_FILE = "./conf/cate.cf"
# Source spreadsheet. Not referenced in this module's visible code —
# presumably consumed by build_index (confirm).
RAW_DATA = "question.xlsx"
# Not referenced in this module's visible code; presumably read elsewhere
# (e.g. during index building) — confirm.
DIRTY_DATA = "./conf/dirty"
# Not referenced in this module's visible code; presumably a stop-word list
# read elsewhere — confirm.
STOP_WORDS = "./conf/stopword"
# Default model location passed to SentenceTransformer by load_model().
MODEL_DIR = "/home/model-server/m3e"
def load_model(model_dir=MODEL_DIR):
    """Load the SentenceTransformer model stored at ``model_dir``.

    The model is placed on ``"cuda"`` when ``torch.cuda.is_available()``
    reports a GPU, and on ``"cpu"`` otherwise.

    Args:
        model_dir: Local path (or model name) handed to SentenceTransformer.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    return SentenceTransformer(model_name_or_path=model_dir, device=target_device)
def load_cate(file=CATE_FILE):
    """Load category data from ``file``.

    Only the first line of the file is read and decoded as JSON. Any further
    lines are ignored.

    Args:
        file: Path to the category file.

    Returns:
        The decoded JSON value from the first line. Its structure is not
        visible in this module; presumably a dict or list (confirm with
        callers).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the first line is not valid JSON (including
            an empty file).
    """
    # Decode explicitly as UTF-8. Without an explicit encoding, open() uses the
    # locale's preferred encoding, so non-ASCII category text can fail to
    # decode or be misread on hosts with a non-UTF-8 locale.
    with open(file, encoding="utf-8") as f:
        return json.loads(f.readline())
def load_index(index_path=INDEX_FILE):
    """Load the faiss index from ``index_path``, building it first if missing.

    When ``index_path`` does not exist, ``build_index.index()`` is called once
    and the path is checked again.

    NOTE(review): ``build_index.index()`` takes no path argument, so it
    presumably writes to its own default location (likely INDEX_FILE). With a
    non-default ``index_path``, the file may still be missing after the build.
    Confirm against build_index.

    Args:
        index_path: Path of the serialized faiss index.

    Returns:
        The index returned by ``faiss.read_index``, or ``None`` if the file
        still does not exist after attempting to build it.
    """
    if not os.path.exists(index_path):
        logger.info("faiss does not exists, need build index")
        build_index.index()

    if not os.path.exists(index_path):
        # Preserve the historical ``None`` return so existing callers behave
        # the same, but make the failure visible instead of returning silently.
        logger.error("faiss index %s still missing after build; returning None", index_path)
        return None

    faiss_index = faiss.read_index(index_path)
    # Lazy %-style args: formatting only happens if INFO is enabled.
    logger.info("load faiss %d sucess", faiss_index.ntotal)
    return faiss_index
# Load shared resources once, when the module is imported, so importers can use
# `index` and `cate` directly. Importing this module therefore reads files from
# disk and may run build_index.index() if the index file is missing.
# `index` can be None if the index file is still absent after that build attempt.
index = load_index()
cate = load_cate()
if __name__ == "__main__":
    # Manual smoke test. The module-level code above has already loaded `index`
    # (from INDEX_FILE) and `cate` (from CATE_FILE). Reading them from disk
    # again here would only repeat the I/O, so report what was loaded instead.
    # Logging is configured here because, unconfigured, INFO records are dropped.
    logging.basicConfig(level=logging.INFO)
    logger.info("index: %s vectors", getattr(index, "ntotal", None))
    logger.info("cate: %s loaded from %s", type(cate).__name__, CATE_FILE)
|