Main Page | Modules | Namespace List | Class Hierarchy | Class List | File List | Namespace Members | Class Members | File Members | Related Pages

sampler.cpp

Go to the documentation of this file.
00001 #include <csignal>
00002 #include <cstdio>
00003 #include <iostream>
00004 #include <fstream>
00005 #include <sstream>
00006 #include <algorithm>
00007 #include <ctime>
00008 #include <math.h>
00009 #include <qthread.h>
00010 #include <gzstream.h>
00011 #include "coloring.hpp"
00012 #include "random.hpp"
00013 #include "gui.hpp"
00014 #include "point_obs.hpp"
00015 #include "range_obs.hpp"
00016 #include "sonar_obs.hpp"
00017 #include "coloring_mcmc.hpp"
00018 #include "cn94.hpp"
00019 #include "query_point.hpp"
00020 #include "estimation.hpp"
00021 #include "parsecl.hpp"
00022 #include "properties.hpp"
00023 #include "verification.hpp"
00024 
00025 using std::cout; 
00026 using std::endl;
00027 using std::istream;
00028 using std::ifstream;
00029 using std::ofstream;
00030 using namespace Arak;
00031 using namespace Arak::Util;
00032 
// Built-in property defaults.  main() parses this text into the property map
// before reading any property files (-p) or definitions (-D), so those later
// sources can override these values.  Lines beginning with '#' are example
// settings that are (presumably) treated as comments by the PropertyMap
// parser -- confirm against properties.hpp.
const char* default_props =
  "\n"
  "arak.proposal = cn94 \n"
  "arak.cn94.zeta.bd_not_mr = 0.5 \n"
  "arak.cn94.zeta.move_not_recolor = 0.5 \n"
  "arak.cn94.zeta.birth_not_death = 0.5 \n"
  "arak.cn94.zeta.int_not_boundary = 0.68 \n"
  "arak.cn94.zeta.bound_not_corner = 0.9 \n"
  "arak.cn94.zeta.ib_vertex_not_triangle = 0.6 \n"
  "arak.modified_cn94.zeta.local_not_global_recolor = 0.9 \n"
  "arak.modified_cn94.zeta.local_not_global_bd_tri_birth = 0.5 \n"
  "# arak.sampler.coloring_path = /tmp/coloring.out \n"
  "# arak.sampler.map_coloring_path = /tmp/map-coloring.out \n"
  "# arak.sampler.estimates_path = /tmp/estimates.out.gz \n"
  "# arak.sampler.flush_interval = 600.0 \n"
  "# arak.sampler.samples_path = /tmp/samples.bin \n"
  "# arak.sampler.record_stride = 1000 \n"
  "arak.scale = 1.0 \n"
  "arak.coloring.xmin = 0.0 \n"
  "arak.coloring.ymin = 0.0 \n"
  "arak.coloring.width = 1.0 \n"
  "arak.coloring.height = 1.0 \n"
  "arak.coloring.rows = 35 \n"
  "arak.coloring.cols = 35 \n"
  "arak.mcmc.estimation.stride = 100 \n"
  "arak.mcmc.estimation.burn_in = 1000 \n"
  "arak.mcmc.estimation.grid_size = 10000 \n";
00065 
// Stop request flag.  Set by the SIGINT handler below (non-GUI mode) or
// directly by sampler() once the GUI exits; polled by the sampling thread.
// NOTE(review): volatile sig_atomic_t is correct for the signal handler, but
// it is also read from another thread, for which volatile is not a
// synchronization primitive.
volatile std::sig_atomic_t interrupted = 0;

