import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt

from factor import Factor, all_assignments
from factor_graph import FactorGraph

def construct_factor_graph(H: np.ndarray, yTilde, epsilon):
	"""
	Build the factor graph used to decode a codeword from a noisy observation.

	The graph has one binary variable per codeword bit (named 'Y_1' ... 'Y_n'),
	one unary channel factor per bit (named '<var>_tilde'), and one parity-check
	factor per row of H (named 'p_1' ... 'p_m').

	:param H: parity-check matrix of shape (m, n); entries equal to 1 mark the
		bits that take part in each check. n must equal 2*m.
	:param yTilde: observed noisy codeword, indexable by bit position; a bit is
		treated as 0 when it equals 0 and as 1 otherwise
	:param epsilon: the probability that a bit is flipped under noise
	:return: the populated FactorGraph
	:raises ValueError: if H does not have exactly twice as many columns as rows
	"""
	m, n = H.shape
	# Explicit check rather than `assert`, which is removed under `python -O`.
	if n != 2 * m:
		raise ValueError('expected H with n == 2*m columns, got shape %s' % (H.shape,))
	var_names = np.array(['Y_%d' % i for i in range(1, n + 1)])

	G = FactorGraph(numVar=n, numFactor=m + n)
	G.var = var_names

	# Channel weights: entry k is the weight of Y_i == k given the observed bit.
	yt0 = [1. - epsilon, epsilon]
	yt1 = [epsilon, 1. - epsilon]

	# unary factors, P(Y[i] | y_tilde[i])
	for i, var in enumerate(var_names):
		fName = var + '_tilde'  # name of this factor
		G.domains[var] = [0, 1]
		G.factors.append(
			Factor(scope=[var], card=[2], val=yt0 if yTilde[i] == 0 else yt1, name=fName)
		)
		varToFactor = G.varToFactor.get(var, set())
		varToFactor.add(fName)
		G.varToFactor[var] = varToFactor

		factorToVar = G.factorToVar.get(fName, set())
		factorToVar.add(var)
		G.factorToVar[fName] = factorToVar

	# parity check factors: weight 1 for even-parity assignments, 0 otherwise
	for i, row in enumerate(H):
		idx = np.where(row == 1)[0]
		scope = var_names[idx]
		card = [2] * len(idx)
		val = [1 if np.sum(assi) % 2 == 0 else 0 for assi in all_assignments(card=card)]

		fName = 'p_%d' % (i + 1)
		G.factors.append(
			Factor(scope=scope, card=card, val=val, name=fName)
		)

		G.factorToVar[fName] = set(scope)

		# Every variable already has an entry from the unary-factor loop above.
		for var in scope:
			G.varToFactor[var].add(fName)

	return G

def applyChannelNoise(y, epsilon):
	"""
	Simulate a binary symmetric channel acting on ``y``.

	Each position is flipped independently with probability ``epsilon``, using
	draws from NumPy's global random state.

	:param y: numpy array expected to hold 0/1 values (integer or float)
	:param epsilon: per-bit flip probability
	:return: the noisy copy of ``y`` as an int32 array of the same shape
	"""
	draws = np.random.uniform(low=0, high=1, size=y.shape)
	flipped = draws < epsilon
	# |y - 1| maps 0 -> 1 and 1 -> 0; |y - 0| leaves the bit unchanged.
	noisy = np.abs(y - flipped)
	return noisy.astype('int32')

def encodeMessage(x, G):
	"""
	Encode the message bits ``x`` with generator matrix ``G`` over GF(2).

	:param x: message bit vector (array-like of 0/1 or booleans)
	:param G: generator matrix whose number of columns equals len(x)
	:return: the codeword ``(G . x) mod 2`` as an integer array
	"""
	# Cast to a wide integer type first: for boolean inputs (e.g. MATLAB logical
	# arrays loaded by scipy.io.loadmat) np.dot returns a logical OR of the
	# products instead of their sum, which gives the wrong parity.
	G_int = np.asarray(G, dtype=np.int64)
	x_int = np.asarray(x, dtype=np.int64)
	return np.mod(np.dot(G_int, x_int), 2)

def qa():
	"""
	Build the factor graph for the small 3x6 parity-check example.

	The observed word is [1, 1, 1, 0, 0, 0] with flip probability 0.1.

	:return: the constructed FactorGraph, so callers can inspect or evaluate it
	"""
	H = np.array([
		[0, 1, 1, 0, 1, 0],
		[0, 1, 0, 1, 1, 0],
		[1, 0, 1, 0, 1, 1]])
	y_tilde = np.array([1, 1, 1, 0, 0, 0])
	return construct_factor_graph(H, y_tilde, epsilon=0.1)

def qb():
	"""
	Run two rounds of parallel loopy BP on the small 3x6 parity-check example.

	The observed word is [1, 1, 1, 0, 0, 0] with flip probability 0.1.

	:return: the FactorGraph after belief propagation, so callers can read
		marginals or the MAP estimate from it
	"""
	H = np.array([
		[0, 1, 1, 0, 1, 0],
		[0, 1, 0, 1, 1, 0],
		[1, 0, 1, 0, 1, 1]])
	y_tilde = np.array([1, 1, 1, 0, 0, 0])
	G = construct_factor_graph(H, y_tilde, epsilon=0.1)
	G.initializeBeliefs()
	G.runParallelLoopyBeliefPropagation(2)
	return G

