import matplotlib.pyplot as plt
import numpy as np

def read_labeled_matrix(filename):
	"""Read and parse the labeled dataset.

	  Output:
	      Xij: dictionary of measured statistics
	          Dictionary is indexed by tuples (i,j).
	          The value assigned to each key is a (1,2) numpy.matrix encoding X_ij.
	      Zij: dictionary of party choices.
	          Dictionary is indexed by tuples (i,j).
	          The value assigned to each key is a float.
	      N, M: Counts of precincts and voters.
	  """
	Zij = {}
	Xij = {}
	M = 0
	N = 0
	with open(filename, 'r') as f:
		lines = f.readlines()
		for line in lines[1:]:
			i, j, Z, X1, X2 = line.split()
			i, j = int(i), int(j)
			if i > N:
				N = i
			if j > M:
				M = j

			Zij[i - 1, j - 1] = float(Z)
			Xij[i - 1, j - 1] = np.array([float(X1), float(X2)])
	return Xij, Zij, N, M
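
# Assumed file layout for the labeled data, inferred from the parser above:
# one header line, then whitespace-separated rows of
#   i  j  Z  X1  X2
# with 1-indexed precinct i and voter j (hence the i-1, j-1 re-indexing).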

def read_unlabeled_matrix(filename):
	"""Read and parse the unlabeled dataset.

	  Output:
	      Xij: dictionary of measured statistics
	          Dictionary is indexed by tuples (i,j).
	          The value assigned to each key is a (1,2) numpy.matrix encoding X_ij.
	      N, M: Counts of precincts and voters.
	  """
	Xij = {}
	M = 0
	N = 0
	with open(filename, 'r') as f:
		lines = f.readlines()
		for line in lines[1:]:
			i, j, X1, X2 = line.split()
			i, j = int(i), int(j)
			if i > N:
				N = i
			if j > M:
				M = j

			Xij[i - 1, j - 1] = np.array([float(X1), float(X2)])
	return Xij, N, M

def multivariate_gaussian(x, mu, sigma):
	"""Density of the multivariate normal N(mu, sigma) evaluated at x."""
	n = len(mu)
	Z = ((2 * np.pi) ** (n / 2)) * (np.linalg.det(sigma) ** 0.5)
	return np.exp(-0.5 * (x - mu) @ np.linalg.inv(sigma) @ (x - mu)) / Z
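
# Quick sanity check for the density above, kept as comments since scipy is
# not otherwise needed here; this assumes scipy is installed.
#   from scipy.stats import multivariate_normal
#   x, mu, sigma = np.zeros(2), np.zeros(2), np.identity(2)
#   assert np.isclose(multivariate_gaussian(x, mu, sigma),
#                     multivariate_normal.pdf(x, mean=mu, cov=sigma))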

def compute_log_likelihood_A(Xij, N, M, pi, mu1, mu0, sigma1, sigma0):
	l = 0
	for n in range(N):
		for m in range(M):
			l += np.log(
				(1-pi) * multivariate_gaussian(Xij[n,m], mu0, sigma0) +
				pi * multivariate_gaussian(Xij[n,m], mu1, sigma1)
			)
	return l
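
# A vectorized sketch of the same quantity, assuming the observations have
# been stacked into a single (K, 2) array X, e.g.
#   X = np.stack([Xij[n, m] for n in range(N) for m in range(M)])
# It should agree with the loop above up to floating-point summation order.
def compute_log_likelihood_A_vectorized(X, pi, mu1, mu0, sigma1, sigma0):
	def batch_gaussian(X, mu, sigma):
		# Batched bivariate normal density via an einsum quadratic form.
		d = X - mu
		quad = np.einsum('ki,ij,kj->k', d, np.linalg.inv(sigma), d)
		Z = 2 * np.pi * np.sqrt(np.linalg.det(sigma))
		return np.exp(-0.5 * quad) / Z
	mix = (1 - pi) * batch_gaussian(X, mu0, sigma0) + pi * batch_gaussian(X, mu1, sigma1)
	return np.sum(np.log(mix))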

def plot_samples(ax, Xij, Zij, N, M):
	"""Scatter each observation, colored by its hard label: blue for 1, red for 0."""
	for n in range(N):
		for m in range(M):
			if Zij[n,m] == 1:
				ax.plot(Xij[n,m][0], Xij[n,m][1], 'bo', markersize=2)
			else:
				ax.plot(Xij[n, m][0], Xij[n, m][1], 'ro', markersize=2)

def Q_A_1():
	X_ij, Z_ij, N, M = read_labeled_matrix('./surveylabeled.dat')
	M1 = int(sum(Z_ij.values()))
	M0 = (M*N) - M1

	pi = M1 / (M0 + M1)

	mu1 = np.zeros(shape=[2])
	mu0 = np.zeros(shape=[2])
	for n in range(N):
		for m in range(M):
			if Z_ij[n, m] == 1.:
				mu1 += X_ij[n, m]
			else:
				mu0 += X_ij[n, m]
	mu1 /= M1
	mu0 /= M0

	sigma1 = np.zeros(shape=[2, 2])
	sigma0 = np.zeros(shape=[2, 2])
	for n in range(N):
		for m in range(M):
			# MLE covariance: center each sample on its class mean.
			if Z_ij[n, m] == 1.:
				sigma1 += np.outer(X_ij[n, m] - mu1, X_ij[n, m] - mu1)
			else:
				sigma0 += np.outer(X_ij[n, m] - mu0, X_ij[n, m] - mu0)

	sigma1 /= M1
	sigma0 /= M0
	return pi, mu1, mu0, sigma1, sigma0
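
# In equations, the supervised MLEs computed above are:
#   pi      = M1 / (M0 + M1)
#   mu_k    = (1 / M_k) * sum_{(i,j): Z_ij = k} X_ij
#   Sigma_k = (1 / M_k) * sum_{(i,j): Z_ij = k} (X_ij - mu_k)(X_ij - mu_k)^T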

def EM_Mixture_of_Gaussian(Xij, N, M, pi=None, mu1=None, mu0=None, sigma1=None, sigma0=None):
	# initialization
	if pi is None:
		pi = 0.5
		mu1 = np.random.randn(2)
		mu0 = np.random.randn(2)
		sigma1 = np.identity(2)
		sigma0 = np.identity(2)

	Q = np.zeros(shape=[2, N, M])   # Q[0,:,:] = P(Z=0 | X) and Q[1,:,:] = P(Z=1 | X)

	num_iter = 0
	l = None
	log_likelihood = []

	while True:
		num_iter += 1
		l_new = compute_log_likelihood_A(Xij, N, M, pi, mu1, mu0, sigma1, sigma0)

		log_likelihood.append(l_new)
		print('iteration %d, log-likelihood %.5f' % (num_iter, l_new))

		# E-step
		for n in range(N):
			for m in range(M):
				Q[0,n,m] = (1-pi) * multivariate_gaussian(Xij[n,m], mu0, sigma0)
				Q[1,n,m] = pi * multivariate_gaussian(Xij[n,m], mu1, sigma1)

		ZQ = np.sum(Q, axis=0)
		Q /= ZQ  # renormalize so Q[0] + Q[1] = 1 at each (n, m)

		# M-step: update the means first, then the covariances around the
		# new means (the exact M-step for a Gaussian mixture).
		mu0, mu1 = np.zeros_like(mu0), np.zeros_like(mu1)
		for n in range(N):
			for m in range(M):
				mu1 += Q[1, n, m] * Xij[n, m]
				mu0 += Q[0, n, m] * Xij[n, m]
		mu1 /= np.sum(Q[1])
		mu0 /= np.sum(Q[0])

		sigma0, sigma1 = np.zeros_like(sigma0), np.zeros_like(sigma1)
		for n in range(N):
			for m in range(M):
				sigma1 += Q[1, n, m] * np.outer(Xij[n, m] - mu1, Xij[n, m] - mu1)
				sigma0 += Q[0, n, m] * np.outer(Xij[n, m] - mu0, Xij[n, m] - mu0)
		sigma1 /= np.sum(Q[1])
		sigma0 /= np.sum(Q[0])

		pi = np.mean(Q[1])

		# Stop once the log-likelihood improvement drops below the threshold.
		if l is not None and (l_new - l) <= 1e-2:
			return pi, mu1, mu0, sigma1, sigma0, Q, log_likelihood

		l = l_new

def Q_A_2():
	Xij, N, M = read_unlabeled_matrix('./surveyunlabeled.dat')
	fig, (ax1, ax2, ax3) = plt.subplots(1,3)

	# poor initialization
	pi, mu1, mu0, sigma1, sigma0, Q, log_likelihood = EM_Mixture_of_Gaussian(Xij, N, M)

	ax1.plot(log_likelihood, 'b--', label='poor initialization')
	plot_samples(ax2, Xij, Q[1] >= 0.5, N, M)
	ax2.set_title('poor initialization')

	# MLE initialization
	pi, mu1, mu0, sigma1, sigma0 = Q_A_1()
	pi, mu1, mu0, sigma1, sigma0, Q, log_likelihood = EM_Mixture_of_Gaussian(Xij, N, M, pi, mu1, mu0, sigma1, sigma0)

	ax1.plot(log_likelihood, 'r--', label='MLE initialization')
	plot_samples(ax3, Xij, Q[1] >= 0.5, N, M)
	ax3.set_title('MLE initialization')

	ax1.legend()
	plt.show()

def get_precincts_preferences(Z_ij, N, M):
	"""Majority vote per precinct; ties resolve to 1 via the >= comparison."""
	Y_i = []
	for n in range(N):
		cn = 0
		for m in range(M):
			if Z_ij[n, m] == 1.:
				cn += 1
		if cn >= (M / 2):
			Y_i.append(1)
		else:
			Y_i.append(0)
	return Y_i
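
# Worked micro-example (hypothetical numbers): with M = 3 and precinct votes
# [1., 1., 0.], cn = 2 >= 1.5, so Y_i = 1; a tie (cn == M / 2) also resolves
# to 1 under the >= comparison above.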

def get_individual_consensus(Y_i, Z_ij, N, M):
	"""Indicator of whether each voter's choice agrees with their precinct's preference."""
	consensus = np.zeros([N, M])
	for n in range(N):
		for m in range(M):
			if Z_ij[n, m] == Y_i[n]:
				consensus[n, m] = 1
	return consensus

def Q_B_1():
	X_ij, Z_ij, N, M = read_labeled_matrix('./surveylabeled.dat')
	Y_i = get_precincts_preferences(Z_ij, N, M)
	consensus = get_individual_consensus(Y_i, Z_ij, N, M)

	phi = sum(Y_i) / len(Y_i)
	lamb = np.sum(consensus) / (N * M)
	_, mu1, mu0, sigma1, sigma0 = Q_A_1()

	return phi, lamb, mu1, mu0, sigma1, sigma0
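
# In equations, the supervised estimates above are:
#   phi    = (1 / N) * sum_i 1{Y_i = 1}
#   lambda = (1 / (N * M)) * sum_{i,j} 1{Z_ij = Y_i}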

def compute_P_Yi_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0):
	normals = np.zeros(shape=[2, N, M])  # P(X_ij | Z_ij=0), P(X_ij | Z_ij=1)
	mus = [mu0, mu1]
	sigmas = [sigma0, sigma1]

	for i in range(2):
		for n in range(N):
			for m in range(M):
				normals[i, n, m] = multivariate_gaussian(Xij[n, m], mus[i], sigmas[i])

	P_Xij_given_Yi = np.zeros(shape=[2, N, M])
	P_Xij_given_Yi[1] = lamb * normals[1] + (1 - lamb) * normals[0]  # P(X_ij | Yi=1)
	P_Xij_given_Yi[0] = (1 - lamb) * normals[1] + lamb * normals[0]  # P(X_ij | Yi=0)

	# Multiply per-voter likelihoods within each precinct; summing logs first
	# is slightly more robust to underflow than a direct product.
	P_Xi_given_Yi = np.exp(np.sum(np.log(P_Xij_given_Yi), axis=2))

	# Bayes rule: weight by the prior over Y_i, then renormalize per precinct.
	# Copy so the in-place scaling below cannot alias P_Xi_given_Yi.
	P_Yi_given_X = P_Xi_given_Yi.T.copy()
	P_Yi_given_X[:, 0] *= 1 - phi
	P_Yi_given_X[:, 1] *= phi
	P_Yi_given_X /= np.sum(P_Yi_given_X, axis=1, keepdims=True)  # [P(Yi=0|Xi), P(Yi=1|Xi)]
	return P_Yi_given_X
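
# In equations: P(Y_i | X_i) ∝ P(Y_i) * prod_j P(x_ij | Y_i), where
#   P(x_ij | Y_i) = lambda * N(x_ij; mu_{Y_i}, Sigma_{Y_i})
#                 + (1 - lambda) * N(x_ij; mu_{1-Y_i}, Sigma_{1-Y_i}).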

def Q_B_2_1():
	phi, lamb, mu1, mu0, sigma1, sigma0 = Q_B_1()
	Xij, N, M = read_unlabeled_matrix('./surveyunlabeled.dat')

	P_Yi_given_X = compute_P_Yi_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)

	# Columns: precinct index (0-based), P(Y_i = 1 | X_i), MAP decision.
	result = np.stack(
		[np.arange(N), P_Yi_given_X[:, 1], P_Yi_given_X[:, 1] >= 0.5]
	).T
	print(result)

def compute_P_Zij_given_Yi_Xij(yi, xij, phi, lamb, mu1, mu0, sigma1, sigma0):
	"""Posterior over Z_ij given Y_i = yi and the single observation x_ij."""
	# Joint P(yi, zij=1, xij) = P(yi) * P(zij=1 | yi) * P(xij | zij=1);
	# the P(yi) factor cancels in the normalization below.
	P_Z1_Y_X = (phi if yi == 1 else 1 - phi) * (lamb if yi == 1 else 1 - lamb) \
		* multivariate_gaussian(xij, mu1, sigma1)

	# Joint P(yi, zij=0, xij) = P(yi) * P(zij=0 | yi) * P(xij | zij=0)
	P_Z0_Y_X = (phi if yi == 1 else 1 - phi) * (lamb if yi == 0 else 1 - lamb) \
		* multivariate_gaussian(xij, mu0, sigma0)
	return np.array([P_Z0_Y_X, P_Z1_Y_X]) / (P_Z0_Y_X + P_Z1_Y_X)

def compute_P_Zij_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0):
	P_Yi_given_all_Xi = compute_P_Yi_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	P_Zij_given_all_Xi = np.zeros(shape=[2,N,M])

	# Marginalize Y_i out: P(Z_ij | X_i) = sum_y P(Y_i=y | X_i) * P(Z_ij | Y_i=y, x_ij).
	for n in range(N):
		for m in range(M):
			P_Zij_given_all_Xi[1, n, m] = \
				P_Yi_given_all_Xi[n, 1] * compute_P_Zij_given_Yi_Xij(1, Xij[n, m], phi, lamb, mu1, mu0, sigma1, sigma0)[1] + \
				P_Yi_given_all_Xi[n, 0] * compute_P_Zij_given_Yi_Xij(0, Xij[n, m], phi, lamb, mu1, mu0, sigma1, sigma0)[1]

	P_Zij_given_all_Xi[0, :, :] = 1 - P_Zij_given_all_Xi[1, :, :]
	return P_Zij_given_all_Xi

def Q_B_2_2():
	Xij, N, M = read_unlabeled_matrix('./surveyunlabeled.dat')
	phi, lamb, mu1, mu0, sigma1, sigma0 = Q_B_1()
	P_Zij_given_all_Xi = compute_P_Zij_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	fig, ax = plt.subplots(1, 1)
	plot_samples(ax, Xij, P_Zij_given_all_Xi[1,:,:] >= 0.5, N, M)
	plt.show()

def EM_compute_Q(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0):
	"""Exact E-step: Q_i(y_i, z_ij) = P(y_i | X_i) * P(z_ij | y_i, x_ij).

	Y_i is shared by every voter in precinct i, so its posterior must
	condition on the whole precinct X_i rather than on x_ij alone.
	"""
	P_Yi = compute_P_Yi_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	Q = np.zeros(shape=[N, M, 2, 2])
	for n in range(N):
		for m in range(M):
			for yi in range(2):
				Q[n, m, yi] = P_Yi[n, yi] * \
					compute_P_Zij_given_Yi_Xij(yi, Xij[n, m], phi, lamb, mu1, mu0, sigma1, sigma0)
	return Q

def compute_log_likelihood_B(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0):
	"""Exact log-likelihood of the geography-aware model.

	Y_i is shared within each precinct, so it is marginalized once per
	precinct, not once per observation.
	"""
	mus = [mu0, mu1]
	sigmas = [sigma0, sigma1]
	priors = np.array([1 - phi, phi])

	l = 0
	for n in range(N):
		# log P(X_i | y_i) = sum_j log sum_{z_ij} P(z_ij | y_i) * P(x_ij | z_ij)
		log_P_Xi_given_Yi = np.zeros(2)
		for yi in range(2):
			for m in range(M):
				P_Xij = 0
				for zij in range(2):
					P_Xij += (lamb if yi == zij else 1 - lamb) * \
						multivariate_gaussian(Xij[n, m], mus[zij], sigmas[zij])
				log_P_Xi_given_Yi[yi] += np.log(P_Xij)
		# P(X_i) = sum_{y_i} P(y_i) P(X_i | y_i), via log-sum-exp for stability.
		a = np.log(priors) + log_P_Xi_given_Yi
		l += np.max(a) + np.log(np.sum(np.exp(a - np.max(a))))
	return l

def EM_Geography_Aware_Mixture_Model(Xij, N, M, phi=None, lamb=None, mu1=None, mu0=None, sigma1=None, sigma0=None):
	if phi is None:
		phi = 0.65
		lamb = 0.7
		mu1 = np.random.randn(2)
		mu0 = np.random.randn(2)
		sigma1 = np.identity(2)
		sigma0 = np.identity(2)

	l = None
	Q = EM_compute_Q(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	num_iter = 0
	log_likelihoods = []

	while True:
		num_iter += 1
		l_new = compute_log_likelihood_B(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
		log_likelihoods.append(l_new)
		print("iterator %d, log-likelihood %.5f" % (num_iter, l_new))

		# E-step
		Q_old = Q.copy()
		Q = EM_compute_Q(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)

		if l is not None and np.linalg.norm(Q - Q_old) <= 1e-1:
			return phi, lamb, mu1, mu0, sigma1, sigma0, Q, log_likelihoods
		l = l_new

		# M-step
		# max over phi: expected fraction of precinct-voter pairs with y_i = 1
		a = np.sum(Q[:, :, 1, :])
		b = np.sum(Q[:, :, 0, :])
		phi = a / (a + b)

		# max over lamb: expected fraction of voters agreeing with their precinct
		alpha = np.sum(Q[:, :, 0, 0]) + np.sum(Q[:, :, 1, 1])
		beta = np.sum(Q[:, :, 0, 1]) + np.sum(Q[:, :, 1, 0])
		lamb = alpha / (alpha + beta)

		# max over mu (means first, so the covariances below center on the new means)
		mu0_numerator = np.zeros_like(mu0)
		mu1_numerator = np.zeros_like(mu1)
		for n in range(N):
			for m in range(M):
				mu0_numerator += Xij[n, m] * np.sum(Q[n, m, :, 0])
				mu1_numerator += Xij[n, m] * np.sum(Q[n, m, :, 1])
		mu0 = mu0_numerator / np.sum(Q[:, :, :, 0])
		mu1 = mu1_numerator / np.sum(Q[:, :, :, 1])

		# max over sigma
		sigma0_numerator = np.zeros_like(sigma0)
		sigma1_numerator = np.zeros_like(sigma1)
		for n in range(N):
			for m in range(M):
				sigma0_numerator += np.outer(Xij[n, m] - mu0, Xij[n, m] - mu0) * np.sum(Q[n, m, :, 0])
				sigma1_numerator += np.outer(Xij[n, m] - mu1, Xij[n, m] - mu1) * np.sum(Q[n, m, :, 1])
		sigma0 = sigma0_numerator / np.sum(Q[:, :, :, 0])
		sigma1 = sigma1_numerator / np.sum(Q[:, :, :, 1])

def Q_B_4():
	Xij, N, M = read_unlabeled_matrix('./surveyunlabeled.dat')
	fig, (ax1, ax2, ax3) = plt.subplots(1,3)

	# MLE initialization
	phi, lamb, mu1, mu0, sigma1, sigma0 = Q_B_1()
	phi, lamb, mu1, mu0, sigma1, sigma0, Q, log_likelihoods = \
		EM_Geography_Aware_Mixture_Model(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	print(phi, lamb)
	P_Zij_given_all_Xi = compute_P_Zij_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	plot_samples(ax2, Xij, P_Zij_given_all_Xi[1,:,:]>=0.5, N, M)
	ax2.set_title('MLE initialization')
	ax1.plot(log_likelihoods,'r--', label='MLE initialization')

	# poor initialization
	phi, lamb, mu1, mu0, sigma1, sigma0, Q, log_likelihoods = EM_Geography_Aware_Mixture_Model(Xij, N, M)
	print(phi, lamb)
	P_Zij_given_all_Xi = compute_P_Zij_given_all_Xi(Xij, N, M, phi, lamb, mu1, mu0, sigma1, sigma0)
	plot_samples(ax3, Xij, P_Zij_given_all_Xi[1,:,:] >= 0.5, N, M)
	ax3.set_title('poor initialization')
	ax1.plot(log_likelihoods, 'b--', label='poor initialization')

	ax1.legend()
	plt.show()
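
# Minimal driver sketch, assuming surveylabeled.dat and surveyunlabeled.dat
# sit in the working directory; each Q_* routine is independent, so comment
# out the ones you do not need.
if __name__ == '__main__':
	Q_A_2()
	Q_B_2_1()
	Q_B_2_2()
	Q_B_4()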

