import torch
import torch.nn as nn
import probs
import nets
import numpy as np


class VAE(nn.Module):
	"""Variational autoencoder with a fixed standard-normal latent prior.

	Wraps an encoder that parameterizes q(z|x) = N(mean, var) and a decoder
	that outputs Bernoulli parameters p(x|z), and exposes two training
	objectives: the negative ELBO and the negative importance-weighted
	lower bound (IWAE).
	"""

	def __init__(self, name='vae', input_dim=784, z_dim=2):
		"""
		:param name: str, identifier for this model instance
		:param input_dim: int, flattened input dimensionality (784 = 28x28)
		:param z_dim: int, latent dimensionality
		"""
		super(VAE, self).__init__()
		self.name = name
		self.z_dim = z_dim
		self.input_dim = input_dim

		self.encoder = nets.EncoderV1(self.input_dim, self.z_dim)
		self.decoder = nets.DecoderV1(self.input_dim, self.z_dim)

		# N(0, I) prior over z. Registered as non-trainable Parameters so they
		# follow the module across .to(device)/.float() calls.
		self.z_prior_mean = torch.nn.Parameter(torch.zeros(z_dim), requires_grad=False)
		self.z_prior_var = torch.nn.Parameter(torch.ones(z_dim), requires_grad=False)

	def negative_evidence_lower_bound(self, x, num_sample=5):
		"""Compute -ELBO = KL(q(z|x) || p(z)) - E_q[log p(x|z)].

		:param x: torch.Tensor, of shape [batch, input_dim]; treated as
			Bernoulli targets, so entries are assumed to lie in [0, 1]
		:param num_sample: int, Monte-Carlo samples for the expectation term
		:return: (negative ELBO, mean KL, mean expected log-likelihood),
			all scalar tensors averaged over the batch
		"""
		mean, var = self.encoder.encode(x)  # [batch, z_dim], [batch, z_dim]
		KL = probs.KL_gaussians(mean, var, self.z_prior_mean, self.z_prior_var)  # [batch,]

		z_samples = probs.sample_from_gaussian_repram(mean, var, num_sample=num_sample)  # [batch, num_sample, z_dim]
		P_x_given_z = self.decoder.decode(z_samples)    # [batch, num_sample, input_dim]
		P_x_given_z = torch.transpose(P_x_given_z, 0, 1)    # [num_sample, batch, input_dim]

		log_likelihood = self.compute_log_likelihood(P_x_given_z, x)  # [num_sample, batch]
		expected_log_likelihood = torch.mean(log_likelihood, dim=0)  # [batch,]

		KL = torch.mean(KL)
		expected_log_likelihood = torch.mean(expected_log_likelihood)
		return KL - expected_log_likelihood, KL, expected_log_likelihood

	def negative_importance_weighting_lower_bound(self, x, num_sample=5):
		"""Compute the negative IWAE bound -log mean_k [p(x,z_k)/q(z_k|x)].

		Tighter than the ELBO for num_sample > 1 (Burda et al., 2016).

		:param x: torch.Tensor, of shape [batch, input_dim]; treated as
			Bernoulli targets, so entries are assumed to lie in [0, 1]
		:param num_sample: int, number of importance samples
		:return: scalar tensor, negative bound averaged over the batch
		"""
		mean, var = self.encoder.encode(x)  # [batch, z_dim], [batch, z_dim]
		z_samples = probs.sample_from_gaussian_repram(mean, var, num_sample=num_sample)  # [batch, num_sample, z_dim]
		z_samples = torch.transpose(z_samples, 0, 1)  # [num_sample, batch, z_dim]

		log_P_z = probs.log_gaussian_pdf(self.z_prior_mean, self.z_prior_var, z_samples)  # [num_sample, batch]
		log_q_z_given_x = probs.log_gaussian_pdf(mean, var, z_samples)  # [num_sample, batch]
		P_x_given_z = self.decoder.decode(z_samples)  # [num_sample, batch, input_dim]
		log_P_x_given_z = self.compute_log_likelihood(P_x_given_z, x)  # [num_sample, batch]
		log_probs = log_P_x_given_z + log_P_z - log_q_z_given_x  # [num_sample, batch]

		# log mean_k exp(w_k) = logsumexp_k(w_k) - log K, computed stably.
		iwlb = torch.logsumexp(log_probs, dim=0) - np.log(num_sample)  # [batch]
		return -torch.mean(iwlb)

	def compute_log_likelihood(self, P_x_given_z, x, eps=1e-8):
		"""Bernoulli log-likelihood of x under decoder probabilities.

		:param P_x_given_z: torch.Tensor, of shape [num_sample, batch, input_dim],
			Bernoulli success probabilities
		:param x: torch.Tensor, of shape [batch, input_dim] (broadcast against
			the sample axis)
		:param eps: float, added inside both logs to avoid log(0)
		:return: torch.Tensor, of shape [num_sample, batch]
		"""
		return torch.sum(
			x * torch.log(P_x_given_z + eps) + (1 - x) * torch.log(1 - P_x_given_z + eps),
			dim=-1
		)

	def sample_from_prior(self, num_sample):
		"""Decode samples drawn from the N(0, I) prior.

		:param num_sample: int
		:return: torch.Tensor, of shape [num_sample, input_dim], decoder
			output probabilities
		"""
		# Prior params are unsqueezed to a fake batch of 1; the leading
		# singleton axis of the decoded result is stripped before returning.
		zs = probs.sample_from_gaussian_repram(self.z_prior_mean[None, :], self.z_prior_var[None, :], num_sample)
		P_x_given_z = self.decoder.decode(zs)  # [1, num_sample, input_dim]
		return P_x_given_z[0]



if __name__ == '__main__':
	# Smoke test: run the IWAE objective on a random batch of 10 inputs.
	model = VAE()

	batch = torch.from_numpy(np.random.randn(10, 784)).float()
	model.negative_importance_weighting_lower_bound(batch)
