00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070 #ifndef SCYTHE_IDE_H
00071 #define SCYTHE_IDE_H
00072
00073 #ifdef SCYTHE_COMPILE_DIRECT
00074 #include "matrix.h"
00075 #include "error.h"
00076 #include "defs.h"
00077 #ifdef SCYTHE_LAPACK
00078 #include "lapack.h"
00079 #include "stat.h"
00080 #endif
00081 #else
00082 #include "scythestat/matrix.h"
00083 #include "scythestat/error.h"
00084 #include "scythestat/defs.h"
00085 #ifdef SCYTHE_LAPACK
00086 #include "scythestat/lapack.h"
00087 #include "scythestat/stat.h"
00088 #endif
00089 #endif
00090
#include <algorithm>
#include <cmath>
#include <vector>
00093
00094 namespace scythe {
00095
namespace {
// Short unsigned index type used by the loops throughout this header.
// NOTE(review): an unnamed namespace in a header gives every including
// translation unit its own copy; harmless for a typedef.
typedef unsigned int uint;
}
00099
00123 template <matrix_order RO, matrix_style RS, typename T,
00124 matrix_order PO, matrix_style PS>
00125 Matrix<T, RO, RS>
00126 cholesky (const Matrix<T, PO, PS>& A)
00127 {
00128 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
00129 "Matrix not square");
00130 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00131 "Matrix is NULL");
00132
00133
00134
00135
00136 Matrix<T,RO,Concrete> temp (A.rows(), A.cols(), false);
00137 T h;
00138
00139 if (PO == Row) {
00140 for (uint i = 0; i < A.rows(); ++i) {
00141 for (uint j = i; j < A.cols(); ++j) {
00142 h = A(i,j);
00143 for (uint k = 0; k < i; ++k)
00144 h -= temp(i, k) * temp(j, k);
00145 if (i == j) {
00146 SCYTHE_CHECK_20(h <= (T) 0, scythe_type_error,
00147 "Matrix not positive definite");
00148
00149 temp(i,i) = std::sqrt(h);
00150 } else {
00151 temp(j,i) = (((T) 1) / temp(i,i)) * h;
00152 temp(i,j) = (T) 0;
00153 }
00154 }
00155 }
00156 } else {
00157 for (uint j = 0; j < A.cols(); ++j) {
00158 for (uint i = j; i < A.rows(); ++i) {
00159 h = A(i, j);
00160 for (uint k = 0; k < j; ++k)
00161 h -= temp(j, k) * temp(i, k);
00162 if (i == j) {
00163 SCYTHE_CHECK_20(h <= (T) 0, scythe_type_error,
00164 "Matrix not positive definite");
00165 temp(j,j) = std::sqrt(h);
00166 } else {
00167 temp(i,j) = (((T) 1) / temp(j,j)) * h;
00168 temp(j,i) = (T) 0;
00169 }
00170 }
00171 }
00172 }
00173
00174 SCYTHE_VIEW_RETURN(T, RO, RS, temp)
00175 }
00176
00177 template <typename T, matrix_order O, matrix_style S>
00178 Matrix<T, O, Concrete>
00179 cholesky (const Matrix<T,O,S>& A)
00180 {
00181 return cholesky<O,Concrete>(A);
00182 }
00183
00184 namespace {
00185
00186
00187
00188 template <typename T,
00189 matrix_order PO1, matrix_style PS1,
00190 matrix_order PO2, matrix_style PS2,
00191 matrix_order PO3, matrix_style PS3>
00192 inline void
00193 solve(const Matrix<T,PO1,PS1>& L, const Matrix<T,PO2,PS2>& U,
00194 Matrix<T,PO3,PS3> b, T* x, T* y)
00195 {
00196 T sum;
00197
00198
00199
00200
00201
00202
00203 for (uint i = 0; i < b.size(); ++i) {
00204 sum = T (0);
00205 for (uint j = 0; j < i; ++j) {
00206 sum += L(i,j) * y[j];
00207 }
00208 y[i] = (b[i] - sum) / L(i, i);
00209 }
00210
00211
00212 if (U.isNull()) {
00213 for (int i = b.size() - 1; i >= 0; --i) {
00214 sum = T(0);
00215 for (uint j = i + 1; j < b.size(); ++j) {
00216 sum += L(j,i) * x[j];
00217 }
00218 x[i] = (y[i] - sum) / L(i, i);
00219 }
00220 } else {
00221 for (int i = b.size() - 1; i >= 0; --i) {
00222 sum = T(0);
00223 for (uint j = i + 1; j < b.size(); ++j) {
00224 sum += U(i,j) * x[j];
00225 }
00226 x[i] = (y[i] - sum) / U(i, i);
00227 }
00228 }
00229 }
00230 }
00231
00258 template <matrix_order RO, matrix_style RS, typename T,
00259 matrix_order PO1, matrix_style PS1,
00260 matrix_order PO2, matrix_style PS2,
00261 matrix_order PO3, matrix_style PS3>
00262 Matrix<T,RO,RS>
00263 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b,
00264 const Matrix<T,PO3,PS3>& M)
00265 {
00266 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00267 "A is NULL")
00268 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error,
00269 "b must be a column vector");
00270 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error,
00271 "A and b do not conform");
00272 SCYTHE_CHECK_10(A.rows() != M.rows(), scythe_conformation_error,
00273 "A and M do not conform");
00274 SCYTHE_CHECK_10(! M.isSquare(), scythe_dimension_error,
00275 "M must be square");
00276
00277 T *y = new T[A.rows()];
00278 T *x = new T[A.rows()];
00279
00280 solve(M, Matrix<>(), b, x, y);
00281
00282 Matrix<T,RO,RS> result(A.rows(), 1, x);
00283
00284 delete[]x;
00285 delete[]y;
00286
00287 return result;
00288 }
00289
00290 template <typename T, matrix_order PO1, matrix_style PS1,
00291 matrix_order PO2, matrix_style PS2,
00292 matrix_order PO3, matrix_style PS3>
00293 Matrix<T,PO1,Concrete>
00294 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b,
00295 const Matrix<T,PO3,PS3>& M)
00296 {
00297 return chol_solve<PO1,Concrete>(A,b,M);
00298 }
00299
00324 template <matrix_order RO, matrix_style RS, typename T,
00325 matrix_order PO1, matrix_style PS1,
00326 matrix_order PO2, matrix_style PS2>
00327 Matrix<T,RO,RS>
00328 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b)
00329 {
00330
00331
00332
00333
00334 return chol_solve<RO,RS>(A, b, cholesky<RO,Concrete>(A));
00335 }
00336
00337 template <typename T, matrix_order PO1, matrix_style PS1,
00338 matrix_order PO2, matrix_style PS2>
00339 Matrix<T,PO1,Concrete>
00340 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b)
00341 {
00342 return chol_solve<PO1,Concrete>(A, b);
00343 }
00344
00345
00368 template <matrix_order RO, matrix_style RS, typename T,
00369 matrix_order PO1, matrix_style PS1,
00370 matrix_order PO2, matrix_style PS2>
00371 Matrix<T,RO,RS>
00372 invpd (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& M)
00373 {
00374 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00375 "A is NULL")
00376 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
00377 "A is not square")
00378 SCYTHE_CHECK_10(A.rows() != M.cols() || A.cols() != M.rows(),
00379 scythe_conformation_error, "A and M do not conform");
00380
00381
00382 T *y = new T[A.rows()];
00383 T *x = new T[A.rows()];
00384 Matrix<T, RO, Concrete> b(A.rows(), 1);
00385 Matrix<T, RO, Concrete> null;
00386
00387
00388 Matrix<T, RO, Concrete> Ainv(A.rows(), A.cols(), false);
00389
00390 for (uint k = 0; k < A.rows(); ++k) {
00391 b[k] = (T) 1;
00392
00393 solve(M, null, b, x, y);
00394
00395 b[k] = (T) 0;
00396 for (uint l = 0; l < A.rows(); ++l)
00397 Ainv(l,k) = x[l];
00398 }
00399
00400 delete[] y;
00401 delete[] x;
00402
00403 SCYTHE_VIEW_RETURN(T, RO, RS, Ainv)
00404 }
00405
00406 template <typename T, matrix_order PO1, matrix_style PS1,
00407 matrix_order PO2, matrix_style PS2>
00408 Matrix<T,PO1,Concrete>
00409 invpd (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& M)
00410 {
00411 return invpd<PO1,Concrete>(A, M);
00412 }
00413
00434 template <matrix_order RO, matrix_style RS, typename T,
00435 matrix_order PO, matrix_style PS>
00436 Matrix<T, RO, RS>
00437 invpd (const Matrix<T, PO, PS>& A)
00438 {
00439
00440
00441 return invpd<RO,RS>(A, cholesky<RO,Concrete>(A));
00442 }
00443
00444 template <typename T, matrix_order O, matrix_style S>
00445 Matrix<T, O, Concrete>
00446 invpd (const Matrix<T,O,S>& A)
00447 {
00448 return invpd<O,Concrete>(A);
00449 }
00450
00451
00452
00453
00454
00455
00456 namespace {
00457 template <matrix_order PO1, matrix_style PS1, typename T,
00458 matrix_order PO2, matrix_order PO3, matrix_order PO4>
00459 inline T
00460 lu_decomp_alg(Matrix<T,PO1,PS1>& A, Matrix<T,PO2,Concrete>& L,
00461 Matrix<T,PO3,Concrete>& U,
00462 Matrix<unsigned int, PO4, Concrete>& perm_vec)
00463 {
00464 if (A.isRowVector()) {
00465 L = Matrix<T,PO2,Concrete> (1, 1, true, 1);
00466 U = A;
00467 perm_vec = Matrix<uint, PO4, Concrete>(1, 1);
00468 return (T) 0;
00469 }
00470
00471 L = U = Matrix<T, PO2, Concrete>(A.rows(), A.cols(), false);
00472 perm_vec = Matrix<uint, PO3, Concrete> (A.rows() - 1, 1, false);
00473
00474 uint pivot;
00475 T temp;
00476 T sign = (T) 1;
00477
00478 for (uint k = 0; k < A.rows() - 1; ++k) {
00479 pivot = k;
00480
00481 for (uint i = k; i < A.rows(); ++i) {
00482 if (std::fabs(A(pivot,k)) < std::fabs(A(i,k)))
00483 pivot = i;
00484 }
00485
00486 SCYTHE_CHECK_20(A(pivot,k) == (T) 0, scythe_type_error,
00487 "Matrix is singular");
00488
00489
00490 if (k != pivot) {
00491 sign *= -1;
00492 for (uint i = 0; i < A.rows(); ++i) {
00493 temp = A(pivot,i);
00494 A(pivot,i) = A(k,i);
00495 A(k,i) = temp;
00496 }
00497 }
00498 perm_vec[k] = pivot;
00499
00500 for (uint i = k + 1; i < A.rows(); ++i) {
00501 A(i,k) = A(i,k) / A(k,k);
00502 for (uint j = k + 1; j < A.rows(); ++j)
00503 A(i,j) = A(i,j) - A(i,k) * A(k,j);
00504 }
00505 }
00506
00507 L = A;
00508
00509 for (uint i = 0; i < A.rows(); ++i) {
00510 for (uint j = i; j < A.rows(); ++j) {
00511 U(i,j) = A(i,j);
00512 L(i,j) = (T) 0;
00513 L(i,i) = (T) 1;
00514 }
00515 }
00516 return sign;
00517 }
00518 }
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00557 template <matrix_order PO1, matrix_style PS1, typename T,
00558 matrix_order PO2, matrix_order PO3, matrix_order PO4>
00559 void
00560 lu_decomp(Matrix<T,PO1,PS1> A, Matrix<T,PO2,Concrete>& L,
00561 Matrix<T,PO3,Concrete>& U,
00562 Matrix<unsigned int, PO4, Concrete>& perm_vec)
00563 {
00564 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00565 "A is NULL")
00566 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
00567 "Matrix A not square");
00568
00569 lu_decomp_alg(A, L, U, perm_vec);
00570 }
00571
00572
00573
00574
00575
00604 template <matrix_order RO, matrix_style RS, typename T,
00605 matrix_order PO1, matrix_style PS1,
00606 matrix_order PO2, matrix_style PS2,
00607 matrix_order PO3, matrix_style PS3,
00608 matrix_order PO4, matrix_style PS4,
00609 matrix_order PO5, matrix_style PS5>
00610 Matrix<T, RO, RS>
00611 lu_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b,
00612 const Matrix<T,PO3,PS3>& L, const Matrix<T,PO4,PS4>& U,
00613 const Matrix<unsigned int, PO5, PS5> &perm_vec)
00614 {
00615 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00616 "A is NULL")
00617 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error,
00618 "b is not a column vector");
00619 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
00620 "A is not square");
00621 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error,
00622 "A and b have different row sizes");
00623 SCYTHE_CHECK_10(A.rows() != L.rows() || A.rows() != U.rows() ||
00624 A.cols() != L.cols() || A.cols() != U.cols(),
00625 scythe_conformation_error,
00626 "A, L, and U do not conform");
00627 SCYTHE_CHECK_10(perm_vec.rows() + 1 != A.rows(),
00628 scythe_conformation_error,
00629 "perm_vec does not have exactly one less row than A");
00630
00631 T *y = new T[A.rows()];
00632 T *x = new T[A.rows()];
00633
00634 Matrix<T,RO,Concrete> bb = row_interchange(b, perm_vec);
00635 solve(L, U, bb, x, y);
00636
00637 Matrix<T,RO,RS> result(A.rows(), 1, x);
00638
00639 delete[]x;
00640 delete[]y;
00641
00642 return result;
00643 }
00644
00645 template <typename T, matrix_order PO1, matrix_style PS1,
00646 matrix_order PO2, matrix_style PS2,
00647 matrix_order PO3, matrix_style PS3,
00648 matrix_order PO4, matrix_style PS4,
00649 matrix_order PO5, matrix_style PS5>
00650 Matrix<T, PO1, Concrete>
00651 lu_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b,
00652 const Matrix<T,PO3,PS3>& L, const Matrix<T,PO4,PS4>& U,
00653 const Matrix<unsigned int, PO5, PS5> &perm_vec)
00654 {
00655 return lu_solve<PO1,Concrete>(A, b, L, U, perm_vec);
00656 }
00657
00678 template <matrix_order RO, matrix_style RS, typename T,
00679 matrix_order PO1, matrix_style PS1,
00680 matrix_order PO2, matrix_style PS2>
00681 Matrix<T,RO,RS>
00682 lu_solve (Matrix<T,PO1,PS1> A, const Matrix<T,PO2,PS2>& b)
00683 {
00684
00685 Matrix<T, RO, Concrete> L, U;
00686 Matrix<uint, RO, Concrete> perm_vec;
00687 lu_decomp_alg(A, L, U, perm_vec);
00688
00689 return lu_solve<RO,RS>(A, b, L, U, perm_vec);
00690 }
00691
00692 template <typename T, matrix_order PO1, matrix_style PS1,
00693 matrix_order PO2, matrix_style PS2>
00694 Matrix<T,PO1,Concrete>
00695 lu_solve (Matrix<T,PO1,PS1> A, const Matrix<T,PO2,PS2>& b)
00696 {
00697
00698
00699
00700 Matrix<T, PO1, Concrete> L, U;
00701 Matrix<uint, PO1, Concrete> perm_vec;
00702 lu_decomp_alg(A, L, U, perm_vec);
00703
00704 return lu_solve<PO1,Concrete>(A, b, L, U, perm_vec);
00705 }
00706
00730 template<matrix_order RO, matrix_style RS, typename T,
00731 matrix_order PO1, matrix_style PS1,
00732 matrix_order PO2, matrix_style PS2,
00733 matrix_order PO3, matrix_style PS3,
00734 matrix_order PO4, matrix_style PS4>
00735 Matrix<T,RO,RS>
00736 inv (const Matrix<T,PO1,PS1>& A,
00737 const Matrix<T,PO2,PS2>& L, const Matrix<T,PO3,PS3>& U,
00738 const Matrix<unsigned int,PO4,PS4>& perm_vec)
00739 {
00740 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00741 "A is NULL")
00742 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error,
00743 "A is not square");
00744 SCYTHE_CHECK_10(A.rows() != L.rows() || A.rows() != U.rows() ||
00745 A.cols() != L.cols() || A.cols() != U.cols(),
00746 scythe_conformation_error,
00747 "A, L, and U do not conform");
00748 SCYTHE_CHECK_10(perm_vec.rows() + 1 != A.rows()
00749 && !(A.isScalar() && perm_vec.isScalar()),
00750 scythe_conformation_error,
00751 "perm_vec does not have exactly one less row than A");
00752
00753
00754 Matrix<T,RO,Concrete> Ainv(A.rows(), A.rows(), false);
00755
00756
00757 T *y = new T[A.rows()];
00758 T *x = new T[A.rows()];
00759 Matrix<T, RO, Concrete> b(A.rows(), 1);
00760 Matrix<T,RO,Concrete> bb;
00761
00762 for (uint k = 0; k < A.rows(); ++k) {
00763 b[k] = (T) 1;
00764 bb = row_interchange(b, perm_vec);
00765
00766 solve(L, U, bb, x, y);
00767
00768 b[k] = (T) 0;
00769 for (uint l = 0; l < A.rows(); ++l)
00770 Ainv(l,k) = x[l];
00771 }
00772
00773 delete[] y;
00774 delete[] x;
00775
00776 SCYTHE_VIEW_RETURN(T, RO, RS, Ainv)
00777 }
00778
00779 template<typename T,
00780 matrix_order PO1, matrix_style PS1,
00781 matrix_order PO2, matrix_style PS2,
00782 matrix_order PO3, matrix_style PS3,
00783 matrix_order PO4, matrix_style PS4>
00784 Matrix<T,PO1,Concrete>
00785 inv (const Matrix<T,PO1,PS1>& A,
00786 const Matrix<T,PO2,PS2>& L, const Matrix<T,PO3,PS3>& U,
00787 const Matrix<unsigned int,PO4,PS4>& perm_vec)
00788 {
00789 return inv<PO1,Concrete>(A, L, U, perm_vec);
00790 }
00791
00810 template <matrix_order RO, matrix_style RS, typename T,
00811 matrix_order PO, matrix_style PS>
00812 Matrix<T, RO, RS>
00813 inv (const Matrix<T, PO, PS>& A)
00814 {
00815
00816
00817 Matrix<T,RO,Concrete> AA = A;
00818
00819
00820 Matrix<T, RO, Concrete> L, U;
00821 Matrix<uint, RO, Concrete> perm_vec;
00822 lu_decomp_alg(AA, L, U, perm_vec);
00823
00824 return inv<RO,RS>(A, L, U, perm_vec);
00825 }
00826
00827 template <typename T, matrix_order O, matrix_style S>
00828 Matrix<T, O, Concrete>
00829 inv (const Matrix<T, O, S>& A)
00830 {
00831 return inv<O,Concrete>(A);
00832 }
00833
00834
00851 template <matrix_order RO, matrix_style RS, typename T,
00852 matrix_order PO1, matrix_style PS1,
00853 matrix_order PO2, matrix_style PS2>
00854 Matrix<T,RO,RS>
00855 row_interchange (Matrix<T,PO1,PS1> A,
00856 const Matrix<unsigned int,PO2,PS2>& p)
00857 {
00858 SCYTHE_CHECK_10(! p.isColVector(), scythe_dimension_error,
00859 "p not a column vector");
00860 SCYTHE_CHECK_10(p.rows() + 1 != A.rows() && ! p.isScalar(),
00861 scythe_conformation_error, "p must have one less row than A");
00862
00863 for (uint i = 0; i < A.rows() - 1; ++i) {
00864 Matrix<T,PO1,View> vec1 = A(i, _);
00865 Matrix<T,PO1,View> vec2 = A(p[i], _);
00866 std::swap_ranges(vec1.begin_f(), vec1.end_f(), vec2.begin_f());
00867 }
00868
00869 return A;
00870 }
00871
00872 template <typename T, matrix_order PO1, matrix_style PS1,
00873 matrix_order PO2, matrix_style PS2>
00874 Matrix<T,PO1,Concrete>
00875 row_interchange (const Matrix<T,PO1,PS1>& A,
00876 const Matrix<unsigned int,PO2,PS2>& p)
00877 {
00878 return row_interchange<PO1,Concrete>(A, p);
00879 }
00880
00893 template <typename T, matrix_order PO, matrix_style PS>
00894 T
00895 det (const Matrix<T, PO, PS>& A)
00896 {
00897 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
00898 "Matrix is not square")
00899 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00900 "Matrix is NULL")
00901
00902
00903
00904 Matrix<T,PO,Concrete> AA = A;
00905
00906
00907 Matrix<T, PO, Concrete> L, U;
00908 Matrix<uint, PO, Concrete> perm_vec;
00909 T sign = lu_decomp_alg(AA, L, U, perm_vec);
00910
00911
00912 T det = (T) 1;
00913 for (uint i = 0; i < AA.rows(); ++i)
00914 det *= AA(i, i);
00915
00916 return sign * det;
00917 }
00918
00919 #ifdef SCYTHE_LAPACK
00920
00921 template<>
00922 Matrix<>
00923 cholesky (const Matrix<>& A)
00924 {
00925 SCYTHE_DEBUG_MSG("Using lapack/blas for cholesky");
00926 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
00927 "Matrix not square");
00928 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00929 "Matrix is NULL");
00930
00931
00932
00933 Matrix<> AA = A;
00934
00935
00936 double* Aarray = AA.getArray();
00937 int rows = (int) AA.rows();
00938 int err = 0;
00939
00940
00941 lapack::dpotrf_("L", &rows, Aarray, &rows, &err);
00942 SCYTHE_CHECK_10(err > 0, scythe_type_error,
00943 "Matrix is not positive definite")
00944 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
00945 "The " << err << "th value of the matrix had an illegal value")
00946
00947
00948 for (uint j = 1; j < AA.cols(); ++j)
00949 for (uint i = 0; i < j; ++i)
00950 AA(i, j) = 0;
00951
00952 return AA;
00953 }
00954
00955 template<>
00956 Matrix<>
00957 chol_solve (const Matrix<>& A, const Matrix<>& b, const Matrix<>& M)
00958 {
00959 SCYTHE_DEBUG_MSG("Using lapack/blas for chol_solve");
00960 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00961 "A is NULL")
00962 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error,
00963 "b must be a column vector");
00964 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error,
00965 "A and b do not conform");
00966 SCYTHE_CHECK_10(A.rows() != M.rows(), scythe_conformation_error,
00967 "A and M do not conform");
00968 SCYTHE_CHECK_10(! M.isSquare(), scythe_dimension_error,
00969 "M must be square");
00970
00971
00972 Matrix<> bb = b;
00973
00974
00975 const double* Marray = M.getArray();
00976 double* barray = bb.getArray();
00977 int rows = (int) bb.rows();
00978 int cols = (int) bb.cols();
00979 int err = 0;
00980
00981
00982 lapack::dpotrs_("L", &rows, &cols, Marray, &rows, barray, &rows, &err);
00983 SCYTHE_CHECK_10(err > 0, scythe_type_error,
00984 "Matrix is not positive definite")
00985 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
00986 "The " << err << "th value of the matrix had an illegal value")
00987
00988 return bb;
00989 }
00990
00991 template<>
00992 Matrix<>
00993 chol_solve (const Matrix<>& A, const Matrix<>& b)
00994 {
00995 SCYTHE_DEBUG_MSG("Using lapack/blas for chol_solve");
00996 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
00997 "A is NULL")
00998 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error,
00999 "b must be a column vector");
01000 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error,
01001 "A and b do not conform");
01002
01003
01004 Matrix<> AA =A;
01005 Matrix<> bb = b;
01006
01007
01008 double* Aarray = AA.getArray();
01009 double* barray = bb.getArray();
01010 int rows = (int) bb.rows();
01011 int cols = (int) bb.cols();
01012 int err = 0;
01013
01014
01015 lapack::dposv_("L", &rows, &cols, Aarray, &rows, barray, &rows, &err);
01016 SCYTHE_CHECK_10(err > 0, scythe_type_error,
01017 "Matrix is not positive definite")
01018 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
01019 "The " << err << "th value of the matrix had an illegal value")
01020
01021 return bb;
01022 }
01023
01024 template <matrix_order PO2, matrix_order PO3, matrix_order PO4>
01025 inline double
01026 lu_decomp_alg(Matrix<>& A, Matrix<double,PO2,Concrete>& L,
01027 Matrix<double,PO3,Concrete>& U,
01028 Matrix<unsigned int, PO4, Concrete>& perm_vec)
01029 {
01030 SCYTHE_DEBUG_MSG("Using lapack/blas for lu_decomp_alg");
01031 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL")
01032 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error,
01033 "A is not square");
01034
01035 if (A.isRowVector()) {
01036 L = Matrix<double,PO2,Concrete> (1, 1, true, 1);
01037 U = A;
01038 perm_vec = Matrix<uint, PO4, Concrete>(1, 1);
01039 return 0.;
01040 }
01041
01042 L = U = Matrix<double, PO2, Concrete>(A.rows(), A.cols(), false);
01043 perm_vec = Matrix<uint, PO3, Concrete> (A.rows(), 1, false);
01044
01045
01046 double* Aarray = A.getArray();
01047 int rows = (int) A.rows();
01048 int* ipiv = (int*) perm_vec.getArray();
01049 int err = 0;
01050
01051
01052 lapack::dgetrf_(&rows, &rows, Aarray, &rows, ipiv, &err);
01053
01054 SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular");
01055 SCYTHE_CHECK_10(err < 0, scythe_lapack_internal_error,
01056 "The " << err << "th value of the matrix had an illegal value");
01057
01058
01059 L = A;
01060 for (uint i = 0; i < A.rows(); ++i) {
01061 for (uint j = i; j < A.rows(); ++j) {
01062 U(i,j) = A(i,j);
01063 L(i,j) = 0.;
01064 L(i,i) = 1.;
01065 }
01066 }
01067
01068
01069
01070
01071 if (perm_vec(perm_vec.size() - 1) != perm_vec.size())
01072 SCYTHE_THROW(scythe_unexpected_default_error,
01073 "This is an unexpected error. Please notify the developers.")
01074 perm_vec = perm_vec(0, 0, perm_vec.rows() - 2, 0) - 1;
01075
01076
01077 if (sum(perm_vec > 0) % 2 == 0)
01078 return 1;
01079
01080 return -1;
01081 }
01082
/* Result of qr_decomp(): the packed output of LAPACK dgeqp3 (A*P = Q*R).
 *   QR    - A overwritten by dgeqp3.  R lies on and above the diagonal;
 *           the Householder vectors defining Q lie below it.
 *   tau   - min(rows, cols) x 1 scalar factors of the Householder
 *           reflectors.
 *   pivot - cols x 1 column permutation as 0-based indices: column j of
 *           A*P is column pivot[j] of A.
 */
