In this notebook we demonstrate the asymmetry of the KL divergence \(D_{KL}(p \| q)\) by using KL as the cost function for a descent algorithm for density estimation.

Definition of KL

\begin{align} D_{KL}(p \| q) & = -\int_{\mathbb{R}} p(x) \log \left(\frac{q(x)}{p(x)}\right) dx\\
&= \int_{\mathbb{R}} p(x) \log \left(\frac{p(x)}{q(x)}\right) dx \\
&= \int_{\mathbb{R}} p(x) \log p(x) dx - \int_{\mathbb{R}} p(x) \log q(x) dx \\
&= -H(p) + \underbrace{H(p,q)}_{\text{cross-entropy}} \\
\end{align}

\[D_{KL}(q \| p) = \int_{\mathbb{R}} q(x) \log \left( \frac{q(x)}{p(x)} \right) dx\]
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm,entropy
import seaborn as sns
# Apply seaborn's "whitegrid" style to the figures drawn below.
sns.set_style("whitegrid")

We first create a mixture of two Gaussians of the following form

\[p(x) = \theta_1 \frac{1}{\sqrt{2\pi} \sigma_{1}} \exp\left(-\frac{(x-\mu_1)^2}{2\sigma_{1}^{2}}\right) + \theta_2 \frac{1}{\sqrt{2\pi} \sigma_{2}} \exp\left(-\frac{(x-\mu_2)^2}{2\sigma_{2}^{2}}\right)\]

where $\theta_1$ and $\theta_2$, with $\theta_1 + \theta_2 = 1$, are the mixing proportions of the two Gaussian components.

# Evaluation grid on [-10, 10) with a 0.01 step.
x = np.arange(-10, 10, 0.01)

# Mixture parameters: weight of the first component, then the means and
# standard deviations of the two components.
theta = 0.5
mu1, mu2 = -2.0, 2.0
sigma1, sigma2 = 1., 1.

# Weighted component densities on the grid, and their sum (the mixture p).
y1 = norm.pdf(x, loc=mu1, scale=sigma1) * theta
y2 = norm.pdf(x, loc=mu2, scale=sigma2) * (1. - theta)
y = y1 + y2

%matplotlib inline
fig, ax = plt.subplots(1,1)
ax.plot(x, y1+y2,
       'r-', lw=5, alpha=0.6, label='$p = p_1+p_2$')
ax.plot(x, y1,
       'k--', lw=2, alpha=1, label='$p_1$')
ax.plot(x, y2,
       'k:', lw=2, alpha=1, label='$p_2$')
ax.set_xlim([-5,5])
ax.set_xlabel('x')
ax.set_ylabel('p(x)')
plt.legend()
plt.show()

png

We then approximate the Gaussian mixture defined above with a single Gaussian parametrized by $\theta = (\mu,\sigma)$:

\[q_{\theta}(x) = \frac{1}{\sqrt{2\pi} \sigma} \exp\left(-\frac{(x-\mu)^2}{2\sigma^{2}}\right)\]

