rich.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. """
  2. `rich.progress` decorator for iterators.
  3. Usage:
  4. >>> from tqdm.rich import trange, tqdm
  5. >>> for i in trange(10):
  6. ... ...
  7. """
  8. from warnings import warn
  9. from rich.progress import (
  10. BarColumn, Progress, ProgressColumn, Text, TimeElapsedColumn, TimeRemainingColumn, filesize)
  11. from .std import TqdmExperimentalWarning
  12. from .std import tqdm as std_tqdm
  13. __author__ = {"github.com/": ["casperdcl"]}
  14. __all__ = ['tqdm_rich', 'trrange', 'tqdm', 'trange']
  15. class FractionColumn(ProgressColumn):
  16. """Renders completed/total, e.g. '0.5/2.3 G'."""
  17. def __init__(self, unit_scale=False, unit_divisor=1000):
  18. self.unit_scale = unit_scale
  19. self.unit_divisor = unit_divisor
  20. super().__init__()
  21. def render(self, task):
  22. """Calculate common unit for completed and total."""
  23. completed = int(task.completed)
  24. total = int(task.total)
  25. if self.unit_scale:
  26. unit, suffix = filesize.pick_unit_and_suffix(
  27. total,
  28. ["", "K", "M", "G", "T", "P", "E", "Z", "Y"],
  29. self.unit_divisor,
  30. )
  31. else:
  32. unit, suffix = filesize.pick_unit_and_suffix(total, [""], 1)
  33. precision = 0 if unit == 1 else 1
  34. return Text(
  35. f"{completed/unit:,.{precision}f}/{total/unit:,.{precision}f} {suffix}",
  36. style="progress.download")
  37. class RateColumn(ProgressColumn):
  38. """Renders human readable transfer speed."""
  39. def __init__(self, unit="", unit_scale=False, unit_divisor=1000):
  40. self.unit = unit
  41. self.unit_scale = unit_scale
  42. self.unit_divisor = unit_divisor
  43. super().__init__()
  44. def render(self, task):
  45. """Show data transfer speed."""
  46. speed = task.speed
  47. if speed is None:
  48. return Text(f"? {self.unit}/s", style="progress.data.speed")
  49. if self.unit_scale:
  50. unit, suffix = filesize.pick_unit_and_suffix(
  51. speed,
  52. ["", "K", "M", "G", "T", "P", "E", "Z", "Y"],
  53. self.unit_divisor,
  54. )
  55. else:
  56. unit, suffix = filesize.pick_unit_and_suffix(speed, [""], 1)
  57. precision = 0 if unit == 1 else 1
  58. return Text(f"{speed/unit:,.{precision}f} {suffix}{self.unit}/s",
  59. style="progress.data.speed")
  60. class tqdm_rich(std_tqdm): # pragma: no cover
  61. """Experimental rich.progress GUI version of tqdm!"""
  62. # TODO: @classmethod: write()?
  63. def __init__(self, *args, **kwargs):
  64. """
  65. This class accepts the following parameters *in addition* to
  66. the parameters accepted by `tqdm`.
  67. Parameters
  68. ----------
  69. progress : tuple, optional
  70. arguments for `rich.progress.Progress()`.
  71. options : dict, optional
  72. keyword arguments for `rich.progress.Progress()`.
  73. """
  74. kwargs = kwargs.copy()
  75. kwargs['gui'] = True
  76. # convert disable = None to False
  77. kwargs['disable'] = bool(kwargs.get('disable', False))
  78. progress = kwargs.pop('progress', None)
  79. options = kwargs.pop('options', {}).copy()
  80. super(tqdm_rich, self).__init__(*args, **kwargs)
  81. if self.disable:
  82. return
  83. warn("rich is experimental/alpha", TqdmExperimentalWarning, stacklevel=2)
  84. d = self.format_dict
  85. if progress is None:
  86. progress = (
  87. "[progress.description]{task.description}"
  88. "[progress.percentage]{task.percentage:>4.0f}%",
  89. BarColumn(bar_width=None),
  90. FractionColumn(
  91. unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
  92. "[", TimeElapsedColumn(), "<", TimeRemainingColumn(),
  93. ",", RateColumn(unit=d['unit'], unit_scale=d['unit_scale'],
  94. unit_divisor=d['unit_divisor']), "]"
  95. )
  96. options.setdefault('transient', not self.leave)
  97. self._prog = Progress(*progress, **options)
  98. self._prog.__enter__()
  99. self._task_id = self._prog.add_task(self.desc or "", **d)
  100. def close(self):
  101. if self.disable:
  102. return
  103. super(tqdm_rich, self).close()
  104. self._prog.__exit__(None, None, None)
  105. def clear(self, *_, **__):
  106. pass
  107. def display(self, *_, **__):
  108. if not hasattr(self, '_prog'):
  109. return
  110. self._prog.update(self._task_id, completed=self.n, description=self.desc)
  111. def reset(self, total=None):
  112. """
  113. Resets to 0 iterations for repeated use.
  114. Parameters
  115. ----------
  116. total : int or float, optional. Total to use for the new bar.
  117. """
  118. if hasattr(self, '_prog'):
  119. self._prog.reset(total=total)
  120. super(tqdm_rich, self).reset(total=total)
  121. def trrange(*args, **kwargs):
  122. """Shortcut for `tqdm.rich.tqdm(range(*args), **kwargs)`."""
  123. return tqdm_rich(range(*args), **kwargs)
  124. # Aliases
  125. tqdm = tqdm_rich
  126. trange = trrange