#ifndef SPDBANDF_H
#define SPDBANDF_H
#include "bandstor.h"
/// Band Cholesky factor of a symmetric positive-definite band matrix.
///
/// factor() overwrites this object's band storage (privately inherited from
/// bandStorage<T>) with the lower-triangular factor L of B = L * L^T;
/// solve() then uses L for forward elimination and back substitution.
template <class T>
class SPDBandMatrixFactor : private bandStorage<T>
{
public:
    /// Construct an empty factor.
    SPDBandMatrixFactor() {}

    /// Construct with bandStorage<T>(k); the meaning of k is defined by
    /// bandStorage (presumably a size/band parameter — confirm in bandstor.h).
    SPDBandMatrixFactor(const int k) : bandStorage<T>(k) {}

    /// Factor B in place into this object's storage.
    void factor(SPDBandMatrix<T>& B);

    /// Solve A x = b with the stored factor; returns x.
    vector<T> solve(vector<T>& b);

    /// Element access; upper-triangle requests (i < j) are mirrored to (j, i).
    T& operator()(const int i, const int j);

    /// Access the stored diagonal at offset -|i|.
    vector<T>& diagonal(const int i);
};
/// Element access into the factor's band storage.
///
/// Only the lower triangle (row >= column) is ever addressed: a request for
/// an upper-triangle position (i < j) is redirected to the mirrored element
/// (j, i). Diagonal requests (i == j) pass through unchanged.
template <class T> T& SPDBandMatrixFactor<T>::operator()(const int i,
                                                         const int j)
{
    const bool belowDiagonal = i > j;
    const int row = belowDiagonal ? i : j;   // larger index becomes the row
    const int col = belowDiagonal ? j : i;   // smaller index becomes the column
    return bandStorage<T>::operator()(row, col);
}
/// Return the stored diagonal bandStorage<T>::diag(-|i|).
///
/// The sign of i is ignored, so diagonal(d) and diagonal(-d) name the same
/// stored diagonal; the offset passed to diag() is always non-positive.
template <class T> vector<T>& SPDBandMatrixFactor<T>::diagonal(const int i)
{
    // -|i| computed in integer arithmetic (no floating-point round trip).
    const int offset = (i < 0) ? i : -i;
    return bandStorage<T>::diag(offset);
}
/// Compute the band Cholesky factorization B = L * L^T into this object.
///
/// The diagonals 0..p of B (p = -B.lowerBandWidth()) are copied into this
/// object's storage together with B's band widths, and then overwritten
/// column by column with L using the column-oriented band Cholesky scheme:
/// for each column j, subtract the contributions of the earlier columns k
/// that overlap it inside the band, then scale the in-band part of column j
/// by 1/sqrt of its (updated) diagonal entry.
///
/// @param B  symmetric positive-definite band matrix to factor.
///
/// NOTE(review): no pivot check is performed; if B is not positive definite,
/// sqrt() receives a non-positive value and L will contain NaN/inf.
template <class T> void SPDBandMatrixFactor<T>::factor(SPDBandMatrix<T> &B)
{
    const int n = B.order();
    // lowerBandWidth() is a non-positive offset; negate it to get the number
    // of sub-diagonals p.
    const int bandWidth = -B.lowerBandWidth();

    // Copy the band of B (main diagonal plus p sub-diagonals) into storage.
    for (int d = 0; d <= bandWidth; d++)
        diagonal(d) = B.diagonal(d);
    this->upperBandWidth() = B.upperBandWidth();
    this->lowerBandWidth() = B.lowerBandWidth();

    // Indices below are 1-based, as in the textbook algorithm; element
    // access subtracts 1.
    for (int j = 1; j <= n; j++) {
        // Update column j with every earlier column k that overlaps it in the
        // band: A(j..lambda, j) -= A(j, k) * A(j..lambda, k).
        for (int k = MAX(1, j - bandWidth); k <= j - 1; k++) {
            const int lambda = MIN(k + bandWidth, n);
            for (int i = j; i <= lambda; i++)
                (*this)(i-1, j-1) -= (*this)(j-1, k-1) * (*this)(i-1, k-1);
        }
        // Scale the in-band part of column j, including the diagonal, so the
        // diagonal becomes the square root of its updated value.
        const int lambda = MIN(j + bandWidth, n);
        const T sqrtDiag = sqrt((*this)(j-1, j-1));
        for (int i = j; i <= lambda; i++)
            (*this)(i-1, j-1) /= sqrtDiag;
    }
}
/// Solve A x = b using the Cholesky factor L stored by factor() (A = L * L^T).
///
/// Forward elimination solves L y = b, then back substitution solves
/// L^T x = y. Only entries of L within p sub-diagonals of the main diagonal
/// are read, where p = -lowerBandWidth() is the band width recorded by
/// factor().
///
/// @param b  right-hand side; it is copied and not modified.
/// @return   the solution vector x.
template <class T> vector<T> SPDBandMatrixFactor<T>::solve(vector<T> &b)
{
    // bandStorage<T> is a dependent base, so its members must be reached
    // through this-> (two-phase name lookup); unqualified calls are
    // ill-formed in standard C++.
    const int n = this->order();
    // L has the same lower band width as A. factor() records it as a
    // non-positive offset in lowerBandWidth(); using upper - lower here would
    // over-count whenever the upper band width is non-zero and would read
    // entries outside the band.
    const int width = -this->lowerBandWidth();
    vector<T> x(b);

    // forward elimination: L y = b (y overwrites x)
    for (int i = 1; i <= n; i++) {
        T sum = T(0);
        for (int j = MAX(1, i - width); j <= i - 1; j++)
            sum += (*this)(i-1, j-1) * x[j-1];
        x[i-1] = (x[i-1] - sum) / (*this)(i-1, i-1);
    }
    // back substitution: L^T x = y, reading L^T(i, j) as L(j, i)
    for (int i = n; i >= 1; i--) {
        T sum = T(0);
        const int last = MIN(i + width, n);
        for (int j = i + 1; j <= last; j++)
            sum += (*this)(j-1, i-1) * x[j-1];
        x[i-1] = (x[i-1] - sum) / (*this)(i-1, i-1);
    }
    return x;
}
#endif
//End of File