#ifndef TRIDIAG_H
#define TRIDIAG_H
#include "bandstor.h"
// Tridiagonal matrix built on the band-storage base class bandStorage<T>.
// The three bands are exposed as mutable vectors via subDiagonal(),
// mainDiagonal() and superDiagonal(); solve() solves A x = b.
template <class T> class tridiagonalMatrix : public bandStorage<T>
{
private:
    // Scratch workspaces filled and overwritten by solve(); they are separate
    // from the band data held in bandStorage<T>.
    vector<T> c;
    vector<T> d;
    vector<T> e;
public:
    tridiagonalMatrix() {}

    // Forwards k to bandStorage<T>(k) (presumably the matrix order -- confirm
    // against bandstor.h) and fixes the band widths for a tridiagonal shape.
    tridiagonalMatrix(const int k) : bandStorage<T>(k)
    {
        // bandStorage<T> is a dependent base, so its members must be reached
        // through this-> to be found by two-phase name lookup.
        // NOTE(review): the lower width is stored as -1 (a signed offset?);
        // confirm this matches bandStorage's convention.
        this->lowerBandWidth() = -1;
        this->upperBandWidth() = 1;
    }

    tridiagonalMatrix& operator=(const tridiagonalMatrix&);

    // Mutable access to the band below, above and on the main diagonal.
    vector<T>& subDiagonal(void);
    vector<T>& superDiagonal(void);
    vector<T>& mainDiagonal(void);

    // Solve A x = b; throws a const char* message on a zero pivot.
    vector<T> solve(const vector<T>&);
};
// Copy-assign from another tridiagonal matrix: the band data is assigned via
// bandStorage<T>'s assignment operator, then the solve() scratch vectors are
// copied so the result matches ordinary memberwise assignment.
template <class T> tridiagonalMatrix<T>&
tridiagonalMatrix<T>::operator=(const
tridiagonalMatrix<T>& M)
{
    if (this != &M) {
        // Must name the base explicitly: an unqualified operator=(M) here
        // resolves to this very function and recurses without end.
        bandStorage<T>::operator=(M);
        c = M.c;
        d = M.d;
        e = M.e;
    }
    return *this;
}
// Mutable reference to the band one position below the main diagonal,
// as returned by bandStorage<T>::diag(-1).
template <class T>
vector<T>& tridiagonalMatrix<T>::subDiagonal()
{
    const int belowMain = -1;
    vector<T>& band = bandStorage<T>::diag(belowMain);
    return band;
}
// Mutable reference to the band one position above the main diagonal,
// as returned by bandStorage<T>::diag(1).
template <class T>
vector<T>& tridiagonalMatrix<T>::superDiagonal()
{
    const int aboveMain = 1;
    vector<T>& band = bandStorage<T>::diag(aboveMain);
    return band;
}
// Mutable reference to the main diagonal, as returned by
// bandStorage<T>::diag(0).
template <class T>
vector<T>& tridiagonalMatrix<T>::mainDiagonal()
{
    const int onMain = 0;
    vector<T>& band = bandStorage<T>::diag(onMain);
    return band;
}
// Solve A x = b for this tridiagonal matrix using Gaussian elimination with
// partial pivoting between adjacent rows (structure follows LINPACK's
// SGTSL tridiagonal solver).
//
// b : right-hand side; must hold at least order() entries (not checked).
// returns : the solution x (a copy of b overwritten in place).
// throws  : const char* "zero diagonal encoutered in factorization" when a
//           zero pivot is met, i.e. the matrix is singular.
// Side effect: overwrites the scratch members c, d and e.
template <class T> vector<T>
tridiagonalMatrix<T>::solve(const vector<T>& b)
{
    // order() lives in the dependent base bandStorage<T>, hence this->.
    const int n = this->order();
    vector<T> x = b;
    if (n <= 0)
        return x;  // nothing to solve; avoids indexing empty workspaces

    // Unpack the bands into the solver layout:
    //   c[1..n-1] = sub-diagonal, d[0..n-1] = main diagonal,
    //   e[0..n-2] = super-diagonal.
    c = vector<T>(n, 0.0);
    e = vector<T>(n, 0.0);
    for (int k = 0; k < n - 1; k++) {
        c[k + 1] = subDiagonal()[k];
        e[k] = superDiagonal()[k];
    }
    d = mainDiagonal();

    // Forward elimination. During this phase c[k] holds the pivot-column
    // entry of row k, d[k] the next column, and e[k] the fill-in column that
    // row interchanges can introduce.
    c[0] = d[0];
    if (n > 1) {
        d[0] = e[0];
        e[0] = 0.0;
        e[n - 1] = 0.0;
        for (int k = 0; k < n - 1; k++) {
            // Partial pivoting: keep the row with the larger pivot entry.
            if (fabs(c[k + 1]) > fabs(c[k])) {
                swap(c[k + 1], c[k]);
                swap(d[k + 1], d[k]);
                swap(e[k + 1], e[k]);
                swap(x[k + 1], x[k]);
            }
            if (c[k] == 0.0)
                throw ("zero diagonal encoutered in factorization");
            // Eliminate the pivot-column entry of row k+1 using row k.
            const T t = -c[k + 1] / c[k];
            c[k + 1] = d[k + 1] + t * d[k];
            d[k + 1] = e[k + 1] + t * e[k];
            e[k + 1] = 0.0;
            x[k + 1] = x[k + 1] + t * x[k];
        }
    }
    if (c[n - 1] == 0.0)
        throw ("zero diagonal encoutered in factorization");

    // Back substitution on the resulting upper-triangular system, whose rows
    // have at most two entries beyond the diagonal.
    x[n - 1] = x[n - 1] / c[n - 1];
    if (n > 1) {
        x[n - 2] = (x[n - 2] - d[n - 2] * x[n - 1]) / c[n - 2];
        // Runs for every n >= 3 (including n == 3, which a stricter guard
        // would skip, leaving x[0] unsolved).
        for (int i = n - 3; i >= 0; i--)
            x[i] = (x[i] - d[i] * x[i + 1] - e[i] * x[i + 2]) / c[i];
    }
    return x;
}
#endif
//End of File