Source code for jacinle.nd.indexing

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : indexing.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

from typing import Union, Sequence

import numpy as np

from .meta import isndarray

__all__ = ['one_hot', 'one_hot_nd', 'index_select_batch']


def one_hot(label: Union[int, Sequence[int], np.ndarray], nr_classes: int, dtype='float32') -> np.ndarray:
    """Convert a label array to one-hot array.

    This function works for either 0d (scalar) or 1d (vector) label array. If you want to convert
    higher dimensional label array, use :meth:`one_hot_nd` instead.

    Scalar labels may be Python integers, NumPy integer scalars (e.g. ``np.int64``) or 0d arrays.
    Vector labels may be 1d arrays or flat sequences of integers.

    Args:
        label: the label array.
        nr_classes: the number of classes.
        dtype: the data type of the one-hot array.

    Returns:
        the one-hot array. Its shape is ``(nr_classes, )`` for a scalar label and
        ``(len(label), nr_classes)`` for a vector label.

    Raises:
        AssertionError: if ``label`` has more than one dimension.
    """
    # Normalize first so that Python ints, NumPy scalars, 0d arrays and flat sequences are all
    # dispatched on their number of dimensions instead of on their concrete Python type.
    label = np.asarray(label)

    if label.ndim == 0:
        out = np.zeros(nr_classes, dtype=dtype)
        out[int(label)] = 1
        return out

    assert label.ndim == 1, 'one_hot only supports 0d or 1d labels; use one_hot_nd for higher dimensions.'
    nr_labels = label.shape[0]
    out = np.zeros((nr_labels, nr_classes), dtype=dtype)
    # Pair row i with column label[i] to set exactly one entry per row.
    out[np.arange(nr_labels), label] = 1
    return out
def one_hot_nd(label, nr_classes, dtype='float32'):
    """Convert a label array of arbitrary dimensionality to a one-hot array.

    The labels are flattened, encoded with :meth:`one_hot`, and then folded back so that the
    result has the shape of ``label`` with one extra trailing axis of size ``nr_classes``.

    Args:
        label: the label array.
        nr_classes: the number of classes.
        dtype: the data type of the one-hot array.

    Returns:
        the one-hot array, of shape ``label.shape + (nr_classes, )``.
    """
    target_shape = label.shape + (nr_classes, )
    flat_labels = label.reshape(-1)
    flat_encoded = one_hot(flat_labels, nr_classes, dtype=dtype)
    return flat_encoded.reshape(target_shape)
def index_select_batch(data: Union[np.ndarray, Sequence[np.ndarray]], indices: Union[Sequence[int], np.ndarray]) -> np.ndarray:
    """Gather ``indices`` as batch indices from ``data``, which can either be typical nd array or a list of nd array.

    Args:
        data: the data array.
        indices: the indices to be selected. It should be a tuple, a list, or a 1d array.

    Returns:
        the selected data. The first axis has length ``len(indices)``. When ``data`` is a list,
        the remaining shape and the dtype are taken from its first element.

    Raises:
        AssertionError: if ``indices`` is not a tuple, a list, or a 1d array, or if ``data`` is an
            empty list (there is no first element to infer the output shape and dtype from).
    """
    assert isinstance(indices, (tuple, list)) or (isinstance(indices, np.ndarray) and indices.ndim == 1), \
        'indices should be a tuple, a list, or a 1d array.'

    if isinstance(data, np.ndarray):
        # NumPy interprets a tuple as a multi-dimensional index (``data[(i, j)]`` is ``data[i, j]``),
        # so convert it to a list to select along the first (batch) axis instead.
        if isinstance(indices, tuple):
            indices = list(indices)
        return data[indices]

    assert len(data) > 0, 'data should not be empty.'
    sample = np.array(data[0])  # Try to convert the first element to a typical nd array.
    # Empty ``indices`` yields an empty batch, matching what the nd array branch returns.
    output = np.empty((len(indices), ) + sample.shape, dtype=sample.dtype)
    for i, j in enumerate(indices):
        output[i] = data[j]
    return output