import matplotlib.pyplot as plt
import numpy as np
import torch
import gzip
import os

"""
This file is poorly coded
Use the training set as unlabeled dataset and the testing set as labeled dataset
"""

mnist_path = "../../../../DL/MNIST/"
class MNIST():
	def __init__(self):
		self.train_imgs = self.load_imgs(os.path.join(mnist_path,'train-images-idx3-ubyte.gz'))
		self.test_imgs = self.load_imgs(os.path.join(mnist_path,'t10k-images-idx3-ubyte.gz'))
		self.train_labels = self.load_labels(os.path.join(mnist_path,'train-labels-idx1-ubyte.gz'))
		self.test_labels = self.load_labels(os.path.join(mnist_path,'t10k-labels-idx1-ubyte.gz'))

		# Epoch and cursor bookkeeping for the unlabeled (train) stream ...
		self.epochs = 0
		self.cursor = 0
		# ... and for the labeled (test) stream.
		self.test_epochs = 0
		self.test_cursor = 0

	def load_batch(self, batch_size):
		"""Return the next batch of unlabeled training images, wrapping around at the end of an epoch."""
		if self.cursor + batch_size < len(self.train_imgs):
			X = torch.from_numpy(self.train_imgs[self.cursor: self.cursor + batch_size]).float()
			self.cursor += batch_size
			return X
		else:
			# Wrap around: take the tail of this epoch plus the head of the next one.
			X = [self.train_imgs[self.cursor:]]
			new_cursor = (self.cursor + batch_size) % len(self.train_imgs)
			X.append(self.train_imgs[:new_cursor])
			self.cursor = new_cursor
			self.epochs += 1
			return torch.from_numpy(np.concatenate(X)).float()

	def load_batch_with_label(self, batch_size):
		"""Return the next batch of labeled test images and their labels, wrapping around at the end of an epoch."""
		if self.test_cursor + batch_size < len(self.test_imgs):
			imgs = torch.from_numpy(self.test_imgs[self.test_cursor: self.test_cursor + batch_size]).float()
			labels = torch.from_numpy(self.test_labels[self.test_cursor: self.test_cursor + batch_size])
			self.test_cursor += batch_size
			return imgs, labels.long()
		else:
			# Wrap around: take the tail of this epoch plus the head of the next one.
			X = [self.test_imgs[self.test_cursor:]]
			y = [self.test_labels[self.test_cursor:]]
			new_cursor = (self.test_cursor + batch_size) % len(self.test_imgs)
			X.append(self.test_imgs[:new_cursor])
			y.append(self.test_labels[:new_cursor])
			self.test_cursor = new_cursor
			self.test_epochs += 1
			imgs = torch.from_numpy(np.concatenate(X)).float()
			labels = torch.from_numpy(np.concatenate(y))
			return imgs, labels.long()

	def load_test(self, batch_size):
		"""Return a random batch of test images (without labels)."""
		idx = np.random.randint(0, len(self.test_imgs), batch_size)
		return torch.from_numpy(self.test_imgs[idx]).float()

	def load_imgs(self, path):
		"""Read a gzipped IDX image file; the first 16 bytes are the header (magic number, count, rows, cols)."""
		with gzip.open(path) as f:
			pixels = np.frombuffer(f.read(), 'B', offset=16)
			# Flatten each 28x28 image to 784 values and scale to [0, 1].
			return pixels.reshape(-1, 784).astype('float32') / 255

	def load_labels(self, path):
		"""Read a gzipped IDX label file; the first 8 bytes are the header (magic number, count)."""
		with gzip.open(path) as f:
			integer_labels = np.frombuffer(f.read(), 'B', offset=8)
			# Copy out of the read-only buffer; labels are kept as integers rather than one-hot vectors.
			return np.array(integer_labels)


if __name__ == '__main__':
	mnist = MNIST()
	X = mnist.load_batch(10)
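	# Added check (not part of the original demo): load_test draws a random
	# batch of test images, so both tensors should be of shape (10, 784).
	X_test = mnist.load_test(10)
	print(X.shape, X_test.shape)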

	# Display ten labeled test images together with their labels.
	img, label = mnist.load_batch_with_label(10)
	img = img.detach().numpy()
	label = label.detach().numpy()
	for i in range(10):
		print(label[i])
		plt.imshow(img[i].reshape([28, 28]), cmap='gray')
		plt.show()
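
	# Hedged sketch (added, not in the original): how the two streams might be
	# consumed together in a semi-supervised loop. `model`, `sup_loss` and
	# `unsup_loss` are hypothetical placeholders, not defined in this file.
	#
	# for step in range(n_steps):
	# 	x_unlabeled = mnist.load_batch(64)              # unlabeled images from the training set
	# 	x_labeled, y = mnist.load_batch_with_label(64)  # labeled images from the test set
	# 	loss = sup_loss(model(x_labeled), y) + unsup_loss(model(x_unlabeled))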
