#include "NumericalTests.h"
#include "../utilSolvers/BrendtRootSolver.h"
#include "../utilSolvers/NumerovCauchySolver.h"
#include "TestFunctions.h"
#include "../common/CommonFunctions.h"
#include "math.h"
#include "../Schrodinger/OmegaFunction.h"

#include <vector>

/**
 * Runs BrendtRootSolver::solveLocalized on bracketing intervals with known
 * roots and logs, for each, the expected root, the root found and the
 * function value at the found root:
 *  - BrendtTestPoly1 on [-3, 0], [0, 2] and [2, 10] (expected -1, 1, 3);
 *  - Sine on [-0.1, 0.1] (expected 0).
 * Each call also passes the function values at both interval ends.
 *
 * @param logStream stream that receives all test output.
 */
void testBrendt(std::ostream & logStream) {
	BrendtRootSolver rootSolver;

	Sine sine;
	BrendtTestPoly1 brendtPoly;
	// Requested root precision (1e-13).
	data_t brendtPrecision = 0.0000000000001;

	logStream << "Brendt tests:" << std::endl;
	data_t root = rootSolver.solveLocalized(brendtPoly, brendtPrecision, -3, 0,
			brendtPoly.at(-3), brendtPoly.at(0));
	logStream << "root1:" <<  -1.0 << " <-> " <<  root << ", f(root) = " << brendtPoly.at(root) << std::endl;
	root = rootSolver.solveLocalized(brendtPoly, brendtPrecision, 0, 2,
			brendtPoly.at(0), brendtPoly.at(2));
	logStream << "root2: " << 1.0 << " <-> " << root << " , f(root) = " << brendtPoly.at(root) << std::endl;
	root = rootSolver.solveLocalized(brendtPoly, brendtPrecision, 2, 10,
			brendtPoly.at(2), brendtPoly.at(10));
	logStream << "root3: " << 3.0 << " <-> " << root << " , f(root) = " << brendtPoly.at(root) << std::endl;
	root = rootSolver.solveLocalized(sine, brendtPrecision, -0.1, 0.1,
			sine.at(-0.1), sine.at(0.1));
	// Distinct label: this result was previously logged as a second "root1",
	// indistinguishable from the first polynomial root above.
	logStream << "sine root: " << 0.0 << " <-> " << root << " , f(root) = " << sine.at(root) << std::endl;

}

void testNumerov(std::ostream & logStream) {
	NumerovCauchySolver cauchySolver;
	const ID x;
	Sine sine;
	printf("Numerov tests:\n");
	ConstFun k1Test(1);
	int numerovNodesCount = 100;
	data_t* numerovX = new data_t[numerovNodesCount];
	data_t* numerovY = new data_t[numerovNodesCount];
	data_t numerovBase = 0;
	data_t numerovStep = M_PI / (numerovNodesCount - 1);

	printf("Sine:\n");
	//init points
	numerovX[0] = numerovBase;
	for (int i = 1; i < numerovNodesCount; ++i) {
		numerovX[i] = numerovX[i - 1] + numerovStep;
	}
	cauchySolver.solve(numerovNodesCount, numerovX, numerovY, k1Test, 0, 1);
	for (int i = 0; i < numerovNodesCount; ++i) {
		printf("x = %f, %f <-> %f, diff = %f\n", numerovX[i], numerovY[i],
				sine.at(numerovX[i]), abs(numerovY[i] - sine.at(numerovX[i])));
	}

	printf("x*sin(x)\n");
	numerovBase = 1;
	//init points
	numerovX[0] = numerovBase;
	for (int i = 1; i < numerovNodesCount; ++i) {
		numerovX[i] = numerovX[i - 1] + numerovStep;
	}
	Numerov1Answer numerov1Answer;
	NumerovK1 numerovk1;
	cauchySolver.solve(numerovNodesCount, numerovX, numerovY, numerovk1,
			numerov1Answer.at(1), sin(1.0) + cos(1.0));
	for (int i = 0; i < numerovNodesCount; ++i) {
		printf("x = %f, %f <-> %f, diff = %f\n", numerovX[i], numerovY[i],
				numerov1Answer.at(numerovX[i]),
				abs(numerovY[i] - numerov1Answer.at(numerovX[i])));
	}

	delete[] numerovX;
	delete[] numerovY;
}

void testOmega(std::ostream& logStream, PotentialFunction & potential, data_t xMin, data_t xStep, int pointsCount, data_t mass, data_t startValue, data_t endValue, SchrodingerShootingBoundaryConditionGenerator & startBoundaryGenerator, SchrodingerShootingBoundaryConditionGenerator & endBoundaryGenerator, data_t energyStart, data_t energyStep, int energySteps) {
	std::vector<data_t> phiPlusX;
	std::vector<data_t> phiPlusY;
	std::vector<data_t> phiMinusX;
	std::vector<data_t> phiMinusY;
	OmegaFunction omega(potential, xMin, xStep, pointsCount, mass, startValue, endValue, startBoundaryGenerator, endBoundaryGenerator, phiPlusX, phiMinusX, phiPlusY, phiMinusY); //XXX: set some value for the log
	data_t curEnergy = energyStart;
	for (int i = 0; i < energySteps; ++i, curEnergy += energyStep) {
		logStream << curEnergy << " " << omega.at(curEnergy) << std::endl;
	}
}