utils.py 9.2 KB


  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. @Time : 2024/10/15 16:19
  5. @File : utils.py
  6. @Desc :
  7. """
  8. import sys
  9. sys.path.append("..")
  10. from datetime import datetime
  11. from functools import wraps
  12. from typing import (Any,
  13. List,
  14. Text,
  15. Dict
  16. )
  17. from config import (
  18. get_logger,
  19. GENERATED,
  20. FIXED,
  21. MOUDLES
  22. )
  23. import pandas as pd
  24. from threading import Thread
  25. import json
# Module-wide logger obtained from the project's config module.
logger = get_logger()
  27. from database import Mysql
  28. from pypinyin import pinyin, Style
  29. import jieba
  30. import re
  31. import itertools
  32. from concurrent.futures import ThreadPoolExecutor,as_completed
# Shared module-level thread pool with 20 workers, created at import time.
executor = ThreadPoolExecutor(max_workers=20)
  34. def get_speech_status(bid: Text = None, options: List[Dict[Text, Text]] = None):
  35. """which speech template to choose"""
  36. res = dict()
  37. module = MOUDLES[bid]
  38. option = ''
  39. if options:
  40. option = options[-1]['title'] # 外呼机器人单选
  41. for key, value in module.items():
  42. if isinstance(value, dict):
  43. if option in value['content']:
  44. res['action'] = key
  45. res['speech_id'] = value['speech_id']
  46. res['speech_type'] = value['speech_type']
  47. res['speech_interrupt'] = value['speech_interrupt']
  48. res['asr'] = options[-1]['asr']
  49. break
  50. return res
  51. def get_robot_speeches(msg, bid, uid, questions: Dict[Any, Any] = None):
  52. """GET speech and speech status"""
  53. from entity import Status
  54. choose_speech_status = get_speech_status(bid, msg.option)
  55. if not choose_speech_status:
  56. choose_speech = Status.base.value
  57. robot_speech, interrupt = speech_main_contents(uid, bid, msg.code, questions)
  58. interrupt = questions[msg.code]['mainInterrupt']
  59. return robot_speech, choose_speech, interrupt
  60. return None, None, None
  61. def speech_main_contents(uid: Text = None,
  62. bid: Text = None,
  63. code: Text = None,
  64. questions: Dict[Any, Any] = None,
  65. options: List[Dict[Any, Any]] = None):
  66. from util import nlg_service
  67. from entity import Status
  68. def parse_options():
  69. if options:
  70. option = options[-1]
  71. is_faq = option['isFaq']
  72. faq_content = option['faqContent']
  73. if is_faq:
  74. return faq_content
  75. if "businessContent" in option:
  76. return option['businessContent']
  77. return
  78. _robot_speech = ''
  79. speech_type = questions[code]['mainType']
  80. node_name = questions[code]['nodeName']
  81. interrupt = questions[code]['mainInterrupt']
  82. choose_speech = Status.base.value
  83. #logger.info(f"code:{code},speech_type:{speech_type}")
  84. topic = questions[code]
  85. # logger.info(f"topic: {topic}")
  86. _faq_content = parse_options()
  87. if speech_type == GENERATED:
  88. # TODO 生成话术
  89. resp = nlg_service(uid, bid, node_name, choose_speech)
  90. if _faq_content or resp:
  91. _robot_speech = "{}&{}".format(_faq_content, resp) if _faq_content else resp
  92. elif speech_type == FIXED:
  93. content = questions[code]['mainContent']
  94. if content or _faq_content:
  95. _robot_speech = "{}&{}".format(_faq_content, content) if _faq_content else content
  96. return _robot_speech, interrupt
  97. def get_next_code_with_track(uid: Text = None,
  98. code: Text = None,
  99. option: Text = None,
  100. questions: Dict[Any, Any] = None
  101. ):
  102. """
  103. use user trackCode get qus id
  104. @param uid:
  105. @param code:
  106. @param option:
  107. @param questions:
  108. @return:
  109. """
  110. options = questions[code]['options']
  111. for cell in options:
  112. if option == cell['title']:
  113. next = cell['next']
  114. logger.info(f"uid:{uid}:code from {code} to {next}")
  115. return next
  116. return code
  117. def _async(f):
  118. def wrapper(*args, **kwargs):
  119. thr = Thread(target=f, args=args, kwargs=kwargs)
  120. thr.start()
  121. return wrapper
  122. @_async
  123. def insert_log(bid, uid, session_id, scene):
  124. """
  125. CREATE TABLE botrecords (
  126. id INT AUTO_INCREMENT PRIMARY KEY,
  127. session VARCHAR(50) unique not null COMMENT '请求id',
  128. req_time DATETIME COMMENT '来电时间',
  129. uid VARCHAR(20) COMMENT '来电手机号',
  130. bid VARCHAR(20) COMMENT '话术id',
  131. intent VARCHAR(20) COMMENT '意图',
  132. contents TEXT comment '内容',
  133. dialog TEXT COMMENT '对话'
  134. );
  135. """
  136. tmp = json.dumps(scene.case, ensure_ascii=False, default=lambda obj: obj.__dict__)
  137. res = json.loads(tmp)
  138. if res:
  139. answers = res.get("answer")
  140. intent = ''
  141. codes = [i for i in map(lambda x: x.get("code"), answers)]
  142. if "1.20" in codes or "1.10" in codes or "1.00" in codes:
  143. for answer in answers:
  144. if answer.get("code") in ["1.20", "1.10", "1.00"]:
  145. asr = answer.get("option", [{}])[-1].get("asr", '')
  146. code = answer.get("code")
  147. opt = answer.get("option", {})
  148. if (code == "1.10" and asr!="1") or (code == "1.00" and asr not in ["2", "1"]):
  149. if opt:
  150. if opt[-1].get("subclass"):
  151. intent = opt[-1].get("subclass")
  152. elif opt[-1].get("firstclass"):
  153. intent = opt[-1].get("firstclass")
  154. else:
  155. intent =opt[-1].get("title", '')
  156. elif code == "1.20":
  157. tags = {"1": "1_停水咨询", "2":"1_漏水保修", "3":"1_户号查询","4":"1_水费查询", "5":"1_水价咨询","6":"1_水质水价保修", "7":"1_投诉建议", "0":"1_转人工"}
  158. intent = tags.get(asr, "1_其他")
  159. elif code == "1.00" and asr=="2":
  160. intent = "不体验AI服务"
  161. elif code == "1.00" and asr=="1":
  162. intent = asr
  163. filter_ans = filter(lambda x: x.get("code") not in ["2.00", "3.00", "4.00"], answers)
  164. content = [i for i in map(lambda x:[ x.get("question"), x.get("option", [{}])[-1].get("asr", '')], filter_ans)]
  165. content.append((res.get("robot_speech"), ''))
  166. contents = json.dumps({"data": content}, ensure_ascii=False)
  167. req_time = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
  168. mysql = Mysql()
  169. mysql.insert_records([session_id, req_time, uid, bid, intent, contents, json.dumps(res, ensure_ascii=False)])
  170. # records = mysql.get_records(uid)
  171. # logger.info(f"databases:uid:{uid}, {records}")
  172. mysql.close_mysql()
  173. def timetic(func):
  174. @wraps(func)
  175. def wrapper(*args, **kwargs):
  176. start = datetime.now()
  177. results = func(*args, **kwargs)
  178. cost = (datetime.now() - start).total_seconds()
  179. if func.__qualname__ == "botservice":
  180. bot = kwargs.get("reqbot")
  181. sessionId, userId,nodeId = bot.sessionId, bot.userId,bot.nodeId
  182. logger.info("{},session:{},uid:{},nodeid:{} ==> {}s".format(func.__qualname__,sessionId , userId, nodeId, cost))
  183. else:
  184. logger.info("{} ==> {}s".format(func.__qualname__, cost))
  185. return results
  186. return wrapper
  187. def loaddict():
  188. loc = dict()
  189. df = pd.read_excel("../data/location.xlsx", header=0)
  190. loc['zh'] = dict(df[['norm_name', 'name']].values)
  191. loc['pinyin'] = dict(df[['name_pinyin', 'name']].values)
  192. short_val = [(i, 80) for i in df['short_name'].dropna().tolist()]
  193. norm_val = [(i, 100) for i in df['norm_name'].dropna().tolist()]
  194. norm_val.extend(short_val)
  195. loc['total'] = dict(norm_val)
  196. return loc
# Location lexicon loaded once at import time; norm_community reads it as a global.
user_dict = loaddict()
# NOTE(review): jieba.load_userdict is documented to take a file path or a
# file-like object of "word [freq] [tag]" lines. Given this dict it appears to
# iterate the keys only, so the 100/80 weights built by loaddict are most likely
# not applied -- confirm against the installed jieba version.
jieba.load_userdict(user_dict['total'])
  199. def norm_community(asr):
  200. def match_loc(comb):
  201. cur_wd = ''.join(comb)
  202. if cur_wd in asr:
  203. if cur_wd in user_dict['zh']:
  204. return user_dict['zh'].get(cur_wd)
  205. term = "|".join([term[0] for term in pinyin(cur_wd, style=Style.NORMAL)])
  206. if term in user_dict['pinyin']:
  207. return user_dict['pinyin'][term]
  208. return None
  209. if asr in user_dict['zh']:
  210. return user_dict['zh'][asr]
  211. text = re.sub(r'[(())]', '', asr)
  212. text = "|".join([word[0] for word in pinyin(text, style=Style.NORMAL)])
  213. if text in user_dict['pinyin']:
  214. return user_dict['pinyin'][text]
  215. words = jieba.lcut(asr)
  216. for word in words:
  217. if word in user_dict['zh']:
  218. return user_dict['zh'].get(word)
  219. term = "|".join([term[0] for term in pinyin(word, style=Style.NORMAL)])
  220. if term in user_dict['pinyin']:
  221. return user_dict['pinyin'][term]
  222. for r in range(1, len(words) + 1):
  223. combinations_list = list(itertools.combinations(words, r))
  224. features = [executor.submit(match_loc, combo) for combo in combinations_list]
  225. result =[feature.result() for feature in as_completed(features)]
  226. res = [i for i in filter(lambda x: x is not None, result)]
  227. if len(res) >0:
  228. return res[0]
  229. return asr
  230. if __name__ == "__main__":
  231. print(norm_community("嗯,那个我们家是碧水兰庭的"))