_abnf.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. """
  2. """
  3. """
  4. _abnf.py
  5. websocket - WebSocket client library for Python
  6. Copyright 2021 engn33r
  7. Licensed under the Apache License, Version 2.0 (the "License");
  8. you may not use this file except in compliance with the License.
  9. You may obtain a copy of the License at
  10. http://www.apache.org/licenses/LICENSE-2.0
  11. Unless required by applicable law or agreed to in writing, software
  12. distributed under the License is distributed on an "AS IS" BASIS,
  13. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. See the License for the specific language governing permissions and
  15. limitations under the License.
  16. """
  17. import array
  18. import os
  19. import struct
  20. import sys
  21. from ._exceptions import *
  22. from ._utils import validate_utf8
  23. from threading import Lock
  24. try:
  25. # If wsaccel is available, use compiled routines to mask data.
  26. # wsaccel only provides around a 10% speed boost compared
  27. # to the websocket-client _mask() implementation.
  28. # Note that wsaccel is unmaintained.
  29. from wsaccel.xormask import XorMaskerSimple
  30. def _mask(_m, _d):
  31. return XorMaskerSimple(_m).process(_d)
  32. except ImportError:
  33. # wsaccel is not available, use websocket-client _mask()
  34. native_byteorder = sys.byteorder
  35. def _mask(mask_value, data_value):
  36. datalen = len(data_value)
  37. data_value = int.from_bytes(data_value, native_byteorder)
  38. mask_value = int.from_bytes(mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder)
  39. return (data_value ^ mask_value).to_bytes(datalen, native_byteorder)
