import random
from math import sqrt
from utilities import mat_mult, mat_dot, mat_minus
from copy import deepcopy
import sys
# sys.path.append('/Users/jim/Python_3/math_libraries')
sys.path.append('/home/pi/python3/math_libraries')
from number_theory import DataError

def positive_definite_system(verbose, a, rhs, scaling_limit, lam_iters, mat_iters):
	"""Iteratively solve a x = rhs for a positive-definite matrix a.

	Stages: optional symmetric rescaling of rows/columns (do_rescale),
	a power-method estimate of the largest eigenvalue (largest_eigenvalue),
	division of the whole system by that estimate (normalize_max_lamda),
	then mat_iters steps of the recurrence driven by build_k_matrix.

	Parameters:
		verbose: print progress for every stage when true.
		a: n x n matrix as a list of row lists.  Not modified.
		rhs: right-hand side list of length n.  Not modified.
		scaling_limit: column-norm imbalance tolerated before rescaling.
		lam_iters: number of power iterations for the eigenvalue estimate.
		mat_iters: number of solution iterations.

	Returns:
		The approximate solution x as a list of n floats.
	"""
	# The helpers below rescale and normalise in place; work on private
	# copies so the caller's matrix and right-hand side are left untouched.
	a = deepcopy(a)
	rhs = list(rhs)

	n = len(a)
	rng_n = range(n)
	rescale, scales = do_rescale(verbose, a, rhs, rng_n, n, scaling_limit)
	lam = largest_eigenvalue(a, rng_n, lam_iters)
	if verbose:
		print('\nAfter {0:d} iterations:\n   Largest eigenvalue <= {1:.1f}'.format(lam_iters, lam))
	normalize_max_lamda(a, rhs, lam, rng_n)
	k = build_k_matrix(a, rhs, n, rng_n)
	pm, pmm1 = initialize_solution_vectors(rng_n, n, rhs)
	# Norm of the transformed right-hand side; used for relative-error reports
	mag_rhs = magnitude_of_rhs(rhs)
	sol = iteration(verbose, a, rhs, rescale, scales, mat_iters, rng_n, k, pm, pmm1, mag_rhs)

	return sol

def scale_iteration(verbose, it_num, a, rhs, rng_n, rescale, scales, pmp1, mag_rhs):
	"""Extract the current solution estimate from the iteration state vector.

	pmp1 is the latest state column vector (a list of [value] rows).  The
	estimate for the system held in a / rhs is its first entries (indices in
	rng_n) divided by entry -2.  If rescaling was applied, each component is
	then divided by scales[i] to map it back to the original variables.

	When verbose, the iterate is printed together with the a-priori error
	bound and the actual relative residual of the system a / rhs.

	Returns:
		The solution estimate (original variables) as a list of floats.
	"""
	denominator = pmp1[-2][0]
	# Solution of the (possibly rescaled) system that a and rhs describe
	scaled_sol = [pmp1[i][0] / denominator for i in rng_n]

	sol = list(scaled_sol)
	if rescale:
		for i in rng_n:
			if scales[i] != 1.0:
				sol[i] /= scales[i]

	if verbose:
		print('\nIterate {0:d}:'.format(it_num))
		for i in rng_n:
			print('   Variable {0:d} = {1:7.5f}'.format((i + 1), sol[i]))
		print('\n   Relative error within 5% if matrix\'s condition number is < {0:.1f}'.format(((it_num + 2) / 2.56)**2))
		# a and rhs are the rescaled system, so the residual has to be measured
		# with the rescaled solution, not the one mapped back to original variables.
		print(actual_relative_error(a, rhs, rng_n, scaled_sol, mag_rhs))

	return sol

