tests_contrib.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. """
  2. Tests for `tqdm.contrib`.
  3. """
  4. import pytest
  5. from tqdm import tqdm
  6. from tqdm.contrib import tenumerate, tmap, tzip
  7. from .tests_tqdm import StringIO, closing, importorskip
  8. def incr(x):
  9. """Dummy function"""
  10. return x + 1
  11. @pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
  12. def test_enumerate(tqdm_kwargs):
  13. """Test contrib.tenumerate"""
  14. with closing(StringIO()) as our_file:
  15. a = range(9)
  16. assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a))
  17. assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list(
  18. enumerate(a, 42)
  19. )
  20. with closing(StringIO()) as our_file:
  21. _ = list(tenumerate(iter(a), file=our_file, **tqdm_kwargs))
  22. assert "100%" not in our_file.getvalue()
  23. with closing(StringIO()) as our_file:
  24. _ = list(tenumerate(iter(a), file=our_file, total=len(a), **tqdm_kwargs))
  25. assert "100%" in our_file.getvalue()
  26. def test_enumerate_numpy():
  27. """Test contrib.tenumerate(numpy.ndarray)"""
  28. np = importorskip("numpy")
  29. with closing(StringIO()) as our_file:
  30. a = np.random.random((42, 7))
  31. assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a))
  32. @pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
  33. def test_zip(tqdm_kwargs):
  34. """Test contrib.tzip"""
  35. with closing(StringIO()) as our_file:
  36. a = range(9)
  37. b = [i + 1 for i in a]
  38. gen = tzip(a, b, file=our_file, **tqdm_kwargs)
  39. assert gen != list(zip(a, b))
  40. assert list(gen) == list(zip(a, b))
  41. @pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
  42. def test_map(tqdm_kwargs):
  43. """Test contrib.tmap"""
  44. with closing(StringIO()) as our_file:
  45. a = range(9)
  46. b = [i + 1 for i in a]
  47. gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs)
  48. assert gen != b
  49. assert list(gen) == b