# tests_keras.py
  1. from .tests_tqdm import importorskip, mark
  2. pytestmark = mark.slow
  3. @mark.filterwarnings("ignore:.*:DeprecationWarning")
  4. def test_keras(capsys):
  5. """Test tqdm.keras.TqdmCallback"""
  6. TqdmCallback = importorskip('tqdm.keras').TqdmCallback
  7. np = importorskip('numpy')
  8. try:
  9. import keras as K
  10. except ImportError:
  11. K = importorskip('tensorflow.keras')
  12. # 1D autoencoder
  13. dtype = np.float32
  14. model = K.models.Sequential([
  15. K.layers.InputLayer((1, 1), dtype=dtype), K.layers.Conv1D(1, 1)])
  16. model.compile("adam", "mse")
  17. x = np.random.rand(100, 1, 1).astype(dtype)
  18. batch_size = 10
  19. batches = len(x) / batch_size
  20. epochs = 5
  21. # just epoch (no batch) progress
  22. model.fit(
  23. x,
  24. x,
  25. epochs=epochs,
  26. batch_size=batch_size,
  27. verbose=False,
  28. callbacks=[
  29. TqdmCallback(
  30. epochs,
  31. desc="training",
  32. data_size=len(x),
  33. batch_size=batch_size,
  34. verbose=0)])
  35. _, res = capsys.readouterr()
  36. assert "training: " in res
  37. assert "{epochs}/{epochs}".format(epochs=epochs) in res
  38. assert "{batches}/{batches}".format(batches=batches) not in res
  39. # full (epoch and batch) progress
  40. model.fit(
  41. x,
  42. x,
  43. epochs=epochs,
  44. batch_size=batch_size,
  45. verbose=False,
  46. callbacks=[
  47. TqdmCallback(
  48. epochs,
  49. desc="training",
  50. data_size=len(x),
  51. batch_size=batch_size,
  52. verbose=2)])
  53. _, res = capsys.readouterr()
  54. assert "training: " in res
  55. assert "{epochs}/{epochs}".format(epochs=epochs) in res
  56. assert "{batches}/{batches}".format(batches=batches) in res
  57. # auto-detect epochs and batches
  58. model.fit(
  59. x,
  60. x,
  61. epochs=epochs,
  62. batch_size=batch_size,
  63. verbose=False,
  64. callbacks=[TqdmCallback(desc="training", verbose=2)])
  65. _, res = capsys.readouterr()
  66. assert "training: " in res
  67. assert "{epochs}/{epochs}".format(epochs=epochs) in res
  68. assert "{batches}/{batches}".format(batches=batches) in res
  69. # continue training (start from epoch != 0)
  70. initial_epoch = 3
  71. model.fit(
  72. x,
  73. x,
  74. initial_epoch=initial_epoch,
  75. epochs=epochs,
  76. batch_size=batch_size,
  77. verbose=False,
  78. callbacks=[TqdmCallback(desc="training", verbose=0,
  79. miniters=1, mininterval=0, maxinterval=0)])
  80. _, res = capsys.readouterr()
  81. assert "training: " in res
  82. assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
  83. assert "{epochs}/{epochs}".format(epochs=epochs) in res