#include "householder.h"

#include <rumba/manifoldmatrix.h>

using namespace RUMBA;

namespace
{
// File-local helper (defined near the end of this file): returns the sum of
// the squares of every element stored in the given matrix.
double sumSquares(const ManifoldMatrix& matrix);
}

// Applies the Householder reflection H = I - 2 v v^T / (v^T v) from the left
// to the trailing v.rows() rows of M, i.e. rows [M.rows()-v.rows(), M.rows()).
//   M    - matrix updated in place.
//   v    - Householder vector, read from column 0; expected v.rows() <= M.rows().
//   skip - when true, only columns from start_index = M.rows()-v.rows()
//          onward are updated and earlier columns are left untouched;
//          when false, every column of M is updated.
// No beta is passed in: the scale 2/(v^T v) is recomputed here.  A zero
// vector v defines no reflection, so M is left unchanged in that case.
void RUMBA::apply_hh_left 
( ManifoldMatrix& M, const ManifoldMatrix& v, bool skip  )
{
	const int start_index = M.rows() - v.rows();

	double v_dot_v = 0;
	for ( int j = 0; j < v.rows(); ++j ) 
		v_dot_v += v.element(j,0) * v.element(j,0);

	// H is undefined for v == 0; treat it as the identity rather than
	// dividing 0/0 below and filling M with NaNs.
	if ( v_dot_v == 0 )
		return;

	for ( int i = skip ? start_index : 0; i < M.cols(); ++i )
	{
		// column i:  x <- x - 2 (v.x) v / (v.v)
		double v_dot_x = 0;
		for ( int j = start_index; j < M.rows(); ++j )
			v_dot_x  += v.element(j-start_index,0) * M.element(j,i);
		for ( int j = start_index; j < M.rows(); ++j )
			M.element(j,i) -= 2 * v_dot_x * v.element(j-start_index,0) / v_dot_v;
	}
}

// Applies the Householder reflection H = I - 2 v v^T / (v^T v) from the right
// to the trailing v.rows() columns of M, i.e. columns
// [M.cols()-v.rows(), M.cols()).
//   M    - matrix updated in place.
//   v    - Householder vector, read from column 0; expected v.rows() <= M.cols().
//   skip - when true, only rows from start_index-1 onward are updated, where
//          start_index = M.cols()-v.rows(); this matches the bidiagonalization,
//          where the row-j reflector has start_index j+1 and acts on rows j..
//          When false, every row of M is updated.
// No beta is passed in: the scale 2/(v^T v) is recomputed here.  A zero
// vector v defines no reflection, so M is left unchanged in that case.
void RUMBA::apply_hh_right 
( ManifoldMatrix& M, const ManifoldMatrix& v, bool skip  )
{
	const int start_index = M.cols() - v.rows();

	double v_dot_v = 0;
	for ( int j = 0; j < v.rows(); ++j ) 
		v_dot_v += v.element(j,0) * v.element(j,0);

	// H is undefined for v == 0; treat it as the identity rather than
	// dividing 0/0 below and filling M with NaNs.
	if ( v_dot_v == 0 )
		return;

	// Clamp at row 0: with skip and a full-length v (start_index == 0) the
	// unclamped start would be row -1, which is out of bounds.
	const int first_row = ( skip && start_index > 0 ) ? start_index - 1 : 0;

	for ( int i = first_row; i < M.rows(); ++i )
	{
		// row i:  x <- x - 2 (v.x) v / (v.v)
		double v_dot_x = 0;
		for ( int j = start_index; j < M.cols(); ++j )
			v_dot_x  += v.element(j-start_index,0) * M.element(i,j);

		for ( int j = start_index; j < M.cols(); ++j )
			M.element(i,j) -= 2 * v_dot_x * v.element(j-start_index,0) / v_dot_v;
	}
}

