00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00028 #ifndef SCYTHE_RTMVNORM_H
00029 #define SCYTHE_RTMVNORM_H
00030
00031 #include <iostream>
00032 #include <cmath>
00033
00034 #ifdef SCYTHE_COMPILE_DIRECT
00035 #include "matrix.h"
00036 #include "rng.h"
00037 #include "error.h"
00038 #include "algorithm.h"
00039 #include "ide.h"
00040 #else
00041 #include "scythestat/matrix.h"
00042 #include "scythestat/rng.h"
00043 #include "scythestat/error.h"
00044 #include "scythestat/algorithm.h"
00045 #include "scythestat/ide.h"
00046 #endif
00047 namespace scythe
00048 {
00049
00050
00051
00052
00053
00054
00055
00056
00067 template <class RNGTYPE>
00068 class rtmvnorm {
00069 public:
00070
00120 template <matrix_order PO1, matrix_style PS1, matrix_order PO2,
00121 matrix_style PS2, matrix_order PO3, matrix_style PS3,
00122 matrix_order PO4, matrix_style PS5, matrix_order PO5,
00123 matrix_style PS4>
00124 rtmvnorm (const Matrix<double, PO1,PS1>& mu,
00125 const Matrix<double, PO2, PS2>& sigma,
00126 const Matrix<double, PO3, PS3>& D,
00127 const Matrix<double, PO4, PS4>& a,
00128 const Matrix<double, PO5, PS5>& b, rng<RNGTYPE>& generator,
00129 unsigned int burnin = 0, unsigned int thin = 1,
00130 bool preinvertedD = false)
00131 : mu_ (mu), C_ (mu.rows(), mu.rows(), false),
00132 h_ (mu.rows(), 1, false), z_ (mu.rows(), 1, true, 0),
00133 generator_ (generator), n_ (mu.rows()), thin_ (thin), iter_ (0)
00134 {
00135 SCYTHE_CHECK_10(thin == 0, scythe_invalid_arg,
00136 "thin must be >= 1");
00137 SCYTHE_CHECK_10(! mu.isColVector(), scythe_dimension_error,
00138 "mu not column vector");
00139 SCYTHE_CHECK_10(! sigma.isSquare(), scythe_dimension_error,
00140 "sigma not square");
00141 SCYTHE_CHECK_10(! D.isSquare(), scythe_dimension_error,
00142 "D not square");
00143 SCYTHE_CHECK_10(! a.isColVector(), scythe_dimension_error,
00144 "a not column vector");
00145 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error,
00146 "b not column vector");
00147 SCYTHE_CHECK_10(sigma.rows() != n_ || D.rows() != n_ ||
00148 a.rows() != n_ || b.rows() != n_, scythe_conformation_error,
00149 "mu, sigma, D, a, and b not conformable");
00150
00151
00152
00153 if (preinvertedD)
00154 Dinv_ = D;
00155 else
00156 Dinv_ = inv(D);
00157 Matrix<> Tinv = inv(D * sigma * t(D));
00158 alpha_ = a - D * mu;
00159 beta_ = b - D * mu;
00160
00161
00162 if (SCYTHE_DEBUG > 0) {
00163 for (unsigned int i = 0; i < n_; ++i) {
00164 SCYTHE_CHECK(alpha_(i) >= beta_(i), scythe_invalid_arg,
00165 "Truncation bound " << i
00166 << " not logically consistent");
00167 }
00168 }
00169
00170
00171 for (unsigned int i = 0; i < n_; ++i) {
00172 C_(i, _) = -(1 / Tinv(i, i)) % Tinv(i, _);
00173 C_(i, i) = 0;
00174 h_(i) = std::sqrt(1 / Tinv(i, i));
00175 SCYTHE_CHECK_30(std::isnan(h_(i)), scythe_invalid_arg,
00176 "sigma is not positive definite");
00177 }
00178
00179
00180 for (unsigned int i = 0; i < burnin; ++i)
00181 sample ();
00182 }
00183
00193 template <matrix_order O, matrix_style S>
00194 Matrix<double, O, S> operator() ()
00195 {
00196 do { sample (); } while (iter_ % thin_ != 0);
00197
00198 return (mu_ + Dinv_ * z_);
00199 }
00200
00207 Matrix<double,Col,Concrete> operator() ()
00208 {
00209 return operator()<Col, Concrete>();
00210 }
00211
00212 protected:
00213
00214 void sample ()
00215 {
00216 double czsum;
00217 double above;
00218 double below;
00219 for (unsigned int i = 0; i < n_; ++i) {
00220
00221
00222 czsum = 0;
00223 for (unsigned int j = 0; j < n_; ++j) {
00224 if (i == j) continue;
00225 czsum += C_(i, j) * z_(j);
00226 }
00227
00228
00229 below = (alpha_(i) - czsum) / h_(i);
00230 above = (beta_(i) - czsum) / h_(i);
00231
00232
00233 z_(i) = h_(i);
00234 if (above == std::numeric_limits<double>::infinity()){
00235 if (below == -std::numeric_limits<double>::infinity())
00236 z_(i) *= generator_.rnorm(0, 1);
00237 else
00238 z_(i) *= generator_.rtbnorm_combo(0, 1, below);
00239 } else if (below ==
00240 -std::numeric_limits<double>::infinity())
00241 z_(i) *= generator_.rtanorm_combo(0, 1, above);
00242 else
00243 z_(i) *= generator_.rtnorm_combo(0, 1, below, above);
00244
00245 z_(i) += czsum;
00246 }
00247
00248 ++iter_;
00249 }
00250
00251
00252
00253
00254 Matrix<> mu_; Matrix<> Dinv_;
00255 Matrix<> C_; Matrix<> alpha_; Matrix<> beta_; Matrix<> h_;
00256
00257 Matrix<> z_;
00258
00259 rng<RNGTYPE>& generator_;
00260
00261 unsigned int n_;
00262 unsigned int thin_;
00263 unsigned int iter_;
00264 };
00265 }
00266 #endif