struct QRdecomp {
  Matrix<> QR;
  Matrix<> tau;
  Matrix<> pivot;
};
01107
01141 QRdecomp
01142 qr_decomp (const Matrix<>& A)
01143 {
01144 SCYTHE_DEBUG_MSG("Using lapack/blas for qr_decomp");
01145 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL");
01146
01147
01148 Matrix<> QR = A;
01149 double* QRarray = QR.getArray();
01150 int rows = (int) QR.rows();
01151 int cols = (int) QR.cols();
01152 Matrix<unsigned int> pivot(cols, 1);
01153 int* parray = (int*) pivot.getArray();
01154 Matrix<> tau = Matrix<>(rows < cols ? rows : cols, 1);
01155 double* tarray = tau.getArray();
01156 double tmp, *work;
01157 int lwork, info;
01158
01159
01160 lwork = -1;
01161 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, &tmp,
01162 &lwork, &info);
01163
01164 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01165 "Internal error in LAPACK routine dgeqp3");
01166
01167 lwork = (int) tmp;
01168 work = new double[lwork];
01169
01170
01171 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, work,
01172 &lwork, &info);
01173
01174 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01175 "Internal error in LAPACK routine dgeqp3");
01176
01177 delete[] work;
01178
01179 pivot -= 1;
01180
01181 QRdecomp result;
01182 result.QR = QR;
01183 result.tau = tau;
01184 result.pivot = pivot;
01185
01186 return result;
01187 }
01188
01225 inline Matrix<>
01226 qr_solve(const Matrix<>& A, const Matrix<>& b, const QRdecomp& QR)
01227 {
01228 SCYTHE_DEBUG_MSG("Using lapack/blas for qr_solve");
01229 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL")
01230 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error,
01231 "A and b do not conform");
01232 SCYTHE_CHECK_10(A.rows() != QR.QR.rows() || A.cols() != QR.QR.cols(),
01233 scythe_conformation_error, "A and QR do not conform");
01234 int taudim = (int) (A.rows() < A.cols() ? A.rows() : A.cols());
01235 SCYTHE_CHECK_10(QR.tau.size() != taudim, scythe_conformation_error,
01236 "A and tau do not conform");
01237 SCYTHE_CHECK_10(QR.pivot.size() != A.cols(), scythe_conformation_error,
01238 "pivot vector is not the right length");
01239
01240 int rows = (int) QR.QR.rows();
01241 int cols = (int) QR.QR.cols();
01242 int nrhs = (int) b.cols();
01243 int lwork, info;
01244 double *work, tmp;
01245 double* QRarray = QR.QR.getArray();
01246 double* tarray = QR.tau.getArray();
01247 Matrix<> bb = b;
01248 double* barray = bb.getArray();
01249
01250
01251 lwork = -1;
01252 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows,
01253 tarray, barray, &rows, &tmp, &lwork, &info);
01254
01255 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01256 "Internal error in LAPACK routine dormqr");
01257
01258
01259 lwork = (int) tmp;
01260 work = new double[lwork];
01261 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows,
01262 tarray, barray, &rows, work, &lwork, &info);
01263
01264 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01265 "Internal error in LAPACK routine dormqr");
01266
01267 lapack::dtrtrs_("U", "N", "N", &taudim, &nrhs, QRarray, &rows, barray,
01268 &rows, &info);
01269
01270 SCYTHE_CHECK_10(info > 0, scythe_type_error, "Matrix is singular");
01271 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error,
01272 "Internal error in LAPACK routine dtrtrs");
01273
01274 delete[] work;
01275
01276 Matrix<> result(A.cols(), b.cols(), false);
01277 for (uint i = 0; i < QR.pivot.size(); ++i)
01278 result(i, _) = bb(QR.pivot(i), _);
01279 return result;
01280 }
01281
01313 inline Matrix<>
01314 qr_solve (const Matrix<>& A, const Matrix<>& b)
01315 {
01316 SCYTHE_DEBUG_MSG("Using lapack/blas for qr_solve");
01317 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL")
01318 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error,
01319 "A and b do not conform");
01320
01321
01322
01323
01324 Matrix<> QR = A;
01325 double* QRarray = QR.getArray();
01326 int rows = (int) QR.rows();
01327 int cols = (int) QR.cols();
01328 Matrix<unsigned int> pivot(cols, 1);
01329 int* parray = (int*) pivot.getArray();
01330 Matrix<> tau = Matrix<>(rows < cols ? rows : cols, 1);
01331 double* tarray = tau.getArray();
01332 double tmp, *work;
01333 int lwork, info;
01334
01335
01336 lwork = -1;
01337 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, &tmp,
01338 &lwork, &info);
01339
01340 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01341 "Internal error in LAPACK routine dgeqp3");
01342
01343 lwork = (int) tmp;
01344 work = new double[lwork];
01345
01346
01347 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, work,
01348 &lwork, &info);
01349
01350 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01351 "Internal error in LAPACK routine dgeqp3");
01352
01353 delete[] work;
01354
01355 pivot -= 1;
01356
01357
01358
01359
01360 int nrhs = (int) b.cols();
01361 Matrix<> bb = b;
01362 double* barray = bb.getArray();
01363 int taudim = (int) tau.size();
01364
01365
01366 lwork = -1;
01367 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows,
01368 tarray, barray, &rows, &tmp, &lwork, &info);
01369
01370 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01371 "Internal error in LAPACK routine dormqr");
01372
01373
01374 lwork = (int) tmp;
01375 work = new double[lwork];
01376 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows,
01377 tarray, barray, &rows, work, &lwork, &info);
01378
01379 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01380 "Internal error in LAPACK routine dormqr");
01381
01382 lapack::dtrtrs_("U", "N", "N", &taudim, &nrhs, QRarray, &rows, barray,
01383 &rows, &info);
01384
01385 SCYTHE_CHECK_10(info > 0, scythe_type_error, "Matrix is singular");
01386 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error,
01387 "Internal error in LAPACK routine dtrtrs");
01388
01389 delete[] work;
01390
01391 Matrix<> result(A.cols(), b.cols(), false);
01392 for (uint i = 0; i < pivot.size(); ++i)
01393 result(i, _) = bb(pivot(i), _);
01394
01395 return result;
01396 }
01397
01398 template<>
01399 Matrix<>
01400 invpd (const Matrix<>& A)
01401 {
01402 SCYTHE_DEBUG_MSG("Using lapack/blas for invpd");
01403 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
01404 "A is NULL")
01405 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error,
01406 "A is not square");
01407
01408
01409
01410 Matrix<> AA = A;
01411
01412
01413 double* Aarray = AA.getArray();
01414 int rows = (int) AA.rows();
01415 int err = 0;
01416
01417
01418 lapack::dpotrf_("L", &rows, Aarray, &rows, &err);
01419 SCYTHE_CHECK_10(err > 0, scythe_type_error,
01420 "Matrix is not positive definite")
01421 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
01422 "The " << err << "th value of the matrix had an illegal value")
01423
01424
01425 lapack::dpotri_("L", &rows, Aarray, &rows, &err);
01426 SCYTHE_CHECK_10(err > 0, scythe_type_error,
01427 "The (" << err << ", " << err << ") element of the matrix is zero"
01428 << " and the inverse could not be computed")
01429 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
01430 "The " << err << "th value of the matrix had an illegal value")
01431 lapack::make_symmetric(Aarray, rows);
01432
01433 return AA;
01434 }
01435
01436 template<>
01437 Matrix<>
01438 invpd (const Matrix<>& A, const Matrix<>& M)
01439 {
01440 SCYTHE_DEBUG_MSG("Using lapack/blas for invpd");
01441 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
01442 "A is NULL")
01443 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error,
01444 "A is not square");
01445 SCYTHE_CHECK_10(A.rows() != M.cols() || A.cols() != M.rows(),
01446 scythe_conformation_error, "A and M do not conform");
01447
01448
01449
01450 Matrix<> MM = M;
01451
01452
01453 double* Marray = MM.getArray();
01454 int rows = (int) MM.rows();
01455 int err = 0;
01456
01457
01458 lapack::dpotri_("L", &rows, Marray, &rows, &err);
01459 SCYTHE_CHECK_10(err > 0, scythe_type_error,
01460 "The (" << err << ", " << err << ") element of the matrix is zero"
01461 << " and the inverse could not be computed")
01462 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
01463 "The " << err << "th value of the matrix had an illegal value")
01464 lapack::make_symmetric(Marray, rows);
01465
01466 return MM;
01467 }
01468
01469 template <>
01470 Matrix<>
01471 inv(const Matrix<>& A)
01472 {
01473 SCYTHE_DEBUG_MSG("Using lapack/blas for inv");
01474 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
01475 "A is NULL")
01476 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error,
01477 "A is not square");
01478
01479
01480
01481 Matrix<> AA = A;
01482
01483
01484 double* Aarray = AA.getArray();
01485 int rows = (int) AA.rows();
01486 int* ipiv = new int[rows];
01487 int err = 0;
01488
01489
01490 lapack::dgetrf_(&rows, &rows, Aarray, &rows, ipiv, &err);
01491
01492 SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular");
01493 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
01494 "The " << err << "th value of the matrix had an illegal value");
01495
01496
01497
01498 double work_query = 0;
01499 int work_size = -1;
01500 lapack::dgetri_(&rows, Aarray, &rows, ipiv, &work_query,
01501 &work_size, &err);
01502 double* workspace = new double[(work_size = (int) work_query)];
01503 lapack::dgetri_(&rows, Aarray, &rows, ipiv, workspace, &work_size,
01504 &err);
01505 delete[] ipiv;
01506 delete[] workspace;
01507
01508 SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular");
01509 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg,
01510 "Internal error in LAPACK routine dgetri");
01511
01512 return AA;
01513 }
01514
/* Result of svd(): A = U * diag(d) * Vt, as computed by LAPACK dgesdd.
 *   d  - min(rows, cols) x 1 singular values.
 *   U  - left singular vectors as columns (may be empty or trimmed,
 *        depending on the nu argument to svd()).
 *   Vt - the transpose of V; its rows are the right singular vectors
 *        (may be empty or trimmed, depending on nv).
 */
