import torch
import torch.nn as nn
import probs
import nets
from torch.nn import functional as F
import numpy as np
from MNIST import MNIST

class SSVAE(nn.Module):
	"""Semi-supervised VAE with encoder q(z|x,y), decoder p(x|y,z) and classifier q(y|x).

	Unlabeled data contribute a negative ELBO in which the label y is marginalized
	under q(y|x); labeled data contribute a weighted classification loss. The
	sub-networks come from the project-local ``nets`` module and the KL / sampling
	helpers from ``probs``; their exact conventions are not visible here.
	"""

	# Floor applied to probabilities before taking logs, to avoid log(0) -> -inf/nan.
	_EPS = 1e-8

	def __init__(self, name='ssvae', input_dim=784, z_dim=25, y_dim=10):
		super(SSVAE, self).__init__()
		self.name = name
		self.input_dim = input_dim
		self.z_dim = z_dim
		self.y_dim = y_dim
		# Relative weights of the unlabeled negative ELBO and the supervised
		# classification loss, combined in ``loss``.
		self.gen_weight = 1
		self.classification_weight = 200

		self.encoder = nets.EncoderV1(input_dim, z_dim, y_dim)
		self.decoder = nets.DecoderV1(input_dim, z_dim, y_dim)
		self.classifier = nets.ClassifierV1(input_dim, y_dim)
		self.classification_loss = nn.CrossEntropyLoss()

		# Fixed (non-trainable) priors. Stored as frozen Parameters so they move
		# with the module on .to(device) and appear in the state_dict.
		# p(z): zero mean, unit variance per latent dimension.
		self.z_prior_mean = torch.nn.Parameter(torch.zeros(z_dim), requires_grad=False)
		self.z_prior_var = torch.nn.Parameter(torch.ones(z_dim), requires_grad=False)
		# p(y): uniform over the y_dim classes.
		self.y_prior = torch.nn.Parameter(torch.ones(y_dim) / y_dim, requires_grad=False)

	def negative_evidence_lower_bound_unlabeld(self, x, num_sample=5):
		"""Negative ELBO for unlabeled inputs, marginalizing y under q(y|x).

		nelbo = KL(q(y|x) || p(y))
		        + E_{y~q(y|x)} KL(q(z|x,y) || p(z))
		        - E_{y~q(y|x)} E_{z~q(z|x,y)} log p(x|y,z)
		Each term is averaged over the batch; the inner expectation over z is a
		Monte-Carlo mean over ``num_sample`` reparameterized samples.

		:param x: torch.Tensor of shape [batch, input_dim]
		:param num_sample: number of z samples per (x, y) pair
		:return: (nelbo, KL_y, expected KL_z, expected log-likelihood), all scalars
		"""
		batch_size = x.shape[0]

		# q(y|x) is computed once and reused for every expectation over y.
		q_y_given_x = self.classifier.classify(x)  # [batch, y_dim]
		log_q_y_given_x = torch.log(q_y_given_x.clamp_min(self._EPS))

		# KL( q(y|x) || p(y) )
		KL1 = probs.KL_categorical(
			q_y_given_x, log_q_y_given_x, torch.log(self.y_prior)
		)  # [batch,]
		KL1 = torch.mean(KL1)

		# Pair every example with every one-hot label so y can be summed out exactly.
		generated_y = torch.eye(self.y_dim, device=x.device)[None, ...]
		generated_y = torch.repeat_interleave(generated_y, batch_size, 0)  # [batch, y_dim, y_dim]
		generated_x = torch.repeat_interleave(x[:, None, :], self.y_dim, 1)  # [batch, y_dim, input_dim]

		# E_{y ~ q(y|x)} KL( q(z|x,y) || p(z) )
		z_mean, z_var = self.encoder.encode(generated_x, generated_y)  # [batch, y_dim, z_dim] each
		KL2 = probs.KL_gaussians(z_mean, z_var, self.z_prior_mean, self.z_prior_var)  # [batch, y_dim]
		expected_KL = torch.mean(torch.sum(q_y_given_x * KL2, dim=-1))

		# Monte-Carlo samples z ~ q(z|x,y) for every (example, label) pair.
		z_samples = probs.sample_from_gaussian_repram(
			torch.reshape(z_mean, [batch_size * self.y_dim, self.z_dim]),
			torch.reshape(z_var, [batch_size * self.y_dim, self.z_dim]),
			num_sample
		)  # [batch * y_dim, num_sample, z_dim]
		z_samples = torch.reshape(z_samples, [batch_size, self.y_dim, num_sample, self.z_dim])
		y_samples = generated_y[:, :, None, :]  # [batch, y_dim, 1, y_dim]
		y_samples = torch.repeat_interleave(y_samples, num_sample, 2)  # [batch, y_dim, num_sample, y_dim]

		P_x_given_y_z = self.decoder.decode(z_samples, y_samples)  # [batch, y_dim, num_sample, input_dim]
		# Broadcast x over the (label, sample) axes rather than transposing the decoder output.
		log_likelihoods = self.compute_log_likelihood(
			P_x_given_y_z, x[:, None, None, :]
		)  # [batch, y_dim, num_sample]

		# E_{y ~ q(y|x)} E_{z ~ q(z|x,y)} log p(x|y,z)
		E_z_given_xy_log_likelihoods = torch.mean(log_likelihoods, dim=-1)  # [batch, y_dim]
		expected_log_likelihoods = torch.mean(
			torch.sum(E_z_given_xy_log_likelihoods * q_y_given_x, dim=-1)
		)

		return KL1 + expected_KL - expected_log_likelihoods, KL1, expected_KL, expected_log_likelihoods

	def loss(self, unlabeled_x, x, y):
		"""Weighted sum of the unlabeled negative ELBO and the labeled classification loss.

		:param unlabeled_x: torch.Tensor [batch_u, input_dim] without labels
		:param x: torch.Tensor [batch_l, input_dim] with labels ``y``
		:param y: targets accepted by nn.CrossEntropyLoss (class indices or class probabilities)
		:return: (total loss, nelbo, KL_y, KL_z, expected log-likelihood)
		"""
		nelbo, KL_y, KL_z, log_likelihoods = self.negative_evidence_lower_bound_unlabeld(unlabeled_x)

		# classify() yields probabilities, but CrossEntropyLoss expects logits and applies
		# log_softmax itself. Feeding log-probabilities makes that log_softmax a no-op
		# (they are already normalized, up to the eps floor), so this is the proper NLL.
		q_y_given_x = self.classifier.classify(x)
		classification_loss = self.classification_loss(
			torch.log(q_y_given_x.clamp_min(self._EPS)), y
		)

		return self.gen_weight * nelbo + self.classification_weight * classification_loss, nelbo, KL_y, KL_z, log_likelihoods

	def compute_log_likelihood(self, P_x_given_z, x, eps=1e-8):
		"""Bernoulli log-likelihood of ``x`` under per-feature means ``P_x_given_z``.

		Sums x*log(p + eps) + (1 - x)*log(1 - p + eps) over the last (feature) axis.
		Leading axes broadcast, so ``P_x_given_z`` may carry extra label / sample axes,
		e.g. [batch, y_dim, num_sample, input_dim] against x of [batch, 1, 1, input_dim].

		:param P_x_given_z: torch.Tensor of probabilities, shape [..., input_dim]
		:param x: torch.Tensor broadcastable to P_x_given_z; values expected in [0, 1]
		:param eps: float added inside the logs to avoid log(0)
		:return: torch.Tensor of the broadcast shape with the last axis summed out
		"""
		return torch.sum(
			x * torch.log(P_x_given_z + eps) + (1 - x) * torch.log(1 - P_x_given_z + eps),
			dim=-1
		)

	def sample_from_prior(self, ys):
		"""Decode one sample per label in ``ys`` with z drawn from the prior p(z).

		:param ys: sequence (or 1-D tensor/array) of integer class labels in [0, y_dim)
		:return: decoder output for (z, one-hot y), one row per label
		"""
		batch_size = len(ys)
		device = self.z_prior_mean.device
		Y = F.one_hot(torch.as_tensor(ys, dtype=torch.long, device=device), self.y_dim).float()
		Z = probs.sample_from_gaussian_repram(
			self.z_prior_mean[None, :], self.z_prior_var[None, :], batch_size
		)[0]  # [batch_size, z_dim]

		return self.decoder.decode(Z, Y)




if __name__ == '__main__':
	# Smoke test: build the data source and the model, then evaluate the
	# combined loss on one unlabeled batch and one labeled batch.
	batch_size = 10
	dataset = MNIST()
	model = SSVAE()

	unlabeled_batch = dataset.load_batch(batch_size)
	labeled_x, labeled_y = dataset.load_batch_with_label(batch_size)

	model.loss(unlabeled_batch, labeled_x, labeled_y)


