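// Weight-learning front end for Markov logic networks. Supports
// discriminative learning (voted perceptron driven by MAP or MCMC
// inference) and generative learning (pseudo-log-likelihood optimized
// with L-BFGS-B), and writes the learned weights to an output .mln file.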
#include <fstream>
#include <iostream>
#include <sstream>
#include "arguments.h"
#include "inferenceargs.h"
#include "lbfgsb.h"
#include "votedperceptron.h"
#include "learnwts.h"
#include "maxwalksat.h"
#include "mcsat.h"
#include "gibbssampler.h"
#include "simulatedtempering.h"

bool PRINT_CLAUSE_DURING_COUNT = true;

const double DISC_DEFAULT_STD_DEV = 1;
const double GEN_DEFAULT_STD_DEV = 100;

bool discLearn = false;
bool genLearn = false;
char* outMLNFile = NULL;
char* dbFiles = NULL;
char* nonEvidPredsStr = NULL;
bool noAddUnitClauses = false;
bool multipleDatabases = false;
bool initToZero = false;
bool isQueryEvidence = false;

bool noPrior = false;
double priorMean = 0;
double priorStdDev = -1;

int maxIter = 10000;
double convThresh = 1e-5;
bool noEqualPredWt = false;

int numIter = 200;
double learningRate = 0.001;
double momentum = 0.0;
bool rescaleGradient = false;
bool withEM = false;
char* aInferStr = NULL;
int amwsMaxSubsequentSteps = -1;

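// Command-line option table. Each ARGS entry binds a flag name to one of
// the globals above or to an a*-prefixed inference option (presumably
// declared in inferenceargs.h, which is included above).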
ARGS ARGS::Args[] =
{
  ARGS("i", ARGS::Req, ainMLNFiles,
       "Comma-separated input .mln files. (With the -multipleDatabases "
       "option, the second through last files contain the constants from "
       "the different databases, and they correspond, in order, to the "
       ".db files specified with the -t option.)"),

  ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
       "Specified non-evidence atoms (comma-separated with no space) are "
       "closed world; otherwise, all non-evidence atoms are open world. "
       "Atoms appearing here cannot be query atoms and cannot appear in "
       "the -ow option."),

  ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
       "Specified evidence atoms (comma-separated with no space) are open "
       "world, while other evidence atoms are closed world. "
       "Atoms appearing here cannot appear in the -cw option."),

  ARGS("m", ARGS::Tog, amapPos,
       "Run MAP inference and return only positive query atoms."),

  ARGS("a", ARGS::Tog, amapAll,
       "Run MAP inference and show 0/1 results for all query atoms."),

  ARGS("p", ARGS::Tog, agibbsInfer,
       "Run inference using MCMC (Gibbs sampling) and return probabilities "
       "for all query atoms."),

  ARGS("ms", ARGS::Tog, amcsatInfer,
       "Run inference using MC-SAT and return probabilities "
       "for all query atoms."),

  ARGS("simtp", ARGS::Tog, asimtpInfer,
       "Run inference using simulated tempering and return probabilities "
       "for all query atoms."),

  ARGS("seed", ARGS::Opt, aSeed,
       "[random] Seed used to initialize the randomizer in the inference "
       "algorithm. If not set, the seed is initialized from the current "
       "date and time."),

  ARGS("lazy", ARGS::Opt, aLazy,
       "[false] Run the lazy version of inference if this flag is set."),

  ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
       "[false] The lazy version of inference will not approximate by "
       "deactivating atoms to save memory. This flag is ignored if -lazy "
       "is not set."),

  ARGS("memLimit", ARGS::Opt, aMemLimit,
       "[-1] Maximum memory, in kilobytes, to be used by inference. "
       "-1 means all main memory available on the system may be used."),

  ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
       "[1000000] (MaxWalkSat) The max number of steps taken."),

  ARGS("tries", ARGS::Opt, amwsTries,
       "[1] (MaxWalkSat) The max number of attempts to find a solution."),

  ARGS("targetWt", ARGS::Opt, amwsTargetWt,
       "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a "
       "solution with weight <= specified weight."),

  ARGS("hard", ARGS::Opt, amwsHard,
       "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in "
       "order to satisfy a soft one."),

  ARGS("heuristic", ARGS::Opt, amwsHeuristic,
       "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, "
       "1 = BEST, 2 = TABU, 3 = SAMPLESAT)."),

  ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
       "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
       "atom when using the tabu heuristic in MaxWalkSat."),

  ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState,
       "[false] (MaxWalkSat) If false, the naive way of saving low states "
       "(each time a low state is found, the whole state is saved) is "
       "used; otherwise, a list of variables flipped since the last low "
       "state is kept and the low state is reconstructed. This can be "
       "much faster for very large data sets."),

  ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
       "[100] (MCMC) Minimum number of burn-in steps (-1: no minimum)."),

  ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
       "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),

  ARGS("minSteps", ARGS::Opt, amcmcMinSteps,
       "[-1] (MCMC) Minimum number of Gibbs sampling steps."),

  ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps,
       "[1000] (MCMC) Maximum number of Gibbs sampling steps."),

  ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds,
       "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),

  ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
       "[2] (Simulated Tempering) Selection interval between swap "
       "attempts."),

  ARGS("numRuns", ARGS::Opt, asimtpNumST,
       "[3] (Simulated Tempering) Number of simulated tempering runs."),

  ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
       "[10] (Simulated Tempering) Number of swapping chains."),

  ARGS("numStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
       "[1] (MC-SAT) Number of total steps (MC-SAT + Gibbs) for every "
       "MC-SAT step."),

  ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
       "[10] (MC-SAT) Return the nth SAT solution in SampleSat."),

  ARGS("saRatio", ARGS::Opt, assSaRatio,
       "[50] (MC-SAT) Ratio of simulated annealing steps mixed with "
       "WalkSAT in MC-SAT."),

  ARGS("saTemperature", ARGS::Opt, assSaTemp,
       "[10] (MC-SAT) Temperature (/100) for the simulated annealing step "
       "in SampleSat."),

  ARGS("lateSa", ARGS::Tog, assLateSa,
       "[false] Run simulated annealing from the start in SampleSat."),

  ARGS("numChains", ARGS::Opt, amcmcNumChains,
       "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must "
       "be at least 2)."),

  ARGS("delta", ARGS::Opt, agibbsDelta,
       "[0.05] (Gibbs) During Gibbs sampling, the probability that the "
       "epsilon error is exceeded is less than this value."),

  ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
       "[0.01] (Gibbs) Fractional error from the true probability."),

  ARGS("fracConverged", ARGS::Opt, agibbsFracConverged,
       "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
       "have converged."),

  ARGS("walksatType", ARGS::Opt, agibbsWalksatType,
       "[1] (Gibbs) Use MaxWalkSat to initialize ground atoms' truth "
       "values in Gibbs sampling (1: use MaxWalkSat, 0: random "
       "initialization)."),

  ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
       "[100] Perform the convergence test after this many samples per "
       "chain."),

  ARGS("infer", ARGS::Opt, aInferStr,
       "Specifies the inference parameters used during discriminative "
       "learning. The arguments are to be encapsulated in \"\", and the "
       "syntax is identical to that of the infer command (run infer with "
       "no arguments to see it). If not specified, MaxWalkSat with "
       "default parameters is used."),

  ARGS("d", ARGS::Tog, discLearn, "Discriminative weight learning."),

  ARGS("g", ARGS::Tog, genLearn, "Generative weight learning."),

  ARGS("o", ARGS::Req, outMLNFile,
       "Output .mln file containing formulas with learned weights."),

  ARGS("t", ARGS::Req, dbFiles,
       "Comma-separated .db files containing the training database "
       "(of true/false ground atoms), including function definitions, "
       "e.g., ai.db,graphics.db,languages.db."),

  ARGS("ne", ARGS::Opt, nonEvidPredsStr,
       "First-order non-evidence predicates (comma-separated with no "
       "space), e.g., cancer,smokes,friends. For discriminative learning, "
       "at least one non-evidence predicate must be specified. For "
       "generative learning, the specified predicates are included in the "
       "(weighted) pseudo-log-likelihood computation; if none are "
       "specified, all are included."),

  ARGS("noAddUnitClauses", ARGS::Tog, noAddUnitClauses,
       "If specified, unit clauses are not included in the .mln file; "
       "otherwise they are included."),

  ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
       "If specified, each .db file belongs to a separate database; "
       "otherwise all .db files belong to the same database."),

  ARGS("withEM", ARGS::Tog, withEM,
       "If set, EM is used to fill in missing truth values; "
       "otherwise missing truth values are set to false."),

  ARGS("dNumIter", ARGS::Opt, numIter,
       "[200] (For discriminative learning only.) "
       "Number of iterations to run voted perceptron."),

  ARGS("dLearningRate", ARGS::Opt, learningRate,
       "[0.001] (For discriminative learning only.) "
       "Learning rate for gradient descent in the voted perceptron "
       "algorithm."),

  ARGS("dMomentum", ARGS::Opt, momentum,
       "[0.0] (For discriminative learning only.) "
       "Momentum term for gradient descent in the voted perceptron "
       "algorithm."),

  ARGS("queryEvidence", ARGS::Tog, isQueryEvidence,
       "If this flag is set, all groundings of the query predicates not "
       "appearing in the database are assumed to be false evidence."),

  ARGS("dRescale", ARGS::Tog, rescaleGradient,
       "(For discriminative learning only.) "
       "Rescale the gradient by the number of true groundings per weight."),

  ARGS("dZeroInit", ARGS::Tog, initToZero,
       "(For discriminative learning only.) "
       "Initialize clause weights to zero instead of their log odds."),

  ARGS("dMwsMaxSubsequentSteps", ARGS::Opt, amwsMaxSubsequentSteps,
       "[Same as mwsMaxSteps] (For discriminative learning only.) The max "
       "number of MaxWalkSat steps taken in subsequent iterations (>= 2) "
       "of discriminative learning. If not specified, mwsMaxSteps is used "
       "in each iteration."),

  ARGS("gMaxIter", ARGS::Opt, maxIter,
       "[10000] (For generative learning only.) "
       "Max number of iterations to run L-BFGS-B, "
       "the optimization algorithm for generative learning."),

  ARGS("gConvThresh", ARGS::Opt, convThresh,
       "[1e-5] (For generative learning only.) "
       "Fractional change in pseudo-log-likelihood at which "
       "L-BFGS-B terminates."),

  ARGS("gNoEqualPredWt", ARGS::Opt, noEqualPredWt,
       "(For generative learning only.) "
       "If specified, the predicates are not weighted equally in the "
       "pseudo-log-likelihood computation; otherwise they are."),

  ARGS("noPrior", ARGS::Tog, noPrior,
       "No Gaussian priors on formula weights."),

  ARGS("priorMean", ARGS::Opt, priorMean,
       "[0] Mean of the Gaussian prior on a formula's weight. By default "
       "the mean is the weight given in the .mln input file (or a "
       "fraction thereof if the formula is converted to multiple "
       "clauses); the value given here applies only to formulas with no "
       "weight in the .mln file."),

  ARGS("priorStdDev", ARGS::Opt, priorStdDev,
       "[1 for discriminative learning; 100 for generative learning] "
       "Standard deviations of the Gaussian priors on clause weights."),

  ARGS()
};

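// Overall flow: parse options, read the MLN and training databases into
// per-database domains, then either run voted perceptron (discriminative)
// or L-BFGS-B on the pseudo-log-likelihood (generative), and finally
// write the learned weights to the output .mln file.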
int main(int argc, char* argv[])
{
  ARGS::parse(argc, argv, &cout);

  if (!discLearn && !genLearn)
  {
    cout << "ERROR: must specify either -d or -g "
         << "(discriminative or generative learning)" << endl;
    return -1;
  }

  Timer timer;
  double startSec = timer.time();
  double begSec;

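  // priorStdDev of -1 means "unset"; fall back to the per-mode default.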
  if (priorStdDev < 0)
  {
    if (discLearn)
    {
      cout << "priorStdDev set to (discriminative learning's) default of "
           << DISC_DEFAULT_STD_DEV << endl;
      priorStdDev = DISC_DEFAULT_STD_DEV;
    }
    else
    {
      cout << "priorStdDev set to (generative learning's) default of "
           << GEN_DEFAULT_STD_DEV << endl;
      priorStdDev = GEN_DEFAULT_STD_DEV;
    }
  }

  if (discLearn && nonEvidPredsStr == NULL)
  {
    cout << "ERROR: you must specify non-evidence predicates for "
         << "discriminative learning" << endl;
    return -1;
  }

  if (maxIter <= 0) { cout << "maxIter must be > 0" << endl; return -1; }
  if (convThresh <= 0 || convThresh > 1)
  { cout << "convThresh must be > 0 and <= 1" << endl; return -1; }
  if (priorStdDev <= 0)
  { cout << "priorStdDev must be > 0" << endl; return -1; }

  if (amwsMaxSteps <= 0)
  { cout << "ERROR: maxSteps must be positive" << endl; return -1; }

  if (amwsMaxSubsequentSteps <= 0) amwsMaxSubsequentSteps = amwsMaxSteps;

  if (amwsTries <= 0)
  { cout << "ERROR: tries must be positive" << endl; return -1; }

  if (aMemLimit <= 0 && aMemLimit != -1)
  { cout << "ERROR: memLimit must be positive (or -1)" << endl; return -1; }

  if (!discLearn && aLazy)
  {
    cout << "ERROR: lazy can only be used with discriminative learning"
         << endl;
    return -1;
  }

  ofstream out(outMLNFile);
  if (!out.good())
  {
    cout << "ERROR: unable to open " << outMLNFile << endl;
    return -1;
  }

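  // For discriminative learning, the inference procedure used by the
  // learner is chosen here. Without -infer we default to MAP inference
  // via MaxWalkSat; otherwise the -infer string is tokenized and re-run
  // through the ARGS parser, and its mode flags (-m, -a, -p, -ms, -simtp)
  // are inspected to pick the algorithm.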
  if (discLearn)
  {
    if (!aInferStr)
    {
      amapPos = true;
    }
    else
    {
      int inferArgc = 0;
      char** inferArgv = new char*[200];
      for (int i = 0; i < 200; i++)
      {
        inferArgv[i] = new char[500];
      }

      extractArgs(aInferStr, inferArgc, inferArgv);
      cout << "extractArgs " << inferArgc << endl;
      for (int i = 0; i < inferArgc; i++)
      {
        cout << i << ": " << inferArgv[i] << endl;
      }

      ARGS::parseFromCommandLine(inferArgc, inferArgv);

      for (int i = 0; i < inferArgc; i++)
      {
        if (string(inferArgv[i]) == "-m") amapPos = true;
        else if (string(inferArgv[i]) == "-a") amapAll = true;
        else if (string(inferArgv[i]) == "-p") agibbsInfer = true;
        else if (string(inferArgv[i]) == "-ms") amcsatInfer = true;
        else if (string(inferArgv[i]) == "-simtp") asimtpInfer = true;
      }

      for (int i = 0; i < 200; i++)
      {
        delete[] inferArgv[i];
      }
      delete[] inferArgv;
    }
  }

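  // The first file given with -i is the input MLN; any remaining files
  // hold the per-database constants used with -multipleDatabases.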
  Array<string> constFilesArr;
  Array<string> dbFilesArr;
  extractFileNames(ainMLNFiles, constFilesArr);
  assert(constFilesArr.size() >= 1);
  string inMLNFile = constFilesArr[0];
  constFilesArr.removeItem(0);
  extractFileNames(dbFiles, dbFilesArr);

  if (dbFilesArr.size() <= 0)
  { cout << "ERROR: must specify training data with -t option." << endl; return -1; }

  if (multipleDatabases)
  {
    if (constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size())
    {
      cout << "ERROR: when there are multiple databases, if .mln files "
           << "containing constants are specified, there must "
           << "be the same number of them as .db files" << endl;
      return -1;
    }
  }

  StringHashArray nonEvidPredNames;
  if (nonEvidPredsStr)
  {
    if (!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames))
    {
      cout << "ERROR: failed to extract non-evidence predicate names." << endl;
      return -1;
    }
  }

  StringHashArray owPredNames;
  StringHashArray cwPredNames;

  cout << "Parsing MLN and creating domains..." << endl;
  StringHashArray* nePredNames = (discLearn) ? &nonEvidPredNames : NULL;
  Array<Domain*> domains;
  Array<MLN*> mlns;
  begSec = timer.time();
  bool allPredsExceptQueriesAreCW = true;
  if (discLearn)
  {
    if (aOpenWorldPredsStr)
    {
      if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames))
        return -1;
      assert(owPredNames.size() > 0);
    }

    if (aClosedWorldPredsStr)
    {
      if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames))
        return -1;
      assert(cwPredNames.size() > 0);
      if (!checkQueryPredsNotInClosedWorldPreds(nonEvidPredNames, cwPredNames))
        return -1;
    }

    allPredsExceptQueriesAreCW = owPredNames.empty();
  }

  createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile,
                       constFilesArr, dbFilesArr, nePredNames,
                       !noAddUnitClauses, priorMean, true,
                       allPredsExceptQueriesAreCW, &owPredNames);
  cout << "Parsing MLN and creating domains took ";
  Timer::printTime(cout, timer.time() - begSec); cout << endl;

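  // needIndexTranslator() reports whether any clause weights must be tied
  // (the message below names the typical case: clauses in the CNF of an
  // existential formula). The translator maps clause indices to shared
  // weight indices during learning.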
  IndexTranslator* indexTrans
    = (IndexTranslator::needIndexTranslator(mlns, domains)) ?
      new IndexTranslator(&mlns, &domains) : NULL;

  if (indexTrans)
    cout << endl << "the weights of clauses in the CNFs of existential"
         << " formulas will be tied" << endl;

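  // Gaussian priors: one mean per clause (its initial weight, per the
  // -priorMean semantics above) and a shared standard deviation. With an
  // IndexTranslator the means are laid out per tied weight instead.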
  Array<double> priorMeans, priorStdDevs;
  if (!noPrior)
  {
    if (indexTrans)
    {
      indexTrans->setPriorMeans(priorMeans);
      priorStdDevs.growToSize(priorMeans.size());
      for (int i = 0; i < priorMeans.size(); i++)
        priorStdDevs[i] = priorStdDev;
    }
    else
    {
      const ClauseHashArray* clauses = mlns[0]->getClauses();
      int numClauses = clauses->size();
      for (int i = 0; i < numClauses; i++)
      {
        priorMeans.append((*clauses)[i]->getWt());
        priorStdDevs.append(priorStdDev);
      }
    }
  }

  int numClausesFormulas = mlns[0]->getClauses()->size();

  Array<double> wts;

  if (discLearn)
  {
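    // wts is sized numClausesFormulas + 1 because the learners treat the
    // weight array as 1-based; wwts points at wts[1] so slot 0 is unused.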
    wts.growToSize(numClausesFormulas + 1);
    double* wwts = (double*) wts.getItems();
    wwts++;

    string nePredsStr = nonEvidPredsStr;

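    // Wire the command-line settings into per-algorithm parameter structs.
    // The SampleSat params nest inside the MaxWalkSat params, which in
    // turn are shared by MC-SAT, Gibbs sampling, and simulated tempering.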
    SampleSatParams* ssparams = new SampleSatParams;
    ssparams->lateSa = assLateSa;
    ssparams->saRatio = assSaRatio;
    ssparams->saTemp = assSaTemp;

    MaxWalksatParams* mwsparams = new MaxWalksatParams;
    mwsparams->ssParams = ssparams;
    mwsparams->maxSteps = amwsMaxSteps;
    mwsparams->maxTries = amwsTries;
    mwsparams->targetCost = amwsTargetWt;
    mwsparams->hard = amwsHard;
    mwsparams->numSolutions = amwsNumSolutions;
    mwsparams->heuristic = amwsHeuristic;
    mwsparams->tabuLength = amwsTabuLength;
    mwsparams->lazyLowState = amwsLazyLowState;

    MCSatParams* msparams = new MCSatParams;
    msparams->mwsParams = mwsparams;
    msparams->numChains = 1;
    msparams->burnMinSteps = amcmcBurnMinSteps;
    msparams->burnMaxSteps = amcmcBurnMaxSteps;
    msparams->minSteps = amcmcMinSteps;
    msparams->maxSteps = amcmcMaxSteps;
    msparams->maxSeconds = amcmcMaxSeconds;
    msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;

    GibbsParams* gibbsparams = new GibbsParams;
    gibbsparams->mwsParams = mwsparams;
    gibbsparams->numChains = amcmcNumChains;
    gibbsparams->burnMinSteps = amcmcBurnMinSteps;
    gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
    gibbsparams->minSteps = amcmcMinSteps;
    gibbsparams->maxSteps = amcmcMaxSteps;
    gibbsparams->maxSeconds = amcmcMaxSeconds;
    gibbsparams->gamma = 1 - agibbsDelta;
    gibbsparams->epsilonError = agibbsEpsilonError;
    gibbsparams->fracConverged = agibbsFracConverged;
    gibbsparams->walksatType = agibbsWalksatType;
    gibbsparams->samplesPerTest = agibbsSamplesPerTest;

    SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
    stparams->mwsParams = mwsparams;
    stparams->numChains = amcmcNumChains;
    stparams->burnMinSteps = amcmcBurnMinSteps;
    stparams->burnMaxSteps = amcmcBurnMaxSteps;
    stparams->minSteps = amcmcMinSteps;
    stparams->maxSteps = amcmcMaxSteps;
    stparams->maxSeconds = amcmcMaxSeconds;
    stparams->subInterval = asimtpSubInterval;
    stparams->numST = asimtpNumST;
    stparams->numSwap = asimtpNumSwap;

    Array<VariableState*> states;
    Array<Inference*> inferences;
    states.growToSize(domains.size());
    inferences.growToSize(domains.size());

    Array<int> allPredGndingsAreNonEvid;
    Array<Predicate*> ppreds;

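    // Build one VariableState (ground network) and one Inference object
    // per training domain; the voted perceptron below uses them to compute
    // the counts that drive its gradient.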
    for (int i = 0; i < domains.size(); i++)
    {
      Domain* domain = domains[i];
      MLN* mln = mlns[i];

      if (!aLazy)
        domains[i]->getDB()->setLazyFlag(false);

      GroundPredicateHashArray* unePreds = NULL;
      GroundPredicateHashArray* knePreds = NULL;
      Array<TruthValue>* knePredValues = NULL;

      for (int j = 0; j < mln->getNumClauses(); j++)
        ((Clause*) mln->getClause(j))->setWt(1);

      if (!allPredsExceptQueriesAreCW)
      {
        for (int j = 0; j < owPredNames.size(); j++)
        {
          nePredsStr.append(",");
          nePredsStr.append(owPredNames[j]);
          nonEvidPredNames.append(owPredNames[j]);
        }
      }

      if (!aLazy)
      {
        unePreds = new GroundPredicateHashArray;
        knePreds = new GroundPredicateHashArray;
        knePredValues = new Array<TruthValue>;

        allPredGndingsAreNonEvid.growToSize(domain->getNumPredicates(), false);

        createComLineQueryPreds(nePredsStr, domain, domain->getDB(),
                                unePreds, knePreds,
                                &allPredGndingsAreNonEvid);

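        // Record the DB values of the known non-evidence atoms while
        // setting them to UNKNOWN (setValue returns the previous value),
        // so state construction treats them as queries; the originals are
        // restored after the state is built.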
        knePredValues->growToSize(knePreds->size(), FALSE);
        for (int predno = 0; predno < knePreds->size(); predno++)
          (*knePredValues)[predno] =
            domain->getDB()->setValue((*knePreds)[predno], UNKNOWN);

        if (isQueryEvidence)
          for (int predno = 0; predno < unePreds->size(); predno++)
            domain->getDB()->setValue((*unePreds)[predno], FALSE);
      }

      cout << endl << "constructing state for domain " << i << "..." << endl;
      bool markHardGndClauses = false;
      bool trackParentClauseWts = true;
      VariableState*& state = states[i];
      state = new VariableState(unePreds, knePreds, knePredValues,
                                &allPredGndingsAreNonEvid, markHardGndClauses,
                                trackParentClauseWts, mln, domain, aLazy);

      Inference*& inference = inferences[i];
      bool trackClauseTrueCnts = true;

      if (amapPos || amapAll)
      {
        mwsparams->numSolutions = 1;
        inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts,
                                   mwsparams);
      }
      else if (amcsatInfer)
      {
        inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
      }
      else if (asimtpInfer)
      {
        mwsparams->numSolutions = 1;
        inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
                                           stparams);
      }
      else if (agibbsInfer)
      {
        mwsparams->numSolutions = 1;
        inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
                                     gibbsparams);
      }

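      // Restore the evidence DB: put the recorded values back on the known
      // non-evidence atoms, and set the unknown ones to FALSE, presumably
      // because the state keeps its own copy of the query atoms.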
      if (!aLazy)
      {
        domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);

        for (int predno = 0; predno < unePreds->size(); predno++)
        {
          domain->getDB()->setValue((*unePreds)[predno], FALSE);
        }
      }
    }
    cout << endl << "done constructing variable states" << endl << endl;

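    // Run the voted perceptron. Unless -noPrior is given, the Gaussian
    // priors computed above act as a regularizer on the learned weights.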
    VotedPerceptron vp(inferences, nonEvidPredNames, indexTrans, aLazy,
                       rescaleGradient, withEM);
    if (!noPrior)
      vp.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
                         priorStdDevs.getItems());
    else
      vp.setMeansStdDevs(-1, NULL, NULL);

    begSec = timer.time();
    cout << "learning (discriminative) weights..." << endl;
    vp.learnWeights(wwts, wts.size() - 1, numIter, learningRate, momentum,
                    !initToZero, amwsMaxSubsequentSteps);
    cout << endl << endl << "Done learning discriminative weights." << endl;
    cout << "Time taken for learning = ";
    Timer::printTime(cout, timer.time() - begSec); cout << endl;

    if (mwsparams) delete mwsparams;
    if (ssparams) delete ssparams;
    if (msparams) delete msparams;
    if (gibbsparams) delete gibbsparams;
    if (stparams) delete stparams;
    for (int i = 0; i < inferences.size(); i++) delete inferences[i];
    for (int i = 0; i < states.size(); i++) delete states[i];
  }
  else
  {
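    // Generative learning: mark which predicates enter the
    // pseudo-log-likelihood. With no -ne list every predicate counts,
    // except '=' and other internal predicates.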
    Array<bool> areNonEvidPreds;
    if (nonEvidPredNames.empty())
    {
      areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), true);
      for (int i = 0; i < domains[0]->getNumPredicates(); i++)
      {
        if (domains[0]->getPredicateTemplate(i)->isEqualPred())
        {
          const char* pname = domains[0]->getPredicateTemplate(i)->getName();
          int predId = domains[0]->getPredicateId(pname);
          areNonEvidPreds[predId] = false;
        }

        if (domains[0]->getPredicateTemplate(i)->isInternalPredicateTemplate())
        {
          const char* pname = domains[0]->getPredicateTemplate(i)->getName();
          int predId = domains[0]->getPredicateId(pname);
          areNonEvidPreds[predId] = false;
        }
      }
    }
    else
    {
      areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), false);
      for (int i = 0; i < nonEvidPredNames.size(); i++)
      {
        int predId = domains[0]->getPredicateId(nonEvidPredNames[i].c_str());
        if (predId < 0)
        {
          cout << "ERROR: Predicate " << nonEvidPredNames[i] << " undefined."
               << endl;
          exit(-1);
        }
        areNonEvidPreds[predId] = true;
      }
    }

    Array<bool>* nePreds = &areNonEvidPreds;
    PseudoLogLikelihood pll(nePreds, &domains, !noEqualPredWt, false,
                            -1, -1, -1);
    pll.setIndexTranslator(indexTrans);

    if (!noPrior)
      pll.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
                          priorStdDevs.getItems());
    else
      pll.setMeansStdDevs(-1, NULL, NULL);

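    // The PLL objective needs the number of true groundings of each clause
    // around each ground atom; compute these counts once per domain, then
    // optimize over the weights without touching the data again.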
    begSec = timer.time();
    for (int m = 0; m < mlns.size(); m++)
    {
      cout << "Computing counts for clauses in domain " << m << "..." << endl;
      const ClauseHashArray* clauses = mlns[m]->getClauses();
      for (int i = 0; i < clauses->size(); i++)
      {
        if (PRINT_CLAUSE_DURING_COUNT)
        {
          cout << "clause " << i << ": ";
          (*clauses)[i]->printWithoutWt(cout, domains[m]);
          cout << endl; cout.flush();
        }
        MLNClauseInfo* ci = (MLNClauseInfo*) mlns[m]->getMLNClauseInfo(i);
        pll.computeCountsForNewAppendedClause((*clauses)[i], &(ci->index), m,
                                              NULL, false, NULL);
      }
    }
    pll.compress();
    cout << "Computing counts took ";
    Timer::printTime(cout, timer.time() - begSec); cout << endl;

    wts.growToSize(numClausesFormulas + 1);
    for (int i = 0; i < numClausesFormulas; i++) wts[i+1] = 0;

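    // L-BFGS-B minimizes the negative pseudo-log-likelihood; like the
    // discriminative path, it treats wts as 1-based (slot 0 unused), which
    // is why the array was grown to numClausesFormulas + 1.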
00948 cout << "L-BFGS-B is finding optimal weights......" << endl;
00949 begSec = timer.time();
00950 LBFGSB lbfgsb(maxIter, convThresh, &pll, numClausesFormulas);
00951 int iter;
00952 bool error;
00953 double pllValue = lbfgsb.minimize((double*)wts.getItems(), iter, error);
00954
00955 if (error) cout << "LBFGSB returned with an error!" << endl;
00956 cout << "num iterations = " << iter << endl;
00957 cout << "time taken = ";
00958 Timer::printTime(cout, timer.time() - begSec);
00959 cout << endl;
00960 cout << "pseudo-log-likelihood = " << -pllValue << endl;
00961
00962 }
00963
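  // Write the learned weights back into the MLN(s), emit the output .mln
  // file, and release everything that was allocated above.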
  if (indexTrans) assignWtsAndOutputMLN(out, mlns, domains, wts, indexTrans);
  else assignWtsAndOutputMLN(out, mlns, domains, wts);

  out.close();

  deleteDomains(domains);
  for (int i = 0; i < mlns.size(); i++) delete mlns[i];
  PowerSet::deletePowerSet();
  if (indexTrans) delete indexTrans;

  cout << "Total time = ";
  Timer::printTime(cout, timer.time() - startSec); cout << endl;
}