Source code for jactorch.quickstart.inference

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : inference.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/18/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import time
import queue
import threading
import contextlib

import torch

from jacinle.concurrency.future import FutureResult
from jacinle.utils.meta import map_exec_method
from jactorch.utils.meta import as_numpy, as_tensor
from jacinle.nd import batchify, unbatchify

__all__ = ['ModelInferencer', 'AsyncInferenceTask', 'AsyncModelInferencer', 'BatchedAsyncModelInferencer']


class AsyncInferenceTask(object):
    __slots__ = ('future', 'feed_dict')

    def __init__(self, feed_dict, future=None):
        self.feed_dict = feed_dict
        if future is None:
            future = FutureResult()
        self.future = future

    def get_result(self):
        return self.future.get()

    def put_result(self, result):
        return self.future.put(result)


class ModelInferencer(object):
    def __init__(self, model):
        self._model = model

    @contextlib.contextmanager
    def activate(self):
        self.initialize()
        yield self
        self.finalize()

    def initialize(self):
        pass

    def finalize(self):
        pass

    def inference(self, feed_dict):
        return self._inference_model(feed_dict)

    def _inference_model(self, feed_dict):
        feed_dict = as_tensor(feed_dict)
        with torch.no_grad():
            return as_numpy(self._model(feed_dict))
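
A minimal synchronous usage sketch. The `nn.Linear` model and random tensor input below are placeholders, assuming the wrapped model accepts the feed_dict directly; the call returns numpy outputs via `as_numpy`.

import torch.nn as nn

model = nn.Linear(4, 2)                         # stand-in for any torch module
inferencer = ModelInferencer(model)
with inferencer.activate():
    output = inferencer.inference(torch.randn(1, 4))   # numpy array of shape (1, 2)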


class AsyncModelInferencer(ModelInferencer):
    def __init__(self, model, nr_workers=1):
        super().__init__(model)
        self._nr_workers = nr_workers
        self._task_queue = None
        self._workers = []

    def initialize(self):
        assert len(self._workers) == 0
        self._task_queue = queue.Queue()
        for rank in range(self._nr_workers):
            th = threading.Thread(target=self._mainloop_worker, args=(rank, ))
            th.start()
            self._workers.append(th)

    def finalize(self):
        if len(self._workers) == 0:
            return
        for rank in range(self._nr_workers):
            self._task_queue.put(None)
        map_exec_method('join', self._workers)

    def _mainloop_worker(self, rank):
        while True:
            task = self._task_queue.get()
            if task is None:
                break
            task.put_result(self._inference_model(task.feed_dict))

    def inference(self, feed_dict, future=None):
        task = AsyncInferenceTask(feed_dict, future=future)
        self._task_queue.put(task)
        return task
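
A sketch of asynchronous usage: `inference()` enqueues a task and returns it immediately, and the result is fetched later through the task's future. The model and inputs are again placeholders.

import torch.nn as nn

model = nn.Linear(4, 2)
inferencer = AsyncModelInferencer(model, nr_workers=2)
with inferencer.activate():
    tasks = [inferencer.inference(torch.randn(1, 4)) for _ in range(8)]
    results = [t.get_result() for t in tasks]   # blocks until each worker has produced its output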


class BatchedAsyncModelInferencer(AsyncModelInferencer):
    def __init__(self, model, nr_workers=1, batch_size=8, latency=10):
        super().__init__(model, nr_workers=nr_workers)
        self._batch_size = batch_size
        self._latency = latency / 1000  # latency is given in milliseconds; store it in seconds

    def _mainloop_worker(self, rank):
        while True:
            tasks = []
            stop_signal = False

            # Accumulate up to `batch_size` tasks, but wait no longer than the latency window.
            last_time = time.time() + self._latency
            for i in range(self._batch_size):
                if len(tasks) > 0:
                    # Subsequent tasks: wait only for the remainder of the latency window.
                    try:
                        task = self._task_queue.get(timeout=max(last_time - time.time(), 0))
                    except queue.Empty:
                        break
                else:
                    # First task: block until a request arrives.
                    task = self._task_queue.get()

                if task is None:
                    stop_signal = True
                    break
                else:
                    tasks.append(task)

            if len(tasks):
                # Run a single forward pass over the batched feeds and scatter the results.
                batched_feed = batchify([t.feed_dict for t in tasks])
                outputs = unbatchify(self._inference_model(batched_feed))
                for t, o in zip(tasks, outputs):
                    t.put_result(o)

            if stop_signal:
                break
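
A sketch of batched usage, assuming `batchify`/`unbatchify` can stack and split the per-request feeds (here plain per-sample tensors). Requests are grouped into batches of up to `batch_size`, waiting at most `latency` milliseconds for a batch to fill before running one forward pass.

import torch.nn as nn

model = nn.Linear(4, 2)
inferencer = BatchedAsyncModelInferencer(model, nr_workers=1, batch_size=8, latency=10)
with inferencer.activate():
    tasks = [inferencer.inference(torch.randn(4)) for _ in range(32)]   # per-sample feeds
    results = [t.get_result() for t in tasks]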