For each $\theta$, we compute KL divergences: \(D_{KL}(p \| q_{\theta}), D_{KL}(q_{\theta} \| p)\)

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
# Grids of mean and standard-deviation values. KL_exp below builds the same
# two ranges itself (self.MU, self.SIGMA) and does not read these globals.
MU = np.arange(-5.,5.,0.1)
SIGMA = np.arange(0.1,10.,0.1)
%matplotlib widget
class KL_exp(object):
    """Interactive demo of the asymmetry of the KL divergence.

    The reference density ``p`` is a two-component Gaussian mixture and the
    approximating density ``q`` is a single Gaussian with mean ``mu`` and
    standard deviation ``sigma``.  For every ``(mu, sigma)`` on a fixed grid
    the class evaluates both ``D_KL(p||q)`` and ``D_KL(q||p)`` numerically
    (Riemann sums on ``self.x``) and draws:

    * ``ax1``: ``p`` and the currently selected ``q``;
    * ``ax2``: contours of ``log D_KL(p||q)`` over the ``(mu, sigma)`` grid;
    * ``ax3``: contours of ``log D_KL(q||p)`` over the same grid.

    The slider widget built by :meth:`start` is kept in ``self.controls``;
    it has to be displayed (e.g. as the last expression of a notebook cell)
    for the sliders to appear.
    """

    # Value substituted for densities that underflow to exactly 0, so that
    # every logarithm taken in computeKL stays finite.
    _DENSITY_FLOOR = 1e-8

    def __init__(self):
        # Initial mixture parameters (theta is the weight of component 1).
        self.theta = 0.5
        self.mu1 = -3.0
        self.mu2 = 3.0
        self.sigma1 = 1.0
        self.sigma2 = 1.0
        # Initial single-Gaussian parameters.
        self.mu = 0.
        self.sigma = 1.0

        # Integration grid and the (mu, sigma) grid of candidate fits.
        self.dx = 0.05
        self.x = np.arange(-10, 10, self.dx)
        self.MU = np.arange(-5., 5., 0.1)
        self.SIGMA = np.arange(0.1, 10., 0.1)

        self.MM, self.SS = np.meshgrid(self.MU, self.SIGMA)

        # One row of Qs per (mu, sigma) pair, one column per point of x.
        # Broadcasting a column of parameters against the row vector x
        # avoids materialising repeated copies of the grids.
        mu_col = self.MM.reshape(-1, 1)
        sigma_col = self.SS.reshape(-1, 1)
        Qs = (1. / sigma_col) * np.exp(-(self.x - mu_col) ** 2 / (2 * (sigma_col ** 2)))
        # Normalise each row so that sum(row) * dx == 1 on the truncated grid.
        Qs /= self.dx * np.sum(Qs, axis=1, keepdims=True)
        Qs[Qs == 0] = self._DENSITY_FLOOR  # set 0 to be finite small value
        self.Qs = Qs

        # Large panel on the left (2x2 cells), two small panels on the right.
        self.fig = plt.figure(figsize=(12, 6))
        self.ax1 = plt.subplot2grid((2, 3), (0, 0), colspan=2, rowspan=2)
        self.ax2 = plt.subplot2grid((2, 3), (0, 2))
        self.ax3 = plt.subplot2grid((2, 3), (1, 2))

        self.computeKL(self.theta, self.mu1, self.mu2, self.sigma1, self.sigma2)
        self.changeSingle(self.mu, self.sigma)
        # Keep the widget: an `interactive` object that is not stored or
        # displayed is simply thrown away.
        self.controls = self.start()

    def start(self):
        """Build and return the slider widget bound to :meth:`update`.

        Each keyword gives the ``(min, max, step)`` range of one slider.
        The returned widget must be displayed by the caller.
        """
        return interactive(self.update,
                           theta=(0., 1., 0.1),
                           mu1=(-3., 3., 0.1),
                           mu2=(-3., 3., 0.1),
                           sigma1=(0.1, 2., 0.1),
                           sigma2=(0.1, 2., 0.1),
                           mu=(-5., 5., 0.1),
                           sigma=(0.1, 10., 0.1))

    def update(self, theta, mu1, mu2, sigma1, sigma2, mu, sigma):
        """Slider callback: refresh the KL surfaces if needed, then redraw.

        The KL surfaces depend only on the mixture parameters, so they are
        recomputed only when one of those changed; moving ``mu`` or
        ``sigma`` alone just redraws.
        """
        mixture = (theta, mu1, mu2, sigma1, sigma2)
        current = (self.theta, self.mu1, self.mu2, self.sigma1, self.sigma2)
        if mixture != current:
            self.computeKL(*mixture)
        self.changeSingle(mu, sigma)

        self.theta, self.mu1, self.mu2, self.sigma1, self.sigma2 = mixture
        self.mu, self.sigma = mu, sigma

    def computeKL(self, theta, mu1, mu2, sigma1, sigma2):
        """Compute both KL divergences for a new mixture over the whole grid.

        theta: float
            weight of the first Gaussian (the second gets ``1 - theta``)
        mu1,mu2: float,float
            mean of two gaussians
        sigma1,sigma2: float,float
            std of two gaussians

        Nothing is returned; the results are stored on the instance:

        y_ref: ndarray of float
            mixture density on ``self.x``, normalised so that
            ``sum(y_ref) * dx == 1``
        KL_PQ: ndarray of float
            ``D_KL(p||q)`` for every grid point, same shape as ``self.MM``
        KL_QP: ndarray of float
            ``D_KL(q||p)`` for every grid point, same shape as ``self.MM``
        """
        _y1 = theta * norm.pdf(self.x, loc=mu1, scale=sigma1)
        _y2 = (1.0 - theta) * norm.pdf(self.x, loc=mu2, scale=sigma2)
        mixture = _y1 + _y2
        self.y_ref = mixture / (np.sum(mixture) * self.dx)

        # Narrow components underflow to exactly 0 far from their mean; floor
        # those entries the same way Qs is floored, otherwise 0*log(0) puts
        # NaN into KL_PQ and log(q/0) puts inf into KL_QP.
        p = np.where(self.y_ref > 0, self.y_ref, self._DENSITY_FLOOR)

        # Work with log p - log q instead of log(p/q): the ratio itself can
        # overflow when one of the densities is denormal.
        log_ratio = np.log(p) - np.log(self.Qs)  # broadcasts p over rows of Qs

        # Riemann sums; the dx factor makes these approximate the integrals
        # in the definition of the KL divergence.
        kl_pq = np.sum(p * log_ratio, axis=1) * self.dx
        kl_qp = np.sum(self.Qs * (-log_ratio), axis=1) * self.dx
        self.KL_PQ = kl_pq.reshape(self.MM.shape)
        self.KL_QP = kl_qp.reshape(self.MM.shape)

    def changeSingle(self, mu, sigma):
        """Redraw all three panels for the single Gaussian ``(mu, sigma)``.

        mu: float
            mean, in range (-5,5)
        sigma: float
            std, in range (0.1,10.)

        Uses the ``y_ref``, ``KL_PQ`` and ``KL_QP`` stored by the last
        :meth:`computeKL` call.
        """
        y = norm.pdf(self.x, loc=mu, scale=sigma)
        for ax in (self.ax1, self.ax2, self.ax3):
            ax.cla()

        # Raw strings: '\m', '\s' and '\l' are invalid escape sequences in
        # ordinary string literals.
        self.ax2.contour(self.MM, self.SS, np.log(self.KL_PQ))
        self.ax2.set_xlabel(r'$\mu$')
        self.ax2.set_ylabel(r'$\sigma$')
        self.ax2.set_title(r'$\log D_{KL}(p||q)$')

        self.ax3.contour(self.MM, self.SS, np.log(self.KL_QP))
        self.ax3.set_xlabel(r'$\mu$')
        self.ax3.set_title(r'$\log D_{KL}(q||p)$')
        self.ax3.set_ylabel(r'$\sigma$')

        self.ax1.set_xlim([-8, 8])
        self.ax1.plot(self.x, self.y_ref,
                      'r-', lw=5, alpha=0.6, label='$p$')
        self.ax1.plot(self.x, y,
                      'k-', lw=2, alpha=0.6, label='$q_{\\theta}$')
        # scipy's entropy(pk, qk) normalises both inputs and returns the
        # discrete KL divergence sum(pk * log(pk / qk)).
        self.ax1.set_title("$KL(p||q) = {:.2f},KL(q||p) = {:.2f}$".format(
            entropy(self.y_ref, y), entropy(y, self.y_ref)))
        self.ax1.legend()

        # Mark the selected (mu, sigma) on both KL surfaces.
        self.ax2.plot(mu, sigma, 'yo')
        self.ax3.plot(mu, sigma, 'yo')

        self.fig.tight_layout()
        self.fig.show()
# Create the demo figure, then expose its parameters as sliders. The widget
# is the last expression of the cell so that the notebook displays it.
myKL = KL_exp()
slider_ranges = dict(
    theta=(0., 1., 0.1),
    mu1=(-3., 3., 0.1),
    mu2=(-3., 3., 0.1),
    sigma1=(0.1, 2., 0.1),
    sigma2=(0.1, 2., 0.1),
    mu=(-5., 5., 0.1),
    sigma=(0.1, 10., 0.1),
)
interactive(myKL.update, **slider_ranges)

Note: running this notebook in JupyterLab with the Matplotlib widget extension (ipympl) installed will give you an interactive environment as follows: