test_fernet.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import base64
  5. import calendar
  6. import json
  7. import os
  8. import time
  9. import iso8601
  10. import pretend
  11. import pytest
  12. import cryptography_vectors
  13. from cryptography.fernet import Fernet, InvalidToken, MultiFernet
  14. from cryptography.hazmat.primitives.ciphers import algorithms, modes
  15. def json_parametrize(keys, filename):
  16. vector_file = cryptography_vectors.open_vector_file(
  17. os.path.join("fernet", filename), "r"
  18. )
  19. with vector_file:
  20. data = json.load(vector_file)
  21. return pytest.mark.parametrize(
  22. keys,
  23. [tuple([entry[k] for k in keys]) for entry in data],
  24. ids=[f"{filename}[{i}]" for i in range(len(data))],
  25. )
  26. @pytest.mark.supported(
  27. only_if=lambda backend: backend.cipher_supported(
  28. algorithms.AES(b"\x00" * 32), modes.CBC(b"\x00" * 16)
  29. ),
  30. skip_message="Does not support AES CBC",
  31. )
  32. class TestFernet:
  33. @json_parametrize(
  34. ("secret", "now", "iv", "src", "token"),
  35. "generate.json",
  36. )
  37. def test_generate(self, secret, now, iv, src, token, backend):
  38. f = Fernet(secret.encode("ascii"), backend=backend)
  39. actual_token = f._encrypt_from_parts(
  40. src.encode("ascii"),
  41. calendar.timegm(iso8601.parse_date(now).utctimetuple()),
  42. bytes(iv),
  43. )
  44. assert actual_token == token.encode("ascii")
  45. @json_parametrize(
  46. ("secret", "now", "src", "ttl_sec", "token"),
  47. "verify.json",
  48. )
  49. def test_verify(
  50. self, secret, now, src, ttl_sec, token, backend, monkeypatch
  51. ):
  52. # secret & token are both str
  53. f = Fernet(secret.encode("ascii"), backend=backend)
  54. current_time = calendar.timegm(iso8601.parse_date(now).utctimetuple())
  55. payload = f.decrypt_at_time(
  56. token, # str
  57. ttl=ttl_sec,
  58. current_time=current_time,
  59. )
  60. assert payload == src.encode("ascii")
  61. payload = f.decrypt_at_time(
  62. token.encode("ascii"), # bytes
  63. ttl=ttl_sec,
  64. current_time=current_time,
  65. )
  66. assert payload == src.encode("ascii")
  67. monkeypatch.setattr(time, "time", lambda: current_time)
  68. payload = f.decrypt(token, ttl=ttl_sec) # str
  69. assert payload == src.encode("ascii")
  70. payload = f.decrypt(token.encode("ascii"), ttl=ttl_sec) # bytes
  71. assert payload == src.encode("ascii")
  72. @json_parametrize(("secret", "token", "now", "ttl_sec"), "invalid.json")
  73. def test_invalid(self, secret, token, now, ttl_sec, backend, monkeypatch):
  74. f = Fernet(secret.encode("ascii"), backend=backend)
  75. current_time = calendar.timegm(iso8601.parse_date(now).utctimetuple())
  76. with pytest.raises(InvalidToken):
  77. f.decrypt_at_time(
  78. token.encode("ascii"),
  79. ttl=ttl_sec,
  80. current_time=current_time,
  81. )
  82. monkeypatch.setattr(time, "time", lambda: current_time)
  83. with pytest.raises(InvalidToken):
  84. f.decrypt(token.encode("ascii"), ttl=ttl_sec)
  85. def test_invalid_start_byte(self, backend):
  86. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  87. with pytest.raises(InvalidToken):
  88. f.decrypt(base64.urlsafe_b64encode(b"\x81"))
  89. def test_timestamp_too_short(self, backend):
  90. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  91. with pytest.raises(InvalidToken):
  92. f.decrypt(base64.urlsafe_b64encode(b"\x80abc"))
  93. def test_non_base64_token(self, backend):
  94. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  95. with pytest.raises(InvalidToken):
  96. f.decrypt(b"\x00")
  97. with pytest.raises(InvalidToken):
  98. f.decrypt("nonsensetoken")
  99. def test_invalid_types(self, backend):
  100. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  101. with pytest.raises(TypeError):
  102. f.encrypt("") # type: ignore[arg-type]
  103. with pytest.raises(TypeError):
  104. f.decrypt(12345) # type: ignore[arg-type]
  105. def test_timestamp_ignored_no_ttl(self, monkeypatch, backend):
  106. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  107. pt = b"encrypt me"
  108. token = f.encrypt(pt)
  109. monkeypatch.setattr(time, "time", pretend.raiser(ValueError))
  110. assert f.decrypt(token, ttl=None) == pt
  111. def test_ttl_required_in_decrypt_at_time(self, monkeypatch, backend):
  112. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  113. pt = b"encrypt me"
  114. token = f.encrypt(pt)
  115. with pytest.raises(ValueError):
  116. f.decrypt_at_time(
  117. token,
  118. ttl=None, # type: ignore[arg-type]
  119. current_time=int(time.time()),
  120. )
  121. @pytest.mark.parametrize("message", [b"", b"Abc!", b"\x00\xFF\x00\x80"])
  122. def test_roundtrips(self, message, backend):
  123. f = Fernet(Fernet.generate_key(), backend=backend)
  124. assert f.decrypt(f.encrypt(message)) == message
  125. @pytest.mark.parametrize("key", [base64.urlsafe_b64encode(b"abc"), b"abc"])
  126. def test_bad_key(self, backend, key):
  127. with pytest.raises(ValueError):
  128. Fernet(key, backend=backend)
  129. def test_extract_timestamp(self, monkeypatch, backend):
  130. f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  131. current_time = 1526138327
  132. token = f.encrypt_at_time(b"encrypt me", current_time)
  133. assert f.extract_timestamp(token) == current_time
  134. assert f.extract_timestamp(token.decode("ascii")) == current_time
  135. with pytest.raises(InvalidToken):
  136. f.extract_timestamp(b"nonsensetoken")
  137. @pytest.mark.supported(
  138. only_if=lambda backend: backend.cipher_supported(
  139. algorithms.AES(b"\x00" * 32), modes.CBC(b"\x00" * 16)
  140. ),
  141. skip_message="Does not support AES CBC",
  142. )
  143. class TestMultiFernet:
  144. def test_encrypt(self, backend):
  145. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  146. f2 = Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend)
  147. f = MultiFernet([f1, f2])
  148. assert f1.decrypt(f.encrypt(b"abc")) == b"abc"
  149. def test_decrypt(self, backend):
  150. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  151. f2 = Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend)
  152. f = MultiFernet([f1, f2])
  153. # token as bytes
  154. assert f.decrypt(f1.encrypt(b"abc")) == b"abc"
  155. assert f.decrypt(f2.encrypt(b"abc")) == b"abc"
  156. # token as str
  157. assert f.decrypt(f1.encrypt(b"abc").decode("ascii")) == b"abc"
  158. assert f.decrypt(f2.encrypt(b"abc").decode("ascii")) == b"abc"
  159. with pytest.raises(InvalidToken):
  160. f.decrypt(b"\x00" * 16)
  161. def test_decrypt_at_time(self, backend):
  162. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  163. f = MultiFernet([f1])
  164. pt = b"encrypt me"
  165. token = f.encrypt_at_time(pt, current_time=100)
  166. assert f.decrypt_at_time(token, ttl=1, current_time=100) == pt
  167. with pytest.raises(InvalidToken):
  168. f.decrypt_at_time(token, ttl=1, current_time=102)
  169. with pytest.raises(ValueError):
  170. f.decrypt_at_time(
  171. token, ttl=None, current_time=100 # type: ignore[arg-type]
  172. )
  173. def test_no_fernets(self, backend):
  174. with pytest.raises(ValueError):
  175. MultiFernet([])
  176. def test_non_iterable_argument(self, backend):
  177. with pytest.raises(TypeError):
  178. MultiFernet(None) # type: ignore[arg-type]
  179. def test_rotate_bytes(self, backend):
  180. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  181. f2 = Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend)
  182. mf1 = MultiFernet([f1])
  183. mf2 = MultiFernet([f2, f1])
  184. plaintext = b"abc"
  185. mf1_ciphertext = mf1.encrypt(plaintext)
  186. assert mf2.decrypt(mf1_ciphertext) == plaintext
  187. rotated = mf2.rotate(mf1_ciphertext)
  188. assert rotated != mf1_ciphertext
  189. assert mf2.decrypt(rotated) == plaintext
  190. with pytest.raises(InvalidToken):
  191. mf1.decrypt(rotated)
  192. def test_rotate_str(self, backend):
  193. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  194. f2 = Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend)
  195. mf1 = MultiFernet([f1])
  196. mf2 = MultiFernet([f2, f1])
  197. plaintext = b"abc"
  198. mf1_ciphertext = mf1.encrypt(plaintext).decode("ascii")
  199. assert mf2.decrypt(mf1_ciphertext) == plaintext
  200. rotated = mf2.rotate(mf1_ciphertext).decode("ascii")
  201. assert rotated != mf1_ciphertext
  202. assert mf2.decrypt(rotated) == plaintext
  203. with pytest.raises(InvalidToken):
  204. mf1.decrypt(rotated)
  205. def test_rotate_preserves_timestamp(self, backend, monkeypatch):
  206. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  207. f2 = Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend)
  208. mf1 = MultiFernet([f1])
  209. mf2 = MultiFernet([f2, f1])
  210. plaintext = b"abc"
  211. original_time = int(time.time()) - 5 * 60
  212. mf1_ciphertext = mf1.encrypt_at_time(plaintext, original_time)
  213. rotated_time, _ = Fernet._get_unverified_token_data(
  214. mf2.rotate(mf1_ciphertext)
  215. )
  216. assert int(time.time()) != rotated_time
  217. assert original_time == rotated_time
  218. def test_rotate_decrypt_no_shared_keys(self, backend):
  219. f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
  220. f2 = Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend)
  221. mf1 = MultiFernet([f1])
  222. mf2 = MultiFernet([f2])
  223. with pytest.raises(InvalidToken):
  224. mf2.rotate(mf1.encrypt(b"abc"))