# utils.py (9.3 KB)
  1. """
  2. General helpers required for `tqdm.std`.
  3. """
  4. import os
  5. import re
  6. import sys
  7. from functools import wraps
  8. # TODO consider using wcswidth third-party package for 0-width characters
  9. from unicodedata import east_asian_width
  10. from warnings import warn
  11. from weakref import proxy
  12. _range, _unich, _unicode, _basestring = range, chr, str, str
  13. CUR_OS = sys.platform
  14. IS_WIN = any(CUR_OS.startswith(i) for i in ['win32', 'cygwin'])
  15. IS_NIX = any(CUR_OS.startswith(i) for i in ['aix', 'linux', 'darwin'])
  16. RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")
  17. try:
  18. if IS_WIN:
  19. import colorama
  20. else:
  21. raise ImportError
  22. except ImportError:
  23. colorama = None
  24. else:
  25. try:
  26. colorama.init(strip=False)
  27. except TypeError:
  28. colorama.init()
  29. class FormatReplace(object):
  30. """
  31. >>> a = FormatReplace('something')
  32. >>> "{:5d}".format(a)
  33. 'something'
  34. """ # NOQA: P102
  35. def __init__(self, replace=''):
  36. self.replace = replace
  37. self.format_called = 0
  38. def __format__(self, _):
  39. self.format_called += 1
  40. return self.replace
  41. class Comparable(object):
  42. """Assumes child has self._comparable attr/@property"""
  43. def __lt__(self, other):
  44. return self._comparable < other._comparable
  45. def __le__(self, other):
  46. return (self < other) or (self == other)
  47. def __eq__(self, other):
  48. return self._comparable == other._comparable
  49. def __ne__(self, other):
  50. return not self == other
  51. def __gt__(self, other):
  52. return not self <= other
  53. def __ge__(self, other):
  54. return not self < other
  55. class ObjectWrapper(object):
  56. def __getattr__(self, name):
  57. return getattr(self._wrapped, name)
  58. def __setattr__(self, name, value):
  59. return setattr(self._wrapped, name, value)
  60. def wrapper_getattr(self, name):
  61. """Actual `self.getattr` rather than self._wrapped.getattr"""
  62. try:
  63. return object.__getattr__(self, name)
  64. except AttributeError: # py2
  65. return getattr(self, name)
  66. def wrapper_setattr(self, name, value):
  67. """Actual `self.setattr` rather than self._wrapped.setattr"""
  68. return object.__setattr__(self, name, value)
  69. def __init__(self, wrapped):
  70. """
  71. Thin wrapper around a given object
  72. """
  73. self.wrapper_setattr('_wrapped', wrapped)
  74. class SimpleTextIOWrapper(ObjectWrapper):
  75. """
  76. Change only `.write()` of the wrapped object by encoding the passed
  77. value and passing the result to the wrapped object's `.write()` method.
  78. """
  79. # pylint: disable=too-few-public-methods
  80. def __init__(self, wrapped, encoding):
  81. super(SimpleTextIOWrapper, self).__init__(wrapped)
  82. self.wrapper_setattr('encoding', encoding)
  83. def write(self, s):
  84. """
  85. Encode `s` and pass to the wrapped object's `.write()` method.
  86. """
  87. return self._wrapped.write(s.encode(self.wrapper_getattr('encoding')))
  88. def __eq__(self, other):
  89. return self._wrapped == getattr(other, '_wrapped', other)
  90. class DisableOnWriteError(ObjectWrapper):
  91. """
  92. Disable the given `tqdm_instance` upon `write()` or `flush()` errors.
  93. """
  94. @staticmethod
  95. def disable_on_exception(tqdm_instance, func):
  96. """
  97. Quietly set `tqdm_instance.miniters=inf` if `func` raises `errno=5`.
  98. """
  99. tqdm_instance = proxy(tqdm_instance)
  100. def inner(*args, **kwargs):
  101. try:
  102. return func(*args, **kwargs)
  103. except OSError as e:
  104. if e.errno != 5:
  105. raise
  106. try:
  107. tqdm_instance.miniters = float('inf')
  108. except ReferenceError:
  109. pass
  110. except ValueError as e:
  111. if 'closed' not in str(e):
  112. raise
  113. try:
  114. tqdm_instance.miniters = float('inf')
  115. except ReferenceError:
  116. pass
  117. return inner
  118. def __init__(self, wrapped, tqdm_instance):
  119. super(DisableOnWriteError, self).__init__(wrapped)
  120. if hasattr(wrapped, 'write'):
  121. self.wrapper_setattr(
  122. 'write', self.disable_on_exception(tqdm_instance, wrapped.write))
  123. if hasattr(wrapped, 'flush'):
  124. self.wrapper_setattr(
  125. 'flush', self.disable_on_exception(tqdm_instance, wrapped.flush))
  126. def __eq__(self, other):
  127. return self._wrapped == getattr(other, '_wrapped', other)
  128. class CallbackIOWrapper(ObjectWrapper):
  129. def __init__(self, callback, stream, method="read"):
  130. """
  131. Wrap a given `file`-like object's `read()` or `write()` to report
  132. lengths to the given `callback`
  133. """
  134. super(CallbackIOWrapper, self).__init__(stream)
  135. func = getattr(stream, method)
  136. if method == "write":
  137. @wraps(func)
  138. def write(data, *args, **kwargs):
  139. res = func(data, *args, **kwargs)
  140. callback(len(data))
  141. return res
  142. self.wrapper_setattr('write', write)
  143. elif method == "read":
  144. @wraps(func)
  145. def read(*args, **kwargs):
  146. data = func(*args, **kwargs)
  147. callback(len(data))
  148. return data
  149. self.wrapper_setattr('read', read)
  150. else:
  151. raise KeyError("Can only wrap read/write methods")
  152. def _is_utf(encoding):
  153. try:
  154. u'\u2588\u2589'.encode(encoding)
  155. except UnicodeEncodeError:
  156. return False
  157. except Exception:
  158. try:
  159. return encoding.lower().startswith('utf-') or ('U8' == encoding)
  160. except Exception:
  161. return False
  162. else:
  163. return True
  164. def _supports_unicode(fp):
  165. try:
  166. return _is_utf(fp.encoding)
  167. except AttributeError:
  168. return False
  169. def _is_ascii(s):
  170. if isinstance(s, str):
  171. for c in s:
  172. if ord(c) > 255:
  173. return False
  174. return True
  175. return _supports_unicode(s)
  176. def _screen_shape_wrapper(): # pragma: no cover
  177. """
  178. Return a function which returns console dimensions (width, height).
  179. Supported: linux, osx, windows, cygwin.
  180. """
  181. _screen_shape = None
  182. if IS_WIN:
  183. _screen_shape = _screen_shape_windows
  184. if _screen_shape is None:
  185. _screen_shape = _screen_shape_tput
  186. if IS_NIX:
  187. _screen_shape = _screen_shape_linux
  188. return _screen_shape
  189. def _screen_shape_windows(fp): # pragma: no cover
  190. try:
  191. import struct
  192. from ctypes import create_string_buffer, windll
  193. from sys import stdin, stdout
  194. io_handle = -12 # assume stderr
  195. if fp == stdin:
  196. io_handle = -10
  197. elif fp == stdout:
  198. io_handle = -11
  199. h = windll.kernel32.GetStdHandle(io_handle)
  200. csbi = create_string_buffer(22)
  201. res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
  202. if res:
  203. (_bufx, _bufy, _curx, _cury, _wattr, left, top, right, bottom,
  204. _maxx, _maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw)
  205. return right - left, bottom - top # +1
  206. except Exception: # nosec
  207. pass
  208. return None, None
  209. def _screen_shape_tput(*_): # pragma: no cover
  210. """cygwin xterm (windows)"""
  211. try:
  212. import shlex
  213. from subprocess import check_call # nosec
  214. return [int(check_call(shlex.split('tput ' + i))) - 1
  215. for i in ('cols', 'lines')]
  216. except Exception: # nosec
  217. pass
  218. return None, None
  219. def _screen_shape_linux(fp): # pragma: no cover
  220. try:
  221. from array import array
  222. from fcntl import ioctl
  223. from termios import TIOCGWINSZ
  224. except ImportError:
  225. return None, None
  226. else:
  227. try:
  228. rows, cols = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))[:2]
  229. return cols, rows
  230. except Exception:
  231. try:
  232. return [int(os.environ[i]) - 1 for i in ("COLUMNS", "LINES")]
  233. except (KeyError, ValueError):
  234. return None, None
  235. def _environ_cols_wrapper(): # pragma: no cover
  236. """
  237. Return a function which returns console width.
  238. Supported: linux, osx, windows, cygwin.
  239. """
  240. warn("Use `_screen_shape_wrapper()(file)[0]` instead of"
  241. " `_environ_cols_wrapper()(file)`", DeprecationWarning, stacklevel=2)
  242. shape = _screen_shape_wrapper()
  243. if not shape:
  244. return None
  245. @wraps(shape)
  246. def inner(fp):
  247. return shape(fp)[0]
  248. return inner
  249. def _term_move_up(): # pragma: no cover
  250. return '' if (os.name == 'nt') and (colorama is None) else '\x1b[A'
  251. def _text_width(s):
  252. return sum(2 if east_asian_width(ch) in 'FW' else 1 for ch in str(s))
  253. def disp_len(data):
  254. """
  255. Returns the real on-screen length of a string which may contain
  256. ANSI control codes and wide chars.
  257. """
  258. return _text_width(RE_ANSI.sub('', data))
  259. def disp_trim(data, length):
  260. """
  261. Trim a string which may contain ANSI control characters.
  262. """
  263. if len(data) == disp_len(data):
  264. return data[:length]
  265. ansi_present = bool(RE_ANSI.search(data))
  266. while disp_len(data) > length: # carefully delete one char at a time
  267. data = data[:-1]
  268. if ansi_present and bool(RE_ANSI.search(data)):
  269. # assume ANSI reset is required
  270. return data if data.endswith("\033[0m") else data + "\033[0m"
  271. return data