_handshake.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. """
  2. _handshake.py
  3. websocket - WebSocket client library for Python
  4. Copyright 2021 engn33r
  5. Licensed under the Apache License, Version 2.0 (the "License");
  6. you may not use this file except in compliance with the License.
  7. You may obtain a copy of the License at
  8. http://www.apache.org/licenses/LICENSE-2.0
  9. Unless required by applicable law or agreed to in writing, software
  10. distributed under the License is distributed on an "AS IS" BASIS,
  11. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. See the License for the specific language governing permissions and
  13. limitations under the License.
  14. """
  15. import hashlib
  16. import hmac
  17. import os
  18. from base64 import encodebytes as base64encode
  19. from http import client as HTTPStatus
  20. from ._cookiejar import SimpleCookieJar
  21. from ._exceptions import *
  22. from ._http import *
  23. from ._logging import *
  24. from ._socket import *
  25. __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
  26. # websocket supported version.
  27. VERSION = 13
  28. SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
  29. SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
# Module-level cookie store. handshake_response.__init__ passes each
# response's set-cookie header to CookieJar.add, and _get_handshake_headers
# queries it with CookieJar.get(host) when building the Cookie request header.
CookieJar = SimpleCookieJar()
  31. class handshake_response:
  32. def __init__(self, status, headers, subprotocol):
  33. self.status = status
  34. self.headers = headers
  35. self.subprotocol = subprotocol
  36. CookieJar.add(headers.get("set-cookie"))
  37. def handshake(sock, hostname, port, resource, **options):
  38. headers, key = _get_handshake_headers(resource, hostname, port, options)
  39. header_str = "\r\n".join(headers)
  40. send(sock, header_str)
  41. dump("request header", header_str)
  42. #print("request header:", header_str)
  43. status, resp = _get_resp_headers(sock)
  44. if status in SUPPORTED_REDIRECT_STATUSES:
  45. return handshake_response(status, resp, None)
  46. success, subproto = _validate(resp, key, options.get("subprotocols"))
  47. if not success:
  48. raise WebSocketException("Invalid WebSocket Header")
  49. return handshake_response(status, resp, subproto)
  50. def _pack_hostname(hostname):
  51. # IPv6 address
  52. if ':' in hostname:
  53. return '[' + hostname + ']'
  54. return hostname
  55. def _get_handshake_headers(resource, host, port, options):
  56. headers = [
  57. "GET %s HTTP/1.1" % resource,
  58. "Upgrade: websocket"
  59. ]
  60. if port == 80 or port == 443:
  61. hostport = _pack_hostname(host)
  62. else:
  63. hostport = "%s:%d" % (_pack_hostname(host), port)
  64. if "host" in options and options["host"] is not None:
  65. headers.append("Host: %s" % options["host"])
  66. else:
  67. headers.append("Host: %s" % hostport)
  68. if "suppress_origin" not in options or not options["suppress_origin"]:
  69. if "origin" in options and options["origin"] is not None:
  70. headers.append("Origin: %s" % options["origin"])
  71. else:
  72. headers.append("Origin: http://%s" % hostport)
  73. key = _create_sec_websocket_key()
  74. # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
  75. if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']:
  76. key = _create_sec_websocket_key()
  77. headers.append("Sec-WebSocket-Key: %s" % key)
  78. else:
  79. key = options['header']['Sec-WebSocket-Key']
  80. if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']:
  81. headers.append("Sec-WebSocket-Version: %s" % VERSION)
  82. if 'connection' not in options or options['connection'] is None:
  83. headers.append('Connection: Upgrade')
  84. else:
  85. headers.append(options['connection'])
  86. subprotocols = options.get("subprotocols")
  87. if subprotocols:
  88. headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
  89. if "header" in options:
  90. header = options["header"]
  91. if isinstance(header, dict):
  92. header = [
  93. ": ".join([k, v])
  94. for k, v in header.items()
  95. if v is not None
  96. ]
  97. headers.extend(header)
  98. server_cookie = CookieJar.get(host)
  99. client_cookie = options.get("cookie", None)
  100. cookie = "; ".join(filter(None, [server_cookie, client_cookie]))
  101. if cookie:
  102. headers.append("Cookie: %s" % cookie)
  103. headers.append("")
  104. headers.append("")
  105. return headers, key
  106. def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
  107. status, resp_headers, status_message = read_headers(sock)
  108. if status not in success_statuses:
  109. raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
  110. return status, resp_headers
  111. _HEADERS_TO_CHECK = {
  112. "upgrade": "websocket",
  113. "connection": "upgrade",
  114. }
  115. def _validate(headers, key, subprotocols):
  116. subproto = None
  117. for k, v in _HEADERS_TO_CHECK.items():
  118. r = headers.get(k, None)
  119. if not r:
  120. return False, None
  121. r = [x.strip().lower() for x in r.split(',')]
  122. if v not in r:
  123. return False, None
  124. if subprotocols:
  125. subproto = headers.get("sec-websocket-protocol", None)
  126. if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
  127. error("Invalid subprotocol: " + str(subprotocols))
  128. return False, None
  129. subproto = subproto.lower()
  130. result = headers.get("sec-websocket-accept", None)
  131. if not result:
  132. return False, None
  133. result = result.lower()
  134. if isinstance(result, str):
  135. result = result.encode('utf-8')
  136. value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
  137. hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
  138. success = hmac.compare_digest(hashed, result)
  139. if success:
  140. return True, subproto
  141. else:
  142. return False, None
  143. def _create_sec_websocket_key():
  144. randomness = os.urandom(16)
  145. return base64encode(randomness).decode('utf-8').strip()