# Source code for jaclearn.models.naive_bayes.hybrid_nb
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : hybrid_nb.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/26/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import sklearn.naive_bayes as nb
from jacinle.logging import get_logger
from jacinle.utils.enum import JacEnum
# Module-level logger, obtained from jacinle's get_logger with this file's path.
logger = get_logger(__file__)
# Names exported by `from ... import *`.
__all__ = ['NaiveBayesianDistribution', 'HybridNB']
# [docs]
class NaiveBayesianDistribution(JacEnum):
    """Likelihood families that can be selected for a naive Bayes model.

    Each member's value is the lowercase name of the distribution.
    """

    # Continuous features modelled with a Gaussian likelihood.
    GAUSSIAN = 'gaussian'
    # Features modelled with a multinomial likelihood.
    MULTINOMIAL = 'multinomial'
    # Features modelled with a Bernoulli likelihood.
    BERNOULLI = 'bernoulli'
# [docs]
class HybridNB(object):
    """Naive Bayes classifier built from one scikit-learn NB model per feature group.

    Each entry of ``distributions`` selects the likelihood family (Gaussian,
    multinomial or Bernoulli) of one model. ``fit`` trains every model on its
    own feature group against the same labels; ``predict`` sums the models'
    log-probabilities and takes the arg-max over the class axis.

    Only the first model is meant to contribute the class prior, so that the
    prior is counted once in the summed score. Later multinomial/Bernoulli
    models are therefore created with ``fit_prior=False``. ``GaussianNB`` has
    no such option, so for later Gaussian models the fitted log prior
    (``class_prior_``) is subtracted again at prediction time.
    """

    def __init__(self, distributions, weights=None, **kwargs):
        """Create one naive Bayes model per requested distribution.

        Args:
            distributions: iterable of distribution identifiers; each one is
                resolved with ``NaiveBayesianDistribution.from_string``.
            weights: stored on the instance. Anything other than ``None``
                makes :meth:`predict` raise ``NotImplementedError``.
            **kwargs: forwarded unchanged to the constructor of every model.

        Raises:
            ValueError: if a resolved distribution is not one of the three
                supported members.
        """
        self.models = []
        # _strip_prior[i] is True when models[i] is a non-first GaussianNB,
        # whose class prior has to be removed from its scores in predict().
        self._strip_prior = []
        for index, dist in enumerate(distributions):
            dist = NaiveBayesianDistribution.from_string(dist)
            is_first = index == 0
            # Every model but the first must not learn the class prior.
            # Build a per-model copy instead of mutating the shared kwargs.
            prior_free_kwargs = kwargs if is_first else dict(kwargs, fit_prior=False)
            strip_prior = False
            if dist is NaiveBayesianDistribution.GAUSSIAN:
                # GaussianNB does not accept fit_prior; passing it would raise
                # TypeError, so the prior is compensated for in predict().
                model = nb.GaussianNB(**kwargs)
                strip_prior = not is_first
            elif dist is NaiveBayesianDistribution.MULTINOMIAL:
                model = nb.MultinomialNB(**prior_free_kwargs)
            elif dist is NaiveBayesianDistribution.BERNOULLI:
                model = nb.BernoulliNB(**prior_free_kwargs)
            else:
                raise ValueError('Unknown distribution: {}.'.format(dist))
            self.models.append(model)
            self._strip_prior.append(strip_prior)
        self.weights = weights

    def _check_inputs(self, xs):
        """Return ``xs`` as a list after checking it has one entry per model.

        Raises:
            ValueError: if the number of feature groups differs from the
                number of models.
        """
        xs = list(xs)
        if len(xs) != len(self.models):
            raise ValueError(
                'Expected {} feature groups (one per model), got {}.'.format(
                    len(self.models), len(xs)))
        return xs

    def fit(self, xs, y, verbose=True):
        """Fit every model on its own feature group.

        Args:
            xs: sequence of feature groups, one per model, in the same order
                as the ``distributions`` given to the constructor.
            y: labels, passed unchanged to each model's ``fit``.
            verbose: when true, log each model before fitting it.

        Raises:
            ValueError: if ``xs`` does not contain exactly one group per model.
        """
        xs = self._check_inputs(xs)
        for x, model in zip(xs, self.models):
            if verbose:
                logger.info('Fitting model: {}.'.format(repr(model)))
            model.fit(x, y)

    def predict(self, xs, verbose=True):
        """Predict the best class index for every sample.

        Args:
            xs: sequence of feature groups, one per model, in the same order
                as in :meth:`fit`.
            verbose: when true, log each model before querying it.

        Returns:
            The arg-max over axis 1 of the summed ``predict_log_proba``
            outputs, i.e. column indices into the models' class ordering, not
            the class labels themselves.

        Raises:
            NotImplementedError: if ``weights`` was set.
            ValueError: if there are no models, or ``xs`` does not contain
                exactly one group per model.
        """
        # Local import: the module itself does not import numpy.
        import numpy as np

        if self.weights is not None:
            raise NotImplementedError('HybridNB.weights is not supported.')
        if not self.models:
            raise ValueError('HybridNB has no models to predict with.')
        xs = self._check_inputs(xs)

        # Instances created without __init__, or whose model list was edited
        # by hand, may lack matching flags; fall back to no compensation.
        strip_prior = getattr(self, '_strip_prior', ())
        if len(strip_prior) != len(self.models):
            strip_prior = [False] * len(self.models)

        log_prob = 0
        for x, model, strip in zip(xs, self.models, strip_prior):
            if verbose:
                logger.info('Predicting using model: {}.'.format(repr(model)))
            scores = model.predict_log_proba(x)
            if strip:
                # Remove this model's class prior so it is not counted twice.
                # The remaining per-sample normaliser does not affect arg-max.
                scores = scores - np.log(model.class_prior_)
            log_prob = log_prob + scores
        return log_prob.argmax(axis=1)