/// SIGINT handler: only records that an interrupt was requested, which is
/// one of the few actions that are safe inside a signal handler.
extern "C" void sigint(int /* signum */) {
  interrupted = 1;
}
00069 
00074 class MCMCThread : public QThread {
00075 
00076  protected:
00077 
00079   ArakMarkovChain& chain;
00080 
00082   const GridColorEstimator& estimates;
00083 
00088   CriticalStatsEstimator* stats;
00089 
00091   QMutex& mutex;
00092 
00094   std::string coloring_path;
00095 
00097   std::string estimates_path;
00098 
00100   std::string map_coloring_path;
00101 
00103   double map_log_likelihood;
00104 
00106   std::ofstream* samples_stream;
00107 
00109   int record_stride;
00110 
00112   double flush_interval;
00113 
00117   void flush() {
00118     if (!coloring_path.empty()) {
00119       ofstream clr_out(coloring_path.data());
00120       mutex.lock();
00121       clr_out << chain.getState();
00122       mutex.unlock();
00123       clr_out.close();
00124     }
00125     if (!estimates_path.empty()) {
00126       ogzstream est_out(estimates_path.data());
00127       mutex.lock();
00128       est_out << estimates;
00129       mutex.unlock();
00130       est_out.close();
00131     }
00132   }
00133 
00137   void processSample() {
00138     mutex.lock();
00139     if (stats != NULL)
00140       stats->update(chain.getState());
00141     // Record the current sample if needed.
00142     if ((samples_stream != NULL) &&
00143   (chain.getNumSteps() % (record_stride + 1) == 0)) {
00144       chain.getState().writeBinary(*samples_stream);
00145     }
00146     double cur_log_likelihood = chain.getLogLikelihood();
00147     if ((cur_log_likelihood > map_log_likelihood)
00148   &&
00149   (!map_coloring_path.empty())
00150   &&
00151   // Make sure we're past the burn-in phase to avoid too many
00152   // initial writes.
00153   (chain.getNumSamples() > 1)) {
00154       map_log_likelihood = cur_log_likelihood;
00155       ofstream clr_out(map_coloring_path.data());
00156       clr_out << chain.getState();
00157       clr_out.close();
00158     }
00159     mutex.unlock();
00160   }
00161 
00166   void run() {
00167     // The number of samples between flushes.  This is determined
00168     // automatically.
00169     unsigned long int interval = 0;
00170     std::time_t last_flush = std::time(NULL);
00171     while (true) {
00172       if (interrupted != 0) goto report;
00173       mutex.lock();
00174       chain.advance();
00175       mutex.unlock();
00176       processSample();
00177       std::time_t cur_time = std::time(NULL);
00178       if (std::difftime(cur_time, last_flush) > flush_interval) 
00179   break;
00180       else
00181   interval++;
00182     }
00183     // We have determined the interval and are ready for the first flush.
00184     while (true) {
00185       // Write the current sample and point color estimates.
00186       last_flush = std::time(NULL);
00187       flush();
00188       std::string timestr(std::asctime(std::localtime(&last_flush)), 23);
00189       std::cerr << "[" << timestr << "] " 
00190     << "Preliminary estimates written to " 
00191     << coloring_path << "/"
00192     << estimates_path << "." << endl;
00193       // Continue sampling.
00194       for (unsigned long int i = 0; i < interval; i++) {
00195   if (interrupted != 0) goto report;
00196   mutex.lock();
00197   chain.advance();
00198   mutex.unlock();
00199   processSample();
00200       }
00201     }
00202   report:
00203     if (stats != NULL) {
00204       std::cout << std::endl << "VERIFICATION STATISTICS:" << std::endl;
00205       stats->report(chain.getState().boundary(),
00206         chain.getDistribution().scale(),
00207         std::cout);
00208       std::cout << std::endl;
00209     }
00210     if (samples_stream != NULL) samples_stream->close();
00211   }
00212 
00213  public:
00214 
00218   MCMCThread(ArakMarkovChain& chain,
00219        const GridColorEstimator& estimates,
00220        QMutex& mutex,
00221        const Arak::Util::PropertyMap& props) : 
00222     chain(chain), 
00223     estimates(estimates),
00224     mutex(mutex),
00225     map_log_likelihood(-std::numeric_limits<double>::infinity()) 
00226   { 
00227     // Parse the properties.
00228     if (hasp(props, "arak.sampler.coloring_path"))
00229       coloring_path = getp(props, "arak.sampler.coloring_path");
00230     else
00231       coloring_path = "";
00232     if (hasp(props, "arak.sampler.map_coloring_path"))
00233       map_coloring_path = getp(props, "arak.sampler.map_coloring_path");
00234     else 
00235       map_coloring_path = "";
00236     if (hasp(props, "arak.sampler.estimates_path"))
00237       estimates_path = getp(props, "arak.sampler.estimates_path");
00238     else
00239       estimates_path = "";
00240     if (hasp(props, "arak.sampler.flush_interval"))
00241       assert(parse(getp(props, "arak.sampler.flush_interval"), 
00242        flush_interval));
00243     else
00244       flush_interval = 600.0; // 10 minutes
00245     if (hasp(props, "arak.sampler.verify")) {
00246       int k;
00247       assert(parse(getp(props, "arak.sampler.verify"), k));
00248       stats = new CriticalStatsEstimator(k, chain.getState());
00249     } else
00250       stats = NULL;
00251     if (hasp(props, "arak.sampler.samples_path")) {
00252       const std::string& samples_path = 
00253   getp(props, "arak.sampler.samples_path");
00254       samples_stream = new std::ofstream(samples_path.data(), 
00255            std::ios::binary |
00256            std::ios::out);
00257       if (hasp(props, "arak.sampler.record_stride"))
00258   assert(parse(getp(props, "arak.sampler.record_stride"), 
00259          record_stride));
00260     } else
00261       samples_stream = NULL;
00262   }
00263 
00267   ~MCMCThread() {
00268     if (stats != NULL) delete stats;
00269     if (samples_stream != NULL) delete samples_stream;
00270   }
00271 };
00272 
00285 int sampler(int argc, 
00286       char** argv,
00287       bool visualize,
00288       bool toolbar,
00289       double refreshRateHz,
00290       const Arak::Util::PropertyMap& props) {
00291   // Form the initial coloring.
00292   Coloring c(props);
00293   if (hasp(props, "arak.sampler.init_state_path")) {
00294     const std::string& path = getp(props, "arak.sampler.init_state_path");
00295     ifstream in(path.data());
00296     in >> c;
00297     if (!in.good()) {
00298       std::cerr << "Cannot read initial coloring from " << path << std::endl;
00299       exit(1);
00300     }
00301   }
00302   // Form the prior process.
00303   ArakProcess* prior = NULL;
00304   using namespace Arak::Util;
00305   if (hasp(props, "arak.prior")) {
00306     const std::string& prior_type = getp(props, "arak.prior");
00307     if (prior_type == "standard") 
00308       prior = new ArakPrior(c, props);
00309     else
00310       std::cerr << "Unrecognized prior type: " 
00311     << prior_type << std::endl;
00312   } else 
00313     prior = new ArakPrior(c, props);
00314   // Form the posterior process.
00315   ArakProcess* process = NULL;
00316   using namespace Arak::Util;
00317   if (hasp(props, "arak.obs_type")) {
00318     const std::string& obs_type = getp(props, "arak.obs_type");
00319     if (obs_type == "gaussian") 
00320       process = new ArakPosteriorGaussianObs(*prior, props);
00321     else if (obs_type == "bernoulli") 
00322       process = new ArakPosteriorBernoulliObs(*prior, props);
00323     else if (obs_type == "range") 
00324       process = new ArakPosteriorRangeObs(*prior, props);
00325     else if (obs_type == "sonar") 
00326       process = new ArakPosteriorSonarObs(*prior, props);
00327     else
00328       std::cerr << "Unrecognized observation type: " 
00329     << obs_type << std::endl;
00330   } else 
00331     process = prior;
00332   
00333   // Form the proposal distribution.
00334   CN94Proposal* proposal = NULL;
00335   const std::string& proposal_type = getp(props, "arak.proposal");
00336   if (proposal_type == "cn94") 
00337     proposal = new CN94Proposal(props);
00338   else if (proposal_type == "modified_cn94") 
00339     proposal = new ModifiedCN94Proposal(props);
00340   else {
00341     std::cerr << "Unrecognized proposal distribution: "
00342         << proposal_type << std::endl;
00343     exit(1);
00344   }
00345 
00346   // Choose the random number generator.
00347   Arak::Util::Random& random = Arak::Util::default_random;
00348   std::cout << "PRNG: " << random << std::endl;
00349 
00350   // Form the Markov chain.
00351   ArakMarkovChain* chain;
00352   if (hasp(props, "arak.sampler.chain_type")) {
00353     const std::string& chain_type = getp(props, "arak.sampler.chain_type");
00354     if (chain_type == "standard") 
00355       chain = new ArakMarkovChain(*process, *proposal, c, 
00356           props, default_random);
00357     else if (chain_type == "annealed") 
00358       chain = new AnnealedArakMarkovChain(*process, *proposal, c, 
00359             props, default_random);
00360     else if (chain_type == "hill-climbing") {
00361       chain = new HillClimbingArakMarkovChain(*process, *proposal, c, 
00362                 props, default_random);
00363       std::cerr << "Warning: hill-climbing in an Arak process is usually a bad"
00364     << std::endl
00365     << "         idea; try annealing instead."
00366     << std::endl;
00367     } else {
00368       std::cerr << "Unrecognized Markov chain type: "
00369     << chain_type << std::endl;
00370       exit(1);
00371     }
00372   } else
00373     chain = new ArakMarkovChain(*process, *proposal, c, 
00374         props, default_random);
00375 
00376   // Form the point set color estimator.
00377   GridColorEstimator estimator(*chain, props);
00378   
00379   // Create the sampling thread and the GUI application.
00380   QMutex mutex;
00381   MCMCThread thread(*chain, estimator, mutex, props);
00382 
00383   // Start the inference thread and transfer control to the GUI.
00384   std::time_t start = std::time(NULL);
00385   thread.start();
00386   int result = 0;
00387 
00388   // If visualization was requested, start the GUI.
00389   if (visualize) {
00390     QApplication *app = InitGui(argc, argv, 
00391         480, 480, 
00392         c, *process, estimator,
00393         mutex,
00394         toolbar, 
00395         refreshRateHz);
00396     result = app->exec();
00397     interrupted = 1;
00398   } else {
00399     if (std::signal(SIGINT, sigint) == SIG_ERR)
00400       std::cerr << "Warning: cannot set signal handler." << endl;
00401     else
00402       cout << "Signal INT (press Ctrl-C) to stop sampling..." << endl;
00403   }
00404 
00405   // Wait for the sampler to stop.
00406   thread.wait();
00407   std::time_t end = std::time(NULL);
00408 
00409   // Report the results.
00410   cout << "Number of samples: " << chain->getNumSteps() 
00411        << " (" << double(chain->getNumSteps()) / difftime(end, start)
00412        << " samples per second)" << endl;
00413   cout << "Acceptance ratio: " << chain->acceptance() 
00414        << " (" << double(chain->getNumMoves()) / difftime(end, start)
00415        << " moves per second)" << endl;
00416   proposal->writeStatistics(cout);
00417 
00418   // Write the final coloring and point color estimates.
00419   if (hasp(props, "arak.sampler.coloring_path")) {
00420     std::string coloring_path = getp(props, "arak.sampler.coloring_path");
00421     ofstream clr_out(coloring_path.data());
00422     clr_out << chain->getState();
00423     clr_out.close();
00424     cout << "Final coloring written to " << coloring_path << "." << endl;
00425   }
00426   if (hasp(props, "arak.sampler.estimates_path")) {
00427     std::string estimates_path = getp(props, "arak.sampler.estimates_path");
00428     ogzstream est_out(estimates_path.data());
00429     est_out << estimator;
00430     est_out.close();
00431     cout << "Final estimates written to " << estimates_path << "." << endl;
00432   }
00433 
00434   // Deallocate the chain, prior, process, and proposal.
00435   delete chain;
00436   if (prior != process)
00437     delete process;
00438   delete prior;
00439   delete proposal;
00440   
00441   // Return the result of the GUI.
00442   return result;
00443 }
00444 
00449 int main(int argc, char** argv) {
00450   using namespace Arak::Util;
00451   // Parse the command line.
00452   CommandLine cl;
00453   CommandLine::Option visualize("-v", "--visualize", 
00454         "visualizes the state of the sampler");
00455   cl.add(visualize);
00456   CommandLine::Parameter<double> 
00457     refresh("-r", "--refresh-rate", 
00458       "the rate (in Hz) that the visualization is updated", 1.0);
00459   cl.add(refresh);
00460   CommandLine::Option toolbar("-t", "--toolbar", 
00461             "uses the toolbar in the visualization");
00462   cl.add(toolbar);
00463   CommandLine::MultiParameter<std::string>
00464     propFiles("-p", "--prop-file", "specifies a property file");
00465   cl.add(propFiles);
00466   CommandLine::MultiParameter<std::string>
00467     propDefs("-D", "--define-prop", "defines a property value");
00468   cl.add(propDefs);
00469 
00470   if (!cl.parse(argc, argv, std::cerr)) {
00471     cl.printUsage(std::cerr);
00472     exit(1);
00473   }
00474 
00475   // Parse the default properties.
00476   Arak::Util::PropertyMap properties;
00477   {
00478     std::string defaultProps(default_props);
00479     std::istringstream in(defaultProps);
00480     in >> properties;
00481   }
00482 
00483   // Parse the properties files.
00484   const std::vector<std::string>& propFilePaths = propFiles.values();
00485   for (std::vector<std::string>::const_iterator it = propFilePaths.begin();
00486        it != propFilePaths.end(); it++) {
00487     std::string path = *it;
00488     std::ifstream in(path.data());
00489     in >> properties;
00490   }
00491   // Parse the overriding property definitions.
00492   const std::vector<std::string>& propDefinitions = propDefs.values();
00493   for (std::vector<std::string>::const_iterator it = propDefinitions.begin();
00494        it != propDefinitions.end(); it++) {
00495     std::string def = *it;
00496     std::istringstream in(def);
00497     in >> properties;
00498   }
00499 
00500   std::cout << "Properties:" << std::endl << properties << std::endl;
00501 
00502   // Invoke the program.
00503   return sampler(argc, argv, 
00504      visualize.supplied(),
00505      toolbar.supplied(),
00506      refresh.value(),
00507      properties);
00508 }

Generated on Wed May 25 14:39:18 2005 for Arak by doxygen 1.3.6