speech_recognizer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import logging
  3. import uuid
  4. import json
  5. import threading
  6. from nls.core import NlsCore
  7. from . import logging
  8. from . import util
  9. from .exception import (StartTimeoutException,
  10. StopTimeoutException,
  11. NotStartException,
  12. InvalidParameter)
  13. __SPEECH_RECOGNIZER_NAMESPACE__ = 'SpeechRecognizer'
  14. __SPEECH_RECOGNIZER_REQUEST_CMD__ = {
  15. 'start': 'StartRecognition',
  16. 'stop': 'StopRecognition'
  17. }
  18. __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1'
  19. __all__ = ['NlsSpeechRecognizer']
  20. class NlsSpeechRecognizer:
  21. """
  22. Api for short sentence speech recognition
  23. """
  24. def __init__(self,
  25. url=__URL__,
  26. token=None,
  27. appkey=None,
  28. on_start=None,
  29. on_result_changed=None,
  30. on_completed=None,
  31. on_error=None, on_close=None,
  32. callback_args=[]):
  33. """
  34. NlsSpeechRecognizer initialization
  35. Parameters:
  36. -----------
  37. url: str
  38. websocket url.
  39. token: str
  40. access token. if you do not have a token, provide access id and key
  41. secret from your aliyun account.
  42. appkey: str
  43. appkey from aliyun
  44. on_start: function
  45. Callback object which is called when recognition started.
  46. on_start has two arguments.
  47. The 1st argument is message which is a json format string.
  48. The 2nd argument is *args which is callback_args.
  49. on_result_changed: function
  50. Callback object which is called when partial recognition result
  51. arrived.
  52. on_result_changed has two arguments.
  53. The 1st argument is message which is a json format string.
  54. The 2nd argument is *args which is callback_args.
  55. on_completed: function
  56. Callback object which is called when recognition is completed.
  57. on_completed has two arguments.
  58. The 1st argument is message which is a json format string.
  59. The 2nd argument is *args which is callback_args.
  60. on_error: function
  61. Callback object which is called when any error occurs.
  62. on_error has two arguments.
  63. The 1st argument is message which is a json format string.
  64. The 2nd argument is *args which is callback_args.
  65. on_close: function
  66. Callback object which is called when connection closed.
  67. on_close has one arguments.
  68. The 1st argument is *args which is callback_args.
  69. callback_args: list
  70. callback_args will return in callbacks above for *args.
  71. """
  72. if not token or not appkey:
  73. raise InvalidParameter('Must provide token and appkey')
  74. self.__response_handler__ = {
  75. 'RecognitionStarted': self.__recognition_started,
  76. 'RecognitionResultChanged': self.__recognition_result_changed,
  77. 'RecognitionCompleted': self.__recognition_completed,
  78. 'TaskFailed': self.__task_failed
  79. }
  80. self.__callback_args = callback_args
  81. self.__appkey = appkey
  82. self.__url = url
  83. self.__token = token
  84. self.__start_cond = threading.Condition()
  85. self.__start_flag = False
  86. self.__on_start = on_start
  87. self.__on_result_changed = on_result_changed
  88. self.__on_completed = on_completed
  89. self.__on_error = on_error
  90. self.__on_close = on_close
  91. self.__allow_aformat = (
  92. 'pcm', 'opus', 'opu', 'wav', 'mp3', 'speex', 'aac', 'amr'
  93. )
  94. def __handle_message(self, message):
  95. logging.debug('__handle_message')
  96. try:
  97. __result = json.loads(message)
  98. if __result['header']['name'] in self.__response_handler__:
  99. __handler = self.__response_handler__[
  100. __result['header']['name']]
  101. __handler(message)
  102. else:
  103. logging.error('cannot handle cmd{}'.format(
  104. __result['header']['name']))
  105. return
  106. except json.JSONDecodeError:
  107. logging.error('cannot parse message:{}'.format(message))
  108. return
  109. def __sr_core_on_open(self):
  110. logging.debug('__sr_core_on_open')
  111. def __sr_core_on_msg(self, msg, *args):
  112. logging.debug('__sr_core_on_msg:msg={} args={}'.format(msg, args))
  113. self.__handle_message(msg)
  114. def __sr_core_on_error(self, msg, *args):
  115. logging.debug('__sr_core_on_error:msg={} args={}'.format(msg, args))
  116. def __sr_core_on_close(self):
  117. logging.debug('__sr_core_on_close')
  118. if self.__on_close:
  119. self.__on_close(*self.__callback_args)
  120. with self.__start_cond:
  121. self.__start_flag = False
  122. self.__start_cond.notify()
  123. def __recognition_started(self, message):
  124. logging.debug('__recognition_started')
  125. if self.__on_start:
  126. self.__on_start(message, *self.__callback_args)
  127. with self.__start_cond:
  128. self.__start_flag = True
  129. self.__start_cond.notify()
  130. def __recognition_result_changed(self, message):
  131. logging.debug('__recognition_result_changed')
  132. if self.__on_result_changed:
  133. self.__on_result_changed(message, *self.__callback_args)
  134. def __recognition_completed(self, message):
  135. logging.debug('__recognition_completed')
  136. self.__nls.shutdown()
  137. logging.debug('__recognition_completed shutdown done')
  138. if self.__on_completed:
  139. self.__on_completed(message, *self.__callback_args)
  140. with self.__start_cond:
  141. self.__start_flag = False
  142. self.__start_cond.notify()
  143. def __task_failed(self, message):
  144. logging.debug('__task_failed')
  145. with self.__start_cond:
  146. self.__start_flag = False
  147. self.__start_cond.notify()
  148. if self.__on_error:
  149. self.__on_error(message, *self.__callback_args)
  150. def start(self, aformat='pcm', sample_rate=16000, ch=1,
  151. enable_intermediate_result=False,
  152. enable_punctuation_prediction=False,
  153. enable_inverse_text_normalization=False,
  154. timeout=10,
  155. ping_interval=8,
  156. ping_timeout=None,
  157. ex:dict=None):
  158. """
  159. Recognition start
  160. Parameters:
  161. -----------
  162. aformat: str
  163. audio binary format, support: 'pcm', 'opu', 'opus', default is 'pcm'
  164. sample_rate: int
  165. audio sample rate, default is 16000
  166. ch: int
  167. audio channels, only support mono which is 1
  168. enable_intermediate_result: bool
  169. whether enable return intermediate recognition result, default is False
  170. enable_punctuation_prediction: bool
  171. whether enable punctuation prediction, default is False
  172. enable_inverse_text_normalization: bool
  173. whether enable ITN, default is False
  174. timeout: int
  175. wait timeout for connection setup
  176. ping_interval: int
  177. send ping interval, 0 for disable ping send, default is 8
  178. ping_timeout: int
  179. timeout after send ping and recive pong, set None for disable timeout check and default is None
  180. ex: dict
  181. dict which will merge into 'payload' field in request
  182. """
  183. self.__nls = NlsCore(
  184. url=self.__url,
  185. token=self.__token,
  186. on_open=self.__sr_core_on_open,
  187. on_message=self.__sr_core_on_msg,
  188. on_close=self.__sr_core_on_close,
  189. on_error=self.__sr_core_on_error,
  190. callback_args=[])
  191. if ch != 1:
  192. raise InvalidParameter(f'Not support channel {ch}')
  193. if aformat not in self.__allow_aformat:
  194. raise InvalidParameter(f'Format {aformat} not support')
  195. __id4 = uuid.uuid4().hex
  196. self.__task_id = uuid.uuid4().hex
  197. __header = {
  198. 'message_id': __id4,
  199. 'task_id': self.__task_id,
  200. 'namespace': __SPEECH_RECOGNIZER_NAMESPACE__,
  201. 'name': __SPEECH_RECOGNIZER_REQUEST_CMD__['start'],
  202. 'appkey': self.__appkey
  203. }
  204. __payload = {
  205. 'format': aformat,
  206. 'sample_rate': sample_rate,
  207. 'enable_intermediate_result': enable_intermediate_result,
  208. 'enable_punctuation_prediction': enable_punctuation_prediction,
  209. 'enable_inverse_text_normalization': enable_inverse_text_normalization
  210. }
  211. if ex:
  212. __payload.update(ex)
  213. __msg = {
  214. 'header': __header,
  215. 'payload': __payload,
  216. 'context': util.GetDefaultContext()
  217. }
  218. __jmsg = json.dumps(__msg)
  219. with self.__start_cond:
  220. if self.__start_flag:
  221. logging.debug('already start...')
  222. return
  223. self.__nls.start(__jmsg, ping_interval, ping_timeout)
  224. if self.__start_flag == False:
  225. if self.__start_cond.wait(timeout=timeout):
  226. return
  227. else:
  228. raise StartTimeoutException(f'Waiting Start over {timeout}s')
  229. def stop(self, timeout=10):
  230. """
  231. Stop recognition and mark session finished
  232. Parameters:
  233. -----------
  234. timeout: int
  235. timeout for waiting completed message from cloud
  236. """
  237. __id4 = uuid.uuid4().hex
  238. __header = {
  239. 'message_id': __id4,
  240. 'task_id': self.__task_id,
  241. 'namespace': __SPEECH_RECOGNIZER_NAMESPACE__,
  242. 'name': __SPEECH_RECOGNIZER_REQUEST_CMD__['stop'],
  243. 'appkey': self.__appkey
  244. }
  245. __msg = {
  246. 'header': __header,
  247. 'context': util.GetDefaultContext()
  248. }
  249. __jmsg = json.dumps(__msg)
  250. with self.__start_cond:
  251. if not self.__start_flag:
  252. logging.debug('not start yet...')
  253. return
  254. self.__nls.send(__jmsg, False)
  255. if self.__start_flag == True:
  256. logging.debug('stop wait..')
  257. if self.__start_cond.wait(timeout):
  258. return
  259. else:
  260. raise StopTimeoutException(f'Waiting stop over {timeout}s')
  261. def shutdown(self):
  262. """
  263. Shutdown connection immediately
  264. """
  265. self.__nls.shutdown()
  266. def send_audio(self, pcm_data):
  267. """
  268. Send audio binary, audio size prefer 20ms length
  269. Parameters:
  270. -----------
  271. pcm_data: bytes
  272. audio binary which format is 'aformat' in start method
  273. """
  274. if not pcm_data:
  275. raise InvalidParameter('data empty!')
  276. __data = pcm_data
  277. with self.__start_cond:
  278. if not self.__start_flag:
  279. raise NotStartException('Need start before send!')
  280. try:
  281. self.__nls.send(__data, True)
  282. except ConnectionResetError as __e:
  283. logging.error('connection reset')
  284. self.__start_flag = False
  285. self.__nls.shutdown()
  286. raise __e