#include <algorithm>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#include <math.h>
#include <qthread.h>
#include <gzstream.h>
#include "coloring.hpp"
#include "random.hpp"
#include "gui.hpp"
#include "point_obs.hpp"
#include "range_obs.hpp"
#include "sonar_obs.hpp"
#include "coloring_mcmc.hpp"
#include "cn94.hpp"
#include "query_point.hpp"
#include "estimation.hpp"
#include "parsecl.hpp"
#include "properties.hpp"
#include "verification.hpp"
00024
00025 using std::cout;
00026 using std::endl;
00027 using std::istream;
00028 using std::ifstream;
00029 using std::ofstream;
00030 using namespace Arak;
00031 using namespace Arak::Util;
00032
// Built-in property defaults in property-file syntax.  main() streams this
// text into the PropertyMap before any -p property files and -D definitions,
// which are read into the same map afterwards (and so presumably override
// these values -- depends on PropertyMap's operator>>).
// The '#' lines appear to be commented-out examples of optional sampler
// properties (assuming the property parser treats '#' as a comment).
const char* default_props = "\n\
arak.proposal = cn94 \n\
arak.cn94.zeta.bd_not_mr = 0.5 \n\
arak.cn94.zeta.move_not_recolor = 0.5 \n\
arak.cn94.zeta.birth_not_death = 0.5 \n\
arak.cn94.zeta.int_not_boundary = 0.68 \n\
arak.cn94.zeta.bound_not_corner = 0.9 \n\
arak.cn94.zeta.ib_vertex_not_triangle = 0.6 \n\
arak.modified_cn94.zeta.local_not_global_recolor = 0.9 \n\
arak.modified_cn94.zeta.local_not_global_bd_tri_birth = 0.5 \n\
# arak.sampler.coloring_path = /tmp/coloring.out \n\
# arak.sampler.map_coloring_path = /tmp/map-coloring.out \n\
# arak.sampler.estimates_path = /tmp/estimates.out.gz \n\
# arak.sampler.flush_interval = 600.0 \n\
# arak.sampler.samples_path = /tmp/samples.bin \n\
# arak.sampler.record_stride = 1000 \n\
arak.scale = 1.0 \n\
arak.coloring.xmin = 0.0 \n\
arak.coloring.ymin = 0.0 \n\
arak.coloring.width = 1.0 \n\
arak.coloring.height = 1.0 \n\
arak.coloring.rows = 35 \n\
arak.coloring.cols = 35 \n\
arak.mcmc.estimation.stride = 100 \n\
arak.mcmc.estimation.burn_in = 1000 \n\
arak.mcmc.estimation.grid_size = 10000 \n\
";
00065
00066
// Stop flag polled by the sampling thread before every chain step.  Nonzero
// means "stop sampling"; it is set by the SIGINT handler below (console mode)
// or by sampler() once the GUI event loop exits.
volatile std::sig_atomic_t interrupted = 0;

