Listing 3: A class that supports Cholesky decomposition


#ifndef SPDBANDF_H
#define SPDBANDF_H

#include "bandstor.h"

// Holds the Cholesky factor of a symmetric positive definite band matrix
// and solves linear systems with it.
//
// The class is built on bandStorage<T> through private inheritance, so the
// base-class interface is available to the member functions below but is
// not exposed to users of this class.
template <class T> class SPDBandMatrixFactor : private bandStorage<T>
{
public:
    // Default construction; the bandStorage<T> base is default-constructed.
    SPDBandMatrixFactor() {}

    // Forwards k unchanged to the bandStorage<T>(k) constructor.
    // NOTE(review): the meaning of k is defined by bandstor.h -- confirm.
    SPDBandMatrixFactor(const int k) : bandStorage<T>(k) {}

    // Copies the band of B into this object and factors it in place.
    void factor(SPDBandMatrix<T>& B);

    // Returns the solution of the system whose right-hand side is b,
    // using the factor previously stored by factor().
    vector<T> solve(vector<T>& b);

    // Element access by (row, column); the two indices may be given in
    // either order and refer to the same stored element.
    T& operator()(const int i, const int j);

    // Access to a whole stored diagonal; the sign of i is ignored.
    vector<T>& diagonal(const int i);
};

// Symmetric element access.
//
// The larger of the two indices is always passed to the base class as the
// first argument and the smaller as the second, so (i,j) and (j,i) yield a
// reference to the same stored element.
template <class T> T& SPDBandMatrixFactor<T>::operator()(const int i,
                                                         const int j)
{
    const int row = (i > j) ? i : j;   // larger index
    const int col = (i > j) ? j : i;   // smaller index
    return bandStorage<T>::operator()(row, col);
}

// Returns a reference to the stored diagonal selected by i.
//
// The sign of i is ignored: the base class is always asked for the
// diagonal with the non-positive index -|i|.
// NOTE(review): presumably non-positive indices name the main diagonal and
// the sub-diagonals in bandStorage<T>::diag -- confirm against bandstor.h.
template <class T> vector<T>& SPDBandMatrixFactor<T>::diagonal(const int i)
{
    // Pure integer arithmetic: no int -> double -> int round trip, and no
    // overflow even for the most negative int (which is passed unchanged).
    const int offset = (i > 0) ? -i : i;
    return bandStorage<T>::diag(offset);
}
// Computes the Cholesky factor L (B = L * L^T) of the band matrix B and
// stores it in this object.
//
// Steps:
//   1. copy diagonals 0..bandWidth of B into this object's storage;
//   2. copy B's upper and lower band widths;
//   3. for each column j, subtract the contributions of the previous
//      columns that lie inside the band, then divide the column by the
//      square root of its diagonal element.
//
// The loops use 1-based indices (as the algorithm is usually written) and
// subtract 1 when addressing elements.
//
// NOTE: the pivot is not checked before sqrt() and the division, so a
// matrix that is not positive definite is not detected here.
template <class T> void SPDBandMatrixFactor<T>::factor(SPDBandMatrix<T> &B)
{
    int i, j, k, lambda;
    int n = B.order();
    // The lower band width is negated to obtain a loop bound.
    // NOTE(review): presumably it is stored as a non-positive offset -- confirm.
    int bandWidth = -B.lowerBandWidth();
    T sqrtDiag;

    for (i=0; i<=bandWidth; i++)
        diagonal(i) = B.diagonal(i);

    this->upperBandWidth() = B.upperBandWidth();
    this->lowerBandWidth() = B.lowerBandWidth();

    for (j=1; j<=n; j++){
        // Update column j with every earlier column k that reaches row j.
        for (k=MAX(1, j-bandWidth); k<=j-1; k++) {
            lambda = MIN(k+bandWidth, n);   // last row inside the band of column k
            // L(j,k) belongs to column k, which this loop never writes,
            // so it can be read once per k.
            T ljk = (*this)(j-1, k-1);
            for (i=j; i<=lambda; i++){
                (*this)(i-1,j-1) -= ljk*(*this)(i-1,k-1);
                }
            }

        // Scale column j by the square root of its (updated) diagonal.
        lambda = MIN(j+bandWidth, n);
        sqrtDiag = sqrt(T((*this)(j-1,j-1)));
        for (i=j; i<=lambda; i++)
            (*this)(i-1,j-1) /= sqrtDiag;
    }
}

// Solves the linear system for right-hand side b using the stored factor L
// and returns the solution by value; b itself is not modified.
//
// Two triangular sweeps are performed on a copy of b:
//   forward elimination  solves L   * y = b  (uses elements L(i,j), j < i)
//   back substitution    solves L^T * x = y  (uses elements L(j,i), j > i)
// Only elements within `width` of the diagonal are visited.
//
// The loops use 1-based indices and subtract 1 when addressing elements.
template <class T> vector<T> SPDBandMatrixFactor<T>::solve(vector<T> &b)
{
    // order() and the band-width accessors are inherited from the dependent
    // base bandStorage<T>; they must be reached through this-> so that name
    // lookup is deferred until the template is instantiated.
    int n = this->order();
    vector<T> x(n);
    int i, j;
    // NOTE(review): this equals the number of sub-diagonals only if the
    // upper band width is 0 and the lower one is non-positive -- confirm
    // against SPDBandMatrix / bandstor.h.
    int width = this->upperBandWidth() - this->lowerBandWidth();
    T sum;
    x = b;

// forward elimination

    for (i=1; i<=n; i++) {
        sum = 0.0f;
        for (j=MAX(1,i-width); j<=i-1; j++)
            sum += (*this)(i-1,j-1)*x[j-1];
        x[i-1] = (x[i-1] - sum)/(*this)(i-1,i-1);
        }

// back substitution

    for (i=n; i>=1; i--) {
        sum = 0.0f;
        for (j=i+1; j<=MIN(i+width,n); j++)
            sum += (*this)(j-1,i-1)*x[j-1];
        x[i-1] = (x[i-1] - sum)/(*this)(i-1,i-1);
        }
    return x;
}

#endif
//End of File