import matplotlib.pyplot as plt
from torch import nn
from torch import optim
import numpy as np

from MNIST import MNIST
from VAE import VAE
from MGVAE import MGVAE
from SSVAE import SSVAE


def train_evidence_lower_bound(model: "VAE | MGVAE", mnist: MNIST, device, batch_size=64, learning_rate=1e-3, max_iters=None):
	"""Train `model` by minimizing the negative evidence lower bound (ELBO).

	Args:
		model: a VAE or MGVAE exposing `negative_evidence_lower_bound(images)`
			(returning (loss, KL, expected_log_likelihood)) and `sample_from_prior(n)`.
		mnist: MNIST loader providing `load_batch(batch_size)`.
		device: torch device (string or torch.device) for model and data.
		batch_size: images per optimization step.
		learning_rate: Adam learning rate.
		max_iters: stop after this many iterations; None trains forever
			(the original behavior, so existing callers are unaffected).

	Side effects: prints per-iteration diagnostics and writes a prior sample
	image to ./samples/ every 100 iterations.
	"""
	optimizer = optim.Adam(model.parameters(), lr=learning_rate)
	num_iter = 0
	model.to(device)

	while max_iters is None or num_iter < max_iters:
		num_iter += 1
		optimizer.zero_grad()

		images = mnist.load_batch(batch_size).to(device)
		loss, KL, expected_log_likelihood = model.negative_evidence_lower_bound(images)
		loss.backward()
		optimizer.step()

		# .item() works on any device; .detach().numpy() raises on CUDA tensors.
		print("iter %d, loss %.5f, KL divergence: %.5f, Expected log-likelihood % .5f" %
		      (num_iter, loss.item(), KL.item(), expected_log_likelihood.item()))

		if num_iter % 100 == 0:
			sample = model.sample_from_prior(1)
			# Move to CPU before numpy conversion so GPU training also works.
			sample = np.reshape(sample[0].detach().cpu().numpy(), [28, 28])
			plt.imsave('./samples/%d.png' % (num_iter), sample, cmap='gray')

def train_importance_weighting_lower_bound(model: "VAE | MGVAE", mnist: MNIST, device, batch_size=64, learning_rate=1e-3, max_iters=None):
	"""Train `model` by minimizing the negative importance-weighted lower bound (IWAE).

	Args:
		model: a VAE or MGVAE exposing `negative_importance_weighting_lower_bound(images)`
			(returning a scalar loss) and `sample_from_prior(n)`.
		mnist: MNIST loader providing `load_batch(batch_size)`.
		device: torch device (string or torch.device) for model and data.
		batch_size: images per optimization step.
		learning_rate: Adam learning rate.
		max_iters: stop after this many iterations; None trains forever
			(the original behavior, so existing callers are unaffected).

	Side effects: prints per-iteration diagnostics and writes a prior sample
	image to ./samples/ every 100 iterations.
	"""
	optimizer = optim.Adam(model.parameters(), lr=learning_rate)
	num_iter = 0
	model.to(device)

	while max_iters is None or num_iter < max_iters:
		num_iter += 1
		optimizer.zero_grad()

		images = mnist.load_batch(batch_size).to(device)
		loss = model.negative_importance_weighting_lower_bound(images)
		loss.backward()
		optimizer.step()

		# Fixed "%.5ff" typo (printed a stray 'f'); .item() works on any device.
		print("iter %d, loss %.5f" % (num_iter, loss.item()))

		if num_iter % 100 == 0:
			sample = model.sample_from_prior(1)
			# Move to CPU before numpy conversion so GPU training also works.
			sample = np.reshape(sample[0].detach().cpu().numpy(), [28, 28])
			plt.imsave('./samples/%d.png' % (num_iter), sample, cmap='gray')

def train_ssvae(model: SSVAE, mnist: MNIST, device, batch_size=64, labeled_batch_size=30, learning_rate=1e-3, max_iters=None):
	"""Train a semi-supervised VAE on a mix of unlabeled and labeled MNIST batches.

	Args:
		model: SSVAE exposing `loss(images, imgs_with_label, label)` (returning
			(loss, nelbo, KL_y, KL_z, log_likelihoods)) and `sample_from_prior(ys)`.
		mnist: MNIST loader providing `load_batch` and `load_batch_with_label`.
		device: torch device (string or torch.device) for model and data.
		batch_size: unlabeled images per optimization step.
		labeled_batch_size: labeled images per optimization step.
		learning_rate: Adam learning rate.
		max_iters: stop after this many iterations; None trains forever
			(the original behavior, so existing callers are unaffected).

	Side effects: prints per-iteration diagnostics and writes two class-conditional
	prior samples to ./samples/ every 20 iterations.
	"""
	optimizer = optim.Adam(model.parameters(), lr=learning_rate)
	num_iter = 0
	model.to(device)

	while max_iters is None or num_iter < max_iters:
		num_iter += 1
		optimizer.zero_grad()

		images = mnist.load_batch(batch_size).to(device)
		imgs_with_label, label = mnist.load_batch_with_label(labeled_batch_size)
		# Tensor.to() is NOT in-place: rebind the results, otherwise the
		# labeled batch silently stays on the CPU and GPU training fails.
		imgs_with_label = imgs_with_label.to(device)
		label = label.to(device)

		loss, nelbo, KL_y, KL_z, log_likelihoods = model.loss(images, imgs_with_label, label)
		loss.backward()
		optimizer.step()

		# .item() works on any device; .detach().numpy() raises on CUDA tensors.
		print("iter %d, loss %.5f, nelbo: %.5f, KL_y: %.5f, KL_z:%.5f, log_likelihoods:%.5f" %
		      (
			      num_iter,
			      loss.item(),
			      nelbo.item(),
			      KL_y.item(),
			      KL_z.item(),
			      log_likelihoods.item()
		      )
		)

		if num_iter % 20 == 0:
			# Save prior samples conditioned on two random digit classes.
			ys = np.random.randint(0, 10, 2)
			imgs = model.sample_from_prior(ys)
			for i in range(2):
				plt.imsave('./samples/%d_%d.png' % (num_iter, ys[i]),
				           imgs[i].detach().cpu().numpy().reshape([28, 28]), cmap='gray')

if __name__ == '__main__':
	# Script entry point: train a semi-supervised VAE on MNIST (CPU).
	dataset = MNIST()

	# Alternative experiments (disabled):
	# vae = VAE(z_dim=20)
	# train_evidence_lower_bound(vae, dataset, 'cpu')
	# mgvae = MGVAE(z_dim=15, k=20)
	# train_importance_weighting_lower_bound(mgvae, dataset, 'cpu')

	net = SSVAE()
	train_ssvae(net, dataset, 'cpu')