# Rescale columns (and rows, for symmetry), if necessary
def do_rescale(verbose, a, rhs, rng_n, n, scaling_limit):
	"""Symmetrically rescale a and rhs in place when column norms are unbalanced.

	For every column j, the sum of its squared entries is compared with the
	average of those sums over all n columns.  If the ratio temp is above
	scaling_limit or below 1 / scaling_limit, row j, rhs[j] and column j are
	each divided by sqrt(temp).  Scaling the row and the column together
	keeps a symmetric if it was symmetric; the solution of the scaled
	system must afterwards be divided by the same factors.

	Parameters:
		verbose: print progress messages when true.
		a: n x n matrix as a list of row lists; modified in place.
		rhs: right-hand side list of length n; modified in place.
		rng_n: range(n), the row/column indices to examine.
		n: dimension of the system.
		scaling_limit: allowed ratio (int or float) between a column's
			squared norm and the average squared column norm.

	Returns:
		(rescale, scales): rescale is True if any row/column was scaled;
		scales[j] is the divisor used for row/column j, 1.0 if untouched.

	Raises:
		ValueError: if some column of a is entirely zero (its scale factor
			would be 0 and rescaling would divide by zero).
	"""
	column_sq_total, column_sums = 0.0, []
	for j in rng_n:
		col_sq = 0.0  # named so it does not shadow the builtin sum()
		for i in rng_n:
			col_sq += a[i][j]**2
		if col_sq == 0.0:
			raise ValueError('Column {0:d} of the matrix is all zeros'.format(j))
		column_sq_total += col_sq
		column_sums.append(col_sq)

	# An empty system has nothing to compare against and needs no scaling.
	column_sq_average = column_sq_total / n if n else 0.0

	rescale, scales = False, []
	for i in rng_n:
		temp = column_sums[i] / column_sq_average
		if temp > scaling_limit or temp < 1.0 / scaling_limit:
			if verbose:
				print('Column {0:d} needs to be rescaled'.format(i))
			# Dividing column i by sqrt(temp) divides its squared norm by temp,
			# bringing it to column_sq_average (before the matching row scaling).
			scales.append(sqrt(temp))
			rescale = True
		else:
			scales.append(1.0) # sentinel: leave this row/column alone

	if rescale:
		if verbose:
			print('Rescaling')
		for k in rng_n:
			if scales[k] != 1.0:
				s_k = scales[k]
				# rows first
				for j in rng_n:
					a[k][j] /= s_k
				# right hand side
				rhs[k] /= s_k
				# columns next
				for i in rng_n:
					a[i][k] /= s_k
	else:
		if verbose:
			# '{0}' rather than '{0:d}': scaling_limit may be a float
			print('Rescaling unnecessary with scaling limit = {0}'.format(scaling_limit))

	return rescale, scales

def largest_eigenvalue(a, rng_n, lam_iters):
	"""Estimate the largest eigenvalue of a with the power method.

	Starting from a random vector b0, the direction of bm = a^(lam_iters-2) b0
	is found, then bmp1 = a bm and bmp2 = a bmp1.  With the ratios
		r1 = (bmp1.bmp1) / (bm.bmp1)
		r2 = (bmp1.bmp2) / (bmp1.bmp1)
		r3 = (bmp2.bmp2) / (bmp1.bmp2)
	the result is r3 plus the extrapolation term
	2*(r3 - r2)**2 / (2*r2 - r1 - r3) when that term is non-zero and defined.
	NOTE(review): the term is twice the standard Aitken correction; presumably
	deliberate so the estimate errs high (the caller prints it as an upper
	bound) - confirm.

	Parameters:
		a: n x n matrix as a list of row lists (not modified).
		rng_n: range(n); one random starting component per index.
		lam_iters: power of a in the final iterate bmp2; must be >= 2.

	Returns:
		The eigenvalue estimate as a float.

	Raises:
		ValueError: if lam_iters < 2.
		DataError: re-raised after printing, presumably signalled by
			mat_mult / mat_dot on unusable data.
	"""
	if lam_iters < 2:
		raise ValueError('lam_iters must be at least 2, got {0}'.format(lam_iters))

	# Random starting vector as a column vector (one [value] per row)
	bm = [[random.random()] for _ in rng_n]

	try:
		# Advance to the direction of a^(lam_iters - 2) * b0, rescaling to unit
		# length after every product so the entries (which otherwise grow like
		# lam**m) cannot overflow.  The ratios r1..r3 do not depend on the
		# length of bm, so this leaves the estimate unchanged up to rounding.
		for _ in range(lam_iters - 2):
			bm = mat_mult(a, bm)
			length = sqrt(sum(row[0]**2 for row in bm))
			bm = [[row[0] / length] for row in bm]

		bmp1 = mat_mult(a, bm)
		bmp2 = mat_mult(a, bmp1)
		c0 = mat_dot(bm, bmp1)
		c1 = mat_dot(bmp1, bmp1)
		c2 = mat_dot(bmp1, bmp2)
		c3 = mat_dot(bmp2, bmp2)

		r1, r2, r3 = c1 / c0, c2 / c1, c3 / c2
		lam = r3
		num, den = 2 * (r3 - r2)**2, 2 * r2 - r1 - r3
		if num != 0.0 and den != 0.0:
			lam += num / den
	except DataError as e:
		print('DataError exception occurred: {0:s}'.format(e.value))
		# lam was never computed; propagate rather than fail later with
		# UnboundLocalError at the return below.
		raise

	return lam


def normalize_max_lamda(a, rhs, lam, rng_n):
	"""Divide the system a x = rhs through by lam, in place.

	Both sides are scaled by the same factor, so the solution is unchanged.
	Only columns j in rng_n of every row of a, and entries i in rng_n of rhs,
	are touched.  If lam is at least the (positive) largest eigenvalue of a,
	the scaled matrix has largest eigenvalue <= 1.
	"""
	# Matrix: each row, restricted to the columns listed in rng_n
	for matrix_row in a:
		for col in rng_n:
			matrix_row[col] = matrix_row[col] / lam

	# Right-hand side gets the same factor
	for idx in rng_n:
		rhs[idx] = rhs[idx] / lam

