Source code for jacinle.comm.cs_simple

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : cs_simple.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 12/26/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Simplified communication module for client-server communication using REP-REQ pattern."""

import collections
import contextlib
import queue
import threading

import zmq

from jacinle.concurrency.packing import dumpb, loadb
from jacinle.concurrency.zmq_utils import get_addr, bind_to_random_ipc, graceful_close
from jacinle.logging import get_logger
from jacinle.utils.meta import notnone_property
from jacinle.utils.registry import CallbackRegistry

logger = get_logger(__file__)


[docs] class SimpleServerPipe(object):
[docs] def __init__(self, name: str, mode: str = 'tcp'): assert mode in ('ipc', 'tcp') self._name = name self._conn_info = None self._mode = mode self._context = zmq.Context() self._socket = self._context.socket(zmq.REP) self._message_counter = 0 self._dispatcher = CallbackRegistry()
@property def dispatcher(self): return self._dispatcher @property def conn_info(self): assert self._conn_info is not None, 'The pipe is not initialized yet.' return self._conn_info
[docs] def initialize(self, tcp_port=None, ipc_port=None): if self._mode == 'tcp': if tcp_port is None: port = self._socket.bind_to_random_port('tcp://*') else: if isinstance(tcp_port, (int, str)): tcp_port = tcp_port elif isinstance(tcp_port, (tuple, list)): tcp_port = tcp_port[0] else: raise ValueError('Invalid tcp_port: {}.'.format(tcp_port)) self._socket.bind('tcp://*:{}'.format(tcp_port)) port = tcp_port self._conn_info = 'tcp://{}:{}'.format(get_addr(), port) print('ServerPipe initialized: {}.'.format(self._conn_info)) else: if ipc_port is None: ipc_port = bind_to_random_ipc(self._socket, self._name) else: if len(ipc_port) == 2: ipc_port = ipc_port[0] self._socket.bind('ipc://{}'.format(ipc_port)) self._conn_info = 'ipc://{}'.format(ipc_port) logger.info('ServerPipe initialized: {}.'.format(self._conn_info))
[docs] def finalize(self): graceful_close(self._socket) self._context.term()
[docs] @contextlib.contextmanager def activate(self, tcp_port=None, ipc_port=None): self.initialize(tcp_port, ipc_port) try: yield finally: self.finalize()
[docs] def serve_forever(self): while True: type, message = loadb(self._socket.recv()) self._message_counter += 1 self._dispatcher.dispatch(type, self, self._message_counter, message)
[docs] def send(self, identifier, response): assert self._message_counter == identifier self._message_counter += 1 self._socket.send(dumpb(response))
[docs] class SimpleClientPipe(object):
[docs] def __init__(self, name: str, conn_info): self._name = name self._conn_info = conn_info self._context = zmq.Context() self._socket = self._context.socket(zmq.REQ) self._initialized = False if isinstance(self._conn_info, (tuple, list)): self._conn_info = conn_info[0]
[docs] def initialize(self, timeout=None): rv = self._socket.connect(self._conn_info) self._initialized = True return True
[docs] def finalize(self): if self._initialized: graceful_close(self._socket) self._context.term() self._initialized = False
[docs] @contextlib.contextmanager def activate(self, timeout=None): self.initialize() try: yield finally: self.finalize()
[docs] def query(self, type, message=None, do_recv=True): self._socket.send(dumpb((type, message))) if do_recv: return self.recv()
[docs] def recv(self): return loadb(self._socket.recv())