import torch
import torch.nn as nn
import probs
import nets
import math

class MGVAE(nn.Module):
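	"""VAE with a learnable mixture-of-Gaussians prior over the latent code.

	The encoder outputs a diagonal Gaussian q(z|x); the decoder outputs
	per-pixel Bernoulli means p(x|z). The prior p(z) is a k-component
	diagonal-Gaussian mixture whose means and variances are trained jointly
	with the networks (the mixture weights are fixed to uniform).
	"""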
	def __init__(self, name='mgvae', input_dim=784, z_dim=3, k=2):
		super(MGVAE, self).__init__()
		self.name = name
		self.z_dim = z_dim
		self.input_dim = input_dim
		self.k = k

		self.encoder = nets.EncoderV1(self.input_dim, self.z_dim)
		self.decoder = nets.DecoderV1(self.input_dim, self.z_dim)

		# learnable mixture-of-Gaussians prior: component means and diagonal
		# variances, [k, z_dim] each; mixture weights fixed to uniform.
		# note: the variance is parameterized directly, so gradient updates can
		# push it non-positive; a log-variance parameterization is a common
		# safer alternative.
		self.z_prior_mean = torch.nn.Parameter(torch.zeros(k, z_dim))
		self.z_prior_var = torch.nn.Parameter(torch.ones(k, z_dim))
		self.z_prior_weight = torch.nn.Parameter(torch.ones(k) / k, requires_grad=False)

	def negative_evidence_lower_bound(self, x):
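		"""Monte Carlo estimate of the negative ELBO, averaged over the batch.

		-ELBO(x) = KL( q(z|x) || p(z) ) - E_{z~q(z|x)}[ log p(x|z) ],
		with both terms estimated from samples of q(z|x).

		:param x: torch.Tensor, of shape [batch, input_dim], values in [0, 1]
		:return: (negative ELBO, KL term, expected log-likelihood), all scalars
		"""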
		mean, var = self.encoder.encode(x)  # [batch, z_dim], [batch, z_dim]
		# reparameterization trick: z = mean + sqrt(var) * eps, eps ~ N(0, I)
		zs = probs.sample_from_gaussian_repram(mean, var, num_sample=10)  # [batch, num_sample, z_dim]
		zs = torch.transpose(zs, 0, 1)  # [num_sample, batch, z_dim]

		# approximate KL( q(z|x) || p(z) ) by Monte Carlo:
		# KL ≈ mean_s [ log q(z_s|x) - log p(z_s) ], z_s ~ q(z|x)
		log_qs = probs.log_gaussian_pdf(mean, var, zs)  # [num_sample, batch]
		log_Ps = probs.log_mixture_gaussian_pdf(
			self.z_prior_mean, self.z_prior_var, self.z_prior_weight, zs
		)  # [num_sample, batch]
		KL = torch.mean(log_qs - log_Ps, dim=0)  # [batch,]

		# approximate the expected log-likelihood E_{z~q}[ log p(x|z) ] by Monte Carlo
		P_x_given_z = self.decoder.decode(zs) # [num_sample, batch, input_dim]
		log_likelihood = self.compute_log_likelihood(P_x_given_z, x) # [num_sample, batch]
		expected_log_likelihood = torch.mean(log_likelihood, dim=0) # [batch]

		KL = torch.mean(KL)
		expected_log_likelihood = torch.mean(expected_log_likelihood)

		return KL - expected_log_likelihood, KL, expected_log_likelihood

	def negative_importance_weighting_lower_bound(self, x):
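		"""Monte Carlo estimate of the negative importance-weighted lower bound.

		IWLB(x) = log (1/S) sum_s w_s, with importance weights
		w_s = p(x|z_s) p(z_s) / q(z_s|x) and z_s ~ q(z|x); this bound is at
		least as tight as the ELBO (Burda et al., 2016, "Importance Weighted
		Autoencoders").

		:param x: torch.Tensor, of shape [batch, input_dim], values in [0, 1]
		:return: scalar, the bound negated and averaged over the batch
		"""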
		mean, var = self.encoder.encode(x)  # [batch, z_dim], [batch, z_dim]
		# reparameterized samples; uses the default num_sample of probs.sample_from_gaussian_repram
		z_samples = probs.sample_from_gaussian_repram(mean, var)  # [batch, num_sample, z_dim]
		z_samples = torch.transpose(z_samples, 0, 1)  # [num_sample, batch, z_dim]

		log_Ps = probs.log_mixture_gaussian_pdf(
			self.z_prior_mean, self.z_prior_var, self.z_prior_weight, z_samples
		) # [num_sample, batch]
		log_q_z_given_x = probs.log_gaussian_pdf(mean, var, z_samples) # [num_sample, batch]
		P_x_given_z = self.decoder.decode(z_samples) # [num_sample, batch, input_dim]

		log_P_x_given_z = self.compute_log_likelihood(P_x_given_z, x) # [num_sample, batch]

		# log importance weights: log w_s = log p(x|z_s) + log p(z_s) - log q(z_s|x)
		log_probs = log_P_x_given_z + log_Ps - log_q_z_given_x  # [num_sample, batch]
		# numerically stable log-mean-exp (shift by the per-column max before exp)
		log_probs_max = torch.max(log_probs, dim=0)[0][None, :]  # [1, batch]
		iwlb = log_probs_max[0] + torch.log(
			torch.mean(
				torch.exp(log_probs - log_probs_max), dim=0
			)
		)  # [batch]
		return -torch.mean(iwlb)

	def compute_log_likelihood(self, P_x_given_z, x, eps=1e-8):
		"""
		:param P_x_given_z: torch.Tensor, of shape [num_sample, batch, input_dim]
		:param x: torch.Tensor, of shape [batch, input_dim]
		:param eps: float
		:return:
		"""
		return torch.sum(
			x * torch.log(P_x_given_z + eps) + (1 - x) * torch.log(1 - P_x_given_z + eps),
			dim=-1
		)

	def sample_from_prior(self, num_sample):
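		"""Draw z ~ p(z) from the mixture prior and decode to Bernoulli means.

		:param num_sample: int, number of latent samples to draw
		:return: torch.Tensor, of shape [num_sample, input_dim]
		"""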
		z = probs.sample_from_mixture_gaussian(
			self.z_prior_mean, self.z_prior_var, self.z_prior_weight, num_sample
		) # [num_sample, z_dim]

		P_x_given_z = self.decoder.decode(z) # [num_sample, input_dim]
		return P_x_given_z
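
# For reference: a minimal sketch of the mixture log-density that
# probs.log_mixture_gaussian_pdf is assumed to compute above, inferred from
# the shape comments in this file; the actual probs module is not shown here.
def _reference_log_mixture_gaussian_pdf(means, variances, weights, zs):
	"""
	:param means: torch.Tensor, of shape [k, z_dim]
	:param variances: torch.Tensor, of shape [k, z_dim], diagonal covariances
	:param weights: torch.Tensor, of shape [k], assumed to sum to 1
	:param zs: torch.Tensor, of shape [num_sample, batch, z_dim]
	:return: torch.Tensor, of shape [num_sample, batch]
	"""
	diff = zs.unsqueeze(2) - means[None, None]  # [num_sample, batch, k, z_dim]
	# per-component diagonal-Gaussian log density, summed over z_dim
	log_comp = -0.5 * (
		torch.log(2 * math.pi * variances)[None, None]
		+ diff ** 2 / variances[None, None]
	).sum(dim=-1)  # [num_sample, batch, k]
	# log p(z) = logsumexp_k [ log w_k + log N(z | mu_k, var_k) ]
	return torch.logsumexp(log_comp + torch.log(weights)[None, None], dim=-1)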



if __name__ == '__main__':
	# smoke test on a dummy batch of 16 flattened 28x28 inputs
	mgvae = MGVAE(k=4)
	arr = torch.ones(16, 784)
	nelbo, kl, ell = mgvae.negative_evidence_lower_bound(arr)
	print('negative ELBO: %.4f (KL %.4f, E[log p(x|z)] %.4f)' % (nelbo.item(), kl.item(), ell.item()))
	print('negative IWLB: %.4f' % mgvae.negative_importance_weighting_lower_bound(arr).item())
	print('prior samples:', mgvae.sample_from_prior(9).shape)
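
	# a minimal single-step training sketch; the optimizer choice and learning
	# rate here are illustrative assumptions, not part of the original model
	optimizer = torch.optim.Adam(mgvae.parameters(), lr=1e-3)
	loss, _, _ = mgvae.negative_evidence_lower_bound(arr)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	print('negative ELBO after one step:', mgvae.negative_evidence_lower_bound(arr)[0].item())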