00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00032
00033
00034
00035
00036 #ifndef SCYTHE_RNG_H
00037 #define SCYTHE_RNG_H
00038
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>
00041
00042 #ifdef HAVE_IEEEFP_H
00043 #include <ieeefp.h>
00044 #endif
00045
00046 #ifdef SCYTHE_COMPILE_DIRECT
00047 #include "matrix.h"
00048 #include "error.h"
00049 #include "algorithm.h"
00050 #include "distributions.h"
00051 #include "ide.h"
00052 #include "la.h"
00053 #else
00054 #include "scythestat/matrix.h"
00055 #include "scythestat/error.h"
00056 #include "scythestat/algorithm.h"
00057 #include "scythestat/distributions.h"
00058 #include "scythestat/ide.h"
00059 #include "scythestat/la.h"
00060 #endif
00061
00062 namespace scythe {
00063
00064
00065
00066
00067
/* SCYTHE_RNGMETH_MATRIX(NAME, RTYPE, ARGNAMES, ...)
 *
 * Generates, inside the rng class body, the two matrix-returning
 * overloads of an existing scalar generator member NAME:
 *
 *   1. A member template NAME<O,S>(rows, cols, <params>) that builds a
 *      rows x cols Concrete matrix of RTYPE (the `false` constructor
 *      argument presumably skips the initial fill -- confirm against
 *      matrix.h), assigns every element, in forward-iterator order, the
 *      result of a separate call to the scalar NAME(ARGNAMES), and hands
 *      the result back through SCYTHE_VIEW_RETURN for order O / style S.
 *   2. A non-template NAME(rows, cols, <params>) that forwards to the
 *      template instantiated as <Col, Concrete>.
 *
 * ARGNAMES is the list of parameter names forwarded to the scalar call;
 * call sites wrap several names in SCYTHE_ARGSET(...) (defined elsewhere,
 * presumably a pass-through variadic macro) so they travel as a single
 * macro argument.  The variadic tail is the matching parameter
 * declaration list and may carry default arguments; it is pasted into
 * both overloads.
 *
 * Comments are kept outside the #define: a // comment on one of the
 * backslash-continued lines would swallow the following line.
 */
#define SCYTHE_RNGMETH_MATRIX(NAME, RTYPE, ARGNAMES, ...) \
  template <matrix_order O, matrix_style S> \
  Matrix<RTYPE, O, S> \
  NAME (unsigned int rows, unsigned int cols, __VA_ARGS__) \
  { \
    Matrix<RTYPE, O, Concrete> ret(rows, cols, false); \
    typename Matrix<RTYPE,O,Concrete>::forward_iterator it; \
    typename Matrix<RTYPE,O,Concrete>::forward_iterator last \
      = ret.end_f(); \
    for (it = ret.begin_f(); it != last; ++it) \
      *it = NAME (ARGNAMES); \
    SCYTHE_VIEW_RETURN(RTYPE, O, S, ret) \
  } \
  \
  Matrix<RTYPE, Col, Concrete> \
  NAME (unsigned int rows, unsigned int cols, __VA_ARGS__) \
  { \
    return NAME <Col,Concrete> (rows, cols, ARGNAMES); \
  }
00087
00127 template <class RNGTYPE>
00128 class rng
00129 {
00130 public:
00131
00132
00133
00134
00135
00143 double operator() ()
00144 {
00145 return runif();
00146 }
00147
00148
00149
00159 double runif ()
00160 {
00161 return as_derived().runif();
00162 }
00163
00164
00165
00166
00167
00168
00169 template <matrix_order O, matrix_style S>
00170 Matrix<double,O,S> runif(unsigned int rows,
00171 unsigned int cols)
00172 {
00173 Matrix<double, O, S> ret(rows, cols, false);
00174 typename Matrix<double,O,S>::forward_iterator it;
00175 typename Matrix<double,O,S>::forward_iterator last=ret.end_f();
00176 for (it = ret.begin_f(); it != last; ++it)
00177 *it = runif();
00178
00179 return ret;
00180 }
00181
00182 Matrix<double,Col,Concrete> runif(unsigned int rows,
00183 unsigned int cols)
00184 {
00185 return runif<Col,Concrete>(rows, cols);
00186 }
00187
00204 double
00205 rbeta (double alpha, double beta)
00206 {
00207 double report;
00208 double xalpha, xbeta;
00209
00210
00211 SCYTHE_CHECK_10(alpha <= 0, scythe_invalid_arg, "alpha <= 0");
00212 SCYTHE_CHECK_10(beta <= 0, scythe_invalid_arg, "beta <= 0");
00213
00214 xalpha = rchisq (2 * alpha);
00215 xbeta = rchisq (2 * beta);
00216 report = xalpha / (xalpha + xbeta);
00217
00218 return (report);
00219 }
00220
00221 SCYTHE_RNGMETH_MATRIX(rbeta, double, SCYTHE_ARGSET(alpha, beta),
00222 double alpha, double beta);
00223
00240 double
00241 rnchypgeom(double m1, double n1, double n2, double psi,
00242 double delta)
00243 {
00244
00245 double a = psi - 1;
00246 double b = -1 * ((n1+m1+2)*psi + n2 - m1);
00247 double c = psi * (n1+1) * (m1+1);
00248 double q = -0.5 * ( b + sgn(b) *
00249 std::sqrt(std::pow(b,2) - 4*a*c));
00250 double root1 = c/q;
00251 double root2 = q/a;
00252 double el = std::max(0.0, m1-n2);
00253 double u = std::min(n1,m1);
00254 double mode = std::floor(root1);
00255 int exactcheck = 0;
00256 if (u<mode || mode<el) {
00257 mode = std::floor(root2);
00258 exactcheck = 1;
00259 }
00260
00261
00262 int size = static_cast<int>(u+1);
00263
00264 double *fvec = new double[size];
00265 fvec[static_cast<int>(mode)] = 1.0;
00266 double s;
00267
00268 if (delta <= 0 || exactcheck==1){
00269
00270 double f = 1.0;
00271 s = 1.0;
00272 for (double i=(mode+1); i<=u; ++i){
00273 double r = ((n1-i+1)*(m1-i+1))/(i*(n2-m1+i)) * psi;
00274 f = f*r;
00275 s += f;
00276 fvec[static_cast<int>(i)] = f;
00277 }
00278
00279
00280 f = 1.0;
00281 for (double i=(mode-1); i>=el; --i){
00282 double r = ((n1-i)*(m1-i))/((i+1)*(n2-m1+i+1)) * psi;
00283 f = f/r;
00284 s += f;
00285 fvec[static_cast<int>(i)] = f;
00286 }
00287 } else {
00288 double epsilon = delta/10.0;
00289
00290 double f = 1.0;
00291 s = 1.0;
00292 double i = mode+1;
00293 double r;
00294 do {
00295 if (i>u) break;
00296 r = ((n1-i+1)*(m1-i+1))/(i*(n2-m1+i)) * psi;
00297 f = f*r;
00298 s += f;
00299 fvec[static_cast<int>(i)] = f;
00300 ++i;
00301 } while(f>=epsilon || r>=5.0/6.0);
00302
00303
00304 f = 1.0;
00305 i = mode-1;
00306 do {
00307 if (i<el) break;
00308 r = ((n1-i)*(m1-i))/((i+1)*(n2-m1+i+1)) * psi;
00309 f = f/r;
00310 s += f;
00311 fvec[static_cast<int>(i)] = f;
00312 --i;
00313 } while(f>=epsilon || r <=6.0/5.0);
00314 }
00315
00316 double udraw = runif();
00317 double psum = fvec[static_cast<int>(mode)]/s;
00318 if (udraw<=psum)
00319 return mode;
00320 double lower = mode-1;
00321 double upper = mode+1;
00322
00323 do{
00324 double fl;
00325 double fu;
00326 if (lower >= el)
00327 fl = fvec[static_cast<int>(lower)];
00328 else
00329 fl = 0.0;
00330
00331 if (upper <= u)
00332 fu = fvec[static_cast<int>(upper)];
00333 else
00334 fu = 0.0;
00335
00336 if (fl > fu) {
00337 psum += fl/s;
00338 if (udraw<=psum)
00339 return lower;
00340 --lower;
00341 } else {
00342 psum += fu/s;
00343 if (udraw<=psum)
00344 return upper;
00345 ++upper;
00346 }
00347 } while(udraw>psum);
00348
00349 delete [] fvec;
00350 SCYTHE_THROW(scythe_convergence_error,
00351 "Algorithm did not converge");
00352 }
00353
00354 SCYTHE_RNGMETH_MATRIX(rnchypgeom, double,
00355 SCYTHE_ARGSET(m1, n1, n2, psi, delta), double m1, double n1,
00356 double n2, double psi, double delta);
00357
00367 unsigned int
00368 rbern (double p)
00369 {
00370 unsigned int report;
00371 double unif;
00372
00373
00374 SCYTHE_CHECK_10(p < 0 || p > 1, scythe_invalid_arg,
00375 "p parameter not in[0,1]");
00376
00377 unif = runif ();
00378 if (unif < p)
00379 report = 1;
00380 else
00381 report = 0;
00382
00383 return (report);
00384 }
00385
00386 SCYTHE_RNGMETH_MATRIX(rbern, unsigned int, p, double p);
00387
00402 unsigned int
00403 rbinom (unsigned int n, double p)
00404 {
00405 unsigned int report;
00406 unsigned int count = 0;
00407 double hold;
00408
00409
00410 SCYTHE_CHECK_10(n == 0, scythe_invalid_arg, "n == 0");
00411 SCYTHE_CHECK_10(p < 0 || p > 1, scythe_invalid_arg,
00412 "p not in [0,1]");
00413
00414
00415 for (unsigned int i = 0; i < n; i++) {
00416 hold = runif ();
00417 if (hold < p)
00418 ++count;
00419 }
00420 report = count;
00421
00422 return (report);
00423 }
00424
00425 SCYTHE_RNGMETH_MATRIX(rbinom, unsigned int, SCYTHE_ARGSET(n, p),
00426 unsigned int n, double p);
00427
00440 double
00441 rchisq (double df)
00442 {
00443 double report;
00444
00445
00446 SCYTHE_CHECK_10(df <= 0, scythe_invalid_arg,
00447 "Degrees of freedom <= 0");
00448
00449
00450 report = rgamma (df / 2, .5);
00451
00452 return (report);
00453 }
00454
00455 SCYTHE_RNGMETH_MATRIX(rchisq, double, df, double df);
00456
00470 double
00471 rexp (double invscale)
00472 {
00473 double report;
00474
00475
00476 SCYTHE_CHECK_10(invscale <= 0, scythe_invalid_arg,
00477 "Inverse scale parameter <= 0");
00478
00479 report = -std::log (runif ()) / invscale;
00480
00481 return (report);
00482 }
00483
00484 SCYTHE_RNGMETH_MATRIX(rexp, double, invscale, double invscale);
00485
00501 double
00502 rf (double df1, double df2)
00503 {
00504 SCYTHE_CHECK_10(df1 <= 0 || df2 <= 0, scythe_invalid_arg,
00505 "n1 or n2 <= 0");
00506
00507 return ((rchisq(df1) / df1) / (rchisq(df2) / df2));
00508 }
00509
00510 SCYTHE_RNGMETH_MATRIX(rf, double, SCYTHE_ARGSET(df1, df2),
00511 double df1, double df2);
00512
00528 double
00529 rgamma (double shape, double rate)
00530 {
00531 double report;
00532
00533
00534 SCYTHE_CHECK_10(shape <= 0, scythe_invalid_arg, "shape <= 0");
00535 SCYTHE_CHECK_10(rate <= 0, scythe_invalid_arg, "rate <= 0");
00536
00537 if (shape > 1)
00538 report = rgamma1 (shape) / rate;
00539 else if (shape == 1)
00540 report = -std::log (runif ()) / rate;
00541 else if (shape < 1)
00542 report = rgamma1 (shape + 1)
00543 * std::pow (runif (), 1 / shape) / rate;
00544
00545 return (report);
00546 }
00547
00548 SCYTHE_RNGMETH_MATRIX(rgamma, double, SCYTHE_ARGSET(shape, rate),
00549 double shape, double rate);
00550
00565 double
00566 rlogis (double location, double scale)
00567 {
00568 double report;
00569 double unif;
00570
00571
00572 SCYTHE_CHECK_10(scale <= 0, scythe_invalid_arg, "scale <= 0");
00573
00574 unif = runif ();
00575 report = location + scale * std::log (unif / (1 - unif));
00576
00577 return (report);
00578 }
00579
00580 SCYTHE_RNGMETH_MATRIX(rlogis, double,
00581 SCYTHE_ARGSET(location, scale),
00582 double location, double scale);
00583
00599 double
00600 rlnorm (double logmean, double logsd)
00601 {
00602 SCYTHE_CHECK_10(logsd < 0.0, scythe_invalid_arg,
00603 "standard deviation < 0");
00604
00605 return std::exp(rnorm(logmean, logsd));
00606 }
00607
00608 SCYTHE_RNGMETH_MATRIX(rlnorm, double,
00609 SCYTHE_ARGSET(logmean, logsd),
00610 double logmean, double logsd);
00611
00628 unsigned int
00629 rnbinom (double n, double p)
00630 {
00631 SCYTHE_CHECK_10(n == 0 || p <= 0 || p > 1, scythe_invalid_arg,
00632 "n == 0, p <= 0, or p > 1");
00633
00634 return rpois(rgamma(n, (1 - p) / p));
00635 }
00636
00637 SCYTHE_RNGMETH_MATRIX(rnbinom, unsigned int,
00638 SCYTHE_ARGSET(n, p), double n, double p);
00639
00654 double
00655 rnorm (double mean = 0, double sd = 1)
00656 {
00657 SCYTHE_CHECK_10(sd <= 0, scythe_invalid_arg,
00658 "Negative standard deviation");
00659
00660 return (mean + rnorm1 () * sd);
00661 }
00662
00663 SCYTHE_RNGMETH_MATRIX(rnorm, double, SCYTHE_ARGSET(mean, sd),
00664 double mean, double sd);
00665
00680 unsigned int
00681 rpois(double lambda)
00682 {
00683 SCYTHE_CHECK_10(lambda <= 0, scythe_invalid_arg, "lambda <= 0");
00684 unsigned int n;
00685
00686 if (lambda < 33) {
00687 double cutoff = std::exp(-lambda);
00688 n = -1;
00689 double t = 1.0;
00690 do {
00691 ++n;
00692 t *= runif();
00693 } while (t > cutoff);
00694 } else {
00695 bool accept = false;
00696 double c = 0.767 - 3.36/lambda;
00697 double beta = M_PI/std::sqrt(3*lambda);
00698 double alpha = lambda*beta;
00699 double k = std::log(c) - lambda - std::log(beta);
00700
00701 while (! accept){
00702 double u1 = runif();
00703 double x = (alpha - std::log((1-u1)/u1))/beta;
00704 while (x <= -0.5){
00705 u1 = runif();
00706 x = (alpha - std::log((1-u1)/u1))/beta;
00707 }
00708 n = static_cast<int>(x + 0.5);
00709 double u2 = runif();
00710 double lhs = alpha - beta*x +
00711 std::log(u2/std::pow(1+std::exp(alpha-beta*x),2));
00712 double rhs = k + n*std::log(lambda) - lnfactorial(n);
00713 if (lhs <= rhs)
00714 accept = true;
00715 }
00716 }
00717
00718 return n;
00719 }
00720
00721 SCYTHE_RNGMETH_MATRIX(rpois, unsigned int, lambda, double lambda);
00722
00723
00724
00725
00726
00727
00728
00743 double
00744 rt (double mu, double sigma2, double nu)
00745 {
00746 double report;
00747 double x, z;
00748
00749
00750 SCYTHE_CHECK_10(sigma2 <= 0, scythe_invalid_arg,
00751 "Variance parameter sigma2 <= 0");
00752 SCYTHE_CHECK_10(nu <= 0, scythe_invalid_arg,
00753 "D.O.F parameter nu <= 0");
00754
00755 z = rnorm1 ();
00756 x = rchisq (nu);
00757 report = mu + std::sqrt (sigma2) * z
00758 * std::sqrt (nu) / std::sqrt (x);
00759
00760 return (report);
00761 }
00762
00763 SCYTHE_RNGMETH_MATRIX(rt1, double, SCYTHE_ARGSET(mu, sigma2, nu),
00764 double mu, double sigma2, double nu);
00765
00779 double
00780 rweibull (double shape, double scale)
00781 {
00782 SCYTHE_CHECK_10(shape <= 0 || scale <= 0, scythe_invalid_arg,
00783 "shape or scale <= 0");
00784
00785 return scale * std::pow(-std::log(runif()), 1.0 / shape);
00786 }
00787
00788 SCYTHE_RNGMETH_MATRIX(rweibull, double,
00789 SCYTHE_ARGSET(shape, scale), double shape, double scale);
00790
00804 double
00805 richisq (double nu)
00806 {
00807 double report;
00808
00809
00810 SCYTHE_CHECK_10(nu <= 0, scythe_invalid_arg,
00811 "Degrees of freedom <= 0");
00812
00813
00814 report = rigamma (nu / 2, .5);
00815 return (report);
00816 }
00817
00818 SCYTHE_RNGMETH_MATRIX(richisq, double, nu, double nu);
00819
00832 double
00833 rigamma (double alpha, double beta)
00834 {
00835 double report;
00836
00837
00838 SCYTHE_CHECK_10(alpha <= 0, scythe_invalid_arg, "alpha <= 0");
00839 SCYTHE_CHECK_10(beta <= 0, scythe_invalid_arg, "beta <= 0");
00840
00841
00842 report = std::pow (rgamma (alpha, beta), -1);
00843
00844 return (report);
00845 }
00846
00847 SCYTHE_RNGMETH_MATRIX(rigamma, double, SCYTHE_ARGSET(alpha, beta),
00848 double alpha, double beta);
00849
00850
00851
00874 double
00875 rtnorm(double mean, double variance, double below, double above)
00876 {
00877 SCYTHE_CHECK_10(below >= above, scythe_invalid_arg,
00878 "Truncation bound not logically consistent");
00879 SCYTHE_CHECK_10(variance <= 0, scythe_invalid_arg,
00880 "Variance <= 0");
00881
00882 double sd = std::sqrt(variance);
00883 double FA = 0.0;
00884 double FB = 0.0;
00885 if ((std::fabs((above-mean)/sd) < 8.2)
00886 && (std::fabs((below-mean)/sd) < 8.2)){
00887 FA = pnorm1((above-mean)/sd, true, false);
00888 FB = pnorm1((below-mean)/sd, true, false);
00889 }
00890 if ((((above-mean)/sd) < 8.2) && (((below-mean)/sd) <= -8.2) ){
00891 FA = pnorm1((above-mean)/sd, true, false);
00892 FB = 0.0;
00893 }
00894 if ( (((above-mean)/sd) >= 8.2) && (((below-mean)/sd) > -8.2) ){
00895 FA = 1.0;
00896 FB = pnorm1((below-mean)/sd, true, false);
00897 }
00898 if ( (((above-mean)/sd) >= 8.2) && (((below-mean)/sd) <= -8.2)){
00899 FA = 1.0;
00900 FB = 0.0;
00901 }
00902 double term = runif()*(FA-FB)+FB;
00903 if (term < 5.6e-17)
00904 term = 5.6e-17;
00905 if (term > (1 - 5.6e-17))
00906 term = 1 - 5.6e-17;
00907 double draw = mean + sd * qnorm1(term);
00908 if (draw > above)
00909 draw = above;
00910 if (draw < below)
00911 draw = below;
00912
00913 return draw;
00914 }
00915
00916 SCYTHE_RNGMETH_MATRIX(rtnorm, double,
00917 SCYTHE_ARGSET(mean, variance, above, below), double mean,
00918 double variance, double above, double below);
00919
00944 double
00945 rtnorm_combo(double mean, double variance, double below,
00946 double above)
00947 {
00948 SCYTHE_CHECK_10(below >= above, scythe_invalid_arg,
00949 "Truncation bound not logically consistent");
00950 SCYTHE_CHECK_10(variance <= 0, scythe_invalid_arg,
00951 "Variance <= 0");
00952
00953 double sd = std::sqrt(variance);
00954 if ((((above-mean)/sd > 0.5) && ((mean-below)/sd > 0.5))
00955 ||
00956 (((above-mean)/sd > 2.0) && ((below-mean)/sd < 0.25))
00957 ||
00958 (((mean-below)/sd > 2.0) && ((above-mean)/sd > -0.25))) {
00959 double x = rnorm(mean, sd);
00960 while ((x > above) || (x < below))
00961 x = rnorm(mean,sd);
00962 return x;
00963 } else {
00964
00965 double FA = 0.0;
00966 double FB = 0.0;
00967 if ((std::fabs((above-mean)/sd) < 8.2)
00968 && (std::fabs((below-mean)/sd) < 8.2)){
00969 FA = pnorm1((above-mean)/sd, true, false);
00970 FB = pnorm1((below-mean)/sd, true, false);
00971 }
00972 if ((((above-mean)/sd) < 8.2) && (((below-mean)/sd) <= -8.2) ){
00973 FA = pnorm1((above-mean)/sd, true, false);
00974 FB = 0.0;
00975 }
00976 if ( (((above-mean)/sd) >= 8.2) && (((below-mean)/sd) > -8.2) ){
00977 FA = 1.0;
00978 FB = pnorm1((below-mean)/sd, true, false);
00979 }
00980 if ( (((above-mean)/sd) >= 8.2) && (((below-mean)/sd) <= -8.2)){
00981 FA = 1.0;
00982 FB = 0.0;
00983 }
00984 double term = runif()*(FA-FB)+FB;
00985 if (term < 5.6e-17)
00986 term = 5.6e-17;
00987 if (term > (1 - 5.6e-17))
00988 term = 1 - 5.6e-17;
00989 double x = mean + sd * qnorm1(term);
00990 if (x > above)
00991 x = above;
00992 if (x < below)
00993 x = below;
00994 return x;
00995 }
00996 }
00997
00998 SCYTHE_RNGMETH_MATRIX(rtnorm_combo, double,
00999 SCYTHE_ARGSET(mean, variance, above, below), double mean,
01000 double variance, double above, double below);
01001
01024 double
01025 rtbnorm_slice (double mean, double variance, double below,
01026 unsigned int iter = 10)
01027 {
01028 SCYTHE_CHECK_10(below < mean, scythe_invalid_arg,
01029 "Truncation point < mean");
01030 SCYTHE_CHECK_10(variance <= 0, scythe_invalid_arg,
01031 "Variance <= 0");
01032
01033 double z = 0;
01034 double x = below + .00001;
01035
01036 for (unsigned int i=0; i<iter; ++i){
01037 z = runif()*std::exp(-1*std::pow((x-mean),2)/(2*variance));
01038 x = runif()*
01039 ((mean + std::sqrt(-2*variance*std::log(z))) - below) + below;
01040 }
01041
01042 if (! finite(x)) {
01043 SCYTHE_WARN("Mean extremely far from truncation point. "
01044 << "Returning truncation point");
01045 return below;
01046 }
01047
01048 return x;
01049 }
01050
01051 SCYTHE_RNGMETH_MATRIX(rtbnorm_slice, double,
01052 SCYTHE_ARGSET(mean, variance, below, iter), double mean,
01053 double variance, double below, unsigned int iter = 10);
01054
01077 double
01078 rtanorm_slice (double mean, double variance, double above,
01079 unsigned int iter = 10)
01080 {
01081 SCYTHE_CHECK_10(above > mean, scythe_invalid_arg,
01082 "Truncation point > mean");
01083 SCYTHE_CHECK_10(variance <= 0, scythe_invalid_arg,
01084 "Variance <= 0");
01085
01086 double below = -1*above;
01087 double newmu = -1*mean;
01088 double z = 0;
01089 double x = below + .00001;
01090
01091 for (unsigned int i=0; i<iter; ++i){
01092 z = runif()*std::exp(-1*std::pow((x-newmu),2)
01093 /(2*variance));
01094 x = runif()
01095 *( (newmu + std::sqrt(-2*variance*std::log(z))) - below)
01096 + below;
01097 }
01098 if (! finite(x)) {
01099 SCYTHE_WARN("Mean extremely far from truncation point. "
01100 << "Returning truncation point");
01101 return above;
01102 }
01103
01104 return -1*x;
01105 }
01106
01107 SCYTHE_RNGMETH_MATRIX(rtanorm_slice, double,
01108 SCYTHE_ARGSET(mean, variance, above, iter), double mean,
01109 double variance, double above, unsigned int iter = 10);
01110
01136 double
01137 rtbnorm_combo (double mean, double variance, double below,
01138 unsigned int iter = 10)
01139 {
01140 SCYTHE_CHECK_10(variance <= 0, scythe_invalid_arg,
01141 "Variance <= 0");
01142
01143 double s = std::sqrt(variance);
01144
01145
01146 if ((mean/s - below/s ) > -0.5){
01147 double x = rnorm(mean, s);
01148 while (x < below)
01149 x = rnorm(mean,s);
01150 return x;
01151 } else if ((mean/s - below/s ) > -5.0 ){
01152
01153 double above = std::numeric_limits<double>::infinity();
01154 double x = rtnorm(mean, variance, below, above);
01155 return x;
01156 } else {
01157
01158 double z = 0;
01159 double x = below + .00001;
01160 for (unsigned int i=0; i<iter; ++i){
01161 z = runif() * std::exp(-1 * std::pow((x - mean), 2)
01162 / (2 * variance));
01163 x = runif()
01164 * ((mean + std::sqrt(-2 * variance * std::log(z)))
01165 - below) + below;
01166 }
01167 if (! finite(x)) {
01168 SCYTHE_WARN("Mean extremely far from truncation point. "
01169 << "Returning truncation point");
01170 return below;
01171 }
01172 return x;
01173 }
01174 }
01175
01176 SCYTHE_RNGMETH_MATRIX(rtbnorm_combo, double,
01177 SCYTHE_ARGSET(mean, variance, below, iter), double mean,
01178 double variance, double below, unsigned int iter = 10);
01179
01204 double
01205 rtanorm_combo (double mean, double variance, double above,
01206 const unsigned int iter = 10)
01207 {
01208 SCYTHE_CHECK_10(variance <= 0, scythe_invalid_arg,
01209 "Variance <= 0");
01210
01211 double s = std::sqrt(variance);
01212
01213 if ((mean/s - above/s ) < 0.5){
01214 double x = rnorm(mean, s);
01215 while (x > above)
01216 x = rnorm(mean,s);
01217 return x;
01218 } else if ((mean/s - above/s ) < 5.0 ){
01219
01220 double below = -std::numeric_limits<double>::infinity();
01221 double x = rtnorm(mean, variance, below, above);
01222 return x;
01223 } else {
01224
01225 double below = -1*above;
01226 double newmu = -1*mean;
01227 double z = 0;
01228 double x = below + .00001;
01229
01230 for (unsigned int i=0; i<iter; ++i){
01231 z = runif() * std::exp(-1 * std::pow((x-newmu), 2)
01232 /(2 * variance));
01233 x = runif()
01234 * ((newmu + std::sqrt(-2 * variance * std::log(z)))
01235 - below) + below;
01236 }
01237 if (! finite(x)) {
01238 SCYTHE_WARN("Mean extremely far from truncation point. "
01239 << "Returning truncation point");
01240 return above;
01241 }
01242 return -1*x;
01243 }
01244 }
01245
01246 SCYTHE_RNGMETH_MATRIX(rtanorm_combo, double,
01247 SCYTHE_ARGSET(mean, variance, above, iter), double mean,
01248 double variance, double above, unsigned int iter = 10);
01249
01250
01251
01264 template <matrix_order O, matrix_style S>
01265 Matrix<double, O, Concrete>
01266 rwish(unsigned int v, const Matrix<double, O, S> &Sigma)
01267 {
01268 SCYTHE_CHECK_10(! Sigma.isSquare(), scythe_dimension_error,
01269 "Sigma not square");
01270 SCYTHE_CHECK_10(v < Sigma.rows(), scythe_invalid_arg,
01271 "v < Sigma.rows()");
01272
01273 Matrix<double,O,Concrete>
01274 A(Sigma.rows(), Sigma.rows());
01275 Matrix<double,O,Concrete> C = cholesky<O,Concrete>(Sigma);
01276 Matrix<double,O,Concrete> alpha;
01277
01278 for (unsigned int i = 0; i < v; ++i) {
01279 alpha = C * rnorm(Sigma.rows(), 1, 0, 1);
01280 A = A + (alpha * (t(alpha)));
01281 }
01282
01283 return A;
01284 }
01285
01297 template <matrix_order O, matrix_style S>
01298 Matrix<double, O, Concrete>
01299 rdirich(const Matrix<double, O, S>& alpha)
01300 {
01301
01302 SCYTHE_CHECK_10(std::min(alpha) <= 0, scythe_invalid_arg,
01303 "alpha has elements < 0");
01304 SCYTHE_CHECK_10(! alpha.isColVector(), scythe_dimension_error,
01305 "alpha not column vector");
01306
01307 Matrix<double, O, Concrete> y(alpha.rows(), 1);
01308 double ysum = 0;
01309
01310
01311
01312 const_matrix_forward_iterator<double,O,O,S> ait;
01313 const_matrix_forward_iterator<double,O,O,S> alast
01314 = alpha.template end_f();
01315 typename Matrix<double,O,Concrete>::forward_iterator yit
01316 = y.begin_f();
01317 for (ait = alpha.begin_f(); ait != alast; ++ait) {
01318 *yit = rgamma(*ait, 1);
01319 ysum += *yit;
01320 ++ait;
01321 }
01322
01323 y /= ysum;
01324
01325 return y;
01326 }
01327
01341 template <matrix_order PO1, matrix_style PS1,
01342 matrix_order PO2, matrix_style PS2>
01343 Matrix<double, PO1, Concrete>
01344 rmvnorm(const Matrix<double, PO1, PS1>& mu,
01345 const Matrix<double, PO2, PS2>& sigma)
01346 {
01347 unsigned int dim = mu.rows();
01348 SCYTHE_CHECK_10(! mu.isColVector(), scythe_dimension_error,
01349 "mu not column vector");
01350 SCYTHE_CHECK_10(! sigma.isSquare(), scythe_dimension_error,
01351 "sigma not square");
01352 SCYTHE_CHECK_10(sigma.rows() != dim, scythe_conformation_error,
01353 "mu and sigma not conformable");
01354
01355 return(mu + cholesky(sigma) * rnorm(dim, 1, 0, 1));
01356 }
01357
01372 template <matrix_order O, matrix_style S>
01373 Matrix<double, O, Concrete>
01374 rmvt (const Matrix<double, O, S>& sigma, double nu)
01375 {
01376 Matrix<double, O, Concrete> result;
01377 SCYTHE_CHECK_10(nu <= 0, scythe_invalid_arg,
01378 "D.O.F parameter nu <= 0");
01379
01380 result =
01381 rmvnorm(Matrix<double, O>(sigma.rows(), 1, true, 0), sigma);
01382 result /= std::sqrt(rchisq(nu) / nu);
01383 return result;
01384 }
01385
01386 protected:
01387
01392 rng()
01393 : rnorm_count_ (1)
01394 {}
01395
01396
01397 RNGTYPE& as_derived()
01398 {
01399 return static_cast<RNGTYPE&>(*this);
01400 }
01401
01402
01403
01404
01405
01406
01407
01408
01409
01410
01411
01412 int rnorm_count_;
01413 double x2_;
01414
01415 double
01416 rnorm1 ()
01417 {
01418 double nu1, nu2, rsquared, sqrt_term;
01419 if (rnorm_count_ == 1){
01420 do {
01421 nu1 = -1 +2*runif();
01422 nu2 = -1 +2*runif();
01423 rsquared = ::pow(nu1,2) + ::pow(nu2,2);
01424 } while (rsquared >= 1 || rsquared == 0.0);
01425 sqrt_term = std::sqrt(-2*std::log(rsquared)/rsquared);
01426 x2_ = nu2*sqrt_term;
01427 rnorm_count_ = 2;
01428 return nu1*sqrt_term;
01429 } else {
01430 rnorm_count_ = 1;
01431 return x2_;
01432 }
01433 }
01434
01435
01436 double accept_;
01437
01438 double
01439 rgamma1 (double alpha)
01440 {
01441 int test;
01442 double u, v, w, x, y, z, b, c;
01443
01444
01445 SCYTHE_CHECK_10(alpha <= 1, scythe_invalid_arg, "alpha <= 1");
01446
01447
01448 b = alpha - 1;
01449 c = 3 * alpha - 0.75;
01450 test = 0;
01451 while (test == 0) {
01452 u = runif ();
01453 v = runif ();
01454
01455 w = u * (1 - u);
01456 y = std::sqrt (c / w) * (u - .5);
01457 x = b + y;
01458
01459 if (x > 0) {
01460 z = 64 * std::pow (v, 2) * std::pow (w, 3);
01461 if (z <= (1 - (2 * std::pow (y, 2) / x))) {
01462 test = 1;
01463 accept_ = x;
01464 } else if ((2 * (b * std::log (x / b) - y)) >= ::log (z)) {
01465 test = 1;
01466 accept_ = x;
01467 } else {
01468 test = 0;
01469 }
01470 }
01471 }
01472
01473 return (accept_);
01474 }
01475
01476 };
01477
01478
01479 }
01480 #endif