Source code for jacinle.comm.distrib.controller

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : controller.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections
import contextlib
import os
import queue
import random
import threading

import zmq

from jacinle.comm.distrib import _configs
from jacinle.concurrency import zmq_utils as utils
from jacinle.logging import get_logger
from jacinle.utils.registry import CallbackRegistry

logger = get_logger(__file__)

__all__ = [
    'Controller', 'control'
]


# mapping from pipe_name => list of pipe
[docs] class ControllerPipeStorage(collections.defaultdict):
[docs] def __init__(self): super().__init__(list)
[docs] def pipe_info(self): return [(k, p.identifier) for k, pipes in self.items() for p in pipes]
[docs] def put(self, pipe): self[pipe.name].append(pipe)
[docs] @staticmethod def fn_filter_notempty(pipe): return not pipe.empty()
[docs] def filter_notempty(self, name): return list(filter(ControllerPipeStorage.fn_filter_notempty, self[name]))
[docs] @staticmethod def fn_filter_notfull(pipe): return not pipe.full()
[docs] def filter_notfull(self, name): return list(filter(ControllerPipeStorage.fn_filter_notfull, self[name]))
ControlMessage = collections.namedtuple('ControlMessage', ['sock', 'identifier', 'payload', 'countdown']) ControllerPeer = collections.namedtuple('ControllerPeer', ['info', 'csock']) PipePeer = collections.namedtuple('PipePeer', ['addr', 'port', 'dsock', 'pipes', 'ids']) BroadcastMessage = collections.namedtuple('BroadcastMessage', ['from_identifier', 'payload']) UnicastMessage = collections.namedtuple('UnicastMessage', ['from_identifier', 'to_identifier', 'payload'])
[docs] class Controller(object):
[docs] def __init__(self): self._uid = utils.uid() self._addr = utils.get_addr() # pipes of this controller self._imanager = ControllerPipeStorage() self._omanager = ControllerPipeStorage() # context and poller self._context = zmq.Context() self._context.sndhwm = _configs.CTL_CTL_HWM self._context.rcvhwm = _configs.CTL_CTL_HWM # socket pools self._ns_socket = None # control router and dispatcher self._control_router = None self._control_router_port = 0 self._control_dispatcher = CallbackRegistry() # queue of ControlMessage self._control_mqueue = queue.Queue() self._control_poller = zmq.Poller() # peers respect to the controller # map uid => ControllerPeer self._controller_peers = dict() # the peers of input pipes (i.e. the output pipes) self._input_from = dict() self._input_cache = dict() self._output_to = dict() self._output_to_pipe = collections.defaultdict(dict) # the peers of output pipes self._output_to_id = dict() self._output_cache = dict() # map pipe_name => cache self._data_poller = zmq.Poller() # threads and stop-event self._all_socks = set() self._all_threads = [] self._stop_event = threading.Event()
[docs] def socket(self, socket_type): sock = self._context.socket(socket_type) self._all_socks.add(sock) return sock
[docs] def close_socket(self, sock): utils.graceful_close(sock) self._all_socks.remove(sock) return self
[docs] def initialize(self, pipes=None): pipes = pipes or [] for pipe in pipes: if pipe.direction == 'IN': self._imanager.put(pipe) else: assert pipe.direction == 'OUT' self._omanager.put(pipe) pipe.set_controller(self) # setup ns socket self._ns_socket = self.socket(zmq.REQ) self._ns_socket.connect(os.getenv( 'JAC_NAME_SERVER', '{}://localhost:{}'.format( _configs.NS_CTL_PROTOCAL, _configs.NS_CTL_PORT ) )) self._control_poller.register(self._ns_socket, zmq.POLLIN) # setup router socket self._control_router = self.socket(zmq.ROUTER) self._control_router_port = self._control_router.bind_to_random_port('tcp://*') self._control_poller.register(self._control_router, zmq.POLLIN) # register on the name-server response = utils.req_send_and_recv(self._ns_socket, { 'action': _configs.Actions.NS_REGISTER_CTL_REQ, 'uid': self._uid, 'ctl_protocal': 'tcp', 'ctl_addr': self._addr, 'ctl_port': self._control_router_port, 'meta': {} }) assert response['action'] == _configs.Actions.NS_REGISTER_CTL_REP # register pipes on name-server response = utils.req_send_and_recv(self._ns_socket, { 'action': _configs.Actions.NS_REGISTER_OUTPUTS_REQ, 'uid': self._uid, 'outputs': list(self._omanager.keys()) }) assert response['action'] == _configs.Actions.NS_REGISTER_OUTPUTS_REP # query name-server for ipipes response = utils.req_send_and_recv(self._ns_socket, { 'action': _configs.Actions.NS_REGISTER_INPUTS_REQ, 'uid': self._uid, 'inputs': list(self._imanager.keys()) }) assert response['action'] == _configs.Actions.NS_REGISTER_INPUTS_REP logger.info('IPipes query {}.'.format(response['results'])) self._initialize_recv_peers(response['results']) # setup dispatcher self._control_dispatcher.register(_configs.Actions.NS_HEARTBEAT_REP, lambda msg: None) self._control_dispatcher.register(_configs.Actions.CTL_CONNECT_REQ, self._on_ctl_connect_req) self._control_dispatcher.register(_configs.Actions.CTL_CONNECT_REP, self._on_ctl_connect_rep) self._control_dispatcher.register(_configs.Actions.CTL_CONNECTED_REQ, self._on_ctl_connected_req) self._control_dispatcher.register(_configs.Actions.CTL_CONNECTED_REP, lambda msg: None) self._control_dispatcher.register(_configs.Actions.NS_NOTIFY_OPEN_REQ, self._on_ctl_notify_open_req) self._control_dispatcher.register(_configs.Actions.NS_NOTIFY_OPEN_REP, lambda msg: None) self._control_dispatcher.register(_configs.Actions.NS_NOTIFY_CLOSE_REQ, self._on_ctl_notify_close_req) self._control_dispatcher.register(_configs.Actions.NS_NOTIFY_CLOSE_REP, lambda msg: None) # run threads self._all_threads.append(threading.Thread(target=self._main, name='ctl-main')) self._all_threads.append(threading.Thread(target=self._main_heartbeat, name='ctl-main-ns-heartbeat')) for i in self._all_threads: i.start()
[docs] def finalize(self): self._stop_event.set() for i in self._all_threads: i.join() for sock in self._all_socks: utils.graceful_close(sock)
def _main(self): wait = 0 while True: if self._stop_event.wait(wait / 1000): break nr_done = 0 socks = dict(self._control_poller.poll(0)) nr_done += self._main_do_control_recv(socks) nr_done += self._main_do_control_send() socks = dict(self._data_poller.poll(0)) nr_done += self._main_do_data_recv(socks) nr_done += self._main_do_data_send() if nr_done > 0: wait = wait / 2 if wait > 1 else 0 else: wait = wait + 1 if wait < 50 else 50 def _main_heartbeat(self): while True: self._control_mqueue.put(ControlMessage(self._ns_socket, None, { 'action': _configs.Actions.NS_HEARTBEAT_REQ, 'uid': self._uid }, countdown=0)) if self._stop_event.wait(_configs.NS_HEARTBEAT_INTERVAL): break def _main_do_control_recv(self, socks): nr_done = 0 # ns if self._ns_socket in socks: for msg in utils.iter_recv(utils.req_recv_json, self._ns_socket): self._control_dispatcher.dispatch(msg['action'], msg) nr_done += 1 # router if self._control_router in socks: for identifier, msg in utils.iter_recv(utils.router_recv_json, self._control_router): self._control_dispatcher.dispatch(msg['action'], identifier, msg) nr_done += 1 for info, csock in self._controller_peers.values(): if csock in socks: for msg in utils.iter_recv(utils.req_recv_json, csock): self._control_dispatcher.dispatch(msg['action'], msg) nr_done += 1 return nr_done def _main_do_control_send(self): nr_scheduled = self._control_mqueue.qsize() nr_done = 0 for i in range(nr_scheduled): job = self._control_mqueue.get() if job.identifier is not None: rc = utils.router_send_json(job.sock, job.identifier, job.payload, flag=zmq.NOBLOCK) else: rc = utils.req_send_json(job.sock, job.payload, flag=zmq.NOBLOCK) if not rc: if job.countdown > 0: self._control_mqueue.put(ControlMessage(job[0], job[1], job[2], job.countdown - 1)) else: nr_done += 1 return nr_done def _main_do_data_recv(self, in_socks): nr_done = 0 for name in self._input_from: cache = self._input_cache.pop(name, None) if cache is None: peer = self._input_from[name] if peer.dsock in in_socks: msg = utils.pull_pyobj(peer.dsock) cache = (msg['name'], msg['from_identifier'], msg.get('to_identifier', None), msg['data']) if cache is None: continue nr_done_this = 0 if cache[2] is None: # is broadcast for p in self._imanager.filter_notfull(cache[0]): p.raw_queue.put_nowait(BroadcastMessage(cache[1], cache[-1])) nr_done_this += 1 else: for p in self._imanager.filter_notfull(cache[0]): if p.identifier == cache[2]: p.raw_queue.put_nowait(UnicastMessage(cache[1], cache[2], cache[-1])) nr_done_this += 1 if nr_done_this > 0: nr_done += nr_done_this else: self._input_cache[name] = cache return nr_done def _main_do_data_send(self): nr_done = 0 for name in self._omanager.keys(): cache = self._output_cache.get(name, None) if cache is None: pipes = self._omanager.filter_notempty(name) if len(pipes) != 0: pipe = random.choice(pipes) cache = pipe.raw_queue.get_nowait() self._output_cache[name] = cache if cache is None: continue nr_done_this = 0 if isinstance(cache, BroadcastMessage): for peer in self._output_to_pipe[name].values(): nr_done_this += utils.push_pyobj(peer.dsock, { 'uid': self._uid, 'name': name, 'from_identifier': cache.from_identifier, 'data': cache.payload }, flag=zmq.NOBLOCK) elif isinstance(cache, UnicastMessage): if (name, cache.to_identifier) in self._output_to_id: peer = self._output_to_id[(name, cache.to_identifier)] nr_done_this += utils.push_pyobj(peer.dsock, { 'uid': self._uid, 'name': name, 'from_identifier': cache.from_identifier, 'to_identifier': cache.to_identifier, 'data': cache.payload }, flag=zmq.NOBLOCK) else: raise TypeError('Unknown message type: {}.'.format(type(cache))) if nr_done_this > 0: self._output_cache[name] = None nr_done += nr_done_this return nr_done # BEGIN:: Connection def _initialize_recv_peers(self, results): for peers in results.values(): for info in peers: uid = info['uid'] if uid not in self._controller_peers: self._controller_peers[uid] = ControllerPeer(info, None) self._do_setup_ctl_peer(uid) def _do_setup_ctl_peer(self, uid): info, sock = self._controller_peers[uid] if sock is not None: return sock = self.socket(zmq.REQ) sock.connect('{}://{}:{}'.format(info['ctl_protocal'], info['ctl_addr'], info['ctl_port'])) self._control_poller.register(sock, zmq.POLLIN) self._control_mqueue.put(ControlMessage(sock, None, { 'action': _configs.Actions.CTL_CONNECT_REQ, 'uid': self._uid, 'inputs': self._imanager.pipe_info() }, countdown=_configs.CTL_CTL_SND_COUNTDOWN)) self._controller_peers[uid] = ControllerPeer(info, sock) logger.info('Connecting to "{}".'.format(uid)) def _on_ctl_connect_req(self, identifier, msg): uid, pipes = msg['uid'], msg['inputs'] flag = False for name, _ in pipes: if name in self._omanager: flag = True break if flag: response = {} if uid in self._output_to: port = self._output_to[uid].port else: sock = self.socket(zmq.PUSH) port = sock.bind_to_random_port('{}://{}'.format(_configs.CTL_DAT_PROTOCAL, _configs.CTL_DAT_HOST)) pipes_rec = {p[0] for p in pipes} ids_rec = {tuple(p) for p in pipes} peer = PipePeer(self._addr, port, sock, pipes_rec, ids_rec) self._output_to[uid] = peer for p in pipes_rec: self._output_to_pipe[p][uid] = peer for i in ids_rec: self._output_to_id[i] = peer logger.info('Connection opened for "{}": port={}.'.format(uid, port)) if port > 0: response = { 'dat_protocal': _configs.CTL_DAT_PROTOCAL, 'dat_addr': self._addr, 'dat_port': port } self._control_mqueue.put(ControlMessage(self._control_router, identifier, { 'action': _configs.Actions.CTL_CONNECT_REP, 'uid': self._uid, 'conn': response }, countdown=_configs.CTL_CTL_SND_COUNTDOWN)) def _on_ctl_connect_rep(self, msg): uid, conn = msg['uid'], msg['conn'] if len(conn) and uid not in self._input_from: sock = self.socket(zmq.PULL) sock.connect('{}://{}:{}'.format(conn['dat_protocal'], conn['dat_addr'], conn['dat_port'])) self._data_poller.register(sock, zmq.POLLIN) self._input_from[uid] = PipePeer(conn['dat_addr'], conn['dat_port'], sock, None, None) logger.info('Connection established to "{}": remote_port={}.'.format(uid, conn['dat_port'])) self._control_mqueue.put(ControlMessage(self._controller_peers[uid].csock, None, { 'action': _configs.Actions.CTL_CONNECTED_REQ, 'uid': self._uid }, countdown=_configs.CTL_CTL_SND_COUNTDOWN)) def _on_ctl_connected_req(self, identifier, msg): self._control_mqueue.put(ControlMessage(self._control_router, identifier, { 'action': _configs.Actions.CTL_CONNECTED_REP, 'uid': self._uid }, countdown=_configs.CTL_CTL_SND_COUNTDOWN)) logger.info('Connection established for "{}".'.format(msg['uid'])) # END:: Connection def _on_ctl_notify_open_req(self, identifier, msg): uid = msg['uid'] if uid not in self._controller_peers: self._controller_peers[uid] = ControllerPeer(msg['info'], None) self._do_setup_ctl_peer(uid) self._control_mqueue.put(ControlMessage(self._control_router, identifier, { 'action': _configs.Actions.NS_NOTIFY_OPEN_REP, 'uid': self._uid }, countdown=_configs.CTL_CTL_SND_COUNTDOWN)) logger.info('Found new controller: "{}".'.format(uid)) def _on_ctl_notify_close_req(self, identifier, msg): uid = msg['uid'] if uid in self._controller_peers: peer = self._controller_peers.pop(uid) self._control_poller.unregister(peer.csock) self.close_socket(peer.csock) if uid in self._input_from: peer = self._input_from.pop(uid) self._data_poller.unregister(peer.dsock) self.close_socket(peer.dsock) if uid in self._input_cache: del self._input_cache[uid] if uid in self._output_to: peer = self._output_to.pop(uid) self.close_socket(peer.dsock) for k in peer.pipes: self._output_to_pipe[k].pop(uid) for k in peer.ids: self._output_to_id.pop(k) if uid in self._output_cache: del self._output_cache[uid] self._control_mqueue.put(ControlMessage(self._control_router, identifier, { 'action': _configs.Actions.NS_NOTIFY_CLOSE_REP, 'uid': self._uid }, countdown=_configs.CTL_CTL_SND_COUNTDOWN)) logger.info('Close timeout controller: "{}".'.format(uid))
[docs] @contextlib.contextmanager def control(pipes): ctl = Controller() ctl.initialize(pipes) yield ctl ctl.finalize()