/*
 * ScatterMatrixProducer.cpp
 *
 *  Created on: 25.03.2013
 *      Author: 
 */

#include "ScatterMatrixProducer.h"
#include <cmath>
#include <iostream>

/*
 * Default constructor. No explicit member initialisation is performed here;
 * layers are supplied afterwards through addLayer().
 */
ScatterMatrixProducer::ScatterMatrixProducer()
{
}

/*
 * Appends a layer to the stack.
 *
 * When a layer has already been stored, the interface matrix between the
 * most recently stored layer and the new one is recorded first:
 *   layer.getBoundaryConditionsReversed() * last.getBoundaryConditions()
 * so interfaceMatrices always holds exactly one entry per adjacent pair.
 *
 * @param layer layer to append (a copy is stored).
 */
void ScatterMatrixProducer::addLayer(LayerData& layer) {
	// The very first layer has no neighbour yet, hence no interface to record.
	if (!layers.empty()) {
		LayerData& lastStored = layers.back();
		Matrix<complex_t> interfaceMatrix =
				layer.getBoundaryConditionsReversed() * lastStored.getBoundaryConditions();
		interfaceMatrices.push_back(interfaceMatrix);
	}
	// Store the layer only after lastStored has been used, so the reference is
	// never touched after a possible reallocation of the container.
	layers.push_back(layer);
}

