import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

def KL_categorical(p:torch.Tensor, log_p:torch.Tensor, log_q:torch.Tensor):
	"""
	Compute KL(p || q) between categorical distributions along the last axis:
		sum_i p_i * (log p_i - log q_i)
	Entries with p_i == 0 contribute exactly 0 (the usual 0 * log 0 = 0
	convention), so passing log_p = log(p) with exact zeros in p does not
	produce NaN.
	:param p: torch.Tensor, of shape [..., n], probabilities of the first distribution
	:param log_p: torch.Tensor, of shape [..., n], log of p
	:param log_q: torch.Tensor, of shape [..., n], log-probabilities of the second distribution
	:return: torch.Tensor, of shape [...]
	"""
	log_ratio = log_p - log_q
	# Mask BEFORE multiplying by p. Masking the product instead would still make
	# the backward pass evaluate 0 * (-inf) and yield NaN gradients.
	log_ratio = torch.where(p > 0, log_ratio, torch.zeros_like(log_ratio))
	return torch.sum(p * log_ratio, dim=-1)

def KL_gaussians(mu1, var1, mu2, var2):
	"""
	KL divergence between two Gaussians with diagonal covariance matrices,
		KL( N(mu1, diag(var1)) || N(mu2, diag(var2)) ),
	summed over the last (feature) axis.
	:param mu1: torch.Tensor, of shape [..., dim], mean of the first Gaussian
	:param var1: torch.Tensor, of shape [..., dim], variances of the first Gaussian
	:param mu2: torch.Tensor, of shape [..., dim], mean of the second Gaussian
	:param var2: torch.Tensor, of shape [..., dim], variances of the second Gaussian
	:return: torch.Tensor, of shape [...]
	"""
	log_var_ratio = torch.log(var2) - torch.log(var1)
	per_dim = log_var_ratio - 1 + var1 / var2 + torch.square(mu1 - mu2) / var2
	return 0.5 * per_dim.sum(dim=-1)



def sample_from_gaussian_repram(mu, var, num_sample=10):
	"""
	Sample from diagonal-covariance Gaussians with the reparameterization trick:
		z = mu + sqrt(var) * eps,  eps ~ N(0, I)
	Gradients flow back to both `mu` and `var`.
	:param mu: torch.Tensor, of shape [batch, dim] (any leading shape [..., dim] is accepted)
	:param var: torch.Tensor, same shape as mu; per-dimension variances
	:param num_sample: int, number of samples drawn per Gaussian
	:return: torch.Tensor, of shape [batch, num_sample, dim] (in general [..., num_sample, dim])
	"""
	# One vectorised draw replaces the per-row MultivariateNormal loop. For a
	# diagonal covariance the Cholesky factor is just diag(sqrt(var)), so building
	# and factorising a dense [dim, dim] matrix per row was pure overhead. An
	# empty batch now yields an empty result instead of failing in torch.stack([]).
	eps = torch.randn(
		mu.shape[:-1] + (num_sample, mu.shape[-1]), dtype=mu.dtype, device=mu.device
	)
	return mu.unsqueeze(-2) + torch.sqrt(var).unsqueeze(-2) * eps

def sample_from_mixture_gaussian(mu, var, ws, num_sample):
	"""
	Draw samples (without gradient) from a mixture of diagonal-covariance Gaussians.
	Component ids are drawn with NumPy's global RNG (np.random); the Gaussian
	noise comes from PyTorch's global RNG.
	:param mu: torch.Tensor, of shape [k, dim], component means
	:param var: torch.Tensor, of shape [k, dim], component variances
	:param ws: torch.Tensor, of shape [k,], mixture weights (np.random.choice
		raises ValueError if they do not sum to 1)
	:param num_sample: int, total number of samples to draw
	:return: torch.Tensor, of shape [num_sample, dim]. Rows are grouped by
		component (all component-0 draws first, then component 1, ...), so
		shuffle them if an i.i.d. row order is needed.
	"""
	k = ws.shape[0]
	# np.random.choice needs host memory; .cpu() makes CUDA weights work too.
	probs = ws.detach().cpu().numpy()
	# Sorting the ids reproduces the grouped-by-component row layout.
	comp = np.sort(np.random.choice(k, size=num_sample, p=probs))
	comp = torch.as_tensor(comp, dtype=torch.long, device=mu.device)
	# Draw every row in one vectorised step. no_grad keeps the result
	# non-differentiable, like MultivariateNormal.sample().
	with torch.no_grad():
		loc = mu[comp]
		scale = torch.sqrt(var[comp])
		return loc + scale * torch.randn_like(loc)

def log_gaussian_pdf(mu, var, x):
	"""
	Log-density of x under a diagonal-covariance Gaussian N(mu, diag(var)),
	summed over the last (feature) axis.
	:param mu: torch.Tensor, of shape [..., dim], mean
	:param var: torch.Tensor, of shape [..., dim], per-dimension variances
	:param x: torch.Tensor, of shape [..., dim], evaluation points
	:return: torch.Tensor, of shape [...]
	"""
	log_norm_const = np.log(np.sqrt(np.pi*2))
	log_std = torch.log(torch.sqrt(var))
	half_sq_dist = torch.square(x - mu) / (2 * var)
	per_dim = -log_norm_const - log_std - half_sq_dist
	return per_dim.sum(dim=-1)

def log_mixture_gaussian_pdf(mu, var, weight, x):
	"""
	Compute the log of a mixture-of-Gaussians pdf with diagonal covariances,
		log( sum_j weight_j * N(x; mu_j, diag(var_j)) ),
	using a max-shift (log-sum-exp) for numerical stability.
	:param mu: torch.Tensor, of shape [..., k, dim]
	:param var: torch.Tensor, of shape [..., k, dim]
	:param weight: torch.Tensor, of shape [..., k], non-negative mixture weights
	:param x: torch.Tensor, of shape [..., dim]
	:return: torch.Tensor, of shape [...] (leading dims broadcast across the inputs)
	"""
	# Insert a component axis so x broadcasts against all k components; this
	# avoids materialising a k-fold copy of x.
	log_gaussians = log_gaussian_pdf(mu, var, x[..., None, :])  # [..., k]

	# Shift by the largest log-density among components that actually carry
	# weight. If the shift came from the max over ALL components, a zero-weight
	# component could set it, every weighted term could underflow to 0, and the
	# result would collapse to -inf. The shift cancels analytically, so it is
	# detached.
	weighted_only = torch.where(
		weight > 0, log_gaussians, torch.full_like(log_gaussians, -float("inf"))
	)
	shift = torch.max(weighted_only, dim=-1, keepdim=True)[0].detach()  # [..., 1]
	# No finite positive-weight term exists; the sum is 0 for any finite shift,
	# so use 0 and return log(0) = -inf rather than NaN from (-inf) - (-inf).
	shift = torch.where(torch.isneginf(shift), torch.zeros_like(shift), shift)

	# Positive-weight terms already satisfy log_gaussians - shift <= 0, so the
	# clamp is a no-op for them. It only caps zero-weight terms, where
	# 0 * exp(large) would otherwise overflow to NaN. For those terms the
	# gradient w.r.t. weight is capped accordingly.
	scaled = torch.exp(torch.clamp(log_gaussians - shift, max=0.0))
	return shift.squeeze(-1) + torch.log(torch.sum(weight * scaled, dim=-1))





