Тензорный поток: дивергенция КЛ для гауссовой смеси

Я знаю с Python и Scikit, как рассчитать расхождение KL для гауссовой смеси, учитывая, что ее параметры, такие как вес, среднее значение и ковариация, как np.array, как показано ниже.

Инициализация GaussianMixture с использованием параметров компонента - sklearn

KL-Дивергенция двух GMM

Но мне интересно, с Tensorflow, есть ли способ рассчитать расхождение KL между двумя гауссовской смеси, учитывая, что ее параметры как Tensor,

1) Я попробовал scikit выше в Tensorflow, но он не работал, так как Tensorflow не дает ему фактические значения, пока сессия не будет выполнена.

2) Есть несколько пакетов TF, но не совсем KL для гауссовой смеси. https://www.tensorflow.org/api_docs/python/tf/contrib/distributions/Mixture

https://www.tensorflow.org/api_docs/python/tf/distributions/kl_divergence

Любая помощь с благодарностью.

Позже я попробовал использовать последнюю версию библиотеки TF, как показано ниже.

import tensorflow as tf
print('tensorflow ',tf.__version__)  # for Python 3
import numpy as np
import matplotlib.pyplot as plt

ds = tf.contrib.distributions
kl_divergence=tf.contrib.distributions.kl_divergence

# Gaussian Mixure1
mix = 0.3# weight
bimix_gauss1 = ds.Mixture(
cat=ds.Categorical(probs=[mix, 1.-mix]),#weight
components=[
   ds.Normal(loc=-1., scale=0.1),
   ds.Normal(loc=+1., scale=0.5),
])

# Gaussian Mixture2
mix = 0.4# weight
bimix_gauss2 = ds.Mixture(
    cat=ds.Categorical(probs=[mix, 1.-mix]),#weight
    components=[
    ds.Normal(loc=-0.4, scale=0.2),
    ds.Normal(loc=+1.2, scale=0.6),
])

# KL between GM1 and GM2
kl_value=kl_divergence(
    distribution_a=bimix_gauss1,
    distribution_b=bimix_gauss2,
    allow_nan_stats=True,
    name=None
)

sess = tf.Session() # 
with sess.as_default():

    x = tf.linspace(-2., 3., int(1e4)).eval()
    plt.plot(x, bimix_gauss1.prob(x).eval(),'r-')
    plt.plot(x, bimix_gauss2.prob(x).eval(),'b-')
    plt.show()

    print('kl_value=',kl_value.eval())

Затем я получил эту ошибку...NotImplementedError: Нет KL(distribution_a || distribution_b) зарегистрировано для типа distribution_a Mixture и типа distribution_b Mixture

Мне сейчас очень грустно.:(

0 ответов

Другие вопросы по тегам