extern "C" {
// SIGINT handler.  It only records the request, which is one of the few
// things that is safe to do from inside a signal handler.
void sigint(int /* signal_number */) {
  interrupted = 1;
}
}
00069
00074 class MCMCThread : public QThread {
00075
00076 protected:
00077
00079 ArakMarkovChain& chain;
00080
00082 const GridColorEstimator& estimates;
00083
00088 CriticalStatsEstimator* stats;
00089
00091 QMutex& mutex;
00092
00094 std::string coloring_path;
00095
00097 std::string estimates_path;
00098
00100 std::string map_coloring_path;
00101
00103 double map_log_likelihood;
00104
00106 std::ofstream* samples_stream;
00107
00109 int record_stride;
00110
00112 double flush_interval;
00113
00117 void flush() {
00118 if (!coloring_path.empty()) {
00119 ofstream clr_out(coloring_path.data());
00120 mutex.lock();
00121 clr_out << chain.getState();
00122 mutex.unlock();
00123 clr_out.close();
00124 }
00125 if (!estimates_path.empty()) {
00126 ogzstream est_out(estimates_path.data());
00127 mutex.lock();
00128 est_out << estimates;
00129 mutex.unlock();
00130 est_out.close();
00131 }
00132 }
00133
00137 void processSample() {
00138 mutex.lock();
00139 if (stats != NULL)
00140 stats->update(chain.getState());
00141
00142 if ((samples_stream != NULL) &&
00143 (chain.getNumSteps() % (record_stride + 1) == 0)) {
00144 chain.getState().writeBinary(*samples_stream);
00145 }
00146 double cur_log_likelihood = chain.getLogLikelihood();
00147 if ((cur_log_likelihood > map_log_likelihood)
00148 &&
00149 (!map_coloring_path.empty())
00150 &&
00151
00152
00153 (chain.getNumSamples() > 1)) {
00154 map_log_likelihood = cur_log_likelihood;
00155 ofstream clr_out(map_coloring_path.data());
00156 clr_out << chain.getState();
00157 clr_out.close();
00158 }
00159 mutex.unlock();
00160 }
00161
00166 void run() {
00167
00168
00169 unsigned long int interval = 0;
00170 std::time_t last_flush = std::time(NULL);
00171 while (true) {
00172 if (interrupted != 0) goto report;
00173 mutex.lock();
00174 chain.advance();
00175 mutex.unlock();
00176 processSample();
00177 std::time_t cur_time = std::time(NULL);
00178 if (std::difftime(cur_time, last_flush) > flush_interval)
00179 break;
00180 else
00181 interval++;
00182 }
00183
00184 while (true) {
00185
00186 last_flush = std::time(NULL);
00187 flush();
00188 std::string timestr(std::asctime(std::localtime(&last_flush)), 23);
00189 std::cerr << "[" << timestr << "] "
00190 << "Preliminary estimates written to "
00191 << coloring_path << "/"
00192 << estimates_path << "." << endl;
00193
00194 for (unsigned long int i = 0; i < interval; i++) {
00195 if (interrupted != 0) goto report;
00196 mutex.lock();
00197 chain.advance();
00198 mutex.unlock();
00199 processSample();
00200 }
00201 }
00202 report:
00203 if (stats != NULL) {
00204 std::cout << std::endl << "VERIFICATION STATISTICS:" << std::endl;
00205 stats->report(chain.getState().boundary(),
00206 chain.getDistribution().scale(),
00207 std::cout);
00208 std::cout << std::endl;
00209 }
00210 if (samples_stream != NULL) samples_stream->close();
00211 }
00212
00213 public:
00214
00218 MCMCThread(ArakMarkovChain& chain,
00219 const GridColorEstimator& estimates,
00220 QMutex& mutex,
00221 const Arak::Util::PropertyMap& props) :
00222 chain(chain),
00223 estimates(estimates),
00224 mutex(mutex),
00225 map_log_likelihood(-std::numeric_limits<double>::infinity())
00226 {
00227
00228 if (hasp(props, "arak.sampler.coloring_path"))
00229 coloring_path = getp(props, "arak.sampler.coloring_path");
00230 else
00231 coloring_path = "";
00232 if (hasp(props, "arak.sampler.map_coloring_path"))
00233 map_coloring_path = getp(props, "arak.sampler.map_coloring_path");
00234 else
00235 map_coloring_path = "";
00236 if (hasp(props, "arak.sampler.estimates_path"))
00237 estimates_path = getp(props, "arak.sampler.estimates_path");
00238 else
00239 estimates_path = "";
00240 if (hasp(props, "arak.sampler.flush_interval"))
00241 assert(parse(getp(props, "arak.sampler.flush_interval"),
00242 flush_interval));
00243 else
00244 flush_interval = 600.0;
00245 if (hasp(props, "arak.sampler.verify")) {
00246 int k;
00247 assert(parse(getp(props, "arak.sampler.verify"), k));
00248 stats = new CriticalStatsEstimator(k, chain.getState());
00249 } else
00250 stats = NULL;
00251 if (hasp(props, "arak.sampler.samples_path")) {
00252 const std::string& samples_path =
00253 getp(props, "arak.sampler.samples_path");
00254 samples_stream = new std::ofstream(samples_path.data(),
00255 std::ios::binary |
00256 std::ios::out);
00257 if (hasp(props, "arak.sampler.record_stride"))
00258 assert(parse(getp(props, "arak.sampler.record_stride"),
00259 record_stride));
00260 } else
00261 samples_stream = NULL;
00262 }
00263
00267 ~MCMCThread() {
00268 if (stats != NULL) delete stats;
00269 if (samples_stream != NULL) delete samples_stream;
00270 }
00271 };
00272
00285 int sampler(int argc,
00286 char** argv,
00287 bool visualize,
00288 bool toolbar,
00289 double refreshRateHz,
00290 const Arak::Util::PropertyMap& props) {
00291
00292 Coloring c(props);
00293 if (hasp(props, "arak.sampler.init_state_path")) {
00294 const std::string& path = getp(props, "arak.sampler.init_state_path");
00295 ifstream in(path.data());
00296 in >> c;
00297 if (!in.good()) {
00298 std::cerr << "Cannot read initial coloring from " << path << std::endl;
00299 exit(1);
00300 }
00301 }
00302
00303 ArakProcess* prior = NULL;
00304 using namespace Arak::Util;
00305 if (hasp(props, "arak.prior")) {
00306 const std::string& prior_type = getp(props, "arak.prior");
00307 if (prior_type == "standard")
00308 prior = new ArakPrior(c, props);
00309 else
00310 std::cerr << "Unrecognized prior type: "
00311 << prior_type << std::endl;
00312 } else
00313 prior = new ArakPrior(c, props);
00314
00315 ArakProcess* process = NULL;
00316 using namespace Arak::Util;
00317 if (hasp(props, "arak.obs_type")) {
00318 const std::string& obs_type = getp(props, "arak.obs_type");
00319 if (obs_type == "gaussian")
00320 process = new ArakPosteriorGaussianObs(*prior, props);
00321 else if (obs_type == "bernoulli")
00322 process = new ArakPosteriorBernoulliObs(*prior, props);
00323 else if (obs_type == "range")
00324 process = new ArakPosteriorRangeObs(*prior, props);
00325 else if (obs_type == "sonar")
00326 process = new ArakPosteriorSonarObs(*prior, props);
00327 else
00328 std::cerr << "Unrecognized observation type: "
00329 << obs_type << std::endl;
00330 } else
00331 process = prior;
00332
00333
00334 CN94Proposal* proposal = NULL;
00335 const std::string& proposal_type = getp(props, "arak.proposal");
00336 if (proposal_type == "cn94")
00337 proposal = new CN94Proposal(props);
00338 else if (proposal_type == "modified_cn94")
00339 proposal = new ModifiedCN94Proposal(props);
00340 else {
00341 std::cerr << "Unrecognized proposal distribution: "
00342 << proposal_type << std::endl;
00343 exit(1);
00344 }
00345
00346
00347 Arak::Util::Random& random = Arak::Util::default_random;
00348 std::cout << "PRNG: " << random << std::endl;
00349
00350
00351 ArakMarkovChain* chain;
00352 if (hasp(props, "arak.sampler.chain_type")) {
00353 const std::string& chain_type = getp(props, "arak.sampler.chain_type");
00354 if (chain_type == "standard")
00355 chain = new ArakMarkovChain(*process, *proposal, c,
00356 props, default_random);
00357 else if (chain_type == "annealed")
00358 chain = new AnnealedArakMarkovChain(*process, *proposal, c,
00359 props, default_random);
00360 else if (chain_type == "hill-climbing") {
00361 chain = new HillClimbingArakMarkovChain(*process, *proposal, c,
00362 props, default_random);
00363 std::cerr << "Warning: hill-climbing in an Arak process is usually a bad"
00364 << std::endl
00365 << " idea; try annealing instead."
00366 << std::endl;
00367 } else {
00368 std::cerr << "Unrecognized Markov chain type: "
00369 << chain_type << std::endl;
00370 exit(1);
00371 }
00372 } else
00373 chain = new ArakMarkovChain(*process, *proposal, c,
00374 props, default_random);
00375
00376
00377 GridColorEstimator estimator(*chain, props);
00378
00379
00380 QMutex mutex;
00381 MCMCThread thread(*chain, estimator, mutex, props);
00382
00383
00384 std::time_t start = std::time(NULL);
00385 thread.start();
00386 int result = 0;
00387
00388
00389 if (visualize) {
00390 QApplication *app = InitGui(argc, argv,
00391 480, 480,
00392 c, *process, estimator,
00393 mutex,
00394 toolbar,
00395 refreshRateHz);
00396 result = app->exec();
00397 interrupted = 1;
00398 } else {
00399 if (std::signal(SIGINT, sigint) == SIG_ERR)
00400 std::cerr << "Warning: cannot set signal handler." << endl;
00401 else
00402 cout << "Signal INT (press Ctrl-C) to stop sampling..." << endl;
00403 }
00404
00405
00406 thread.wait();
00407 std::time_t end = std::time(NULL);
00408
00409
00410 cout << "Number of samples: " << chain->getNumSteps()
00411 << " (" << double(chain->getNumSteps()) / difftime(end, start)
00412 << " samples per second)" << endl;
00413 cout << "Acceptance ratio: " << chain->acceptance()
00414 << " (" << double(chain->getNumMoves()) / difftime(end, start)
00415 << " moves per second)" << endl;
00416 proposal->writeStatistics(cout);
00417
00418
00419 if (hasp(props, "arak.sampler.coloring_path")) {
00420 std::string coloring_path = getp(props, "arak.sampler.coloring_path");
00421 ofstream clr_out(coloring_path.data());
00422 clr_out << chain->getState();
00423 clr_out.close();
00424 cout << "Final coloring written to " << coloring_path << "." << endl;
00425 }
00426 if (hasp(props, "arak.sampler.estimates_path")) {
00427 std::string estimates_path = getp(props, "arak.sampler.estimates_path");
00428 ogzstream est_out(estimates_path.data());
00429 est_out << estimator;
00430 est_out.close();
00431 cout << "Final estimates written to " << estimates_path << "." << endl;
00432 }
00433
00434
00435 delete chain;
00436 if (prior != process)
00437 delete process;
00438 delete prior;
00439 delete proposal;
00440
00441
00442 return result;
00443 }
00444
00449 int main(int argc, char** argv) {
00450 using namespace Arak::Util;
00451
00452 CommandLine cl;
00453 CommandLine::Option visualize("-v", "--visualize",
00454 "visualizes the state of the sampler");
00455 cl.add(visualize);
00456 CommandLine::Parameter<double>
00457 refresh("-r", "--refresh-rate",
00458 "the rate (in Hz) that the visualization is updated", 1.0);
00459 cl.add(refresh);
00460 CommandLine::Option toolbar("-t", "--toolbar",
00461 "uses the toolbar in the visualization");
00462 cl.add(toolbar);
00463 CommandLine::MultiParameter<std::string>
00464 propFiles("-p", "--prop-file", "specifies a property file");
00465 cl.add(propFiles);
00466 CommandLine::MultiParameter<std::string>
00467 propDefs("-D", "--define-prop", "defines a property value");
00468 cl.add(propDefs);
00469
00470 if (!cl.parse(argc, argv, std::cerr)) {
00471 cl.printUsage(std::cerr);
00472 exit(1);
00473 }
00474
00475
00476 Arak::Util::PropertyMap properties;
00477 {
00478 std::string defaultProps(default_props);
00479 std::istringstream in(defaultProps);
00480 in >> properties;
00481 }
00482
00483
00484 const std::vector<std::string>& propFilePaths = propFiles.values();
00485 for (std::vector<std::string>::const_iterator it = propFilePaths.begin();
00486 it != propFilePaths.end(); it++) {
00487 std::string path = *it;
00488 std::ifstream in(path.data());
00489 in >> properties;
00490 }
00491
00492 const std::vector<std::string>& propDefinitions = propDefs.values();
00493 for (std::vector<std::string>::const_iterator it = propDefinitions.begin();
00494 it != propDefinitions.end(); it++) {
00495 std::string def = *it;
00496 std::istringstream in(def);
00497 in >> properties;
00498 }
00499
00500 std::cout << "Properties:" << std::endl << properties << std::endl;
00501
00502
00503 return sampler(argc, argv,
00504 visualize.supplied(),
00505 toolbar.supplied(),
00506 refresh.value(),
00507 properties);
00508 }