# Names exported by a star-import of this module: the frame class, the two
# receive-side helper classes and the close-frame status code constants.
# VALID_CLOSE_STATUS and the private _mask() helper are not listed here.
__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]
  56. # closing frame status codes.
  57. STATUS_NORMAL = 1000
  58. STATUS_GOING_AWAY = 1001
  59. STATUS_PROTOCOL_ERROR = 1002
  60. STATUS_UNSUPPORTED_DATA_TYPE = 1003
  61. STATUS_STATUS_NOT_AVAILABLE = 1005
  62. STATUS_ABNORMAL_CLOSED = 1006
  63. STATUS_INVALID_PAYLOAD = 1007
  64. STATUS_POLICY_VIOLATION = 1008
  65. STATUS_MESSAGE_TOO_BIG = 1009
  66. STATUS_INVALID_EXTENSION = 1010
  67. STATUS_UNEXPECTED_CONDITION = 1011
  68. STATUS_BAD_GATEWAY = 1014
  69. STATUS_TLS_HANDSHAKE_ERROR = 1015
  70. VALID_CLOSE_STATUS = (
  71. STATUS_NORMAL,
  72. STATUS_GOING_AWAY,
  73. STATUS_PROTOCOL_ERROR,
  74. STATUS_UNSUPPORTED_DATA_TYPE,
  75. STATUS_INVALID_PAYLOAD,
  76. STATUS_POLICY_VIOLATION,
  77. STATUS_MESSAGE_TOO_BIG,
  78. STATUS_INVALID_EXTENSION,
  79. STATUS_UNEXPECTED_CONDITION,
  80. STATUS_BAD_GATEWAY,
  81. )
  82. class ABNF:
  83. """
  84. ABNF frame class.
  85. See http://tools.ietf.org/html/rfc5234
  86. and http://tools.ietf.org/html/rfc6455#section-5.2
  87. """
  88. # operation code values.
  89. OPCODE_CONT = 0x0
  90. OPCODE_TEXT = 0x1
  91. OPCODE_BINARY = 0x2
  92. OPCODE_CLOSE = 0x8
  93. OPCODE_PING = 0x9
  94. OPCODE_PONG = 0xa
  95. # available operation code value tuple
  96. OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
  97. OPCODE_PING, OPCODE_PONG)
  98. # opcode human readable string
  99. OPCODE_MAP = {
  100. OPCODE_CONT: "cont",
  101. OPCODE_TEXT: "text",
  102. OPCODE_BINARY: "binary",
  103. OPCODE_CLOSE: "close",
  104. OPCODE_PING: "ping",
  105. OPCODE_PONG: "pong"
  106. }
  107. # data length threshold.
  108. LENGTH_7 = 0x7e
  109. LENGTH_16 = 1 << 16
  110. LENGTH_63 = 1 << 63
  111. def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
  112. opcode=OPCODE_TEXT, mask=1, data=""):
  113. """
  114. Constructor for ABNF. Please check RFC for arguments.
  115. """
  116. self.fin = fin
  117. self.rsv1 = rsv1
  118. self.rsv2 = rsv2
  119. self.rsv3 = rsv3
  120. self.opcode = opcode
  121. self.mask = mask
  122. if data is None:
  123. data = ""
  124. self.data = data
  125. self.get_mask_key = os.urandom
  126. def validate(self, skip_utf8_validation=False):
  127. """
  128. Validate the ABNF frame.
  129. Parameters
  130. ----------
  131. skip_utf8_validation: skip utf8 validation.
  132. """
  133. if self.rsv1 or self.rsv2 or self.rsv3:
  134. raise WebSocketProtocolException("rsv is not implemented, yet")
  135. if self.opcode not in ABNF.OPCODES:
  136. raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
  137. if self.opcode == ABNF.OPCODE_PING and not self.fin:
  138. raise WebSocketProtocolException("Invalid ping frame.")
  139. if self.opcode == ABNF.OPCODE_CLOSE:
  140. l = len(self.data)
  141. if not l:
  142. return
  143. if l == 1 or l >= 126:
  144. raise WebSocketProtocolException("Invalid close frame.")
  145. if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
  146. raise WebSocketProtocolException("Invalid close frame.")
  147. code = 256 * self.data[0] + self.data[1]
  148. if not self._is_valid_close_status(code):
  149. raise WebSocketProtocolException("Invalid close opcode.")
  150. @staticmethod
  151. def _is_valid_close_status(code):
  152. return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
  153. def __str__(self):
  154. return "fin=" + str(self.fin) \
  155. + " opcode=" + str(self.opcode) \
  156. + " data=" + str(self.data)
  157. @staticmethod
  158. def create_frame(data, opcode, fin=1):
  159. """
  160. Create frame to send text, binary and other data.
  161. Parameters
  162. ----------
  163. data: <type>
  164. data to send. This is string value(byte array).
  165. If opcode is OPCODE_TEXT and this value is unicode,
  166. data value is converted into unicode string, automatically.
  167. opcode: <type>
  168. operation code. please see OPCODE_XXX.
  169. fin: <type>
  170. fin flag. if set to 0, create continue fragmentation.
  171. """
  172. if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
  173. data = data.encode("utf-8")
  174. # mask must be set if send data from client
  175. return ABNF(fin, 0, 0, 0, opcode, 1, data)
  176. def format(self):
  177. """
  178. Format this object to string(byte array) to send data to server.
  179. """
  180. if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
  181. raise ValueError("not 0 or 1")
  182. if self.opcode not in ABNF.OPCODES:
  183. raise ValueError("Invalid OPCODE")
  184. length = len(self.data)
  185. if length >= ABNF.LENGTH_63:
  186. raise ValueError("data is too long")
  187. frame_header = chr(self.fin << 7 |
  188. self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
  189. self.opcode).encode('latin-1')
  190. if length < ABNF.LENGTH_7:
  191. frame_header += chr(self.mask << 7 | length).encode('latin-1')
  192. elif length < ABNF.LENGTH_16:
  193. frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1')
  194. frame_header += struct.pack("!H", length)
  195. else:
  196. frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1')
  197. frame_header += struct.pack("!Q", length)
  198. if not self.mask:
  199. return frame_header + self.data
  200. else:
  201. mask_key = self.get_mask_key(4)
  202. return frame_header + self._get_masked(mask_key)
  203. def _get_masked(self, mask_key):
  204. s = ABNF.mask(mask_key, self.data)
  205. if isinstance(mask_key, str):
  206. mask_key = mask_key.encode('utf-8')
  207. return mask_key + s
  208. @staticmethod
  209. def mask(mask_key, data):
  210. """
  211. Mask or unmask data. Just do xor for each byte
  212. Parameters
  213. ----------
  214. mask_key: <type>
  215. 4 byte string.
  216. data: <type>
  217. data to mask/unmask.
  218. """
  219. if data is None:
  220. data = ""
  221. if isinstance(mask_key, str):
  222. mask_key = mask_key.encode('latin-1')
  223. if isinstance(data, str):
  224. data = data.encode('latin-1')
  225. return _mask(array.array("B", mask_key), array.array("B", data))
  226. class frame_buffer:
  227. _HEADER_MASK_INDEX = 5
  228. _HEADER_LENGTH_INDEX = 6
  229. def __init__(self, recv_fn, skip_utf8_validation):
  230. self.recv = recv_fn
  231. self.skip_utf8_validation = skip_utf8_validation
  232. # Buffers over the packets from the layer beneath until desired amount
  233. # bytes of bytes are received.
  234. self.recv_buffer = []
  235. self.clear()
  236. self.lock = Lock()
  237. def clear(self):
  238. self.header = None
  239. self.length = None
  240. self.mask = None
  241. def has_received_header(self):
  242. return self.header is None
  243. def recv_header(self):
  244. header = self.recv_strict(2)
  245. b1 = header[0]
  246. fin = b1 >> 7 & 1
  247. rsv1 = b1 >> 6 & 1
  248. rsv2 = b1 >> 5 & 1
  249. rsv3 = b1 >> 4 & 1
  250. opcode = b1 & 0xf
  251. b2 = header[1]
  252. has_mask = b2 >> 7 & 1
  253. length_bits = b2 & 0x7f
  254. self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
  255. def has_mask(self):
  256. if not self.header:
  257. return False
  258. return self.header[frame_buffer._HEADER_MASK_INDEX]
  259. def has_received_length(self):
  260. return self.length is None
  261. def recv_length(self):
  262. bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
  263. length_bits = bits & 0x7f
  264. if length_bits == 0x7e:
  265. v = self.recv_strict(2)
  266. self.length = struct.unpack("!H", v)[0]
  267. elif length_bits == 0x7f:
  268. v = self.recv_strict(8)
  269. self.length = struct.unpack("!Q", v)[0]
  270. else:
  271. self.length = length_bits
  272. def has_received_mask(self):
  273. return self.mask is None
  274. def recv_mask(self):
  275. self.mask = self.recv_strict(4) if self.has_mask() else ""
  276. def recv_frame(self):
  277. with self.lock:
  278. # Header
  279. if self.has_received_header():
  280. self.recv_header()
  281. (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
  282. # Frame length
  283. if self.has_received_length():
  284. self.recv_length()
  285. length = self.length
  286. # Mask
  287. if self.has_received_mask():
  288. self.recv_mask()
  289. mask = self.mask
  290. # Payload
  291. payload = self.recv_strict(length)
  292. if has_mask:
  293. payload = ABNF.mask(mask, payload)
  294. # Reset for next frame
  295. self.clear()
  296. frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
  297. frame.validate(self.skip_utf8_validation)
  298. return frame
  299. def recv_strict(self, bufsize):
  300. shortage = bufsize - sum(map(len, self.recv_buffer))
  301. while shortage > 0:
  302. # Limit buffer size that we pass to socket.recv() to avoid
  303. # fragmenting the heap -- the number of bytes recv() actually
  304. # reads is limited by socket buffer and is relatively small,
  305. # yet passing large numbers repeatedly causes lots of large
  306. # buffers allocated and then shrunk, which results in
  307. # fragmentation.
  308. bytes_ = self.recv(min(16384, shortage))
  309. self.recv_buffer.append(bytes_)
  310. shortage -= len(bytes_)
  311. unified = bytes("", 'utf-8').join(self.recv_buffer)
  312. if shortage == 0:
  313. self.recv_buffer = []
  314. return unified
  315. else:
  316. self.recv_buffer = [unified[bufsize:]]
  317. return unified[:bufsize]
  318. class continuous_frame:
  319. def __init__(self, fire_cont_frame, skip_utf8_validation):
  320. self.fire_cont_frame = fire_cont_frame
  321. self.skip_utf8_validation = skip_utf8_validation
  322. self.cont_data = None
  323. self.recving_frames = None
  324. def validate(self, frame):
  325. if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
  326. raise WebSocketProtocolException("Illegal frame")
  327. if self.recving_frames and \
  328. frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
  329. raise WebSocketProtocolException("Illegal frame")
  330. def add(self, frame):
  331. if self.cont_data:
  332. self.cont_data[1] += frame.data
  333. else:
  334. if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
  335. self.recving_frames = frame.opcode
  336. self.cont_data = [frame.opcode, frame.data]
  337. if frame.fin:
  338. self.recving_frames = None
  339. def is_fire(self, frame):
  340. return frame.fin or self.fire_cont_frame
  341. def extract(self, frame):
  342. data = self.cont_data
  343. self.cont_data = None
  344. frame.data = data[1]
  345. if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
  346. raise WebSocketPayloadException(
  347. "cannot decode: " + repr(frame.data))
  348. return [data[0], frame]