# Copyright (c) Alibaba, Inc. and its affiliates.

import logging
import uuid
import json
import threading

from nls.core import NlsCore
from . import logging
from . import util
from nls.exception import (StartTimeoutException,
                        StopTimeoutException,
                        NotStartException,
                        InvalidParameter)

# Value placed in the 'namespace' field of every request header built below.
__SPEECH_TRANSCRIBER_NAMESPACE__ = 'SpeechTranscriber'

# Short command key used inside this module -> value placed in the 'name'
# field of the request header.
__SPEECH_TRANSCRIBER_REQUEST_CMD__ = dict(
    start='StartTranscription',
    stop='StopTranscription',
    control='ControlTranscriber',
)

# Default websocket endpoint used when the caller does not pass `url`.
__URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1'

# Public names exported by this module.
__all__ = ['NlsSpeechTranscriber']


class NlsSpeechTranscriber:
    """
    Api for realtime speech transcription

    A session is opened with start(), fed with send_audio(), optionally
    tuned with ctrl() and finished with stop() (graceful) or shutdown()
    (immediate). Server messages are dispatched to the user callbacks given
    to the constructor.

    The "session is running" state is kept in a flag guarded by a
    threading.Condition; start() and stop() block on that condition until a
    callback coming from the connection changes the flag or the timeout
    expires.
    """

    def __init__(self,
                 url=__URL__,
                 token=None,
                 appkey=None,
                 on_start=None,
                 on_sentence_begin=None,
                 on_sentence_end=None,
                 on_result_changed=None,
                 on_completed=None,
                 on_error=None,
                 on_close=None,
                 callback_args=None):
        '''
        NlsSpeechTranscriber initialization

        Parameters:
        -----------
        url: str
            websocket url.
        token: str
            access token (required).
        appkey: str
            appkey from aliyun (required).
        on_start: function
            Callback object which is called when transcription started.
            on_start has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_sentence_begin: function
            Callback object which is called when one sentence started.
            on_sentence_begin has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_sentence_end: function
            Callback object which is called when a sentence ended.
            on_sentence_end has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_result_changed: function
            Callback object which is called when partial recognition result
            arrived.
            on_result_changed has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_completed: function
            Callback object which is called when transcription is completed.
            on_completed has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_error: function
            Callback object which is called when a 'TaskFailed' message is
            received.
            on_error has two arguments.
            The 1st argument is message which is a json format string.
            The 2nd argument is *args which is callback_args.
        on_close: function
            Callback object which is called when connection closed.
            on_close has one argument.
            The 1st argument is *args which is callback_args.
        callback_args: list
            callback_args will return in callbacks above for *args.
            Defaults to an empty list.

        Raises:
        -------
        InvalidParameter
            if token or appkey is missing/empty.
        '''
        if not token or not appkey:
            raise InvalidParameter('Must provide token and appkey')
        # Message name (header.name) -> bound handler.
        self.__response_handler__ = {
            'SentenceBegin': self.__sentence_begin,
            'SentenceEnd': self.__sentence_end,
            'TranscriptionStarted': self.__transcription_started,
            'TranscriptionResultChanged': self.__transcription_result_changed,
            'TranscriptionCompleted': self.__transcription_completed,
            'TaskFailed': self.__task_failed
        }
        # A fresh list per instance instead of a shared mutable default.
        self.__callback_args = [] if callback_args is None else callback_args
        self.__url = url
        self.__appkey = appkey
        self.__token = token
        self.__start_cond = threading.Condition()
        self.__start_flag = False
        # Created by start(); defined here so that stop()/ctrl()/shutdown()
        # called before start() do not fail with AttributeError.
        self.__nls = None
        self.__task_id = None
        self.__on_start = on_start
        self.__on_sentence_begin = on_sentence_begin
        self.__on_sentence_end = on_sentence_end
        self.__on_result_changed = on_result_changed
        self.__on_completed = on_completed
        self.__on_error = on_error
        self.__on_close = on_close
        self.__allow_aformat = (
            'pcm', 'opus', 'opu', 'wav', 'amr', 'speex', 'mp3', 'aac'
        )

    def __set_start_flag(self, started):
        """Update the running flag under the lock and wake one waiter."""
        with self.__start_cond:
            self.__start_flag = started
            self.__start_cond.notify()

    def __new_header(self, cmd):
        """Build a request header for `cmd` ('start', 'stop' or 'control')."""
        return {
            'message_id': uuid.uuid4().hex,
            'task_id': self.__task_id,
            'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__,
            'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__[cmd],
            'appkey': self.__appkey
        }

    def __handle_message(self, message):
        """Parse a server message and dispatch it by its header name."""
        logging.debug('__handle_message')
        # Only the parsing/lookup is guarded, so that exceptions raised by
        # user callbacks are not misreported as parse errors.
        try:
            name = json.loads(message)['header']['name']
            handler = self.__response_handler__.get(name)
        except (json.JSONDecodeError, KeyError, TypeError):
            # Not JSON, or JSON without the expected header/name layout.
            logging.error('cannot parse message:{}'.format(message))
            return
        if handler is None:
            logging.error('cannot handle cmd{}'.format(name))
            return
        handler(message)

    def __tr_core_on_open(self):
        logging.debug('__tr_core_on_open')

    def __tr_core_on_msg(self, msg, *args):
        logging.debug('__tr_core_on_msg:msg={} args={}'.format(msg, args))
        self.__handle_message(msg)

    def __tr_core_on_error(self, msg, *args):
        logging.debug('__tr_core_on_error:msg={} args={}'.format(msg, args))

    def __tr_core_on_close(self):
        logging.debug('__tr_core_on_close')
        try:
            if self.__on_close:
                self.__on_close(*self.__callback_args)
        finally:
            # Always release waiters, even if the user callback raised.
            self.__set_start_flag(False)

    def __sentence_begin(self, message):
        logging.debug('__sentence_begin')
        if self.__on_sentence_begin:
            self.__on_sentence_begin(message, *self.__callback_args)

    def __sentence_end(self, message):
        logging.debug('__sentence_end')
        if self.__on_sentence_end:
            self.__on_sentence_end(message, *self.__callback_args)

    def __transcription_started(self, message):
        logging.debug('__transcription_started')
        try:
            if self.__on_start:
                self.__on_start(message, *self.__callback_args)
        finally:
            # The flag is raised after the callback (original ordering) but
            # must be raised even if the callback fails, otherwise start()
            # would block until its timeout.
            self.__set_start_flag(True)

    def __transcription_result_changed(self, message):
        logging.debug('__transcription_result_changed')
        if self.__on_result_changed:
            self.__on_result_changed(message, *self.__callback_args)

    def __transcription_completed(self, message):
        logging.debug('__transcription_completed')
        try:
            self.__nls.shutdown()
            logging.debug('__transcription_completed shutdown done')
            if self.__on_completed:
                self.__on_completed(message, *self.__callback_args)
        finally:
            # Release a stop() waiter regardless of callback outcome.
            self.__set_start_flag(False)

    def __task_failed(self, message):
        logging.debug('__task_failed')
        self.__set_start_flag(False)
        if self.__on_error:
            self.__on_error(message, *self.__callback_args)

    def start(self, aformat='pcm', sample_rate=16000, ch=1,
              enable_intermediate_result=False,
              enable_punctuation_prediction=False,
              enable_inverse_text_normalization=False,
              timeout=10,
              ping_interval=8,
              ping_timeout=None,
              ex:dict=None):
        """
        Transcription start

        Parameters:
        -----------
        aformat: str
            audio binary format, one of 'pcm', 'opus', 'opu', 'wav', 'amr',
            'speex', 'mp3', 'aac'; default is 'pcm'
        sample_rate: int
            audio sample rate, default is 16000
        ch: int
            audio channels, only support mono which is 1
        enable_intermediate_result: bool
            whether enable return intermediate recognition result, default is False
        enable_punctuation_prediction: bool
            whether enable punctuation prediction, default is False
        enable_inverse_text_normalization: bool
            whether enable ITN, default is False
        timeout: int
            wait timeout for connection setup
        ping_interval: int
            send ping interval, 0 for disable ping send, default is 8
        ping_timeout: int
            timeout after send ping and receive pong, set None for disable timeout check and default is None
        ex: dict
            dict which will merge into 'payload' field in request

        Raises:
        -------
        ValueError
            if ch is not 1 or aformat is not supported.
        StartTimeoutException
            if no start/failure/close notification arrives within timeout.

        Notes:
        ------
        If a session is already running the call is a no-op: the running
        connection and its task id are left untouched.
        The call also returns (without raising) when the wait is ended by a
        'TaskFailed' message or by the connection closing; use on_error /
        on_close to observe those cases.
        """
        # Validate before touching any state.
        if ch != 1:
            raise ValueError('not support channel: {}'.format(ch))
        if aformat not in self.__allow_aformat:
            raise ValueError('format {} not support'.format(aformat))

        with self.__start_cond:
            if self.__start_flag:
                # Must be checked BEFORE creating a new core / task id,
                # otherwise the live session would be orphaned.
                logging.debug('already start...')
                return

            self.__nls = NlsCore(
                url=self.__url,
                token=self.__token,
                on_open=self.__tr_core_on_open,
                on_message=self.__tr_core_on_msg,
                on_close=self.__tr_core_on_close,
                on_error=self.__tr_core_on_error,
                callback_args=[])
            self.__task_id = uuid.uuid4().hex

            payload = {
                'format': aformat,
                'sample_rate': sample_rate,
                'enable_intermediate_result': enable_intermediate_result,
                'enable_punctuation_prediction': enable_punctuation_prediction,
                'enable_inverse_text_normalization': enable_inverse_text_normalization
            }
            if ex:
                payload.update(ex)

            request = {
                'header': self.__new_header('start'),
                'payload': payload,
                'context': util.GetDefaultContext()
            }
            self.__nls.start(json.dumps(request), ping_interval, ping_timeout)
            # The flag may already be set if the core delivered the
            # callback synchronously; only wait when it is not.
            if not self.__start_flag:
                if not self.__start_cond.wait(timeout):
                    raise StartTimeoutException(f'Waiting Start over {timeout}s')

    def stop(self, timeout=10):
        """
        Stop transcription and mark session finished

        Does nothing (apart from a debug log) when no session is running.

        Parameters:
        -----------
        timeout: int
            timeout for waiting completed message from cloud

        Raises:
        -------
        StopTimeoutException
            if no completion/failure/close notification arrives within
            timeout.
        """
        with self.__start_cond:
            if not self.__start_flag:
                logging.debug('not start yet...')
                return
            request = {
                'header': self.__new_header('stop'),
                'context': util.GetDefaultContext()
            }
            self.__nls.send(json.dumps(request), False)
            if self.__start_flag:
                logging.debug('stop wait..')
                if not self.__start_cond.wait(timeout):
                    raise StopTimeoutException(f'Waiting stop over {timeout}s')

    def ctrl(self, **kwargs):
        """
        Send control message to cloud

        Does nothing (apart from a debug log) when no session is running.

        Parameters:
        -----------
        kwargs: dict
            dict which will merge into 'payload' field in request

        Raises:
        -------
        InvalidParameter
            if no keyword argument is given.
        """
        if not kwargs:
            raise InvalidParameter('Empty kwargs not allowed!')
        with self.__start_cond:
            if not self.__start_flag:
                logging.debug('not start yet...')
                return
            request = {
                'header': self.__new_header('control'),
                'payload': dict(kwargs),
                'context': util.GetDefaultContext()
            }
            self.__nls.send(json.dumps(request), False)

    def shutdown(self):
        """
        Shutdown connection immediately

        Safe to call before start(): there is nothing to shut down then.
        """
        if self.__nls is not None:
            self.__nls.shutdown()

    def send_audio(self, pcm_data):
        """
        Send audio binary, audio size prefer 20ms length

        Silently ignored when no session is running.

        Parameters:
        -----------
        pcm_data: bytes
            audio binary which format is 'aformat' in start method

        Raises:
        -------
        ConnectionResetError
            re-raised after the session is marked stopped and the
            connection is shut down.
        """
        # Only the flag check is done under the lock; the send itself runs
        # outside it so that callbacks are not blocked while sending.
        with self.__start_cond:
            if not self.__start_flag:
                return
        try:
            self.__nls.send(pcm_data, True)
        except ConnectionResetError:
            logging.error('connection reset')
            # Same locked transition as every other flag change.
            self.__set_start_flag(False)
            self.__nls.shutdown()
            raise