#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : packing.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import os
import functools
from jacinle.utils.enum import JacEnum
from jacinle.utils.registry import RegistryGroup, CallbackRegistry
# Public names exported by a star-import of this module.
__all__ = [
    # pickle backend
    'check_pickle',
    'loadb_pickle',
    'dumpb_pickle',
    # msgpack backend
    'check_msgpack',
    'loadb_msgpack',
    'dumpb_msgpack',
    # pyarrow backend
    'check_pyarrow',
    'loadb_pyarrow',
    'dumpb_pyarrow',
    # backend-dispatching entry points
    'loadb',
    'dumpb',
    # default-backend management
    'get_available_backends',
    'get_default_backend',
    'set_default_backend',
]
import pickle

# pickle ships with Python, so its loader/dumper are bound unconditionally.
# NOTE(review): pickle.loads can execute arbitrary code while decoding; only
# feed it bytes from a trusted source.
loadb_pickle, dumpb_pickle = pickle.loads, pickle.dumps
# Optional msgpack backend: both names are None when either msgpack or
# msgpack_numpy cannot be imported.
try:
    import msgpack
    import msgpack_numpy

    msgpack_numpy.patch()
except ImportError:
    dumpb_msgpack = None
    loadb_msgpack = None
else:
    loadb_msgpack = msgpack.loads
    # use_bin_type=True is always passed to msgpack.dumps.
    dumpb_msgpack = functools.partial(msgpack.dumps, use_bin_type=True)
# Optional pyarrow backend: both names are None when the backend is unusable.
try:
    import pyarrow

    # ``pyarrow.serialize`` / ``pyarrow.deserialize`` were deprecated in
    # pyarrow 2.0 and later removed. An installation without them must be
    # treated as "backend unavailable"; otherwise the backend would be
    # advertised as usable and every call would die with AttributeError.
    if not (hasattr(pyarrow, 'serialize') and hasattr(pyarrow, 'deserialize')):
        raise ImportError('pyarrow does not provide serialize/deserialize.')
except ImportError:
    dumpb_pyarrow = loadb_pyarrow = None
else:
    def dumpb_pyarrow(obj):
        """Serialize ``obj`` with pyarrow and return the resulting buffer."""
        return pyarrow.serialize(obj).to_buffer()

    def loadb_pyarrow(buffer):
        """Deserialize the object stored in ``buffer`` using pyarrow."""
        return pyarrow.deserialize(buffer)
class _PackingFunctionRegistryGroup(RegistryGroup):
    """Registry group that can dispatch a call through one of its registries."""

    # Presumably the class RegistryGroup instantiates for each member
    # registry -- confirm against jacinle.utils.registry.
    __base_class__ = CallbackRegistry

    def dispatch(self, registry_name, entry, *args, **kwargs):
        """Dispatch ``entry`` through the registry stored under ``registry_name``.

        Remaining positional and keyword arguments are forwarded unchanged to
        that registry's own ``dispatch`` method, whose result is returned.
        """
        target_registry = self[registry_name]
        return target_registry.dispatch(entry, *args, **kwargs)


# Module-wide registry group; this module registers callbacks in it under the
# names 'check', 'loadb' and 'dumpb'.
_packing_function_registry = _PackingFunctionRegistryGroup()
def check_pickle():
    """Report whether the pickle backend can be used.

    Returns:
        bool: always ``True``.
    """
    return True
def check_msgpack():
    """Report whether the msgpack backend can be used.

    Returns:
        bool: ``True`` when the module-level ``dumpb_msgpack`` is not ``None``.
    """
    dumper_missing = dumpb_msgpack is None
    return not dumper_missing
def check_pyarrow():
    """Report whether the pyarrow backend can be used.

    Returns:
        bool: ``True`` when the module-level ``dumpb_pyarrow`` is not ``None``.
    """
    dumper_missing = dumpb_pyarrow is None
    return not dumper_missing
class _PackingBackend(JacEnum):
    """Serialization backends this module knows about."""

    # Each member's value is the backend's lowercase name.
    PICKLE = 'pickle'
    MSGPACK = 'msgpack'
    PYARROW = 'pyarrow'
