00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #include <unistd.h>
00067 #include <fstream>
00068 #include <climits>
00069 #include <sys/times.h>
00070 #include "fol.h"
00071 #include "arguments.h"
00072 #include "util.h"
00073
00074 #include "infer.h"
00075
00076 extern const char* ZZ_TMP_FILE_POSTFIX;
00077
00078
00079
00080
00081 ARGS ARGS::Args[] =
00082 {
00083
00084 ARGS("i", ARGS::Req, ainMLNFiles,
00085 "Comma-separated input .mln files."),
00086
00087 ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00088 "Specified non-evidence atoms (comma-separated with no space) are "
00089 "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00090 "appearing here cannot be query atoms and cannot appear in the -o "
00091 "option."),
00092
00093 ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00094 "Specified evidence atoms (comma-separated with no space) are open "
00095 "world, while other evidence atoms are closed-world. "
00096 "Atoms appearing here cannot appear in the -c option."),
00097
00098
00099
00100 ARGS("m", ARGS::Tog, amapPos,
00101 "Run MAP inference and return only positive query atoms."),
00102
00103 ARGS("a", ARGS::Tog, amapAll,
00104 "Run MAP inference and show 0/1 results for all query atoms."),
00105
00106 ARGS("p", ARGS::Tog, agibbsInfer,
00107 "Run inference using MCMC (Gibbs sampling) and return probabilities "
00108 "for all query atoms."),
00109
00110 ARGS("ms", ARGS::Tog, amcsatInfer,
00111 "Run inference using MC-SAT and return probabilities "
00112 "for all query atoms"),
00113
00114 ARGS("simtp", ARGS::Tog, asimtpInfer,
00115 "Run inference using simulated tempering and return probabilities "
00116 "for all query atoms"),
00117
00118 ARGS("seed", ARGS::Opt, aSeed,
00119 "[random] Seed used to initialize the randomizer in the inference "
00120 "algorithm. If not set, seed is initialized from the current date and "
00121 "time."),
00122
00123 ARGS("lazy", ARGS::Opt, aLazy,
00124 "[false] Run lazy version of inference if this flag is set."),
00125
00126 ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
00127 "[false] Lazy version of inference will not approximate by deactivating "
00128 "atoms to save memory. This flag is ignored if -lazy is not set."),
00129
00130 ARGS("memLimit", ARGS::Opt, aMemLimit,
00131 "[-1] Maximum limit in kbytes which should be used for inference. "
00132 "-1 means main memory available on system is used."),
00133
00134
00135
00136 ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00137 "[1000000] (MaxWalkSat) The max number of steps taken."),
00138
00139 ARGS("tries", ARGS::Opt, amwsTries,
00140 "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00141
00142 ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00143 "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00144 "with weight <= specified weight."),
00145
00146 ARGS("hard", ARGS::Opt, amwsHard,
00147 "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00148 "satisfy a soft one."),
00149
00150 ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00151 "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00152 "2 = TABU, 3 = SAMPLESAT)."),
00153
00154 ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00155 "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00156 "atom when using the tabu heuristic in MaxWalkSat." ),
00157
00158 ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState,
00159 "[false] (MaxWalkSat) If false, the naive way of saving low states "
00160 "(each time a low state is found, the whole state is saved) is used; "
00161 "otherwise, a list of variables flipped since the last low state is "
00162 "kept and the low state is reconstructed. This can be much faster for "
00163 "very large data sets."),
00164
00165
00166
00167 ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00168 "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00169
00170 ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00171 "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00172
00173 ARGS("minSteps", ARGS::Opt, amcmcMinSteps,
00174 "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00175
00176 ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps,
00177 "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00178
00179 ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds,
00180 "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00181
00182
00183
00184 ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00185 "[2] (Simulated Tempering) Selection interval between swap attempts"),
00186
00187 ARGS("numRuns", ARGS::Opt, asimtpNumST,
00188 "[3] (Simulated Tempering) Number of simulated tempering runs"),
00189
00190 ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00191 "[10] (Simulated Tempering) Number of swapping chains"),
00192
00193
00194
00195 ARGS("numStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00196 "[1] (MC-SAT) Number of total steps (mcsat + gibbs) for every mcsat "
00197 "step"),
00198
00199
00200
00201 ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00202 "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00203
00204 ARGS("saRatio", ARGS::Opt, assSaRatio,
00205 "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00206 "MC-SAT"),
00207
00208 ARGS("saTemperature", ARGS::Opt, assSaTemp,
00209 "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00210 "SampleSat"),
00211
00212 ARGS("lateSa", ARGS::Tog, assLateSa,
00213 "[false] Run simulated annealing from the start in SampleSat"),
00214
00215
00216
00217 ARGS("numChains", ARGS::Opt, amcmcNumChains,
00218 "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00219 "at least 2)."),
00220
00221 ARGS("delta", ARGS::Opt, agibbsDelta,
00222 "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00223 "exceeded is less than this value."),
00224
00225 ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00226 "[0.01] (Gibbs) Fractional error from true probability."),
00227
00228 ARGS("fracConverged", ARGS::Opt, agibbsFracConverged,
00229 "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00230 "have converged."),
00231
00232 ARGS("walksatType", ARGS::Opt, agibbsWalksatType,
00233 "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00234 "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00235
00236 ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
00237 "[100] Perform convergence test once after this many number of samples "
00238 "per chain."),
00239
00240
00241
00242 ARGS("e", ARGS::Req, aevidenceFiles,
00243 "Comma-separated .db files containing known ground atoms (evidence), "
00244 "including function definitions."),
00245
00246 ARGS("r", ARGS::Req, aresultsFile,
00247 "The probability estimates are written to this file."),
00248
00249 ARGS("q", ARGS::Opt, aqueryPredsStr,
00250 "Query atoms (comma-separated with no space) "
00251 ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00252 "open world."),
00253
00254 ARGS("f", ARGS::Opt, aqueryFile,
00255 "A .db file containing ground query atoms, "
00256 "which are are always open world."),
00257
00258
00259 ARGS()
00260 };
00261
00262
00278 void printResults(const string& queryFile, const string& queryPredsStr,
00279 Domain *domain, ostream& out,
00280 GroundPredicateHashArray* const &queries,
00281 Inference* const &inference, VariableState* const &state)
00282 {
00283
00284
00285 if (aLazy)
00286 {
00287 const GroundPredicateHashArray* gndPredHashArray = NULL;
00288 Array<double>* gndPredProbs = NULL;
00289
00290
00291
00292 if (!(amapPos || amapAll))
00293 {
00294 gndPredHashArray = state->getGndPredHashArrayPtr();
00295 gndPredProbs = new Array<double>;
00296 gndPredProbs->growToSize(gndPredHashArray->size());
00297 for (int i = 0; i < gndPredProbs->size(); i++)
00298 (*gndPredProbs)[i] =
00299 inference->getProbability((*gndPredHashArray)[i]);
00300 }
00301
00302 if (queryFile.length() > 0)
00303 {
00304 cout << "Writing query predicates that are specified in query file..."
00305 << endl;
00306 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00307 NULL, true, out, amapPos,
00308 gndPredHashArray, gndPredProbs);
00309 if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00310 }
00311
00312 Array<int> allPredGndingsAreQueries;
00313 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00314 if (queryPredsStr.length() > 0)
00315 {
00316 cout << "Writing query predicates that are specified on command line..."
00317 << endl;
00318 bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(),
00319 NULL, NULL, &allPredGndingsAreQueries,
00320 true, out, amapPos, gndPredHashArray,
00321 gndPredProbs);
00322 if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00323 }
00324
00325 if (!(amapPos || amapAll))
00326 delete gndPredProbs;
00327 }
00328
00329
00330 else
00331 {
00332 if (amapPos)
00333 inference->printTruePreds(out);
00334 else
00335 {
00336 for (int i = 0; i < queries->size(); i++)
00337 {
00338
00339 double prob = inference->getProbability((*queries)[i]);
00340 (*queries)[i]->print(out, domain); out << " " << prob << endl;
00341 }
00342 }
00343 }
00344 }
00345
00346
00355 int main(int argc, char* argv[])
00356 {
00358 ARGS::parse(argc, argv, &cout);
00359 Timer timer;
00360 double begSec = timer.time();
00361
00362 ofstream resultsOut(aresultsFile);
00363 if (!resultsOut.good())
00364 { cout << "ERROR: unable to open " << aresultsFile << endl; return -1; }
00365
00366 Domain* domain = NULL;
00367 Inference* inference = NULL;
00368 if (buildInference(inference, domain))
00369 {
00370 inference->init();
00371 inference->infer();
00372
00373 printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00374 inference, inference->getState());
00375 }
00376
00377 resultsOut.close();
00378 delete domain;
00379 for (int i = 0; i < knownQueries.size(); i++) delete knownQueries[i];
00380 delete inference;
00381
00382 cout << "total time taken = "; Timer::printTime(cout, timer.time()-begSec);
00383 cout << endl;
00384 }
00385