dask.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from functools import partial
  2. from dask.callbacks import Callback
  3. from .auto import tqdm as tqdm_auto
  4. __author__ = {"github.com/": ["casperdcl"]}
  5. __all__ = ['TqdmCallback']
  6. class TqdmCallback(Callback):
  7. """Dask callback for task progress."""
  8. def __init__(self, start=None, pretask=None, tqdm_class=tqdm_auto,
  9. **tqdm_kwargs):
  10. """
  11. Parameters
  12. ----------
  13. tqdm_class : optional
  14. `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
  15. tqdm_kwargs : optional
  16. Any other arguments used for all bars.
  17. """
  18. super(TqdmCallback, self).__init__(start=start, pretask=pretask)
  19. if tqdm_kwargs:
  20. tqdm_class = partial(tqdm_class, **tqdm_kwargs)
  21. self.tqdm_class = tqdm_class
  22. def _start_state(self, _, state):
  23. self.pbar = self.tqdm_class(total=sum(
  24. len(state[k]) for k in ['ready', 'waiting', 'running', 'finished']))
  25. def _posttask(self, *_, **__):
  26. self.pbar.update()
  27. def _finish(self, *_, **__):
  28. self.pbar.close()
  29. def display(self):
  30. """Displays in the current cell in Notebooks."""
  31. container = getattr(self.bar, 'container', None)
  32. if container is None:
  33. return
  34. from .notebook import display
  35. display(container)