// Computes a Householder vector for x in the standard formulation
// (Golub & Van Loan, Algorithm 5.1.1): the returned v has v(0) == 1 (up to
// rounding of the final scaling) and beta is set so that (I - beta v v^T) x
// is a multiple of e_1.  When the tail of x is zero, beta = 0 and
// v = [1, x(1:)] is returned unscaled.
//   x    - column vector (taken by value).  A matrix with rows() <= 1 is
//          transposed, so the returned v is always a column.
//   beta - output scale factor.
// x is a column vector
ManifoldMatrix 
RUMBA::houseHolderVector ( ManifoldMatrix x, double & beta )
{
	const double x1 = x.element(0,0);
	const double x1_sq = x1 * x1;

	ManifoldMatrix v = makeMatrix(x.M.copy()); 
	if ( v.rows() <= 1 )
		v = v.transpose();

	double* vbegin = v.begin();

	// sigma = x(1:)^T x(1:).  Sum the tail directly, by zeroing the first
	// element of the copy, instead of computing ||x||^2 - x(0)^2.  The
	// subtraction cancels catastrophically when |x(0)| dominates and can give
	// sigma == 0 for a nonzero tail, which yields a wrong reflection.
	*vbegin = 0;
	const double sigma = sumSquares(v);
	*vbegin = 1;

	if ( sigma == 0 )
	{
		beta = 0;
		return v;
	}

	const double mu = std::sqrt ( x1_sq + sigma );

	// v(0) = x(0) - ||x||, using the cancellation-free form when x(0) > 0.
	const double v1 = ( x1 <= 0 ) ? x1 - mu : -1 * sigma / ( x1 + mu );
	*vbegin = v1;

	beta = 2 * v1 * v1 / ( sigma + v1 * v1 );

	// Normalize so that v(0) == 1.
	v = v * (1/v1);
	return v;
}


// Reduces A in place to upper bidiagonal form with alternating left and
// right Householder reflections (Golub & Van Loan, Algorithm 5.4.2).
// On return the diagonal and first superdiagonal of A hold the bidiagonal
// factor.  The parts of each Householder vector after its leading element
// (which houseHolderVector normalizes to 1) are stored in the zeroed
// entries:
//  - left vector j:  below the diagonal in column j, i.e. A(j+1:m-1, j);
//  - right vector j: in row j, i.e. A(j, j+2:n-1).
void 
RUMBA::houseHolderBiDiagonalize ( ManifoldMatrix&  A)
{
	const int n = A.cols(), m = A.rows();

	// Only min(m, n) steps are possible.  Looping to n when m < n would
	// request submatrices starting past the last row of A.
	const int steps = ( m < n ) ? m : n;

	// Required by houseHolderVector's interface.  apply_hh_left/right
	// recompute the scale from v itself.
	double beta = 0;

	for ( int j = 0; j < steps; ++j )
	{
		// Left reflector: annihilate A(j+1:m-1, j).
		ManifoldMatrix u = houseHolderVector(A.subMatrix(j,m-j,j,1),beta);
		apply_hh_left(A,u,true);
		if (m-j-1 >0) 
			A.put(j+1,j,u.subMatrix(1,m-j-1,0,1)); 

		// Right reflector: annihilate A(j, j+2:n-1).
		if ( j < n - 2 )
		{
			ManifoldMatrix w = houseHolderVector( A.subMatrix(j,1,j+1,n-j-1 ).transpose(), beta );
			apply_hh_right(A,w,true);
			A.put(j,j+2,w.subMatrix(1,n-j-2,0,1).transpose() );
		}
	}
}


namespace
{

// Returns the sum of the squares of every stored element of M, visiting the
// contiguous range [M.begin(), M.end()) in order.
double sumSquares ( const ManifoldMatrix& M)
{
	double total = 0;
	const double* const last = M.end();
	for ( const double* p = M.begin(); p != last; ++p )
	{
		const double value = *p;
		total += value * value;
	}
	return total;
}

}
