infer.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #include <unistd.h>
00067 #include <fstream>
00068 #include <climits>
00069 #include <sys/times.h>
00070 #include "fol.h"
00071 #include "arguments.h"
00072 #include "util.h"
00073 //#include "learnwts.h"
00074 #include "infer.h"
00075 
00076 extern const char* ZZ_TMP_FILE_POSTFIX; //defined in fol.y
00077 
00078 
00079   // TODO: List the arguments common to learnwts and inference in
00080   // inferenceargs.h. This can't be done with a static array.
00081 ARGS ARGS::Args[] = 
00082 {
00083     // BEGIN: Common arguments
00084   ARGS("i", ARGS::Req, ainMLNFiles, 
00085        "Comma-separated input .mln files."),
00086 
00087   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00088        "Specified non-evidence atoms (comma-separated with no space) are "
00089        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00090        "appearing here cannot be query atoms and cannot appear in the -o "
00091        "option."),
00092 
00093   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00094        "Specified evidence atoms (comma-separated with no space) are open "
00095        "world, while other evidence atoms are closed-world. "
00096        "Atoms appearing here cannot appear in the -c option."),
00097     // END: Common arguments
00098 
00099     // BEGIN: Common inference arguments
00100   ARGS("m", ARGS::Tog, amapPos, 
00101        "Run MAP inference and return only positive query atoms."),
00102 
00103   ARGS("a", ARGS::Tog, amapAll, 
00104        "Run MAP inference and show 0/1 results for all query atoms."),
00105 
00106   ARGS("p", ARGS::Tog, agibbsInfer, 
00107        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00108        "for all query atoms."),
00109   
00110   ARGS("ms", ARGS::Tog, amcsatInfer,
00111        "Run inference using MC-SAT and return probabilities "
00112        "for all query atoms"),
00113 
00114   ARGS("simtp", ARGS::Tog, asimtpInfer,
00115        "Run inference using simulated tempering and return probabilities "
00116        "for all query atoms"),
00117 
00118   ARGS("seed", ARGS::Opt, aSeed,
00119        "[random] Seed used to initialize the randomizer in the inference "
00120        "algorithm. If not set, seed is initialized from the current date and "
00121        "time."),
00122 
00123   ARGS("lazy", ARGS::Opt, aLazy, 
00124        "[false] Run lazy version of inference if this flag is set."),
00125   
00126   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00127        "[false] Lazy version of inference will not approximate by deactivating "
00128        "atoms to save memory. This flag is ignored if -lazy is not set."),
00129   
00130   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00131        "[-1] Maximum limit in kbytes which should be used for inference. "
00132        "-1 means main memory available on system is used."),
00133     // END: Common inference arguments
00134 
00135     // BEGIN: MaxWalkSat args
00136   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00137        "[1000000] (MaxWalkSat) The max number of steps taken."),
00138 
00139   ARGS("tries", ARGS::Opt, amwsTries, 
00140        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00141 
00142   ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00143        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00144        "with weight <= specified weight."),
00145 
00146   ARGS("hard", ARGS::Opt, amwsHard, 
00147        "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00148        "satisfy a soft one."),
00149   
00150   ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00151        "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00152        "2 = TABU, 3 = SAMPLESAT)."),
00153   
00154   ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00155        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00156        "atom when using the tabu heuristic in MaxWalkSat." ),
00157 
00158   ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState, 
00159        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00160        "(each time a low state is found, the whole state is saved) is used; "
00161        "otherwise, a list of variables flipped since the last low state is "
00162        "kept and the low state is reconstructed. This can be much faster for "
00163        "very large data sets."),  
00164     // END: MaxWalkSat args
00165 
00166     // BEGIN: MCMC args
00167   ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00168        "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00169 
00170   ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00171        "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00172 
00173   ARGS("minSteps", ARGS::Opt, amcmcMinSteps, 
00174        "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00175 
00176   ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps, 
00177        "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00178 
00179   ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00180        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00181     // END: MCMC args
00182   
00183     // BEGIN: Simulated tempering args
00184   ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00185         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00186 
00187   ARGS("numRuns", ARGS::Opt, asimtpNumST,
00188         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00189 
00190   ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00191         "[10] (Simulated Tempering) Number of swapping chains"),
00192     // END: Simulated tempering args
00193 
00194     // BEGIN: MC-SAT args
00195   ARGS("numStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00196        "[1] (MC-SAT) Number of total steps (mcsat + gibbs) for every mcsat "
00197        "step"),
00198     // END: MC-SAT args
00199 
00200     // BEGIN: SampleSat args
00201   ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00202        "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00203 
00204   ARGS("saRatio", ARGS::Opt, assSaRatio,
00205        "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00206        "MC-SAT"),
00207 
00208   ARGS("saTemperature", ARGS::Opt, assSaTemp,
00209         "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00210         "SampleSat"),
00211 
00212   ARGS("lateSa", ARGS::Tog, assLateSa,
00213        "[false] Run simulated annealing from the start in SampleSat"),
00214     // END: SampleSat args
00215 
00216     // BEGIN: Gibbs sampling args
00217   ARGS("numChains", ARGS::Opt, amcmcNumChains, 
00218        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00219        "at least 2)."),
00220 
00221   ARGS("delta", ARGS::Opt, agibbsDelta,
00222        "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00223        "exceeded is less than this value."),
00224 
00225   ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00226        "[0.01] (Gibbs) Fractional error from true probability."),
00227 
00228   ARGS("fracConverged", ARGS::Opt, agibbsFracConverged, 
00229        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00230        "have converged."),
00231 
00232   ARGS("walksatType", ARGS::Opt, agibbsWalksatType, 
00233        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00234        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00235 
00236   ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00237        "[100] Perform convergence test once after this many number of samples "
00238        "per chain."),
00239     // END: Gibbs sampling args
00240 
00241     // BEGIN: Args specific to stand-alone inference
00242   ARGS("e", ARGS::Req, aevidenceFiles, 
00243        "Comma-separated .db files containing known ground atoms (evidence), "
00244        "including function definitions."),
00245 
00246   ARGS("r", ARGS::Req, aresultsFile,
00247        "The probability estimates are written to this file."),
00248     
00249   ARGS("q", ARGS::Opt, aqueryPredsStr, 
00250        "Query atoms (comma-separated with no space)  "
00251        ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00252        "open world."),
00253 
00254   ARGS("f", ARGS::Opt, aqueryFile,
00255        "A .db file containing ground query atoms, "
00256        "which are are always open world."),
00257     // END: Args specific to stand-alone inference
00258 
00259   ARGS()
00260 };
00261 
00262 
00278 void printResults(const string& queryFile, const string& queryPredsStr,
00279                   Domain *domain, ostream& out, 
00280                   GroundPredicateHashArray* const &queries,
00281                   Inference* const &inference, VariableState* const &state)
00282 {
00283     // Lazy version: Have to generate the queries from the file or query string.
00284     // This involves calling createQueryFilePreds / createComLineQueryPreds
00285   if (aLazy)
00286   {
00287     const GroundPredicateHashArray* gndPredHashArray = NULL;
00288     Array<double>* gndPredProbs = NULL;
00289       // Inference algorithms with probs: have to retrieve this info from state.
00290       // These are the ground preds which have been brought into memory. All
00291       // others have always been false throughout sampling.
00292     if (!(amapPos || amapAll))
00293     {
00294       gndPredHashArray = state->getGndPredHashArrayPtr();
00295       gndPredProbs = new Array<double>;
00296       gndPredProbs->growToSize(gndPredHashArray->size());
00297       for (int i = 0; i < gndPredProbs->size(); i++)
00298         (*gndPredProbs)[i] =
00299           inference->getProbability((*gndPredHashArray)[i]);
00300     }
00301     
00302     if (queryFile.length() > 0)
00303     {
00304       cout << "Writing query predicates that are specified in query file..."
00305            << endl;
00306       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00307                                      NULL, true, out, amapPos,
00308                                      gndPredHashArray, gndPredProbs);
00309       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00310     }
00311 
00312     Array<int> allPredGndingsAreQueries;
00313     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00314     if (queryPredsStr.length() > 0)
00315     {
00316       cout << "Writing query predicates that are specified on command line..." 
00317            << endl;
00318       bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(), 
00319                                         NULL, NULL, &allPredGndingsAreQueries,
00320                                         true, out, amapPos, gndPredHashArray,
00321                                         gndPredProbs);
00322       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00323     }
00324     
00325     if (!(amapPos || amapAll))
00326       delete gndPredProbs;
00327   }
00328     // Eager version: Queries have already been generated and we can get the
00329     // information directly from the state
00330   else
00331   {
00332     if (amapPos)
00333       inference->printTruePreds(out);
00334     else
00335     {
00336       for (int i = 0; i < queries->size(); i++)
00337       {
00338           // Prob is smoothed in inference->getProbability
00339         double prob = inference->getProbability((*queries)[i]);
00340         (*queries)[i]->print(out, domain); out << " " << prob << endl;
00341       }
00342     }
00343   }
00344 }
00345 
00346 
00355 int main(int argc, char* argv[])
00356 {
00358   ARGS::parse(argc, argv, &cout);
00359   Timer timer;
00360   double begSec = timer.time(); 
00361 
00362   ofstream resultsOut(aresultsFile);
00363   if (!resultsOut.good())
00364   { cout << "ERROR: unable to open " << aresultsFile << endl; return -1; }
00365 
00366   Domain* domain = NULL;
00367   Inference* inference = NULL;
00368   if (buildInference(inference, domain))
00369   {
00370     inference->init();
00371     inference->infer();
00372     
00373     printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00374                  inference, inference->getState());
00375   }
00376 
00377   resultsOut.close();
00378   delete domain;
00379   for (int i = 0; i < knownQueries.size(); i++)  delete knownQueries[i];
00380   delete inference;
00381   
00382   cout << "total time taken = "; Timer::printTime(cout, timer.time()-begSec);
00383   cout << endl;
00384 }
00385 

Generated on Wed Feb 14 15:15:17 2007 for Alchemy by  doxygen 1.5.1