Matrix<complex_t> ScatterMatrixProducer::calculateScatterMatrix() {
	if (layers.size() < 2) {
		std::cerr << "Not enough interfaces for S-matrix calculation" << std::endl;
	}
	//XXX: for now, assume that l+ = l- = l, and l(p) = l(k) for all p, k
	size_t l = layers.at(0).getPlusLambdaCount();
	size_t l_plus_p;
	size_t l_minus_p;
	size_t l_plus_p_1 = l;
	size_t l_minus_p_1 = l;
	data_t h_p;
	Matrix<complex_t> T_u_u_p_initial = Matrix<complex_t>::getIdentity(l_plus_p_1);
	Matrix<complex_t> R_u_d_p_initial = Matrix<complex_t>::getZero(l_plus_p_1, l_minus_p_1);
	Matrix<complex_t> R_d_u_p_initial = Matrix<complex_t>::getZero(l_minus_p_1, l_plus_p_1);
	Matrix<complex_t> T_d_d_p_initial = Matrix<complex_t>::getIdentity(l_minus_p_1);
	Matrix<complex_t>* T_u_u_p;
	Matrix<complex_t>* R_u_d_p;
	Matrix<complex_t>* R_d_u_p;
	Matrix<complex_t>* T_d_d_p;
	Matrix<complex_t>* T_u_u_p_1 = &T_u_u_p_initial;
	Matrix<complex_t>* R_u_d_p_1 = &R_u_d_p_initial;
	Matrix<complex_t>* R_d_u_p_1 = &R_d_u_p_initial;
	Matrix<complex_t>* T_d_d_p_1 = &T_d_d_p_initial;
	for (size_t interfaceNum = 0; interfaceNum < interfaceMatrices.size(); ++interfaceNum) {
		//calculate next matrix
		LayerData& layerData = layers.at(interfaceNum);
		l_plus_p = layerData.getPlusLambdaCount();
		l_minus_p = layerData.getMinusLambdaCount();
		h_p = layerData.getThickness();
		std::vector<complex_t> phiPlusDiag;
		std::vector<complex_t> phiMinusDiag;
		std::vector<complex_t>& plusLambdas = layerData.getPlusLambda();
		std::vector<complex_t>& minusLambdas = layerData.getMinusLambda();
		for (size_t i = 0; i < l_plus_p; ++i) {
			complex_t currentLambda = plusLambdas.at(i);
			phiPlusDiag.push_back(complex_t(cos(h_p * currentLambda.getRe()), sin(h_p * currentLambda.getIm())));
		}
		for (size_t i = 0; i < l_minus_p; ++i) {
			complex_t currentLambda = minusLambdas.at(i);
			phiMinusDiag.push_back(complex_t(cos(h_p * currentLambda.getRe()), sin(h_p * currentLambda.getIm())));
		}
		Matrix<complex_t> phi_plus = Matrix<complex_t>::getDiagonalMatrix(phiPlusDiag);
		Matrix<complex_t> phi_minus_ = Matrix<complex_t>::getDiagonalMatrix(phiMinusDiag).inverse();
		Matrix<complex_t> omega = phi_plus * (*R_u_d_p_1) * phi_minus_;
		Matrix<complex_t>& interfaceMatrix = interfaceMatrices.at(interfaceNum);
		std::cout << "interface matrix: " << std::endl << interfaceMatrix << std::endl;
		Matrix<complex_t>const& t_1_1 = interfaceMatrix.getSubMatrix(0, l_plus_p, 0, l_plus_p);
		Matrix<complex_t>const& t_1_2 = interfaceMatrix.getSubMatrix(0, l_plus_p, l_plus_p, l_plus_p + l_minus_p);
		Matrix<complex_t>const& t_2_1 = interfaceMatrix.getSubMatrix(l_plus_p, l_plus_p + l_minus_p, 0, l_plus_p);
		Matrix<complex_t>const& t_2_2 = interfaceMatrix.getSubMatrix(l_plus_p, l_plus_p + l_minus_p, l_plus_p, l_plus_p + l_minus_p);
		Matrix<complex_t> reversedOmegaPiece = (t_2_2 + t_2_1 * omega).inverse();

		//allocate space for matrices of this layer
		T_u_u_p = new Matrix<complex_t>(l_plus_p, l_plus_p);
		R_u_d_p = new Matrix<complex_t>(l_plus_p, l_minus_p);
		R_d_u_p = new Matrix<complex_t>(l_minus_p, l_plus_p);
		T_d_d_p = new Matrix<complex_t>(l_minus_p, l_minus_p);

		(*R_u_d_p) = (t_1_2 + t_1_1 * omega) * reversedOmegaPiece;
		(*T_d_d_p) = (*T_d_d_p_1) * phi_minus_ * reversedOmegaPiece;
		(*T_u_u_p) = ((t_1_1 - (*R_u_d_p) * t_2_1) * phi_plus) * (*T_u_u_p_1);
		(*R_d_u_p) = (*R_d_u_p_1) - (*T_d_d_p) * t_2_1 * phi_plus * (*T_u_u_p_1);

		if (interfaceNum != 0) {
			//free previous matrix data: it's not necessary now
			delete T_u_u_p_1;
			delete R_u_d_p_1;
			delete R_d_u_p_1;
			delete T_d_d_p_1;
		}
		T_u_u_p_1 = T_u_u_p;
		R_u_d_p_1 = R_u_d_p;
		R_d_u_p_1 = R_d_u_p;
		T_d_d_p_1 = T_d_d_p;
		l_plus_p_1 = l_plus_p;
		l_minus_p_1 = l_minus_p;
	}
	Matrix<complex_t> result(l_plus_p + l_minus_p, l_plus_p + l_minus_p);
	result.setSubMatrix(0, l_plus_p, 0, l_plus_p, (*T_u_u_p));
	result.setSubMatrix(0, l_plus_p, l_plus_p, l_plus_p + l_minus_p, (*R_u_d_p));
	result.setSubMatrix(l_plus_p, l_plus_p + l_minus_p, 0, l_plus_p, (*R_d_u_p));
	result.setSubMatrix(l_plus_p, l_plus_p + l_minus_p, l_plus_p, l_plus_p + l_minus_p, (*T_d_d_p));

	delete T_u_u_p_1;
	delete R_u_d_p_1;
	delete R_d_u_p_1;
	delete T_d_d_p_1;

	return result;
}