// Listing 5: A class for tridiagonal matrices


#ifndef TRIDIAG_H
#define TRIDIAG_H

#include "bandstor.h"

// tridiagonalMatrix<T>: a square matrix whose only nonzero entries lie on
// the main diagonal and on the diagonals immediately below and above it.
// Element storage is inherited from bandStorage<T>; the three bands are
// exposed by reference through subDiagonal(), mainDiagonal() and
// superDiagonal(), and solve() solves A*x = b for one right-hand side.
template <class T> class tridiagonalMatrix : public bandStorage<T>
{
private:
    // Scratch storage for solve(): it reinitialises all three vectors on
    // every call and leaves the pivoted factorization in them, so their
    // contents between calls are not meaningful matrix state.
    vector<T> c;
    vector<T> d;
    vector<T> e;

public:
    tridiagonalMatrix() {}

    // k is forwarded to bandStorage<T>'s constructor (presumably the matrix
    // order -- NOTE(review): confirm in bandstor.h). The band limits are then
    // set to -1 and 1, matching the diag(-1)/diag(1) offsets used by the
    // accessors below.
    tridiagonalMatrix(const int k) : bandStorage<T>(k)
        {
        // this-> is required: these accessors belong to the dependent base
        // bandStorage<T>, which unqualified lookup in a template does not
        // search (two-phase name lookup).
        this->lowerBandWidth() = -1;
        this->upperBandWidth() = 1;
        }

    // Copies the band storage of the argument into this matrix.
    tridiagonalMatrix& operator=(const tridiagonalMatrix&);

    // Modifiable references to the stored bands (offsets -1, 1 and 0).
    vector<T>& subDiagonal(void);
    vector<T>& superDiagonal(void);
    vector<T>& mainDiagonal(void);

    // Solves A*x = b by Gaussian elimination with partial pivoting and
    // returns x. Throws a const char* message on a zero pivot.
    vector<T> solve(const vector<T>&);
};
// Assigns M to this matrix by copying the band storage held in the
// bandStorage<T> base. The scratch vectors c, d, e are deliberately not
// copied: solve() rebuilds them from the bands on every call.
//
// The previous body, `return operator=(M);`, called this same function
// again and recursed until the stack overflowed.
template <class T> tridiagonalMatrix<T>&
tridiagonalMatrix<T>::operator=(const
tridiagonalMatrix<T>& M)
{
    if (this != &M)
        bandStorage<T>::operator=(M);   // explicit base call, not self
    return *this;
}

// subDiagonal: gives callers read/write access to the band just below the
// main diagonal (offset -1), which lives in the bandStorage<T> base.
template <class T>
vector<T>& tridiagonalMatrix<T>::subDiagonal(void)
{
    vector<T>& band = bandStorage<T>::diag(-1);
    return band;
}

// superDiagonal: gives callers read/write access to the band just above
// the main diagonal (offset +1), which lives in the bandStorage<T> base.
template <class T>
vector<T>& tridiagonalMatrix<T>::superDiagonal(void)
{
    vector<T>& band = bandStorage<T>::diag(1);
    return band;
}

// mainDiagonal: gives callers read/write access to the main diagonal
// (offset 0), which lives in the bandStorage<T> base.
template <class T>
vector<T>& tridiagonalMatrix<T>::mainDiagonal(void)
{
    vector<T>& band = bandStorage<T>::diag(0);
    return band;
}

// Solves A*x = b, where A is this tridiagonal matrix, by Gaussian
// elimination with partial pivoting. The algorithm follows LINPACK's DGTSL,
// written here with 0-based indices.
//
// Parameters:
//   b - right-hand side; it is expected to hold at least order() entries.
// Returns:
//   the solution x (b is copied and transformed in place).
// Throws:
//   const char* "zero diagonal encoutered in factorization" when a zero
//   pivot is met, i.e. the matrix is singular.
// Side effects:
//   overwrites the member scratch vectors c, d, e. On return, c holds the
//   diagonal of the upper-triangular factor, d its first superdiagonal and
//   e its second superdiagonal (fill-in created by row interchanges).
template <class T> vector<T>
tridiagonalMatrix<T>::solve(const vector<T>& b)
{
    // this-> is required: order() belongs to the dependent base
    // bandStorage<T>, which unqualified lookup in a template does not search.
    const int n = this->order();
    vector<T> x = b;

    // Nothing to solve for an empty system. This guard also keeps the code
    // below from indexing c[0] and d[0] of zero-length vectors.
    if (n <= 0)
        return x;

    // Unpack the bands into DGTSL layout: row k starts as
    // (c[k], d[k], e[k]) = (sub, main, super), so c[0] and e[n-1] are unused.
    c = vector<T>(n, 0.0);
    e = vector<T>(n, 0.0);
    for (int k = 0; k < n-1; k++) {
        c[k+1] = subDiagonal()[k];
        e[k] = superDiagonal()[k];
        }
    d = mainDiagonal();

    // Shift row 0 one slot left so that, when row k becomes the pivot row,
    // c[k], d[k], e[k] hold its entries in columns k, k+1, k+2.
    c[0] = d[0];
    if (n > 1) {
        d[0] = e[0];
        e[0] = 0.0;
        e[n-1] = 0.0;
        }

    // Forward elimination: column k is removed from row k+1.
    for (int k = 0; k < n-1; k++) {
        // Partial pivoting: keep the row with the larger entry in column k
        // as the pivot row.
        if (fabs(c[k+1]) > fabs(c[k])) {
            swap(c[k+1], c[k]);
            swap(d[k+1], d[k]);
            swap(e[k+1], e[k]);
            swap(x[k+1], x[k]);
            }

        if (c[k] == 0.0)
            throw ("zero diagonal encoutered in factorization");

        // row(k+1) += t * row(k), storing the result shifted one slot left.
        T t = -c[k+1]/c[k];
        c[k+1] = d[k+1] + t*d[k];
        d[k+1] = e[k+1] + t*e[k];
        e[k+1] = 0.0;
        x[k+1] = x[k+1] + t*x[k];
        }

    if (c[n-1] == 0.0)
        throw ("zero diagonal encoutered in factorization");

    // Back substitution through the upper factor (bandwidth 2).
    x[n-1] = x[n-1]/c[n-1];
    if (n > 1)
        x[n-2] = (x[n-2] - d[n-2]*x[n-1])/c[n-2];
    // Rows n-3 down to 0. This must run for every n >= 3, including n == 3,
    // where row 0 still has to be solved.
    for (int k = n-3; k >= 0; k--)
        x[k] = (x[k] - d[k]*x[k+1] - e[k]*x[k+2])/c[k];

    return x;
}
#endif
//End of File