Source code for jacinle.nd.meta
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : meta.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/30/2017
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import collections
from typing import Any, Optional, Sequence, List
import numpy as np
__all__ = [
'isndarray', 'is_ndarray',
'nd_concat', 'nd_len', 'nd_batch_size',
'nd_split_n', 'size_split_n'
]
[docs]
def isndarray(thing: Any) -> bool:
"""Check if the given object is a numpy array."""
return isinstance(thing, np.ndarray)
[docs]
def is_ndarray(thing: Any) -> bool:
return isinstance(thing, np.ndarray)
[docs]
def nd_concat(list_of_arrays: Sequence[np.ndarray]) -> Optional[np.ndarray]:
"""Concatenate a list of numpy arrays. This function handles the case when the list is empty or contains only one element.
Args:
list_of_arrays: a list of numpy arrays.
Returns:
the concatenated array, or None if the list is empty.
"""
if len(list_of_arrays) == 0:
return None
elif len(list_of_arrays) == 1:
return list_of_arrays[0]
else:
return np.concatenate(list_of_arrays)
[docs]
def nd_len(thing: Any) -> int:
"""Get the length of a numpy array. This function handles the case when the input is a scalar or plain Python objects.
Args:
thing: the input array.
Returns:
the length of the array, or 1 if the input is a scalar or plain Python objects.
"""
if type(thing) in (int, float):
return 1
if isndarray(thing):
return thing.shape[0]
return len(thing)
[docs]
def nd_batch_size(thing: Any) -> int:
"""Get the batch size of a numpy array. This function handles the case when the input a nested list or dict.
Examples:
>>> nd_batch_size(np.array([1, 2, 3]))
3
>>> nd_batch_size([np.zeros((2, 3)), np.zeros((2, 5))])
2
>>> nd_batch_size({'a': np.zeros((2, 3)), 'b': np.zeros((2, 5))})
2
Args:
thing: the input array or nested list/dict.
Returns:
the batch size of the array.
"""
if type(thing) in (tuple, list):
return nd_len(thing[0])
elif type(thing) in (dict, collections.OrderedDict):
return nd_len(next(thing.values()))
else:
return nd_len(thing)
[docs]
def size_split_n(full_size: Optional[int], n: int) -> Optional[List[int]]:
"""Split a size into n parts. If the size is not divisible by n, the last part will be larger.
When the size is None, None will be returned.
Args:
full_size: the size to be split.
n: the number of parts.
Returns:
a list of sizes.
"""
if full_size is None:
return None
result = [full_size // n] * n
rest = full_size % n
if rest != 0:
result[-1] += rest
return result
[docs]
def nd_split_n(ndarray: np.ndarray, n: int) -> List[np.ndarray]:
"""Split a numpy array into n parts. If the size is not divisible by n, the last part will be larger.
Args:
ndarray: the array to be split.
n: the number of parts.
Returns:
a list of arrays.
"""
sub_sizes = size_split_n(len(ndarray), n)
res = []
cur = 0
for i in range(n):
res.append(ndarray[cur:cur+sub_sizes[i]])
cur += sub_sizes[i]
return res