Source code for jacinle.comm.cs

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : cs.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections
import contextlib
import queue
import threading

import zmq

from jacinle.concurrency.packing import dumpb, loadb
from jacinle.concurrency.zmq_utils import get_addr, bind_to_random_ipc, graceful_close
from jacinle.logging import get_logger
from jacinle.utils.meta import notnone_property
from jacinle.utils.registry import CallbackRegistry

# Module-level logger; the server loops below use it to report socket errors.
logger = get_logger(__file__)

# Public API of this module.
__all__ = ['ServerPipe', 'ClientPipe', 'make_cs_pair']

# One outbound job on a ServerPipe's send queue: the identifier of the
# destination client and the payload that will be serialized and sent to it.
_QueryMessage = collections.namedtuple('QueryMessage', ['identifier', 'payload'])


class ServerPipe(object):
    """Server end of a client/server pipe built on two ZeroMQ sockets.

    Incoming queries arrive on a PULL socket and are handed to the callback
    registered in :attr:`dispatcher` for the query's type. Replies are queued
    with :meth:`send` and written to a ROUTER socket, addressed by the client
    identifier, from a background thread.
    """

    def __init__(self, name, send_qsize=0, mode='tcp'):
        """Create the ZeroMQ context and sockets; nothing is bound yet.

        Args:
            name: name of the pipe; used as the prefix of randomly generated
                ipc endpoints.
            send_qsize: ``maxsize`` of the outgoing ``queue.Queue`` (0 means
                unbounded, as defined by ``queue.Queue``).
            mode: transport to bind with, either ``'tcp'`` or ``'ipc'``.

        Raises:
            AssertionError: if `mode` is neither ``'tcp'`` nor ``'ipc'``.
        """
        # Validate before allocating anything, so that an invalid mode does
        # not leave a ZeroMQ context and two sockets behind.
        assert mode in ('ipc', 'tcp'), 'Unknown mode: {}.'.format(mode)

        self._name = name
        self._conn_info = None

        self._context_lock = threading.Lock()
        self._context = zmq.Context()
        self._tosock = self._context.socket(zmq.ROUTER)  # server -> client
        self._frsock = self._context.socket(zmq.PULL)    # client -> server
        self._tosock.set_hwm(10)
        self._frsock.set_hwm(10)

        self._dispatcher = CallbackRegistry()
        self._send_queue = queue.Queue(maxsize=send_qsize)
        self._rcv_thread = None
        self._snd_thread = None
        self._mode = mode

    @property
    def dispatcher(self):
        """The callback registry used to dispatch incoming queries by type."""
        return self._dispatcher

    @notnone_property
    def conn_info(self):
        """Two addresses, set by :meth:`initialize`.

        Item 0 is the address of the receiving (client-to-server) socket and
        item 1 the address of the sending (server-to-client) socket.
        NOTE(review): `notnone_property` presumably rejects access while the
        value is still None (i.e. before `initialize`) -- confirm.
        """
        return self._conn_info

    def _bind_tcp(self, sock, ports, index):
        """Bind `sock` on tcp to ``ports[index]`` (or a random port when
        `ports` is None) and return the address clients should connect to."""
        if ports is not None:
            port = ports[index]
            sock.bind('tcp://*:{}'.format(port))
        else:
            port = sock.bind_to_random_port('tcp://*')
        return 'tcp://{}:{}'.format(get_addr(), port)

    def _bind_ipc(self, sock, paths, index, tag):
        """Bind `sock` on ipc to ``paths[index]`` (or a random endpoint
        prefixed by the pipe name and `tag` when `paths` is None) and return
        the address clients should connect to."""
        if paths is not None:
            addr = 'ipc://' + paths[index]
            sock.bind(addr)
            return addr
        return bind_to_random_ipc(sock, self._name + tag)

    def initialize(self, tcp_port=None, ipc_port=None):
        """Bind both sockets and start the receive and send threads.

        Args:
            tcp_port: optional pair ``(recv_port, send_port)`` used in tcp
                mode; random ports are chosen when None.
            ipc_port: optional pair ``(recv_path, send_path)`` used in ipc
                mode; random endpoints are chosen when None.
        """
        self._conn_info = []
        # The receiving socket is always bound first, so that conn_info is
        # ordered as [client-to-server address, server-to-client address].
        if self._mode == 'tcp':
            self._conn_info.append(self._bind_tcp(self._frsock, tcp_port, 0))
            self._conn_info.append(self._bind_tcp(self._tosock, tcp_port, 1))
        elif self._mode == 'ipc':
            self._conn_info.append(self._bind_ipc(self._frsock, ipc_port, 0, '-c2s-'))
            self._conn_info.append(self._bind_ipc(self._tosock, ipc_port, 1, '-s2c-'))

        self._rcv_thread = threading.Thread(target=self.mainloop_recv, daemon=True)
        self._rcv_thread.start()
        self._snd_thread = threading.Thread(target=self.mainloop_send, daemon=True)
        self._snd_thread.start()

    def finalize(self):
        """Close both sockets and terminate the ZeroMQ context."""
        graceful_close(self._tosock)
        graceful_close(self._frsock)
        self._context.term()

    @contextlib.contextmanager
    def activate(self, tcp_port=None, ipc_port=None):
        """Context manager that calls :meth:`initialize` on entry and
        :meth:`finalize` on exit, even if the body raises."""
        self.initialize(tcp_port=tcp_port, ipc_port=ipc_port)
        try:
            yield
        finally:
            self.finalize()

    def mainloop_recv(self):
        """Body of the receive thread.

        Reads ``(identifier, type, payload)`` messages from the receiving
        socket until it is closed and dispatches each one as
        ``dispatch(type, self, identifier, payload)``.
        """
        try:
            while not self._frsock.closed:
                msg = loadb(self._frsock.recv(copy=False).bytes)
                # `msg_type` rather than `type`, to avoid shadowing the builtin.
                identifier, msg_type, payload = msg
                self._dispatcher.dispatch(msg_type, self, identifier, payload)
        except zmq.ContextTerminated:
            # The context was terminated by finalize(): a normal shutdown.
            pass
        except zmq.ZMQError:
            # Check the socket this loop actually reads from. The original
            # code tested the sending socket here.
            if self._frsock.closed:
                logger.warning('Recv socket closed unexpectedly.')
            else:
                raise

    def mainloop_send(self):
        """Body of the send thread.

        Takes jobs from the send queue until the sending socket is closed and
        writes each one as a two-part message: the client identifier followed
        by the serialized payload.
        """
        try:
            while not self._tosock.closed:
                job = self._send_queue.get()
                self._tosock.send_multipart([job.identifier, dumpb(job.payload)], copy=False)
        except zmq.ContextTerminated:
            # The context was terminated by finalize(): a normal shutdown.
            pass
        except zmq.ZMQError:
            if self._tosock.closed:
                logger.warning('Send socket closed unexpectedly.')
            else:
                raise

    def send(self, identifier, msg):
        """Queue `msg` to be sent to the client identified by `identifier`.

        Blocks while the send queue is full (only possible when `send_qsize`
        is positive).
        """
        self._send_queue.put(_QueryMessage(identifier, msg))
