Source code for jactorch.functional.probability
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : probability.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/04/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
"""Probability distributions related functions."""
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['normalize_prob', 'check_prob_normalization']
[docs]
def normalize_prob(a, dim=-1):
    """Rescale ``a`` so that its entries sum to one along ``dim``.

    This equals L1 normalization only when ``a`` is non-negative: the divisor
    is the plain sum along ``dim``, not the sum of absolute values.

    Args:
        a: tensor to normalize (presumably non-negative scores/weights).
        dim: dimension to normalize over; defaults to the last one.

    Returns:
        A tensor of the same shape as ``a``, divided by its sum along ``dim``.
    """
    # keepdim=True keeps the reduced axis so the division broadcasts back over ``a``.
    denominator = a.sum(dim=dim, keepdim=True)
    normalized = a / denominator
    return normalized
[docs]
def check_prob_normalization(p, dim=-1, atol=1e-5):
    """Check that ``p`` sums to one along ``dim``, within tolerance ``atol``.

    Args:
        p: tensor of probabilities.
        dim: dimension along which the entries are expected to sum to one.
        atol: absolute tolerance; every sum must lie strictly inside
            ``(1 - atol, 1 + atol)``.

    Raises:
        AssertionError: if any sum along ``dim`` falls outside the tolerance
            band. A NaN sum also fails, because every comparison with NaN is
            false. An empty reduction result passes vacuously.
    """
    tot = p.sum(dim=dim)
    # Element-wise "inside the open tolerance band".
    ok = (tot > 1 - atol) & (tot < 1 + atol)
    # Raise explicitly instead of using ``assert`` so the check is not stripped
    # under ``python -O``; AssertionError is kept for callers that catch it.
    if not bool(ok.all()):
        raise AssertionError('Probability normalization check failed.')