Source code for jacinle.concurrency.counter
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : counter.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/18/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import itertools
import threading
import queue
import collections
import time
# Public API of this module: the names exported by `from ... import *`.
__all__ = ['Counter', 'CounterBasedEvent', 'CounterBasedMonitor']
[docs]
class Counter(object):
    """A monotonically increasing tick counter.

    ``tick`` and ``get`` only call ``next`` on ``itertools.count`` iterators;
    no explicit lock is used. ``get`` returns the number of times ``tick`` has
    been called so far.
    """

    def __init__(self):
        # `_cnt` is advanced by every tick() *and* every get(); `_ref` is
        # advanced only by get(). Their difference is therefore the tick count.
        self._cnt, self._ref = itertools.count(), itertools.count()
        self._iter_cnt, self._iter_ref = iter(self._cnt), iter(self._ref)

    def tick(self):
        """Record one tick."""
        next(self._iter_cnt)

    def get(self):
        """Return the total number of ticks recorded so far."""
        # The reference iterator is read before the counter iterator.
        prior_gets = next(self._iter_ref)
        return next(self._iter_cnt) - prior_gets
[docs]
class CounterBasedEvent(object):
    """Thread-safe counter-based event.

    Each call to :meth:`tick` increments an internal counter under a lock.
    Once the counter has been incremented ``target`` times, the underlying
    :class:`threading.Event` is set, releasing any thread blocked in
    :meth:`wait`.

    Args:
        target: number of ticks after which the event is set.
        tqdm: optional progress-bar-like object. ``update(1)`` is called on
            every tick, and ``close()`` is called when the event becomes set.
    """

    def __init__(self, target, tqdm=None):
        self._cnt = itertools.count()
        self._iter_cnt = iter(self._cnt)
        self._target = target
        self._event = threading.Event()
        self._tick_mutex = threading.Lock()
        self._tqdm = tqdm

    def tick(self):
        """Increment the counter; return the 0-based index of this tick."""
        with self._tick_mutex:
            return self.__tick()

    def __tick(self):
        # Must be called with self._tick_mutex held (see tick()).
        value = next(self._iter_cnt)  # 0-based index of this tick
        if self._tqdm is not None:
            self._tqdm.update(1)

        # `value` is 0-based, so `value + 1` ticks have happened so far. Only
        # act on the unset -> set transition so the progress bar is closed
        # once rather than on every subsequent tick.
        if value + 1 >= self._target and not self._event.is_set():
            self._event.set()
            if self._tqdm is not None:
                self._tqdm.close()
        return value

    def is_set(self):
        """Return True if the target has been reached (and not cleared)."""
        return self._event.is_set()

    def clear(self):
        """Reset the event flag; the tick counter itself is not reset."""
        self._event.clear()

    def wait(self, timeout=None):
        """Block until the event is set or ``timeout`` seconds elapse.

        Returns the event flag, as :meth:`threading.Event.wait` does.
        """
        return self._event.wait(timeout=timeout)
[docs]
class CounterBasedMonitor(object):
    """Periodically report the tick rate of one or more named counters.

    Args:
        counters: names of the counters to track; defaults to ``['DEFAULT']``.
        display_names: optional names shown in the report instead of
            ``counters``; paired positionally with the counters.
        interval: seconds slept between two consecutive reports.
        printf: callable receiving each report string; defaults to the
            ``info`` method of a jacinle logger.
    """

    _displayer = None

    def __init__(self, counters=None, display_names=None, interval=1, printf=None):
        if counters is None:
            counters = ['DEFAULT']
        self._display_names = display_names
        self._counters = collections.OrderedDict([(n, Counter()) for n in counters])
        self._interval = interval
        self._printf = printf
        if self._printf is None:
            from jacinle.logging import get_logger
            logger = get_logger(__file__)
            self._printf = logger.info

    @property
    def _counter_names(self):
        return list(self._counters.keys())

    def tick(self, name=None):
        """Increment one counter.

        With a single counter, ``name`` is ignored and that counter is ticked.
        With several counters, ``name`` selects the counter to tick.

        Raises:
            ValueError: if there are several counters and ``name`` is None.
            KeyError: if ``name`` is not a known counter.
        """
        if len(self._counters) == 1:
            next(iter(self._counters.values())).tick()
            return
        # Explicit check instead of `assert`, which is stripped under -O.
        if name is None:
            raise ValueError('Must provide name if there are multiple counters.')
        self._counters[name].tick()

    def start(self):
        """Start the background reporting thread (a daemon) and return self."""
        self._displayer = threading.Thread(target=self._display_thread, daemon=True)
        self._displayer.start()
        return self

    def _display(self, deltas, interval):
        """Emit one report of ``delta / interval`` ticks/s per counter via printf."""
        names = self._display_names or self._counter_names
        if len(names) == 1:
            self._printf('Counter monitor {}: {} ticks/s.'.format(names[0], deltas[0]/interval))
        else:
            log_strs = ['Counter monitor:']
            for n, v in zip(names, deltas):
                log_strs.append('\t{}: {} ticks/s'.format(n, v/interval))
            self._printf('\n'.join(log_strs))

    def _display_thread(self):
        """Loop forever: sleep, diff counter totals, and report the rates."""
        prev = [c.get() for c in self._counters.values()]
        while True:
            time.sleep(self._interval)
            curr = [c.get() for c in self._counters.values()]
            deltas = [c - p for p, c in zip(prev, curr)]
            prev = curr
            self._display(deltas, self._interval)