asr.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. #!/usr/bin/env python3
  2. # encoding:utf-8
  3. import os
  4. import json
  5. import base64
  6. import threading
  7. import traceback
  8. from datetime import datetime
  9. from src.core.callcenter import registry
  10. import nls # 引入阿里云语音识别库
  11. from aliyunsdkcore.client import AcsClient
  12. from aliyunsdkcore.request import CommonRequest
  13. import time
  14. # 定义实时转写类
  15. class TestSt:
  16. # 静态变量用于缓存Token
  17. token_cache = {
  18. "token": None,
  19. "expire_time": None
  20. }
  21. # 获取Token的函数
  22. @classmethod
  23. def get_token(cls):
  24. ak_id = "LTAI5tQ2HmiHCygZkt5BYrYR"
  25. ak_secret = "KhmxTd14SUcXafpFk5yofA43FoeM99"
  26. client = AcsClient(ak_id, ak_secret, "cn-shanghai")
  27. request = CommonRequest()
  28. request.set_method('POST')
  29. request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
  30. request.set_version('2019-02-28')
  31. request.set_action_name('CreateToken')
  32. try:
  33. response = client.do_action_with_exception(request)
  34. jss = json.loads(response)
  35. if 'Token' in jss and 'Id' in jss['Token']:
  36. token = jss['Token']['Id']
  37. expire_time = jss['Token']['ExpireTime']
  38. print(f"Token: {token}, ExpireTime: {expire_time}")
  39. return token, int(expire_time) # 返回Token和过期时间
  40. else:
  41. print("Token获取失败,响应内容: ", response)
  42. except Exception as e:
  43. print(f"获取Token时发生错误: {e}")
  44. return None, None
  45. @classmethod
  46. def get_cached_token(cls):
  47. # 检查是否已有缓存的Token且未过期s):
  48. # # 检查是否已有缓存的Token且未
  49. current_time = int(time.time())
  50. # if cls.token_cache["token"] and cls.token_cache["expire_time"]:
  51. if cls.token_cache["token"] and cls.token_cache["expire_time"] - current_time > 60:
  52. # if current_time < cls.token_cache["expire_time"]:
  53. # print("使用缓存的Token")
  54. return cls.token_cache["token"]
  55. # 如果没有缓存Token或者Token已过期,重新获取
  56. new_token, expire_time = cls.get_token()
  57. if new_token:
  58. cls.token_cache["token"] = new_token
  59. cls.token_cache["expire_time"] = expire_time
  60. print("获取新的Token")
  61. return new_token
  62. else:
  63. print("无法获取Token")
  64. return None
  65. def __init__(self, tid, logger, message_receiver=None):
  66. # self.is_closed = False
  67. # self.lock = threading.Lock()
  68. self.logger = logger
  69. self.__event = threading.Event()
  70. self.__th = threading.Thread(target=self.__test_run)
  71. self.__id = tid
  72. self.message_receiver = message_receiver
  73. self._Token = self.get_cached_token()
  74. self.sr = None
  75. # WebSocket服务地址
  76. self.URL = "wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"
  77. self.APPKEY = "OKt6jogp6fRjHQVp" # 你的Appkey
  78. self.logger.debug("开始")
  79. def start(self):
  80. self.__th.start()
  81. def send_audio(self, audio_data):
  82. if self.sr:
  83. self.sr.send_audio(audio_data)
  84. def close(self):
  85. try:
  86. self.sr.stop()
  87. except Exception as e:
  88. self.logger.debug(f"[{self.__id}]Error stopping ASR: {e}")
  89. def __test_run(self):
  90. self.logger.debug("Thread:%s start..",self.__id)
  91. nls.enableTrace(True)
  92. count = 0
  93. self.__event.clear()
  94. while not self.__event.is_set():
  95. self.sr = nls.NlsSpeechTranscriber(
  96. url=self.URL,
  97. token=self._Token,
  98. appkey=self.APPKEY,
  99. on_sentence_begin=self.test_on_sentence_begin,
  100. on_sentence_end=self.test_on_sentence_end,
  101. on_start=self.test_on_start,
  102. on_result_changed=self.test_on_result_chg,
  103. on_completed=self.test_on_completed,
  104. on_error=self.test_on_error,
  105. on_close=self.test_on_close,
  106. callback_args=[self.__id]
  107. )
  108. try:
  109. self.sr.start(
  110. aformat="pcm",
  111. sample_rate=8000,
  112. enable_intermediate_result=True,
  113. enable_punctuation_prediction=True,
  114. enable_inverse_text_normalization=True,
  115. ex={'max_sentence_silence': 2000, 'disfluency': True, 'enable_words': True}
  116. )
  117. # _res = self.sr.ctrl(ex={'max_sentence_silence': 6000, 'disfluency': True,'enable_words': True })
  118. self.logger.debug(f"[{self.__id}]ASR session started. {count}")
  119. self.__event.wait(timeout=.5)
  120. self.logger.debug(f"[{self.__id}]ASR session started. {count}")
  121. except Exception as e:
  122. traceback.print_exc()
  123. self.logger.debug(f"[{self.__id}]ASR session start exception. {e}")
  124. count = count + 1
  125. def test_on_sentence_begin(self, message, *args):
  126. self.logger.debug("[%s]test_on_sentence_begin:%s", self.__id, message)
  127. if self.message_receiver:
  128. self.message_receiver(self.convert_message(message), *args)
  129. def test_on_sentence_end(self, message, *args):
  130. self.logger.debug("[%s]test_on_sentence_end:%s", self.__id, message)
  131. if self.message_receiver:
  132. self.message_receiver(self.convert_message(message), *args)
  133. def test_on_start(self, message, *args):
  134. self.__event.set()
  135. self.logger.debug("[%s]test_on_start:%s", self.__id, message)
  136. pass
  137. def test_on_error(self, message, *args):
  138. self.logger.debug("on_error args=>%s", args)
  139. if not self.__event.is_set():
  140. self.__event.set()
  141. if self.message_receiver:
  142. self.message_receiver(self.convert_message(message), *args)
  143. def test_on_close(self, *args):
  144. self.logger.debug("on_close: args=>%s", args)
  145. if not self.__event.is_set():
  146. self.__event.set()
  147. pass
  148. def test_on_result_chg(self, message, *args):
  149. # self.logger.debug("test_on_chg:{}".format(message))
  150. if self.message_receiver:
  151. self.message_receiver(self.convert_message(message), *args)
  152. def test_on_completed(self, message, *args):
  153. # self.logger.debug("on_completed:args=>{} message=>{}".format(args, message))
  154. pass
  155. def convert_message(self, message):
  156. final_result = {}
  157. message = json.loads(message)
  158. if message["header"]["status"] == 20000000:
  159. if message["header"]["name"] == "SentenceBegin":
  160. final_result['name'] = 'SentenceBegin'
  161. if message["header"]["name"] == "SentenceEnd":
  162. result = message["payload"]["result"]
  163. # self.logger.info("asr返回内容Result:%s", result)
  164. final_result['name'] = 'SentenceEnd'
  165. final_result['result'] = result
  166. elif message["header"]["name"] == "TranscriptionResultChanged":
  167. final_result['name'] = 'TranscriptionResultChanged'
  168. else:
  169. final_result['name'] = 'TranscriptionResultError'
  170. final_result['status'] = message['header']['status']
  171. final_result['result'] = ''
  172. self.logger.info(f"Status is not {message['header']['status']}")
  173. registry.ASR_ERRORS.labels(message['header']['status']).inc()
  174. self.logger.error("aliyun.Asr.recv: call_id:%s, final_result: %s", self.__id, final_result)
  175. return final_result
  176. # 讯飞ASR实时转写
  177. class XfAsr:
  178. def __init__(self, tid, logger, message_receiver=None):
  179. self.end_tag = "{\"end\": true}"
  180. self.tid = tid
  181. self.logger = logger
  182. self.message_receiver = message_receiver
  183. self.ws = self.new_connection()
  184. self.logger.info(f"xunfei.Asr: call_id: {tid}, ws.connected:{self.ws.connected}")
  185. self.trecv = threading.Thread(target=self.recv)
  186. def new_connection(self, base_url = "ws://rtasr.xfyun.cn/v1/ws", app_id = "1ec1097b", api_key = "60b7d2d8d172b065b1c3e723e5ba0696"):
  187. import hashlib
  188. import hmac
  189. import base64
  190. # from socket import *
  191. # import json, time, threading
  192. from websocket import create_connection
  193. # import websocket
  194. from urllib.parse import quote
  195. # import logging
  196. # logging.basicConfig()
  197. ts = str(int(time.time()))
  198. tt = (app_id + ts).encode('utf-8')
  199. md5 = hashlib.md5()
  200. md5.update(tt)
  201. baseString = md5.hexdigest()
  202. baseString = bytes(baseString, encoding='utf-8')
  203. apiKey = api_key.encode('utf-8')
  204. signa = hmac.new(apiKey, baseString, hashlib.sha1).digest()
  205. signa = base64.b64encode(signa)
  206. signa = str(signa, 'utf-8')
  207. count = 10
  208. _ws = None
  209. while count > 0:
  210. try:
  211. _ws = create_connection(base_url + "?appid=" + app_id + "&ts=" + ts + "&signa=" + quote(signa))
  212. break
  213. except Exception as e:
  214. count -= 1
  215. self.logger.info("new_connection:exception, call_id: %s, count=%s, message=%s", self.tid, count, e)
  216. time.sleep(.010)
  217. return _ws
  218. def start(self):
  219. self.trecv.start()
  220. def send_audio(self, chunk):
  221. # self.logger.debug('xunfei.Asr.send_audio: call_id: %s, chunk:%s, %s', self.tid, len(chunk), chunk)
  222. if self.ws:
  223. self.ws.send(chunk)
  224. def close(self):
  225. if self.ws:
  226. self.ws.send(bytes(self.end_tag.encode('utf-8')))
  227. self.ws.close()
  228. self.logger.info("connection closed call_id: %s", self.tid)
  229. def recv(self):
  230. try:
  231. self.logger.info(f"xunfei.Asr.recv: call_id: {self.tid}, ws.connected:{self.ws.connected}")
  232. while self.ws and self.ws.connected:
  233. message = str(self.ws.recv())
  234. if len(message) == 0:
  235. self.logger.info("xunfei.Asr.recv: receive result end call_id: %s", self.tid)
  236. break
  237. self.logger.info("xunfei.Asr.recv: call_id: %s, message :%s", self.tid, message)
  238. if self.message_receiver:
  239. self.message_receiver(self.convert_message(message))
  240. except Exception as e:
  241. traceback.print_exc()
  242. self.logger.error("xunfei.Asr.recv: receive result end, call_id: %s, message: %s", self.tid, e)
  243. def convert_message(self, message):
  244. final_result = {}
  245. result_dict = json.loads(message)
  246. if result_dict["code"] == "0":
  247. if result_dict["action"] == "started":
  248. final_result['name'] = 'SentenceBegin'
  249. elif result_dict["action"] == "result":
  250. result_dict = json.loads(message)
  251. result_1 = json.loads(result_dict["data"])
  252. st = result_1["cn"]["st"]
  253. rt = st["rt"]
  254. if st['type'] == "1":
  255. final_result['name'] = 'TranscriptionResultChanged'
  256. if st['type'] == "0":
  257. final_result['name'] = 'SentenceEnd'
  258. final_result['result'] = ''.join(cw["w"] for item in rt for ws in item["ws"] for cw in ws["cw"]).strip()
  259. elif result_dict["action"] == "error":
  260. self.logger.error("xunfei.Asr.recv: call_id: %s, action is error: %s", self.tid, message)
  261. final_result['name'] = 'TranscriptionResultError'
  262. final_result['result'] = message
  263. if self.ws:
  264. self.ws.close()
  265. else:
  266. self.logger.error("xunfei.Asr.recv: call_id: %s, Status is not: %s", self.tid, result_dict["code"])
  267. final_result['name'] = 'TranscriptionResultError'
  268. final_result['status'] = result_dict["code"]
  269. registry.ASR_ERRORS.labels(result_dict["code"]).inc()
  270. if self.ws:
  271. self.ws.close()
  272. self.logger.error("xunfei.Asr.recv: call_id: %s, final_result: %s", self.tid, final_result)
  273. return final_result
  274. class XunfeiAsr:
  275. STATUS_FIRST_FRAME = 0 # 第一帧的标识
  276. STATUS_CONTINUE_FRAME = 1 # 中间帧标识
  277. STATUS_LAST_FRAME = 2 # 最后一帧的标识
  278. def __init__(self, tid, logger, message_receiver=None):
  279. self.__id = tid
  280. self.logger = logger
  281. self.message_receiver = message_receiver
  282. self.connected = False
  283. self.status = XunfeiAsr.STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
  284. self.pre_result = ""
  285. self.APPID = '1ec1097b'
  286. self.APIKey = '7237c29d862daa6fd805f788ed70c409'
  287. self.APISecret = 'YTk1YzAyMmQwYjQ3ZDJkYWQyZGQwMDFm'
  288. # 公共参数(common)
  289. self.CommonArgs = {"app_id": self.APPID}
  290. # 业务参数(business),更多个性化参数可在官网查看
  291. self.BusinessArgs = {"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vinfo":1,"vad_eos":2000, "dwa":"wpgs"}
  292. # self.__event = threading.Event()
  293. self.__th = threading.Thread(target=self.__run)
  294. def create_url(self):
  295. import hashlib
  296. import hmac
  297. from urllib.parse import urlencode
  298. from wsgiref.handlers import format_date_time
  299. from time import mktime
  300. url = 'wss://ws-api.xfyun.cn/v2/iat'
  301. # 生成RFC1123格式的时间戳
  302. now = datetime.now()
  303. date = format_date_time(mktime(now.timetuple()))
  304. # 拼接字符串
  305. signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
  306. signature_origin += "date: " + date + "\n"
  307. signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
  308. # 进行hmac-sha256进行加密
  309. signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
  310. digestmod=hashlib.sha256).digest()
  311. signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
  312. authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
  313. self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
  314. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  315. # 将请求的鉴权参数组合为字典
  316. v = {
  317. "authorization": authorization,
  318. "date": date,
  319. "host": "ws-api.xfyun.cn"
  320. }
  321. # 拼接鉴权参数,生成url
  322. url = url + '?' + urlencode(v)
  323. # print("date: ",date)
  324. # print("v: ",v)
  325. # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
  326. # print('websocket url :', url)
  327. return url
  328. def __run(self):
  329. import ssl
  330. import websocket
  331. websocket.enableTrace(False)
  332. try:
  333. # 测试时候在此处正确填写相关信息即可运行
  334. time1 = datetime.now()
  335. ws_url = self.create_url()
  336. self.logger.info("xunfei.Asr.call_id:%s, ws_url:%s", self.__id, ws_url)
  337. self.ws = websocket.WebSocketApp(ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close)
  338. self.ws.on_open = self.on_open
  339. self.connected = False
  340. self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  341. time_cost = (datetime.now() - time1)
  342. self.logger.info(f"xunfei.Asr.started. call_id:{self.__id}, timeCost:{time_cost}")
  343. except Exception as e:
  344. traceback.print_exc()
  345. self.logger.info(f"[{self.__id}]ASR session start exception. {e}")
  346. def start(self):
  347. self.__th.start()
  348. def send_audio(self, audio_data=None):
  349. if not self.connected:
  350. self.logger.info('xunfei.Asr.send_audio:ws_is_None: call_id: %s, chunk:%s, %s', self.__id, len(audio_data), audio_data)
  351. return
  352. if not audio_data:
  353. self.status = XunfeiAsr.STATUS_LAST_FRAME
  354. buf = bytes(audio_data) if audio_data else bytes()
  355. # self.logger.info('xunfei.Asr.send_audio: call_id: %s, status:%s, chunk:%s, %s', self.__id, status, (len(buf) if buf else 0) , buf)
  356. # 第一帧处理
  357. # 发送第一帧音频,带business 参数
  358. # appid 必须带上,只需第一帧发送
  359. if self.status == XunfeiAsr.STATUS_FIRST_FRAME:
  360. d = {"common": self.CommonArgs,
  361. "business": self.BusinessArgs,
  362. "data": {"status": 0, "format": "audio/L16;rate=16000",
  363. "audio": str(base64.b64encode(buf), 'utf-8'),
  364. "encoding": "raw"}}
  365. d = json.dumps(d)
  366. self.ws.send(d)
  367. self.status = XunfeiAsr.STATUS_CONTINUE_FRAME
  368. # 中间帧处理
  369. elif self.status == XunfeiAsr.STATUS_CONTINUE_FRAME:
  370. d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
  371. "audio": str(base64.b64encode(buf), 'utf-8'),
  372. "encoding": "raw"}}
  373. self.ws.send(json.dumps(d))
  374. # 最后一帧处理
  375. elif self.status == XunfeiAsr.STATUS_LAST_FRAME:
  376. d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
  377. "audio": str(base64.b64encode(buf), 'utf-8'),
  378. "encoding": "raw"}}
  379. self.ws.send(json.dumps(d))
  380. time.sleep(1)
  381. def close(self):
  382. try:
  383. self.send_audio()
  384. self.ws.close()
  385. except Exception as e:
  386. self.logger.info(f"[{self.__id}]Error stopping ASR: {e}")
  387. # 收到websocket连接建立的处理
  388. def on_open(self, ws):
  389. self.connected = True
  390. self.logger.info("xunfei.Asr.open: call_id: %s", self.__id)
  391. # self.__event.set()
  392. # 收到websocket消息的处理
  393. def on_message(self, ws, message):
  394. try:
  395. self.logger.info("xunfei.Asr.recv: call_id: %s, message :%s", self.__id, message)
  396. if self.message_receiver:
  397. self.message_receiver(self.convert_message(message))
  398. except Exception as e:
  399. self.logger.error("receive msg, but parse exception call_id:%s, message:%s, error:%s", self.__id, message, e)
  400. # 收到websocket错误的处理
  401. def on_error(self, ws, error):
  402. self.logger.error("xunfei.Asr.recv::error, call_id:%s, msg:%s", self.__id, error)
  403. # if not self.__event.is_set():
  404. # self.__event.set()
  405. # 收到websocket关闭的处理
  406. def on_close(self, ws, a, b):
  407. self.connected = False
  408. self.logger.error("xunfei.Asr.recv::close, call_id:%s", self.__id)
  409. # if not self.__event.is_set():
  410. # self.__event.set()
  411. def convert_message(self, message):
  412. final_result = {}
  413. message = json.loads(message)
  414. if message["code"] == 0:
  415. data = message["data"]["result"]["ws"]
  416. result = ""
  417. for i in data:
  418. for w in i["cw"]:
  419. result += w["w"]
  420. status = message["data"]["status"]
  421. if status == 0:
  422. final_result['name'] = 'SentenceBegin'
  423. elif status == 1:
  424. final_result['name'] = 'TranscriptionResultChanged'
  425. final_result['result'] = result
  426. self.pre_result = result
  427. elif status == 2:
  428. final_result['status'] = 'SentenceEnd'
  429. final_result['result'] = self.pre_result + result
  430. self.pre_result = ""
  431. else:
  432. final_result['name'] = 'TranscriptionResultError'
  433. final_result['status'] = message['code']
  434. final_result['result'] = message['message']
  435. self.logger.info("call_id:%s, sid:%s call error:%s code is:%s" % (self.__id, message["sid"], message["message"], message["code"]))
  436. registry.ASR_ERRORS.labels(message['code']).inc()
  437. self.logger.error("xunfei.Asr.recv: call_id:%s, final_result: %s", self.__id, final_result)
  438. return final_result