Source code for jacinle.utils.tqdm

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : tqdm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/23/2017
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import threading

from tqdm import tqdm as _tqdm
from .meta import gofor

# Public API of this module.
__all__ = [
    'get_tqdm_defaults',
    'get_current_tqdm',
    'tqdm',
    'tqdm_pbar',
    'tqdm_gofor',
    'tqdm_zip',
]

# Keyword arguments applied to every progress bar created through this module,
# unless the caller passes the same key explicitly.
__tqdm_defaults = dict(dynamic_ncols=True, ascii=True)


def get_tqdm_defaults():
    """Return the module-wide default keyword arguments for tqdm bars.

    The module-level dict itself is returned (not a copy), so in-place
    modifications by the caller are visible to later calls.
    """
    defaults = __tqdm_defaults
    return defaults
def get_current_tqdm():
    """Return the current (innermost) progress bar created by :func:`tqdm`.

    Active bars are tracked on a per-thread stack stored on this function as
    ``get_current_tqdm._stack``: :func:`tqdm` appends its bar when iteration
    starts and pops it when iteration ends. Only bars created by :func:`tqdm`
    are registered there.

    Returns:
        The bar most recently pushed onto the calling thread's stack.

    Raises:
        AssertionError: if no bar is registered in the calling thread.
    """
    _init_tqdm_stack()
    stack = get_current_tqdm._stack.data
    if not stack:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped when Python runs with -O.
        raise AssertionError('No registered tqdm.')
    # The stack is LIFO: with nested bars, the current one is the last pushed.
    return stack[-1]


# Thread-local storage so that every thread keeps its own stack of bars.
# The `data` list created here only exists for the importing thread; other
# threads get theirs lazily from `_init_tqdm_stack`.
get_current_tqdm._stack = threading.local()
get_current_tqdm._stack.data = list()


def _init_tqdm_stack():
    """Create the calling thread's bar stack if it does not exist yet."""
    if not hasattr(get_current_tqdm._stack, 'data'):
        get_current_tqdm._stack.data = list()
def tqdm(iterable, **kwargs):
    """Iterate over ``iterable`` while displaying a tqdm progress bar.

    This is a generator. The module defaults from :func:`get_tqdm_defaults`
    are applied to ``kwargs`` for any key the caller did not set, and the bar
    is registered on the per-thread stack for the duration of the iteration so
    that :func:`get_current_tqdm` can find it.

    Args:
        iterable: the object to iterate over. An ``int`` ``n`` is treated as
            ``range(n)``, which enables ``for i in tqdm(10)``. A ``float`` is
            truncated with ``int()`` and treated the same way.
        **kwargs: keyword arguments forwarded to ``tqdm.tqdm``. An explicit
            ``total`` always takes precedence over the inferred one.

    Yields:
        The items of ``iterable``, in order.
    """
    for key, value in get_tqdm_defaults().items():
        kwargs.setdefault(key, value)

    # bool is a subclass of int; keep rejecting it as before instead of
    # silently iterating over range(True).
    if isinstance(iterable, int) and not isinstance(iterable, bool):
        iterable, total = range(iterable), iterable
    elif isinstance(iterable, float):
        # Use the truncated count for `total` as well, so the bar's total
        # matches the number of iterations actually performed.
        count = int(iterable)
        iterable, total = range(count), count
    else:
        try:
            total = len(iterable)
        except TypeError:
            # Unsized iterables (e.g. generators): length unknown.
            total = None

    if 'total' not in kwargs and total is not None:
        kwargs['total'] = total

    # The bar is created without an iterable and advanced manually, once per
    # item, after the consumer has finished with that item.
    with _tqdm(**kwargs) as pbar:
        _init_tqdm_stack()
        get_current_tqdm._stack.data.append(pbar)
        try:
            for data in iterable:
                yield data
                pbar.update()
        finally:
            # Unregister the bar even if the consumer stops early or raises.
            get_current_tqdm._stack.data.pop()
def tqdm_pbar(**kwargs):
    """Build a bare tqdm progress bar using the module defaults.

    Args:
        **kwargs: keyword arguments for ``tqdm.tqdm``; they override the
            defaults returned by :func:`get_tqdm_defaults`.

    Returns:
        The newly created tqdm instance. It is not registered on the stack
        read by :func:`get_current_tqdm`.
    """
    # Caller-supplied values win over the module defaults.
    options = {**get_tqdm_defaults(), **kwargs}
    return _tqdm(**options)
def tqdm_gofor(iterable, **kwargs):
    """Iterate over ``gofor(iterable)`` with a progress bar.

    The bar's ``total`` defaults to ``len(iterable)`` when the iterable is
    sized, and to ``None`` otherwise; an explicit ``total`` keyword wins.

    Args:
        iterable: the object handed to :func:`jacinle.utils.meta.gofor`.
        **kwargs: keyword arguments forwarded to :func:`tqdm`.

    Returns:
        The generator produced by :func:`tqdm`.
    """
    try:
        size = len(iterable)
    except TypeError:
        size = None

    if 'total' not in kwargs:
        kwargs['total'] = size

    progress = tqdm(gofor(iterable), **kwargs)
    return progress
def tqdm_zip(*iterable, **kwargs):
    """Iterate over ``zip(*iterable)`` with a progress bar.

    Because ``zip`` stops at the shortest input, the bar's ``total`` defaults
    to the smallest length among the inputs that support ``len()``. If none of
    them is sized (or no iterable is given), the total is left unknown. An
    explicit ``total`` keyword always wins.

    Args:
        *iterable: the iterables to zip together.
        **kwargs: keyword arguments forwarded to :func:`tqdm`.

    Returns:
        The generator produced by :func:`tqdm`, yielding tuples from ``zip``.
    """
    lengths = []
    for item in iterable:
        try:
            lengths.append(len(item))
        except TypeError:
            # Unsized input (e.g. a generator): it cannot bound the total.
            continue

    # `default=None` also covers the call with no iterables at all, which
    # previously failed with IndexError on `iterable[0]`.
    total = min(lengths, default=None)
    kwargs.setdefault('total', total)
    return tqdm(zip(*iterable), **kwargs)