import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

class EncoderV1(nn.Module):
	"""
	The neural network that represents q(z;phi) ~ P(z | x; theta).

	The network takes in data x (optionally concatenated with a conditioning
	vector y) and outputs the Gaussian parameters phi = (mean, var).

	Args:
		input_dim: size of the last dimension of x.
		z_dim: dimensionality of the latent variable z.
		y_dim: size of the last dimension of y; 0 when the encoder is
			unconditioned.
	"""
	def __init__(self, input_dim, z_dim, y_dim=0):
		super(EncoderV1, self).__init__()
		self.input_dim = input_dim
		self.z_dim = z_dim
		self.y_dim = y_dim
		# Two hidden ELU layers; the final layer emits mean and var stacked
		# along the last dimension (hence 2 * z_dim outputs).
		self.net = nn.Sequential(
			nn.Linear(input_dim + y_dim, 300),
			nn.ELU(),
			nn.Linear(300, 300),
			nn.ELU(),
			nn.Linear(300, 2 * z_dim),
		)

	def encode(self, x, y=None):
		"""Return the Gaussian parameters (mean, var) of q(z | x[, y]).

		Args:
			x: tensor whose last dimension is input_dim.
			y: optional tensor concatenated to x along the last dimension;
				its last dimension must be y_dim.

		Returns:
			(mean, var): tensors whose last dimension is z_dim. var is
			strictly positive (ReLU output shifted by 1e-6).
		"""
		xy = x if y is None else torch.cat([x, y], dim=-1)
		h = self.net(xy)  # [..., 2 * z_dim]
		# Gaussian parameters representing q(z) ~ P(z|x) = N(encoder(x))
		mean, var = h[..., :self.z_dim], h[..., self.z_dim:]
		# NOTE(review): ReLU has zero gradient for negative pre-activations, so
		# such variance units stop learning; F.softplus(var) + 1e-6 is a
		# smoother alternative, but switching would change the outputs of
		# already-trained models, so the original activation is kept.
		var = F.relu(var) + 1e-6
		return mean, var

	def forward(self, x, y=None):
		"""Make the module callable as ``encoder(x, y)``; same as ``encode``."""
		return self.encode(x, y)

class DecoderV1(nn.Module):
	"""
	The neural network that maps a latent sample z (optionally concatenated
	with a conditioning vector y) back to data space.

	The final Sigmoid bounds every output element to the interval (0, 1).
	NOTE(review): presumably these are per-pixel Bernoulli means for data
	scaled to [0, 1]; confirm against the training loss.

	Args:
		output_dim: size of the last dimension of the reconstruction.
		z_dim: size of the last dimension of z.
		y_dim: size of the last dimension of y; 0 when unconditioned.
	"""
	def __init__(self, output_dim, z_dim, y_dim=0):
		super(DecoderV1, self).__init__()
		self.output_dim = output_dim
		self.z_dim = z_dim
		self.y_dim = y_dim

		self.net = nn.Sequential(
			nn.Linear(z_dim + y_dim, 300),
			nn.ELU(),
			nn.Linear(300, 300),
			nn.ELU(),
			nn.Linear(300, output_dim),
			nn.Sigmoid(),
		)

	def decode(self, z, y=None):
		"""Return the decoder output for z (and y, when given).

		Args:
			z: tensor whose last dimension is z_dim.
			y: optional tensor concatenated to z along the last dimension;
				its last dimension must be y_dim.

		Returns:
			Tensor whose last dimension is output_dim, values in (0, 1).
		"""
		zy = z if y is None else torch.cat((z, y), dim=-1)
		return self.net(zy)

	def forward(self, z, y=None):
		"""Make the module callable as ``decoder(z, y)``; same as ``decode``."""
		return self.decode(z, y)

class ClassifierV1(nn.Module):
	"""
	An MLP classifier that maps inputs to class probabilities.

	The trailing Softmax over the last dimension means outputs are
	probabilities that sum to 1, not raw logits.
	NOTE(review): if this is trained with nn.CrossEntropyLoss (which applies
	log-softmax itself), feed it logits instead; confirm against the caller.

	Args:
		input_dim: size of the last dimension of x.
		output_dim: number of classes.
	"""
	def __init__(self, input_dim, output_dim):
		super(ClassifierV1, self).__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.net = nn.Sequential(
			nn.Linear(input_dim, 300),
			nn.ELU(),
			nn.Linear(300, 300),
			nn.ELU(),
			nn.Linear(300, output_dim),
			nn.Softmax(dim=-1),
		)

	def classify(self, x):
		"""Return class probabilities for x (last dimension output_dim)."""
		return self.net(x)

	def forward(self, x):
		"""Make the module callable as ``classifier(x)``; same as ``classify``."""
		return self.classify(x)

if __name__ == "__main__":
	# Smoke test. The encoder's input_dim must equal x's last dimension,
	# otherwise the first nn.Linear raises a shape-mismatch error
	# (previously 783 vs 784).
	input_dim, z_dim, batch_size = 784, 3, 10
	encoder = EncoderV1(input_dim, z_dim)
	x = torch.from_numpy(np.random.randn(batch_size, input_dim)).float()

	with torch.no_grad():
		mean, var = encoder.encode(x)
	print("mean:", tuple(mean.shape), "var:", tuple(var.shape))