struct SVD {
  Matrix<> d;
  Matrix<> U;
  Matrix<> Vt;
};
01535
01564 inline SVD
01565 svd (const Matrix<>& A, int nu = -1, int nv = -1)
01566 {
01567 SCYTHE_DEBUG_MSG("Using lapack/blas for eigen");
01568 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
01569 "Matrix is NULL");
01570
01571 char* jobz;
01572 int m = (int) A.rows();
01573 int n = (int) A.cols();
01574 int mn = (int) std::min(A.rows(), A.cols());
01575 Matrix<> U;
01576 Matrix<> V;
01577 if (nu < 0) nu = mn;
01578 if (nv < 0) nv = mn;
01579 if (nu <= mn && nv<= mn) {
01580 jobz = "S";
01581 U = Matrix<>(m, mn, false);
01582 V = Matrix<>(mn, n, false);
01583 } else if (nu == 0 && nv == 0) {
01584 jobz = "N";
01585 } else {
01586 jobz = "A";
01587 U = Matrix<>(m, m, false);
01588 V = Matrix<>(n, n, false);
01589 }
01590 double* Uarray = U.getArray();
01591 double* Varray = V.getArray();
01592
01593 int ldu = (int) U.rows();
01594 int ldvt = (int) V.rows();
01595 Matrix<> X = A;
01596 double* Xarray = X.getArray();
01597 Matrix<> d(mn, 1, false);
01598 double* darray = d.getArray();
01599
01600 double tmp, *work;
01601 int lwork, info;
01602 int *iwork = new int[8 * mn];
01603
01604
01605 lwork = -1;
01606 lapack::dgesdd_(jobz, &m, &n, Xarray, &m, darray, Uarray, &ldu,
01607 Varray, &ldvt, &tmp, &lwork, iwork, &info);
01608 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error,
01609 "Internal error in LAPACK routine dgessd");
01610 SCYTHE_CHECK_10(info > 0, scythe_convergence_error, "Did not converge");
01611
01612 lwork = (int) tmp;
01613 work = new double[lwork];
01614
01615
01616 lapack::dgesdd_(jobz, &m, &n, Xarray, &m, darray, Uarray, &ldu,
01617 Varray, &ldvt, work, &lwork, iwork, &info);
01618 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error,
01619 "Internal error in LAPACK routine dgessd");
01620 SCYTHE_CHECK_10(info > 0, scythe_convergence_error, "Did not converge");
01621 delete[] work;
01622
01623 if (nu < mn && nu > 0)
01624 U = U(0, 0, U.rows() - 1, (unsigned int) std::min(m, nu) - 1);
01625 if (nv < mn && nv > 0)
01626 V = V(0, 0, (unsigned int) std::min(n, nv) - 1, V.cols() - 1);
01627 SVD result;
01628 result.d = d;
01629 result.U = U;
01630 result.Vt = V;
01631
01632 return result;
01633 }
01634
/* Result of eigen(): eigen-decomposition of a symmetric matrix, as
 * computed by LAPACK dsyevr.
 *   values  - order x 1 eigenvalues, in the order dsyevr returns them.
 *   vectors - order x order matrix whose columns are the corresponding
 *             eigenvectors; left empty when eigen() is called with
 *             vectors == false.
 */