class ClientPipe(object):
    """Client end of a client/server pipe.

    Queries are pushed through a PUSH socket connected to ``conn_info[0]``
    and replies are read from a DEALER socket connected to ``conn_info[1]``.
    Both sockets carry the client's name as their ZeroMQ identity.
    """

    def __init__(self, name, conn_info):
        """Remember the client's name and the two addresses to connect to.

        Nothing is created or connected until :meth:`initialize` is called.
        """
        self._name, self._conn_info = name, conn_info
        self._context = self._tosock = self._frsock = None

    @property
    def identity(self):
        """The client's name encoded as UTF-8 bytes."""
        encoded_name = self._name.encode('utf-8')
        return encoded_name

    def initialize(self):
        """Create the ZeroMQ context and sockets and connect them."""
        ctx = zmq.Context()
        self._context = ctx
        self._tosock = ctx.socket(zmq.PUSH)
        self._frsock = ctx.socket(zmq.DEALER)

        # Both sockets identify themselves with the client's name.
        for sock in (self._tosock, self._frsock):
            sock.setsockopt(zmq.IDENTITY, self.identity)
        self._tosock.set_hwm(2)

        # Item 0 is the query address, item 1 the reply address.
        for sock, index in ((self._tosock, 0), (self._frsock, 1)):
            sock.connect(self._conn_info[index])

    def finalize(self):
        """Close both sockets (reply socket first) and terminate the context."""
        for sock in (self._frsock, self._tosock):
            graceful_close(sock)
        self._context.term()

    @contextlib.contextmanager
    def activate(self):
        """Context manager: :meth:`initialize` on entry, :meth:`finalize` on
        exit, even if the body raises."""
        self.initialize()
        try:
            yield
        finally:
            self.finalize()

    def query(self, type, inp=None, do_recv=True):
        """Send ``(identity, type, inp)`` to the server.

        Returns the next received reply when `do_recv` is true, otherwise
        None.
        """
        request = dumpb((self.identity, type, inp))
        self._tosock.send(request, copy=False)
        if not do_recv:
            return None
        return self.recv()

    def recv(self):
        """Block until a reply arrives and return it deserialized."""
        frame = self._frsock.recv(copy=False)
        return loadb(frame.bytes)
def make_cs_pair(name, nr_clients=None, mode='tcp', send_qsize=10):
    """Create an initialized server pipe together with client pipe(s) for it.

    The server is named ``name + '-rep'`` and is initialized here, so its
    sockets are bound and its background threads are running on return. The
    clients are named ``name + '-req-<i>'`` and are only constructed; the
    caller still has to call ``initialize()`` or ``activate()`` on them.

    Args:
        name: common name prefix for the server and the clients.
        nr_clients: number of clients to create. When None, a single client
            is created and returned directly instead of in a list.
        mode: transport passed to the server, ``'tcp'`` or ``'ipc'``.
        send_qsize: size of the server's send queue.

    Returns:
        ``(server, client)`` when `nr_clients` is None, otherwise
        ``(server, [client_0, ..., client_{nr_clients - 1}])``.
    """
    rep = ServerPipe(name + '-rep', mode=mode, send_qsize=send_qsize)
    rep.initialize()

    # Test against None explicitly: `nr_clients or 1` would also turn an
    # explicit 0 into one client and return a one-element list.
    nr_reqs = 1 if nr_clients is None else nr_clients
    reqs = [ClientPipe(name + '-req-' + str(i), rep.conn_info) for i in range(nr_reqs)]

    if nr_clients is None:
        return rep, reqs[0]
    return rep, reqs