def build_k_matrix(a, rhs, n, rng_n):
	"""Build the (n + 2) x (n + 2) iteration matrix k as a list of rows.

	Block layout (I is the n x n identity):

		[ 2I - 4a   4*rhs   0 ]
		[   0         2     2 ]
		[   0         0     2 ]

	One upper row is produced per index in rng_n (normally range(n)); the two
	scalar rows are appended last.
	"""
	def upper_row(i):
		# Row i of [2I - 4a | 4*rhs | 0]; the diagonal entry gets +2.0
		entries = [(-4.0 * a[i][j] + 2.0) if j == i else -4.0 * a[i][j] for j in range(n)]
		return entries + [4.0 * rhs[i], 0.0]

	k = [upper_row(i) for i in rng_n]

	second_last = [0.0] * n + [2.0, 2.0]
	last = list(second_last)
	last[n] = 0.0
	k.extend([second_last, last])

	return k

def initialize_solution_vectors(rng_n, n, rhs):
	"""Return the two starting state vectors (pm, pmm1) for the iteration.

	Both are column vectors (lists of one-element lists) of length n + 2:
		pmm1 = [0, ..., 0, 1, 1]
		pm   = [4*rhs[0], ..., 4*rhs[n-1], 4, 1]
	"""
	# A fresh [0.0] per entry: [[0.0]] * n would make all n rows the same
	# list object, so an in-place update of one would silently change all.
	pmm1 = [[0.0] for _ in range(n)]
	pmm1.append([1.0])
	pmm1.append([1.0])

	pm = [[4.0 * rhs[i]] for i in rng_n]
	pm.append([4.0])
	pm.append([1.0])

	return pm, pmm1

def magnitude_of_rhs(rhs):
	"""Return the Euclidean norm of the vector rhs (0.0 for an empty vector)."""
	# Accumulator named so it does not shadow the builtin sum()
	sq_total = 0.0
	for el in rhs:
		sq_total += el**2

	return sqrt(sq_total)

def iteration(verbose, a, rhs, rescale, scales, mat_iters, rng_n, k, pm, pmm1, mag_rhs):
	"""Run mat_iters steps of the three-term recurrence and return the last estimate.

	Each step forms pmp1 = k * pm minus pmm1 (via mat_mult / mat_minus),
	shifts the state (pmm1, pm) <- (pm, pmp1), and converts pmp1 into a
	solution estimate with scale_iteration (which also prints it if verbose).

	Returns:
		The solution estimate from the final step, as a list of floats.

	Raises:
		ValueError: if mat_iters < 1, since no estimate would be produced.
	"""
	if mat_iters < 1:
		raise ValueError('mat_iters must be at least 1, got {0}'.format(mat_iters))

	sol = None
	for it_num in range(1, mat_iters + 1):
		pmp1 = mat_minus(mat_mult(k, pm), pmm1)
		pmm1, pm = pm, pmp1
		sol = scale_iteration(verbose, it_num, a, rhs, rng_n, rescale, scales, pmp1, mag_rhs)

	return sol

def actual_relative_error(a, rhs, rng_n, sol, mag_rhs):
	"""Return a one-line report of the relative residual of sol, in percent.

	The residual norm is sqrt(sum over i in rng_n of (rhs[i] - (a*sol)[i])**2),
	divided by mag_rhs (expected to be the norm of rhs) and multiplied by 100.
	"""
	product = mat_mult(a, [[value] for value in sol])

	sq_residual = 0.0
	for idx in rng_n:
		diff = rhs[idx] - product[idx][0]
		sq_residual += diff**2

	percent = 100 * sqrt(sq_residual) / mag_rhs
	return '   Actual relative error = {0:.3f}%'.format(percent)

if __name__ == "__main__":
	# Demo: a 3 x 3 positive-definite system whose exact solution is x = [1, 2, 3].
	a, rhs = [[3.0, 3.0, 5.0], [3.0, 5.0, 9.0], [5.0, 9.0, 17.0]], [24.0, 40.0, 74.0]
	x = [1.0, 2.0, 3.0]
	lam_iters, mat_iters, scaling_limit = 30, 15, 10
	verbose = True
	sol = positive_definite_system(verbose, a, rhs, scaling_limit, lam_iters, mat_iters)
	print('\n')
	for i, (approx, exact) in enumerate(zip(sol, x), start=1):
		print('x{0:d} = {1:7.5f} (exact {2:7.5f})'.format(i, approx, exact))

# Typical results:
# ...
# Iterate 15:
#    Variable 1 = 1.03584
#    Variable 2 = 1.81987
#    Variable 3 = 3.07039
#
#    Relative error within 5% if matrix's condition number is < 44.1
#    Actual relative error = 0.347%


# x1 = 1.03584 (exact 1.00000)
# x2 = 1.81987 (exact 2.00000)
# x3 = 3.07039 (exact 3.00000)