stream_input_tts.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import logging
  3. import uuid
  4. import json
  5. import threading
  6. from enum import IntEnum
  7. from nls.core import NlsCore
  8. from . import logging
  9. from .exception import StartTimeoutException, WrongStateException, InvalidParameter
  10. __STREAM_INPUT_TTS_NAMESPACE__ = "FlowingSpeechSynthesizer"
  11. __STREAM_INPUT_TTS_REQUEST_CMD__ = {
  12. "start": "StartSynthesis",
  13. "send": "RunSynthesis",
  14. "stop": "StopSynthesis",
  15. }
  16. __STREAM_INPUT_TTS_REQUEST_NAME__ = {
  17. "started": "SynthesisStarted",
  18. "sentence_begin": "SentenceBegin",
  19. "sentence_synthesis": "SentenceSynthesis",
  20. "sentence_end": "SentenceEnd",
  21. "completed": "SynthesisCompleted",
  22. "task_failed": "TaskFailed",
  23. }
  24. __URL__ = "wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1"
  25. __all__ = ["NlsStreamInputTtsSynthesizer"]
  26. class NlsStreamInputTtsRequest:
  27. def __init__(self, task_id, session_id, appkey):
  28. self.task_id = task_id
  29. self.appkey = appkey
  30. self.session_id = session_id
  31. def getStartCMD(self, voice, format, sample_rate, volumn, speech_rate, pitch_rate, ex):
  32. self.voice = voice
  33. self.format = format
  34. self.sample_rate = sample_rate
  35. self.volumn = volumn
  36. self.speech_rate = speech_rate
  37. self.pitch_rate = pitch_rate
  38. cmd = {
  39. "header": {
  40. "message_id": uuid.uuid4().hex,
  41. "task_id": self.task_id,
  42. "name": __STREAM_INPUT_TTS_REQUEST_CMD__["start"],
  43. "namespace": __STREAM_INPUT_TTS_NAMESPACE__,
  44. "appkey": self.appkey,
  45. },
  46. "payload": {
  47. "session_id": self.session_id,
  48. "voice": self.voice,
  49. "format": self.format,
  50. "sample_rate": self.sample_rate,
  51. "volumn": self.volumn,
  52. "speech_rate": self.speech_rate,
  53. "pitch_rate": self.pitch_rate,
  54. },
  55. }
  56. if ex:
  57. cmd["payload"].update(ex)
  58. return json.dumps(cmd)
  59. def getSendCMD(self, text):
  60. cmd = {
  61. "header": {
  62. "message_id": uuid.uuid4().hex,
  63. "task_id": self.task_id,
  64. "name": __STREAM_INPUT_TTS_REQUEST_CMD__["send"],
  65. "namespace": __STREAM_INPUT_TTS_NAMESPACE__,
  66. "appkey": self.appkey,
  67. },
  68. "payload": {"text": text},
  69. }
  70. return json.dumps(cmd)
  71. def getStopCMD(self):
  72. cmd = {
  73. "header": {
  74. "message_id": uuid.uuid4().hex,
  75. "task_id": self.task_id,
  76. "name": __STREAM_INPUT_TTS_REQUEST_CMD__["stop"],
  77. "namespace": __STREAM_INPUT_TTS_NAMESPACE__,
  78. "appkey": self.appkey,
  79. },
  80. }
  81. return json.dumps(cmd)
  82. class NlsStreamInputTtsStatus(IntEnum):
  83. Begin = 1
  84. Start = 2
  85. Started = 3
  86. WaitingComplete = 3
  87. Completed = 4
  88. Failed = 5
  89. Closed = 6
  90. class ThreadSafeStatus:
  91. def __init__(self, state: NlsStreamInputTtsStatus):
  92. self._state = state
  93. self._lock = threading.Lock()
  94. def get(self) -> NlsStreamInputTtsStatus:
  95. with self._lock:
  96. return self._state
  97. def set(self, state: NlsStreamInputTtsStatus):
  98. with self._lock:
  99. self._state = state
  100. class NlsStreamInputTtsSynthesizer:
  101. """
  102. Api for text-to-speech
  103. """
  104. def __init__(
  105. self,
  106. url=__URL__,
  107. token=None,
  108. appkey=None,
  109. session_id=None,
  110. on_data=None,
  111. on_sentence_begin=None,
  112. on_sentence_synthesis=None,
  113. on_sentence_end=None,
  114. on_completed=None,
  115. on_error=None,
  116. on_close=None,
  117. callback_args=[],
  118. ):
  119. """
  120. NlsSpeechSynthesizer initialization
  121. Parameters:
  122. -----------
  123. url: str
  124. websocket url.
  125. akid: str
  126. access id from aliyun. if you provide a token, ignore this argument.
  127. appkey: str
  128. appkey from aliyun
  129. session_id: str
  130. 32-character string, if empty, sdk will generate a random string.
  131. on_data: function
  132. Callback object which is called when partial synthesis result arrived
  133. arrived.
  134. on_result_changed has two arguments.
  135. The 1st argument is binary data corresponding to aformat in start
  136. method.
  137. The 2nd argument is *args which is callback_args.
  138. on_sentence_begin: function
  139. Callback object which is called when detected sentence start.
  140. on_start has two arguments.
  141. The 1st argument is message which is a json format string.
  142. The 2nd argument is *args which is callback_args.
  143. on_sentence_synthesis: function
  144. Callback object which is called when detected sentence synthesis.
  145. The incremental timestamp is returned within payload.
  146. on_start has two arguments.
  147. The 1st argument is message which is a json format string.
  148. The 2nd argument is *args which is callback_args.
  149. on_sentence_end: function
  150. Callback object which is called when detected sentence end.
  151. The timestamp of the whole sentence is returned within payload.
  152. on_start has two arguments.
  153. The 1st argument is message which is a json format string.
  154. The 2nd argument is *args which is callback_args.
  155. on_completed: function
  156. Callback object which is called when recognition is completed.
  157. on_completed has two arguments.
  158. The 1st argument is message which is a json format string.
  159. The 2nd argument is *args which is callback_args.
  160. on_error: function
  161. Callback object which is called when any error occurs.
  162. on_error has two arguments.
  163. The 1st argument is message which is a json format string.
  164. The 2nd argument is *args which is callback_args.
  165. on_close: function
  166. Callback object which is called when connection closed.
  167. on_close has one arguments.
  168. The 1st argument is *args which is callback_args.
  169. callback_args: list
  170. callback_args will return in callbacks above for *args.
  171. """
  172. if not token or not appkey:
  173. raise InvalidParameter("Must provide token and appkey")
  174. self.__response_handler__ = {
  175. __STREAM_INPUT_TTS_REQUEST_NAME__["started"]: self.__synthesis_started,
  176. __STREAM_INPUT_TTS_REQUEST_NAME__["sentence_begin"]: self.__sentence_begin,
  177. __STREAM_INPUT_TTS_REQUEST_NAME__[
  178. "sentence_synthesis"
  179. ]: self.__sentence_synthesis,
  180. __STREAM_INPUT_TTS_REQUEST_NAME__["sentence_end"]: self.__sentence_end,
  181. __STREAM_INPUT_TTS_REQUEST_NAME__["completed"]: self.__synthesis_completed,
  182. __STREAM_INPUT_TTS_REQUEST_NAME__["task_failed"]: self.__task_failed,
  183. }
  184. self.__callback_args = callback_args
  185. self.__url = url
  186. self.__appkey = appkey
  187. self.__token = token
  188. self.__session_id = session_id
  189. self.start_sended = threading.Event()
  190. self.started_event = threading.Event()
  191. self.complete_event = threading.Event()
  192. self.__on_sentence_begin = on_sentence_begin
  193. self.__on_sentence_synthesis = on_sentence_synthesis
  194. self.__on_sentence_end = on_sentence_end
  195. self.__on_data = on_data
  196. self.__on_completed = on_completed
  197. self.__on_error = on_error
  198. self.__on_close = on_close
  199. self.__allow_aformat = ("pcm", "wav", "mp3")
  200. self.__allow_sample_rate = (
  201. 8000,
  202. 11025,
  203. 16000,
  204. 22050,
  205. 24000,
  206. 32000,
  207. 44100,
  208. 48000,
  209. )
  210. self.state = ThreadSafeStatus(NlsStreamInputTtsStatus.Begin)
  211. if not self.__session_id:
  212. self.__session_id = uuid.uuid4().hex
  213. self.request = NlsStreamInputTtsRequest(
  214. uuid.uuid4().hex, self.__session_id, self.__appkey
  215. )
  216. def __handle_message(self, message):
  217. logging.debug("__handle_message")
  218. try:
  219. __result = json.loads(message)
  220. if __result["header"]["name"] in self.__response_handler__:
  221. __handler = self.__response_handler__[__result["header"]["name"]]
  222. __handler(message)
  223. else:
  224. logging.error("cannot handle cmd{}".format(__result["header"]["name"]))
  225. return
  226. except json.JSONDecodeError:
  227. logging.error("cannot parse message:{}".format(message))
  228. return
  229. def __syn_core_on_open(self):
  230. logging.debug("__syn_core_on_open")
  231. self.start_sended.set()
  232. def __syn_core_on_data(self, data, opcode, flag):
  233. logging.debug("__syn_core_on_data")
  234. if self.__on_data:
  235. self.__on_data(data, *self.__callback_args)
  236. def __syn_core_on_msg(self, msg, *args):
  237. logging.debug("__syn_core_on_msg:msg={} args={}".format(msg, args))
  238. self.__handle_message(msg)
  239. def __syn_core_on_error(self, msg, *args):
  240. logging.debug("__sr_core_on_error:msg={} args={}".format(msg, args))
  241. def __syn_core_on_close(self):
  242. logging.debug("__sr_core_on_close")
  243. if self.__on_close:
  244. self.__on_close(*self.__callback_args)
  245. self.state.set(NlsStreamInputTtsStatus.Closed)
  246. self.start_sended.set()
  247. self.started_event.set()
  248. self.complete_event.set()
  249. def __synthesis_started(self, message):
  250. logging.debug("__synthesis_started")
  251. self.started_event.set()
  252. def __sentence_begin(self, message):
  253. logging.debug("__sentence_begin")
  254. if self.__on_sentence_begin:
  255. self.__on_sentence_begin(message, *self.__callback_args)
  256. def __sentence_synthesis(self, message):
  257. logging.debug("__sentence_synthesis")
  258. if self.__on_sentence_synthesis:
  259. self.__on_sentence_synthesis(message, *self.__callback_args)
  260. def __sentence_end(self, message):
  261. logging.debug("__sentence_end")
  262. if self.__on_sentence_end:
  263. self.__on_sentence_end(message, *self.__callback_args)
  264. def __synthesis_completed(self, message):
  265. logging.debug("__synthesis_completed")
  266. if self.__on_completed:
  267. self.__on_completed(message, *self.__callback_args)
  268. self.__nls.shutdown()
  269. logging.debug("__synthesis_completed shutdown done")
  270. self.complete_event.set()
  271. def __task_failed(self, message):
  272. logging.debug("__task_failed")
  273. self.start_sended.set()
  274. self.started_event.set()
  275. self.complete_event.set()
  276. if self.__on_error:
  277. self.__on_error(message, *self.__callback_args)
  278. self.state.set(NlsStreamInputTtsStatus.Failed)
  279. def startStreamInputTts(
  280. self,
  281. voice="longxiaochun",
  282. aformat="pcm",
  283. sample_rate=24000,
  284. volume=50,
  285. speech_rate=0,
  286. pitch_rate=0,
  287. ex:dict=None,
  288. ):
  289. """
  290. Synthesis start
  291. Parameters:
  292. -----------
  293. voice: str
  294. voice for text-to-speech, default is xiaoyun
  295. aformat: str
  296. audio binary format, support: 'pcm', 'wav', 'mp3', default is 'pcm'
  297. sample_rate: int
  298. audio sample rate, default is 24000, support:8000, 11025, 16000, 22050,
  299. 24000, 32000, 44100, 48000
  300. volume: int
  301. audio volume, from 0~100, default is 50
  302. speech_rate: int
  303. speech rate from -500~500, default is 0
  304. pitch_rate: int
  305. pitch for voice from -500~500, default is 0
  306. ex: dict
  307. dict which will merge into 'payload' field in request
  308. """
  309. self.__nls = NlsCore(
  310. url=self.__url,
  311. token=self.__token,
  312. on_open=self.__syn_core_on_open,
  313. on_message=self.__syn_core_on_msg,
  314. on_data=self.__syn_core_on_data,
  315. on_close=self.__syn_core_on_close,
  316. on_error=self.__syn_core_on_error,
  317. callback_args=[],
  318. )
  319. if aformat not in self.__allow_aformat:
  320. raise InvalidParameter("format {} not support".format(aformat))
  321. if sample_rate not in self.__allow_sample_rate:
  322. raise InvalidParameter("samplerate {} not support".format(sample_rate))
  323. if volume < 0 or volume > 100:
  324. raise InvalidParameter("volume {} not support".format(volume))
  325. if speech_rate < -500 or speech_rate > 500:
  326. raise InvalidParameter("speech_rate {} not support".format(speech_rate))
  327. if pitch_rate < -500 or pitch_rate > 500:
  328. raise InvalidParameter("pitch rate {} not support".format(pitch_rate))
  329. request = self.request.getStartCMD(
  330. voice, aformat, sample_rate, volume, speech_rate, pitch_rate, ex
  331. )
  332. last_state = self.state.get()
  333. if last_state != NlsStreamInputTtsStatus.Begin:
  334. logging.debug("start with wrong state {}".format(last_state))
  335. self.state.set(NlsStreamInputTtsStatus.Failed)
  336. raise WrongStateException("start with wrong state {}".format(last_state))
  337. logging.debug("start with request: {}".format(request))
  338. self.__nls.start(request, ping_interval=0, ping_timeout=None)
  339. self.state.set(NlsStreamInputTtsStatus.Start)
  340. if not self.start_sended.wait(timeout=10):
  341. logging.debug("syn start timeout")
  342. raise StartTimeoutException(f"Waiting Connection before Start over 10s")
  343. if last_state != NlsStreamInputTtsStatus.Begin:
  344. logging.debug("start with wrong state {}".format(last_state))
  345. self.state.set(NlsStreamInputTtsStatus.Failed)
  346. raise WrongStateException("start with wrong state {}".format(last_state))
  347. if not self.started_event.wait(timeout=10):
  348. logging.debug("syn started timeout")
  349. self.state.set(NlsStreamInputTtsStatus.Failed)
  350. raise StartTimeoutException(f"Waiting Started over 10s")
  351. self.state.set(NlsStreamInputTtsStatus.Started)
  352. def sendStreamInputTts(self, text):
  353. """
  354. send text to server
  355. Parameters:
  356. -----------
  357. text: str
  358. utf-8 text
  359. """
  360. last_state = self.state.get()
  361. if last_state != NlsStreamInputTtsStatus.Started:
  362. logging.debug("send with wrong state {}".format(last_state))
  363. self.state.set(NlsStreamInputTtsStatus.Failed)
  364. raise WrongStateException("send with wrong state {}".format(last_state))
  365. request = self.request.getSendCMD(text)
  366. logging.debug("send with request: {}".format(request))
  367. self.__nls.send(request, None)
  368. def stopStreamInputTts(self):
  369. """
  370. Synthesis end
  371. """
  372. last_state = self.state.get()
  373. if last_state != NlsStreamInputTtsStatus.Started:
  374. logging.debug("send with wrong state {}".format(last_state))
  375. self.state.set(NlsStreamInputTtsStatus.Failed)
  376. raise WrongStateException("stop with wrong state {}".format(last_state))
  377. request = self.request.getStopCMD()
  378. logging.debug("stop with request: {}".format(request))
  379. self.__nls.send(request, None)
  380. self.state.set(NlsStreamInputTtsStatus.WaitingComplete)
  381. self.complete_event.wait()
  382. self.state.set(NlsStreamInputTtsStatus.Completed)
  383. self.shutdown()
  384. def shutdown(self):
  385. """
  386. Shutdown connection immediately
  387. """
  388. self.__nls.shutdown()