"""
CS 228: Probabilistic Graphical Models
Winter 2018
Programming Assignment 1: Bayesian Networks

Author: Aditya Grover, Luis Perez
"""

import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
from scipy.io import loadmat

NUM_PIXELS = 28 * 28


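# NOTE: Several routines below iterate over the latent grid (and over images)
# with plain Python loops instead of vectorized NumPy operations, so they are
# slow. Vectorizing them is left as future work.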


def plot_histogram(data, title='histogram', xlabel='value', ylabel='frequency',
                   savefile='hist'):
    '''
    Draws a histogram of `data` on a new figure, saves it and displays it.

    data: the values passed to plt.hist.
    title, xlabel, ylabel: text for the figure title and the two axis labels.
    savefile: path handed to plt.savefig (saved with a tight bounding box).

    The figure is closed once it has been shown. Nothing is returned.
    '''

    plt.figure()
    plt.hist(data)

    # Labelling calls are independent of each other.
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    # Save before showing so the file is written while the figure still exists.
    plt.savefig(savefile, bbox_inches='tight')
    plt.show()
    plt.close()


def get_p_z1(z1_val):
    '''
    Helper. Looks up the prior probability P(Z1 = z1_val).

    Reads the 'prior_z1' table of the module-level `bayes_net`, so that global
    must have been set (main() does this) before the helper is called.
    '''

    prior_table = bayes_net['prior_z1']
    return prior_table[z1_val]

def get_p_z2(z2_val):
    '''
    Helper. Looks up the prior probability P(Z2 = z2_val).

    Reads the 'prior_z2' table of the module-level `bayes_net`, so that global
    must have been set (main() does this) before the helper is called.
    '''

    prior_table = bayes_net['prior_z2']
    return prior_table[z2_val]

def get_p_xk_cond_z1_z2(z1_val, z2_val, k):
    '''
    Helper. Looks up P(Xk = 1 | Z1 = z1_val, Z2 = z2_val).

    k is 1-based: pixel k is stored in column k - 1 of row 0 of the table
    found under the (z1_val, z2_val) key of the module-level
    bayes_net['cond_likelihood'].
    '''

    pixel_probs = bayes_net['cond_likelihood'][(z1_val, z2_val)]
    column = k - 1  # convert the 1-based pixel number to a 0-based column
    return pixel_probs[0, column]

def load_model(model_file):
    '''
    Loads a default Bayesian network with latent variables (in this case, a
    variational autoencoder) from a pickle file.

    model_file: path of the pickle file to read. The file is expected to
        unpickle to a sequence whose first three items are, in order, the
        prior table of z1, the prior table of z2 and the conditional
        likelihood table.

    Returns a dict with keys 'prior_z1', 'prior_z2' and 'cond_likelihood'.

    Raises whatever open() / pickle raise for a missing or malformed file.
    '''

    # SECURITY: unpickling can execute arbitrary code; only load model files
    # that come from a trusted source.
    # The path argument is honoured here (it used to be ignored in favour of a
    # hard-coded file name).
    with open(model_file, 'rb') as infile:
        cpts = pkl.load(infile, encoding='bytes')

    model = {
        'prior_z1': cpts[0],  # dictionary of (z1, p(z1)) pairs
        'prior_z2': cpts[1],  # dictionary of (z2, p(z2)) pairs
        # dictionary of ((z1,z2),[p(x1), p(x2),...p(x_784)])
        'cond_likelihood': cpts[2],
    }

    return model

def dict_to_distribution(dict):
    '''
    Splits a {value: probability} mapping into two parallel lists.

    Returns (vals, probs), where vals holds the keys and probs the matching
    values, both in the mapping's iteration order.
    '''
    # NOTE: the parameter name shadows the builtin `dict`; it is kept so that
    # keyword callers keep working.
    vals = list(dict.keys())
    probs = list(dict.values())
    return vals, probs

def q4_sample_from_bayesian_network(bayes_net, size):
    '''
    Draws `size` images from the network by ancestral sampling and shows them
    side by side.

    For every image, z1 and z2 are drawn independently from their priors and
    each pixel is then drawn as a Bernoulli variable with success probability
    P(Xk = 1 | z1, z2).

    bayes_net: dict with 'prior_z1', 'prior_z2' ({value: probability}) and
        'cond_likelihood' ({(z1, z2): table}); row 0 of each table is read as
        the per-pixel probabilities (at least 28 * 28 columns are expected).
    size: number of images to sample; must be at least 1.

    Raises ValueError if size is smaller than 1.
    '''
    if size < 1:
        raise ValueError('size must be a positive integer, got %r' % (size,))

    side = 28
    num_pixels = side * side

    prior_z1 = bayes_net['prior_z1']
    prior_z2 = bayes_net['prior_z2']
    # Use the network passed in, not the module-level global the helper
    # get_p_xk_cond_z1_z2 reads.
    cpds = bayes_net['cond_likelihood']

    z1s = np.random.choice(list(prior_z1.keys()), size,
                           p=list(prior_z1.values()))
    z2s = np.random.choice(list(prior_z2.keys()), size,
                           p=list(prior_z2.values()))

    images = np.zeros(shape=[size, num_pixels])
    for i, (z1, z2) in enumerate(zip(z1s, z2s)):
        pixel_probs = np.asarray(cpds[(z1, z2)])[0, :num_pixels]
        # One uniform draw per pixel: the pixel is on with probability
        # pixel_probs[k]. The boolean result is stored as 1.0 / 0.0.
        images[i] = np.random.random_sample(num_pixels) < pixel_probs
    images = np.reshape(images, [-1, side, side])

    # squeeze=False keeps `axes` two-dimensional even for a single image, so
    # size == 1 no longer fails when iterating over the axes.
    fig, axes = plt.subplots(1, size, squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(image, cmap='gray')
    plt.show()

def q5_conditional_expectation(bayes_net):
    '''
    Shows, as one large grayscale image, the per-pixel conditional
    probabilities stored for every (z1, z2) pair on the latent grid.

    The grid runs from -3 to 3 in steps of 0.25 on both axes. The tile for
    (z1, z2) is the first 784 entries of row 0 of
    bayes_net['cond_likelihood'][(z1, z2)], reshaped to 28 x 28. Tiles are
    laid out with z1 indexing the rows and z2 the columns of the canvas.
    '''
    cpds = bayes_net['cond_likelihood']
    z1s = np.arange(-3, 3.25, 0.25)
    z2s = z1s.copy()

    side = 28
    canvas = np.zeros(shape=[len(z1s) * side, len(z2s) * side])

    for row, z1 in enumerate(z1s):
        top = row * side
        for col, z2 in enumerate(z2s):
            left = col * side
            # Write the tile straight into its slot on the canvas.
            tile = np.reshape(cpds[(z1, z2)][0, :side * side], [side, side])
            canvas[top:top + side, left:left + side] = tile

    plt.imshow(canvas, cmap='gray')
    plt.show()

def q6_compute_marginal_log_likelihood(bayes_net, image):
    '''
    Computes the marginal log-likelihood of one binarized image,
    log P(x) = log sum_{z1} sum_{z2} P(z1) P(z2) P(x | z1, z2),
    over the latent grid -3, -2.75, ..., 3 for both z1 and z2.

    bayes_net: dict with 'prior_z1', 'prior_z2' ({value: probability}) and
        'cond_likelihood' ({(z1, z2): array of P(Xk = 1 | z1, z2)}).
    image: array of pixel values; a pixel counts as "on" when it is > 0.5.
        It must broadcast against the conditional likelihood arrays
        (presumably one entry per pixel -- confirm against the data files).

    Returns the log-likelihood as a float; -inf if the image has zero
    probability under every grid point.

    The sum is evaluated with the log-sum-exp trick, so images whose joint
    probabilities are too small for exp() still give a finite result instead
    of -inf.
    '''
    prior_z1 = bayes_net['prior_z1']
    prior_z2 = bayes_net['prior_z2']
    cpds = bayes_net['cond_likelihood']

    z1s = np.arange(-3, 3.25, 0.25)
    z2s = z1s.copy()

    pixel_on = np.asarray(image) > 0.5
    log_joint = np.empty(len(z1s) * len(z2s))

    # log(0) legitimately occurs for probabilities that are exactly 0 or 1;
    # the resulting -inf is handled below, so silence the warning.
    with np.errstate(divide='ignore'):
        n = 0
        for z1 in z1s:
            log_p_z1 = np.log(prior_z1[z1])
            for z2 in z2s:
                p_on = cpds[(z1, z2)]
                # np.where selects the relevant log term per pixel; this
                # avoids the nan produced by multiplying -inf by 0.
                log_p_xs = np.where(pixel_on, np.log(p_on), np.log(1. - p_on))
                log_joint[n] = np.sum(log_p_xs) + log_p_z1 + np.log(prior_z2[z2])
                n += 1

    # log-sum-exp: factor out the largest term so exp() cannot underflow for
    # the terms that matter.
    peak = np.max(log_joint)
    if not np.isfinite(peak):
        # Every term is -inf (zero probability) or a nan slipped in.
        return peak
    return peak + np.log(np.sum(np.exp(log_joint - peak)))

def q6(bayes_net):
    '''
    Computes the marginal log-likelihood of validation images from 'q6.mat',
    plots their histogram and prints the mean and standard deviation.

    Only the first 500 rows of 'val_x' are used. A progress line
    (index, log-likelihood) is printed for every image.

    bayes_net: the network passed on to q6_compute_marginal_log_likelihood.
    '''
    mdict = loadmat('q6.mat')
    # The statistics and the histogram are labelled "validation", so they are
    # computed from val_x (they used to be computed from test_x while val_x
    # was loaded but never used).
    val = mdict['val_x']

    marginal_log_likelihood_val = []
    for i, image in enumerate(val[:500]):
        log_likelihood = q6_compute_marginal_log_likelihood(bayes_net, image)
        print(i, log_likelihood)
        marginal_log_likelihood_val.append(log_likelihood)
    marginal_log_likelihood_val = np.array(marginal_log_likelihood_val)
    plot_histogram(marginal_log_likelihood_val, 'Marginal Log-likelihood on Validation Set')

    mean = np.mean(marginal_log_likelihood_val)
    std = np.std(marginal_log_likelihood_val)

    print(mean, std)

def q7_calculate_posterior_expectation(bayes_net, image):
    '''
    Computes the posterior means E[z1 | x] and E[z2 | x] for one binarized
    image, over the latent grid -3, -2.75, ..., 3 for both z1 and z2.

    bayes_net: dict with 'prior_z1', 'prior_z2' ({value: probability}) and
        'cond_likelihood' ({(z1, z2): array of P(Xk = 1 | z1, z2)}).
    image: array of pixel values; a pixel counts as "on" when it is > 0.5.
        It must broadcast against the conditional likelihood arrays
        (presumably one entry per pixel -- confirm against the data files).

    Returns the tuple (Ez1, Ez2). Both are nan when the image has zero
    probability under every grid point, since the posterior is then undefined.

    The posterior is normalised in log space, so images whose joint
    probabilities are too small for exp() no longer end up as 0 / 0 = nan.
    '''
    prior_z1 = bayes_net['prior_z1']
    prior_z2 = bayes_net['prior_z2']
    cpds = bayes_net['cond_likelihood']

    z1s = np.arange(-3, 3.25, 0.25)
    z2s = z1s.copy()

    pixel_on = np.asarray(image) > 0.5
    log_joint = np.empty((len(z1s), len(z2s)))

    # log(0) legitimately occurs for probabilities that are exactly 0 or 1;
    # the resulting -inf is handled below, so silence the warning.
    with np.errstate(divide='ignore'):
        for i, z1 in enumerate(z1s):
            log_p_z1 = np.log(prior_z1[z1])
            for j, z2 in enumerate(z2s):
                p_on = cpds[(z1, z2)]
                # np.where selects the relevant log term per pixel; this
                # avoids the nan produced by multiplying -inf by 0.
                log_p_xs = np.where(pixel_on, np.log(p_on), np.log(1. - p_on))
                log_joint[i, j] = np.sum(log_p_xs) + log_p_z1 + np.log(prior_z2[z2])

    peak = np.max(log_joint)
    if not np.isfinite(peak):
        return np.nan, np.nan

    # Shifting by the largest log-joint leaves the normalised posterior
    # unchanged while keeping exp() away from underflow.
    posterior = np.exp(log_joint - peak)
    posterior /= np.sum(posterior)

    # Marginalise over the other variable, then take the weighted mean.
    Ez1 = np.sum(np.sum(posterior, axis=1) * z1s)
    Ez2 = np.sum(np.sum(posterior, axis=0) * z2s)
    return Ez1, Ez2

def q7(bayes_net):
    '''
    Scatter-plots the posterior means of the latent variables for every image
    in 'q7.mat', with one labelled point series per digit 0-9.

    'x' holds the images and the first column of 'y' their labels. The
    horizontal axis is E[z2 | x]; the vertical axis is the largest E[z1 | x]
    minus E[z1 | x], i.e. the z1 axis is flipped.
    '''
    mdict = loadmat('q7.mat')
    X = mdict['x']
    Y = mdict['y'][:, 0]

    expectations = [q7_calculate_posterior_expectation(bayes_net, x) for x in X]
    EZ1s = np.array([pair[0] for pair in expectations])
    EZ2s = np.array([pair[1] for pair in expectations])
    EZ1_max = np.max(EZ1s)

    for digit in range(10):
        selected = Y == digit  # boolean mask of the images with this label
        plt.scatter(EZ2s[selected], EZ1_max - EZ1s[selected],
                    label='%d' % (digit), s=12)
    plt.legend()
    plt.show()

def main():
    '''
    Script entry point: sets up the module-level state and runs one question.

    Publishes three globals: the discretized latent grids `disc_z1` and
    `disc_z2` (25 evenly spaced points from -3 to 3 each) and `bayes_net`,
    the model loaded from 'trained_mnist_model'. The get_p_* helpers read
    `bayes_net` from module scope.
    '''
    global bayes_net, disc_z1, disc_z2

    grid_points = 25
    # Two independent arrays, one per latent variable.
    disc_z1, disc_z2 = (np.linspace(-3, 3, grid_points) for _ in range(2))

    bayes_net = load_model('trained_mnist_model')

    # Uncomment the question to run.
    # q4_sample_from_bayesian_network(bayes_net, 5)
    # q5_conditional_expectation(bayes_net)
    # q6(bayes_net)
    q7(bayes_net)


if __name__ == '__main__':
    main()

