speech_transcriber.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import logging
  3. import uuid
  4. import json
  5. import threading
  6. from nls.core import NlsCore
  7. from . import logging
  8. from . import util
  9. from nls.exception import (StartTimeoutException,
  10. StopTimeoutException,
  11. NotStartException,
  12. InvalidParameter)
  13. __SPEECH_TRANSCRIBER_NAMESPACE__ = 'SpeechTranscriber'
  14. __SPEECH_TRANSCRIBER_REQUEST_CMD__ = {
  15. 'start': 'StartTranscription',
  16. 'stop': 'StopTranscription',
  17. 'control': 'ControlTranscriber'
  18. }
  19. __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1'
  20. __all__ = ['NlsSpeechTranscriber']
  21. class NlsSpeechTranscriber:
  22. """
  23. Api for realtime speech transcription
  24. """
  25. def __init__(self,
  26. url=__URL__,
  27. token=None,
  28. appkey=None,
  29. on_start=None,
  30. on_sentence_begin=None,
  31. on_sentence_end=None,
  32. on_result_changed=None,
  33. on_completed=None,
  34. on_error=None,
  35. on_close=None,
  36. callback_args=[]):
  37. '''
  38. NlsSpeechTranscriber initialization
  39. Parameters:
  40. -----------
  41. url: str
  42. websocket url.
  43. token: str
  44. access token. if you do not have a token, provide access id and key
  45. secret from your aliyun account.
  46. appkey: str
  47. appkey from aliyun
  48. on_start: function
  49. Callback object which is called when recognition started.
  50. on_start has two arguments.
  51. The 1st argument is message which is a json format string.
  52. The 2nd argument is *args which is callback_args.
  53. on_sentence_begin: function
  54. Callback object which is called when one sentence started.
  55. on_sentence_begin has two arguments.
  56. The 1st argument is message which is a json format string.
  57. The 2nd argument is *args which is callback_args.
  58. on_sentence_end: function
  59. Callback object which is called when sentence is end.
  60. on_sentence_end has two arguments.
  61. The 1st argument is message which is a json format string.
  62. The 2nd argument is *args which is callback_args.
  63. on_result_changed: function
  64. Callback object which is called when partial recognition result
  65. arrived.
  66. on_result_changed has two arguments.
  67. The 1st argument is message which is a json format string.
  68. The 2nd argument is *args which is callback_args.
  69. on_completed: function
  70. Callback object which is called when recognition is completed.
  71. on_completed has two arguments.
  72. The 1st argument is message which is a json format string.
  73. The 2nd argument is *args which is callback_args.
  74. on_error: function
  75. Callback object which is called when any error occurs.
  76. on_error has two arguments.
  77. The 1st argument is message which is a json format string.
  78. The 2nd argument is *args which is callback_args.
  79. on_close: function
  80. Callback object which is called when connection closed.
  81. on_close has one arguments.
  82. The 1st argument is *args which is callback_args.
  83. callback_args: list
  84. callback_args will return in callbacks above for *args.
  85. '''
  86. if not token or not appkey:
  87. raise InvalidParameter('Must provide token and appkey')
  88. self.__response_handler__ = {
  89. 'SentenceBegin': self.__sentence_begin,
  90. 'SentenceEnd': self.__sentence_end,
  91. 'TranscriptionStarted': self.__transcription_started,
  92. 'TranscriptionResultChanged': self.__transcription_result_changed,
  93. 'TranscriptionCompleted': self.__transcription_completed,
  94. 'TaskFailed': self.__task_failed
  95. }
  96. self.__callback_args = callback_args
  97. self.__url = url
  98. self.__appkey = appkey
  99. self.__token = token
  100. self.__start_cond = threading.Condition()
  101. self.__start_flag = False
  102. self.__on_start = on_start
  103. self.__on_sentence_begin = on_sentence_begin
  104. self.__on_sentence_end = on_sentence_end
  105. self.__on_result_changed = on_result_changed
  106. self.__on_completed = on_completed
  107. self.__on_error = on_error
  108. self.__on_close = on_close
  109. self.__allow_aformat = (
  110. 'pcm', 'opus', 'opu', 'wav', 'amr', 'speex', 'mp3', 'aac'
  111. )
  112. def __handle_message(self, message):
  113. logging.debug('__handle_message')
  114. try:
  115. __result = json.loads(message)
  116. if __result['header']['name'] in self.__response_handler__:
  117. __handler = self.__response_handler__[
  118. __result['header']['name']]
  119. __handler(message)
  120. else:
  121. logging.error('cannot handle cmd{}'.format(
  122. __result['header']['name']))
  123. return
  124. except json.JSONDecodeError:
  125. logging.error('cannot parse message:{}'.format(message))
  126. return
  127. def __tr_core_on_open(self):
  128. logging.debug('__tr_core_on_open')
  129. def __tr_core_on_msg(self, msg, *args):
  130. logging.debug('__tr_core_on_msg:msg={} args={}'.format(msg, args))
  131. self.__handle_message(msg)
  132. def __tr_core_on_error(self, msg, *args):
  133. logging.debug('__tr_core_on_error:msg={} args={}'.format(msg, args))
  134. def __tr_core_on_close(self):
  135. logging.debug('__tr_core_on_close')
  136. if self.__on_close:
  137. self.__on_close(*self.__callback_args)
  138. with self.__start_cond:
  139. self.__start_flag = False
  140. self.__start_cond.notify()
  141. def __sentence_begin(self, message):
  142. logging.debug('__sentence_begin')
  143. if self.__on_sentence_begin:
  144. self.__on_sentence_begin(message, *self.__callback_args)
  145. def __sentence_end(self, message):
  146. logging.debug('__sentence_end')
  147. if self.__on_sentence_end:
  148. self.__on_sentence_end(message, *self.__callback_args)
  149. def __transcription_started(self, message):
  150. logging.debug('__transcription_started')
  151. if self.__on_start:
  152. self.__on_start(message, *self.__callback_args)
  153. with self.__start_cond:
  154. self.__start_flag = True
  155. self.__start_cond.notify()
  156. def __transcription_result_changed(self, message):
  157. logging.debug('__transcription_result_changed')
  158. if self.__on_result_changed:
  159. self.__on_result_changed(message, *self.__callback_args)
  160. def __transcription_completed(self, message):
  161. logging.debug('__transcription_completed')
  162. self.__nls.shutdown()
  163. logging.debug('__transcription_completed shutdown done')
  164. if self.__on_completed:
  165. self.__on_completed(message, *self.__callback_args)
  166. with self.__start_cond:
  167. self.__start_flag = False
  168. self.__start_cond.notify()
  169. def __task_failed(self, message):
  170. logging.debug('__task_failed')
  171. with self.__start_cond:
  172. self.__start_flag = False
  173. self.__start_cond.notify()
  174. if self.__on_error:
  175. self.__on_error(message, *self.__callback_args)
  176. def start(self, aformat='pcm', sample_rate=16000, ch=1,
  177. enable_intermediate_result=False,
  178. enable_punctuation_prediction=False,
  179. enable_inverse_text_normalization=False,
  180. timeout=10,
  181. ping_interval=8,
  182. ping_timeout=None,
  183. ex:dict=None):
  184. """
  185. Transcription start
  186. Parameters:
  187. -----------
  188. aformat: str
  189. audio binary format, support: 'pcm', 'opu', 'opus', default is 'pcm'
  190. sample_rate: int
  191. audio sample rate, default is 16000
  192. ch: int
  193. audio channels, only support mono which is 1
  194. enable_intermediate_result: bool
  195. whether enable return intermediate recognition result, default is False
  196. enable_punctuation_prediction: bool
  197. whether enable punctuation prediction, default is False
  198. enable_inverse_text_normalization: bool
  199. whether enable ITN, default is False
  200. timeout: int
  201. wait timeout for connection setup
  202. ping_interval: int
  203. send ping interval, 0 for disable ping send, default is 8
  204. ping_timeout: int
  205. timeout after send ping and recive pong, set None for disable timeout check and default is None
  206. ex: dict
  207. dict which will merge into 'payload' field in request
  208. """
  209. self.__nls = NlsCore(
  210. url=self.__url,
  211. token=self.__token,
  212. on_open=self.__tr_core_on_open,
  213. on_message=self.__tr_core_on_msg,
  214. on_close=self.__tr_core_on_close,
  215. on_error=self.__tr_core_on_error,
  216. callback_args=[])
  217. if ch != 1:
  218. raise ValueError('not support channel: {}'.format(ch))
  219. if aformat not in self.__allow_aformat:
  220. raise ValueError('format {} not support'.format(aformat))
  221. __id4 = uuid.uuid4().hex
  222. self.__task_id = uuid.uuid4().hex
  223. __header = {
  224. 'message_id': __id4,
  225. 'task_id': self.__task_id,
  226. 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__,
  227. 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['start'],
  228. 'appkey': self.__appkey
  229. }
  230. __payload = {
  231. 'format': aformat,
  232. 'sample_rate': sample_rate,
  233. 'enable_intermediate_result': enable_intermediate_result,
  234. 'enable_punctuation_prediction': enable_punctuation_prediction,
  235. 'enable_inverse_text_normalization': enable_inverse_text_normalization
  236. }
  237. if ex:
  238. __payload.update(ex)
  239. __msg = {
  240. 'header': __header,
  241. 'payload': __payload,
  242. 'context': util.GetDefaultContext()
  243. }
  244. __jmsg = json.dumps(__msg)
  245. with self.__start_cond:
  246. if self.__start_flag:
  247. logging.debug('already start...')
  248. return
  249. self.__nls.start(__jmsg, ping_interval, ping_timeout)
  250. if self.__start_flag == False:
  251. if self.__start_cond.wait(timeout):
  252. return
  253. else:
  254. raise StartTimeoutException(f'Waiting Start over {timeout}s')
  255. def stop(self, timeout=10):
  256. """
  257. Stop transcription and mark session finished
  258. Parameters:
  259. -----------
  260. timeout: int
  261. timeout for waiting completed message from cloud
  262. """
  263. __id4 = uuid.uuid4().hex
  264. __header = {
  265. 'message_id': __id4,
  266. 'task_id': self.__task_id,
  267. 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__,
  268. 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['stop'],
  269. 'appkey': self.__appkey
  270. }
  271. __msg = {
  272. 'header': __header,
  273. 'context': util.GetDefaultContext()
  274. }
  275. __jmsg = json.dumps(__msg)
  276. with self.__start_cond:
  277. if not self.__start_flag:
  278. logging.debug('not start yet...')
  279. return
  280. self.__nls.send(__jmsg, False)
  281. if self.__start_flag == True:
  282. logging.debug('stop wait..')
  283. if self.__start_cond.wait(timeout):
  284. return
  285. else:
  286. raise StopTimeoutException(f'Waiting stop over {timeout}s')
  287. def ctrl(self, **kwargs):
  288. """
  289. Send control message to cloud
  290. Parameters:
  291. -----------
  292. kwargs: dict
  293. dict which will merge into 'payload' field in request
  294. """
  295. if not kwargs:
  296. raise InvalidParameter('Empty kwargs not allowed!')
  297. __id4 = uuid.uuid4().hex
  298. __header = {
  299. 'message_id': __id4,
  300. 'task_id': self.__task_id,
  301. 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__,
  302. 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['control'],
  303. 'appkey': self.__appkey
  304. }
  305. payload = {}
  306. payload.update(kwargs)
  307. __msg = {
  308. 'header': __header,
  309. 'payload': payload,
  310. 'context': util.GetDefaultContext()
  311. }
  312. __jmsg = json.dumps(__msg)
  313. with self.__start_cond:
  314. if not self.__start_flag:
  315. logging.debug('not start yet...')
  316. return
  317. self.__nls.send(__jmsg, False)
  318. def shutdown(self):
  319. """
  320. Shutdown connection immediately
  321. """
  322. self.__nls.shutdown()
  323. def send_audio(self, pcm_data):
  324. """
  325. Send audio binary, audio size prefer 20ms length
  326. Parameters:
  327. -----------
  328. pcm_data: bytes
  329. audio binary which format is 'aformat' in start method
  330. """
  331. __data = pcm_data
  332. with self.__start_cond:
  333. if not self.__start_flag:
  334. return
  335. try:
  336. self.__nls.send(__data, True)
  337. except ConnectionResetError as __e:
  338. logging.error('connection reset')
  339. self.__start_flag = False
  340. self.__nls.shutdown()
  341. raise __e