# Register each backend's callbacks under the operation they implement.
# The loadb/dumpb callbacks of a backend that failed to import are ``None``.
for _operation, _backend, _callback in (
    ('check', _PackingBackend.PICKLE, lambda: True),
    ('check', _PackingBackend.MSGPACK, check_msgpack),
    ('check', _PackingBackend.PYARROW, check_pyarrow),
    ('loadb', _PackingBackend.PICKLE, loadb_pickle),
    ('dumpb', _PackingBackend.PICKLE, dumpb_pickle),
    ('loadb', _PackingBackend.MSGPACK, loadb_msgpack),
    ('dumpb', _PackingBackend.MSGPACK, dumpb_msgpack),
    ('loadb', _PackingBackend.PYARROW, loadb_pyarrow),
    ('dumpb', _PackingBackend.PYARROW, dumpb_pyarrow),
):
    _packing_function_registry.register(_operation, _backend, _callback)
# Do not leave the loop variables behind in the module namespace.
del _operation, _backend, _callback

# Backend used by loadb()/dumpb() when the caller does not pass one.
_default_packing_backend = _PackingBackend.PICKLE
def get_default_backend():
    """Return the ``name`` attribute of the current default packing backend."""
    current_backend = _default_packing_backend
    return current_backend.name
def get_available_backends():
    """List the names of the backends whose 'check' callback returns a truthy value.

    Returns:
        list: the ``name`` of every passing member yielded by
        ``_PackingBackend.choice_objs()``, in iteration order.
    """
    available = []
    for candidate in _PackingBackend.choice_objs():
        if _packing_function_registry.dispatch('check', candidate):
            available.append(candidate.name)
    return available
def set_default_backend(backend):
    """Select the backend used by ``loadb``/``dumpb`` when none is passed.

    Args:
        backend: any value accepted by ``_PackingBackend.from_string``.

    Raises:
        AssertionError: if the requested backend is not listed by
            ``get_available_backends()``. The previous default is kept.
    """
    global _default_packing_backend

    candidate = _PackingBackend.from_string(backend)
    # Validate before overwriting the global, so a rejected request cannot
    # leave the module pointing at an unusable backend. The error is raised
    # explicitly rather than with ``assert`` so the check survives ``python -O``.
    if candidate.name not in get_available_backends():
        raise AssertionError(
            'Unsupported backend on your machine: "{}".'.format(candidate.name))
    _default_packing_backend = candidate
def loadb(bstr, *args, backend=None, **kwargs):
    """Deserialize ``bstr`` through the 'loadb' callback of a backend.

    Args:
        bstr: the serialized payload, passed to the callback as first argument.
        *args: extra positional arguments forwarded to the callback.
        backend: backend to dispatch to; any falsy value selects the module
            default.
        **kwargs: extra keyword arguments forwarded to the callback.

    Returns:
        Whatever the dispatched callback returns.
    """
    if not backend:
        backend = _default_packing_backend
    return _packing_function_registry.dispatch('loadb', backend, bstr, *args, **kwargs)
def dumpb(obj, *args, backend=None, **kwargs):
    """Serialize ``obj`` through the 'dumpb' callback of a backend.

    Args:
        obj: the object to serialize, passed to the callback as first argument.
        *args: extra positional arguments forwarded to the callback.
        backend: backend to dispatch to; any falsy value selects the module
            default.
        **kwargs: extra keyword arguments forwarded to the callback.

    Returns:
        Whatever the dispatched callback returns.
    """
    if not backend:
        backend = _default_packing_backend
    return _packing_function_registry.dispatch('dumpb', backend, obj, *args, **kwargs)
def _initialize_backend():
    """Set the default backend from the ``JAC_PACKING_BACKEND`` environment variable.

    Falls back to ``_PackingBackend.PICKLE`` when the variable is not set.
    """
    requested = os.environ.get('JAC_PACKING_BACKEND', _PackingBackend.PICKLE)
    set_default_backend(requested)


# Executed at import time so the environment override applies immediately.
_initialize_backend()