struct Eigen {
  Matrix<> values;
  Matrix<> vectors;
};
01650
01679 inline Eigen
01680 eigen (const Matrix<>& A, bool vectors=true)
01681 {
01682 SCYTHE_DEBUG_MSG("Using lapack/blas for eigen");
01683 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error,
01684 "Matrix not square");
01685 SCYTHE_CHECK_10(A.isNull(), scythe_null_error,
01686 "Matrix is NULL");
01687
01688
01689
01690
01691 Matrix<> AA = A;
01692
01693
01694 double* Aarray = AA.getArray();
01695 int order = (int) AA.rows();
01696 double dignored = 0;
01697 int iignored = 0;
01698 double abstol = 0.0;
01699 int m;
01700 Matrix<> result;
01701 char getvecs[1];
01702 if (vectors) {
01703 getvecs[0] = 'V';
01704 result = Matrix<>(order, order + 1, false);
01705 } else {
01706 result = Matrix<>(order, 1, false);
01707 getvecs[0] = 'N';
01708 }
01709 double* eigenvalues = result.getArray();
01710 int* isuppz = new int[2 * order];
01711 double tmp;
01712 int lwork, liwork, *iwork, itmp;
01713 double *work;
01714 int info = 0;
01715
01716
01717 lwork = -1;
01718 liwork = -1;
01719 lapack::dsyevr_(getvecs, "A", "L", &order, Aarray, &order, &dignored,
01720 &dignored, &iignored, &iignored, &abstol, &m, eigenvalues,
01721 eigenvalues + order, &order, isuppz, &tmp, &lwork, &itmp,
01722 &liwork, &info);
01723 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01724 "Internal error in LAPACK routine dsyevr");
01725 lwork = (int) tmp;
01726 liwork = itmp;
01727 work = new double[lwork];
01728 iwork = new int[liwork];
01729
01730
01731 lapack::dsyevr_(getvecs, "A", "L", &order, Aarray, &order, &dignored,
01732 &dignored, &iignored, &iignored, &abstol, &m, eigenvalues,
01733 eigenvalues + order, &order, isuppz, work, &lwork, iwork,
01734 &liwork, &info);
01735 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error,
01736 "Internal error in LAPACK routine dsyevr");
01737
01738 delete[] isuppz;
01739 delete[] work;
01740 delete[] iwork;
01741
01742 Eigen resobj;
01743 if (vectors) {
01744 resobj.values = result(_, 0);
01745 resobj.vectors = result(0, 1, result.rows() -1, result.cols() - 1);
01746 } else {
01747 resobj.values = result;
01748 }
01749
01750 return resobj;
01751 }
01752
01753 #endif
01754
01755 }
01756
01757 #endif