def qc():
	"""
	Decode a noisy all-zeros codeword of the code in 'ldpc36-128.mat'.

	Each bit is flipped with probability 0.01. The function then runs 50 rounds
	of parallel loopy belief propagation, plots entry 1 of every bit's estimated
	marginal (the probability of value 1, since each domain is [0, 1]), and
	prints the marginal MAP assignment.
	"""
	H = loadmat('ldpc36-128.mat')['H']
	epsilon = 0.01

	# The all-zeros word is a codeword of every linear code, so the generator
	# matrix is not needed to produce the transmitted word.
	_, n = H.shape
	y = np.zeros(shape=[n])
	y_tilde = applyChannelNoise(y, epsilon)

	graph = construct_factor_graph(H, y_tilde, epsilon)
	graph.initializeBeliefs()
	graph.runParallelLoopyBeliefPropagation(50)

	values = [graph.estimateMarginalProbability(var)[1] for var in graph.var]

	plt.plot(range(len(graph.var)), values)
	plt.xlabel('bit index')
	plt.ylabel('P(Y_i = 1)')
	plt.show()

	print(graph.getMarginalMAP())

def qd():
	"""
	Track decoding progress on 10 independent noisy transmissions.

	For each trial, the all-zeros codeword of 'ldpc36-128.mat' is sent through a
	channel with flip probability 0.1. After each of 30 single rounds of loopy
	belief propagation, the sum of the marginal MAP assignment is recorded.
	Presumably each entry is 0/1, so the sum is the number of bits still decoded
	as 1, i.e. the Hamming distance to the transmitted word. One curve per trial
	is plotted against the iteration number.
	"""
	H = loadmat('ldpc36-128.mat')['H']
	epsilon = 0.1
	num_trials = 10
	num_iterations = 30

	_, n = H.shape
	y = np.zeros(shape=[n])

	errors = []
	for trial in range(num_trials):
		row = []
		y_tilde = applyChannelNoise(y, epsilon)

		graph = construct_factor_graph(H, y_tilde, epsilon)
		graph.initializeBeliefs()
		print(trial)  # progress indicator
		for _ in range(num_iterations):
			graph.runParallelLoopyBeliefPropagation(1)
			row.append(np.sum(graph.getMarginalMAP()))
		errors.append(row)

	iterations = range(1, num_iterations + 1)
	for trial, row in enumerate(errors):
		# Each curve is one independent trial; the x-axis is the iteration count.
		plt.plot(iterations, row, label='trial %d' % (trial + 1))
	plt.xlabel('iteration')
	plt.ylabel('number of bits decoded as 1')
	plt.legend()
	plt.show()

def qfg():
	"""
	Send a binary image over a noisy channel and watch loopy BP recover it.

	The image 'cs242' from 'images.mat' is flattened and encoded with generator
	matrix 'G' from 'ldpc36-1600.mat'. It is corrupted with flip probability 0.1
	and decoded with the matching parity-check matrix 'H'. After every
	``iters_per_step`` BP iterations, the current marginal-MAP estimate of the
	image is saved to '<k>.png' (k = 1..num). Finally, all estimates are shown
	side by side, each titled with the cumulative iteration count.
	"""
	code = loadmat('ldpc36-1600.mat')
	H = code['H']
	generator = code['G']
	epsilon = 0.1

	image = loadmat('images.mat')['cs242']
	shape = image.shape
	x = np.reshape(image, [-1])

	y = encodeMessage(x, generator)
	y_tilde = applyChannelNoise(y, epsilon)

	graph = construct_factor_graph(H, y_tilde, epsilon)
	graph.initializeBeliefs()

	num = 8
	iters_per_step = 2
	predictions = []

	for step in range(num):
		graph.runParallelLoopyBeliefPropagation(iters_per_step)
		# Assumes a systematic code with the message in the first x.size codeword
		# bits (previously hard-coded as 1600) -- TODO confirm against G.
		estimate = np.reshape(graph.getMarginalMAP()[:x.size], shape)

		# Fresh figure per frame so earlier images are not stacked underneath.
		fig = plt.figure()
		plt.imshow(estimate)
		fig.savefig('%d.png' % (step + 1))
		plt.close(fig)
		predictions.append(estimate)

	fig, axes = plt.subplots(1, num)
	for step, (ax, estimate) in enumerate(zip(axes, predictions)):
		ax.imshow(estimate)
		# Cumulative number of BP iterations run when this estimate was taken.
		ax.set_title('%d iterations' % ((step + 1) * iters_per_step))
	plt.show()


if __name__ == "__main__":
	# Run the experiments named on the command line, in the given order,
	# e.g. `python <this script> qc qd`. With no arguments nothing runs.
	import sys

	experiments = {'qa': qa, 'qb': qb, 'qc': qc, 'qd': qd, 'qfg': qfg}
	requested = sys.argv[1:]
	unknown = [name for name in requested if name not in experiments]
	if unknown:
		choices = ', '.join(experiments)
		sys.exit('unknown experiment(s): %s (choose from %s)' % (', '.join(unknown), choices))
	for name in requested:
		experiments[name]()



