# Copyright (c) Alibaba, Inc. and its affiliates.

import logging
import uuid
import json
import threading

from nls.core import NlsCore
# NOTE: this package-local import rebinds the name ``logging``; every
# ``logging.*`` call below therefore resolves to the package's own module,
# not to the standard-library one imported above.
from . import logging
from . import util
from nls.exception import (StartTimeoutException,
                           StopTimeoutException,
                           NotStartException,
                           InvalidParameter)

__SPEECH_TRANSCRIBER_NAMESPACE__ = 'SpeechTranscriber'

# Maps the local operation name to the command name sent in the request header.
__SPEECH_TRANSCRIBER_REQUEST_CMD__ = {
    'start': 'StartTranscription',
    'stop': 'StopTranscription',
    'control': 'ControlTranscriber'
}

__URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1'
__all__ = ['NlsSpeechTranscriber']


class NlsSpeechTranscriber:
    """
    Api for realtime speech transcription
    """

    def __init__(self,
                 url=__URL__,
                 token=None,
                 appkey=None,
                 on_start=None,
                 on_sentence_begin=None,
                 on_sentence_end=None,
                 on_result_changed=None,
                 on_completed=None,
                 on_error=None,
                 on_close=None,
                 callback_args=None):
        """
        NlsSpeechTranscriber initialization

        Parameters:
        -----------
        url: str
            websocket url.
        token: str
            access token. if you do not have a token, provide access id and
            key secret from your aliyun account.
        appkey: str
            appkey from aliyun
        on_start: function
            Callback object which is called when recognition started.
            on_start has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_sentence_begin: function
            Callback object which is called when one sentence started.
            on_sentence_begin has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_sentence_end: function
            Callback object which is called when sentence is end.
            on_sentence_end has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_result_changed: function
            Callback object which is called when partial recognition result
            arrived.
            on_result_changed has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_completed: function
            Callback object which is called when recognition is completed.
            on_completed has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_error: function
            Callback object which is called when the service reports a
            failed task.
            on_error has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_close: function
            Callback object which is called when connection closed.
            on_close has one argument.
            The 1st argument is *args which is callback_args.
        callback_args: list
            callback_args will return in callbacks above for *args.
            Defaults to an empty list created per instance.

        Raises:
        -------
        InvalidParameter
            if token or appkey is missing/empty.
        """
        if not token or not appkey:
            raise InvalidParameter('Must provide token and appkey')

        # Dispatch table: response header name -> bound handler.
        self.__response_handler__ = {
            'SentenceBegin': self.__sentence_begin,
            'SentenceEnd': self.__sentence_end,
            'TranscriptionStarted': self.__transcription_started,
            'TranscriptionResultChanged': self.__transcription_result_changed,
            'TranscriptionCompleted': self.__transcription_completed,
            'TaskFailed': self.__task_failed
        }

        # A fresh list per instance avoids the shared-mutable-default trap;
        # a caller-supplied list is kept by reference, as before.
        self.__callback_args = [] if callback_args is None else callback_args
        self.__url = url
        self.__appkey = appkey
        self.__token = token

        # Guards __start_flag and is notified on every state transition.
        self.__start_cond = threading.Condition()
        self.__start_flag = False

        # Created in start(); initialised here so that stop()/ctrl()/
        # shutdown() called before start() do not raise AttributeError.
        self.__nls = None
        self.__task_id = None

        self.__on_start = on_start
        self.__on_sentence_begin = on_sentence_begin
        self.__on_sentence_end = on_sentence_end
        self.__on_result_changed = on_result_changed
        self.__on_completed = on_completed
        self.__on_error = on_error
        self.__on_close = on_close

        self.__allow_aformat = (
            'pcm', 'opus', 'opu', 'wav', 'amr', 'speex', 'mp3', 'aac'
        )

    def __build_message(self, cmd, payload=None):
        """
        Serialize a request for the current task.

        Parameters:
        -----------
        cmd: str
            key of __SPEECH_TRANSCRIBER_REQUEST_CMD__ ('start', 'stop',
            'control')
        payload: dict
            optional 'payload' field; the field is omitted when None

        Returns:
        --------
        str
            json encoded request with a fresh message_id
        """
        request = {
            'header': {
                'message_id': uuid.uuid4().hex,
                'task_id': self.__task_id,
                'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__,
                'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__[cmd],
                'appkey': self.__appkey
            }
        }
        if payload is not None:
            request['payload'] = payload
        request['context'] = util.GetDefaultContext()
        return json.dumps(request)

    def __handle_message(self, message):
        """
        Parse a server message and dispatch it by its header name.

        Malformed messages (invalid json, or json without header.name) are
        logged and dropped instead of raising into the caller.
        """
        logging.debug('__handle_message')
        try:
            parsed = json.loads(message)
        except json.JSONDecodeError:
            logging.error('cannot parse message:{}'.format(message))
            return

        try:
            name = parsed['header']['name']
            handler = self.__response_handler__.get(name)
        except (KeyError, TypeError):
            # Valid json, but not the {'header': {'name': ...}} shape.
            logging.error('cannot parse message:{}'.format(message))
            return

        if handler is None:
            logging.error('cannot handle cmd{}'.format(name))
            return

        # Dispatch outside the try blocks so that exceptions raised by user
        # callbacks are never mistaken for parse errors.
        handler(message)

    def __tr_core_on_open(self):
        """Connection-opened hook; only logs."""
        logging.debug('__tr_core_on_open')

    def __tr_core_on_msg(self, msg, *args):
        """Message hook: forwards the raw message to the dispatcher."""
        logging.debug('__tr_core_on_msg:msg={} args={}'.format(msg, args))
        self.__handle_message(msg)

    def __tr_core_on_error(self, msg, *args):
        """Transport-error hook; only logs."""
        logging.debug('__tr_core_on_error:msg={} args={}'.format(msg, args))

    def __tr_core_on_close(self):
        """Connection-closed hook: run on_close, then mark not started."""
        logging.debug('__tr_core_on_close')
        if self.__on_close:
            self.__on_close(*self.__callback_args)
        with self.__start_cond:
            self.__start_flag = False
            self.__start_cond.notify()

    def __sentence_begin(self, message):
        """Forward a SentenceBegin message to on_sentence_begin."""
        logging.debug('__sentence_begin')
        if self.__on_sentence_begin:
            self.__on_sentence_begin(message, *self.__callback_args)

    def __sentence_end(self, message):
        """Forward a SentenceEnd message to on_sentence_end."""
        logging.debug('__sentence_end')
        if self.__on_sentence_end:
            self.__on_sentence_end(message, *self.__callback_args)

    def __transcription_started(self, message):
        """Run on_start, then mark the session started and wake start()."""
        logging.debug('__transcription_started')
        if self.__on_start:
            self.__on_start(message, *self.__callback_args)
        with self.__start_cond:
            self.__start_flag = True
            self.__start_cond.notify()

    def __transcription_result_changed(self, message):
        """Forward a TranscriptionResultChanged message to on_result_changed."""
        logging.debug('__transcription_result_changed')
        if self.__on_result_changed:
            self.__on_result_changed(message, *self.__callback_args)

    def __transcription_completed(self, message):
        """
        Shut the connection down, run on_completed, then mark the session
        not started and wake stop().
        """
        logging.debug('__transcription_completed')
        self.__nls.shutdown()
        logging.debug('__transcription_completed shutdown done')
        if self.__on_completed:
            self.__on_completed(message, *self.__callback_args)
        with self.__start_cond:
            self.__start_flag = False
            self.__start_cond.notify()

    def __task_failed(self, message):
        """Mark the session not started, wake any waiter, then run on_error."""
        logging.debug('__task_failed')
        with self.__start_cond:
            self.__start_flag = False
            self.__start_cond.notify()
        if self.__on_error:
            self.__on_error(message, *self.__callback_args)

    def start(self, aformat='pcm', sample_rate=16000, ch=1,
              enable_intermediate_result=False,
              enable_punctuation_prediction=False,
              enable_inverse_text_normalization=False,
              timeout=10,
              ping_interval=8,
              ping_timeout=None,
              ex: dict = None):
        """
        Transcription start

        Parameters:
        -----------
        aformat: str
            audio binary format, support: 'pcm', 'opus', 'opu', 'wav',
            'amr', 'speex', 'mp3', 'aac', default is 'pcm'
        sample_rate: int
            audio sample rate, default is 16000
        ch: int
            audio channels, only support mono which is 1
        enable_intermediate_result: bool
            whether enable return intermediate recognition result, default is
            False
        enable_punctuation_prediction: bool
            whether enable punctuation prediction, default is False
        enable_inverse_text_normalization: bool
            whether enable ITN, default is False
        timeout: int
            wait timeout for connection setup
        ping_interval: int
            send ping interval, 0 for disable ping send, default is 8
        ping_timeout: int
            timeout after send ping and receive pong, set None for disable
            timeout check and default is None
        ex: dict
            dict which will merge into 'payload' field in request

        Raises:
        -------
        ValueError
            if ch is not 1 or aformat is not supported
        StartTimeoutException
            if no notification arrives within timeout seconds

        Notes:
        ------
        The wait ends on the first notification, which is also sent when the
        task fails or the connection closes; use on_error / on_close to
        detect those outcomes.
        """
        # Validate before touching any connection state.
        if ch != 1:
            raise ValueError('not support channel: {}'.format(ch))
        if aformat not in self.__allow_aformat:
            raise ValueError('format {} not support'.format(aformat))

        payload = {
            'format': aformat,
            'sample_rate': sample_rate,
            'enable_intermediate_result': enable_intermediate_result,
            'enable_punctuation_prediction': enable_punctuation_prediction,
            'enable_inverse_text_normalization':
                enable_inverse_text_normalization
        }
        if ex:
            payload.update(ex)

        with self.__start_cond:
            if self.__start_flag:
                # Return before replacing the live connection or task id,
                # otherwise the running session would be orphaned.
                logging.debug('already start...')
                return

            self.__nls = NlsCore(
                url=self.__url,
                token=self.__token,
                on_open=self.__tr_core_on_open,
                on_message=self.__tr_core_on_msg,
                on_close=self.__tr_core_on_close,
                on_error=self.__tr_core_on_error,
                callback_args=[])
            self.__task_id = uuid.uuid4().hex

            self.__nls.start(self.__build_message('start', payload),
                             ping_interval, ping_timeout)
            if not self.__start_flag:
                if self.__start_cond.wait(timeout):
                    return
                raise StartTimeoutException(
                    f'Waiting Start over {timeout}s')

    def stop(self, timeout=10):
        """
        Stop transcription and mark session finished

        Parameters:
        -----------
        timeout: int
            timeout for waiting completed message from cloud

        Raises:
        -------
        StopTimeoutException
            if no notification arrives within timeout seconds
        """
        with self.__start_cond:
            if not self.__start_flag:
                logging.debug('not start yet...')
                return

            self.__nls.send(self.__build_message('stop'), False)
            if self.__start_flag:
                logging.debug('stop wait..')
                if self.__start_cond.wait(timeout):
                    return
                raise StopTimeoutException(f'Waiting stop over {timeout}s')

    def ctrl(self, **kwargs):
        """
        Send control message to cloud

        Parameters:
        -----------
        kwargs: dict
            dict which will merge into 'payload' field in request

        Raises:
        -------
        InvalidParameter
            if kwargs is empty
        """
        if not kwargs:
            raise InvalidParameter('Empty kwargs not allowed!')

        with self.__start_cond:
            if not self.__start_flag:
                logging.debug('not start yet...')
                return
            self.__nls.send(self.__build_message('control', dict(kwargs)),
                            False)

    def shutdown(self):
        """
        Shutdown connection immediately

        Does nothing when start() has never been called.
        """
        if self.__nls is not None:
            self.__nls.shutdown()

    def send_audio(self, pcm_data):
        """
        Send audio binary, audio size prefer 20ms length

        Silently returns when the session is not started.

        Parameters:
        -----------
        pcm_data: bytes
            audio binary which format is 'aformat' in start method

        Raises:
        -------
        ConnectionResetError
            re-raised after the session is marked stopped and the
            connection is shut down
        """
        with self.__start_cond:
            if not self.__start_flag:
                return
            try:
                self.__nls.send(pcm_data, True)
            except ConnectionResetError:
                logging.error('connection reset')
                self.__start_flag = False
                self.__nls.shutdown()
                raise