00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef _INFER_H_OCT_30_2005
00067 #define _INFER_H_OCT_30_2005
00068
00073 #include "util.h"
00074 #include "fol.h"
00075 #include "mrf.h"
00076 #include "learnwts.h"
00077 #include "inferenceargs.h"
00078 #include "maxwalksat.h"
00079 #include "mcsat.h"
00080 #include "gibbssampler.h"
00081 #include "simulatedtempering.h"
00082
00083
00084
00085 char* aevidenceFiles = NULL;
00086 char* aresultsFile = NULL;
00087 char* aqueryPredsStr = NULL;
00088 char* aqueryFile = NULL;
00089
00090 string queryPredsStr, queryFile;
00091 GroundPredicateHashArray queries;
00092 GroundPredicateHashArray knownQueries;
00093
00125 bool createComLineQueryPreds(const string& queryPredsStr,
00126 const Domain* const & domain,
00127 Database* const & db,
00128 GroundPredicateHashArray* const & queries,
00129 GroundPredicateHashArray* const & knownQueries,
00130 Array<int>* const & allPredGndingsAreQueries,
00131 bool printToFile, ostream& out, bool amapPos,
00132 const GroundPredicateHashArray* const & trueQueries,
00133 const Array<double>* const & trueProbs)
00134 {
00135 if (queryPredsStr.length() == 0) return true;
00136 string preds = Util::trim(queryPredsStr);
00137
00138
00139 int balparen = 0;
00140 for (unsigned int i = 0; i < preds.length(); i++)
00141 {
00142 if (preds.at(i)=='(') balparen++;
00143 else if (preds.at(i)==')') balparen--;
00144 else if (preds.at(i)==',' && balparen==0) preds.at(i)='\n';
00145 }
00146
00147 bool onlyPredName;
00148 bool ret = true;
00149 unsigned int cur;
00150 int termId, varIdCnt = 0;
00151 hash_map<string, int, HashString, EqualString> varToId;
00152 hash_map<string, int, HashString, EqualString>::iterator it;
00153 Array<VarsTypeId*>* vtiArr;
00154 string pred, predName, term;
00155 const PredicateTemplate* ptemplate;
00156 istringstream iss(preds);
00157 char delimit[2]; delimit[1] = '\0';
00158
00159
00160 while (getline(iss, pred))
00161 {
00162 onlyPredName = false;
00163 varToId.clear();
00164 varIdCnt = 0;
00165 cur = 0;
00166
00167
00168 if (!Util::substr(pred,cur,predName,"("))
00169 {
00170 predName = pred;
00171 onlyPredName = true;
00172 }
00173
00174
00175 ptemplate = domain->getPredicateTemplate(predName.c_str());
00176 if (ptemplate == NULL)
00177 {
00178 cout << "ERROR: Cannot find command line query predicate" << predName
00179 << " in domain." << endl;
00180 ret = false;
00181 continue;
00182 }
00183 Predicate ppred(ptemplate);
00184
00185
00186 if (!onlyPredName)
00187 {
00188
00189 for (int i = 0; i < 2; i++)
00190 {
00191 if (i == 0) delimit[0] = ',';
00192 else delimit[0] = ')';
00193 while(Util::substr(pred, cur, term, delimit))
00194 {
00195
00196 if (isupper(term.at(0)) || term.at(0) == '"')
00197 {
00198 termId = domain->getConstantId(term.c_str());
00199 if (termId < 0)
00200 {
00201 cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00202 ret = false;
00203 }
00204 }
00205 else
00206 {
00207 if ((it=varToId.find(term)) == varToId.end())
00208 {
00209 termId = --varIdCnt;
00210 varToId[term] = varIdCnt;
00211 }
00212 else
00213 termId = (*it).second;
00214 }
00215 ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00216 }
00217 }
00218 }
00219 else
00220 {
00221 (*allPredGndingsAreQueries)[ptemplate->getId()] = true;
00222 for (int i = 0; i < ptemplate->getNumTerms(); i++)
00223 ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00224 }
00225
00226
00227 if (ppred.getNumTerms() != ptemplate->getNumTerms())
00228 {
00229 cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00230 << " terms but given " << ppred.getNumTerms() << endl;
00231 ret = false;
00232 }
00233 if (!ret) continue;
00234
00235
00237 vtiArr = NULL;
00238 ppred.createVarsTypeIdArr(vtiArr);
00239
00240
00241 if (vtiArr->size() <= 1)
00242 {
00243 assert(ppred.isGrounded());
00244 assert(!db->isClosedWorld(ppred.getId()));
00245 TruthValue tv = db->getValue(&ppred);
00246 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00247
00248
00249 if (printToFile) assert(tv != UNKNOWN);
00250 if (tv == UNKNOWN)
00251 {
00252 if (queries->append(gndPred) < 0) delete gndPred;
00253 }
00254 else
00255 {
00256
00257 if (printToFile)
00258 {
00259
00260 if (trueQueries)
00261 {
00262 double prob = 0.0;
00263 if (domain->getDB()->getEvidenceStatus(&ppred))
00264 {
00265
00266 continue;
00267
00268 }
00269 else
00270 {
00271 int found = trueQueries->find(gndPred);
00272 if (found >= 0) prob = (*trueProbs)[found];
00273 else
00274
00275 prob = (prob*10000+1/2.0)/(10000+1.0);
00276
00277 }
00278 gndPred->print(out, domain); out << " " << prob << endl;
00279 }
00280 else
00281 {
00282 if (amapPos)
00283 {
00284 if (tv == TRUE)
00285 {
00286 ppred.printWithStrVar(out, domain);
00287 out << endl;
00288 }
00289 }
00290 else
00291 {
00292 ppred.printWithStrVar(out, domain);
00293 out << " " << tv << endl;
00294 }
00295 }
00296 delete gndPred;
00297 }
00298 else
00299 {
00300
00301
00302
00303 if (knownQueries->append(gndPred) < 0) delete gndPred;
00304 }
00305 }
00306 }
00307 else
00308 {
00309 ArraysAccessor<int> acc;
00310 for (int i = 1; i < vtiArr->size(); i++)
00311 {
00312 const Array<int>* cons=domain->getConstantsByType((*vtiArr)[i]->typeId);
00313 acc.appendArray(cons);
00314 }
00315
00316
00317 Array<int> constIds;
00318 while (acc.getNextCombination(constIds))
00319 {
00320 assert(constIds.size() == vtiArr->size()-1);
00321 for (int j = 0; j < constIds.size(); j++)
00322 {
00323 Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00324 for (int k = 0; k < terms.size(); k++)
00325 terms[k]->setId(constIds[j]);
00326 }
00327
00328
00329 assert(!db->isClosedWorld(ppred.getId()));
00330
00331 TruthValue tv = db->getValue(&ppred);
00332 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00333
00334
00335 if (printToFile) assert(tv != UNKNOWN);
00336 if (tv == UNKNOWN)
00337 {
00338 if (queries->append(gndPred) < 0) delete gndPred;
00339 }
00340 else
00341 {
00342
00343 if (printToFile)
00344 {
00345
00346 if (trueQueries)
00347 {
00348 double prob = 0.0;
00349 if (domain->getDB()->getEvidenceStatus(&ppred))
00350 {
00351
00352 continue;
00353
00354 }
00355 else
00356 {
00357 int found = trueQueries->find(gndPred);
00358 if (found >= 0) prob = (*trueProbs)[found];
00359 else
00360
00361 prob = (prob*10000+1/2.0)/(10000+1.0);
00362 }
00363
00364
00365 gndPred->print(out, domain); out << " " << prob << endl;
00366 }
00367 else
00368 {
00369 if (amapPos)
00370 {
00371 if (tv == TRUE)
00372 {
00373 ppred.printWithStrVar(out, domain);
00374 out << endl;
00375 }
00376 }
00377 else
00378 {
00379 ppred.printWithStrVar(out, domain);
00380 out << " " << tv << endl;
00381 }
00382 }
00383 delete gndPred;
00384 }
00385 else
00386 {
00387
00388
00389 if (knownQueries->append(gndPred) < 0) delete gndPred;
00390 }
00391 }
00392 }
00393 }
00394
00395 ppred.deleteVarsTypeIdArr(vtiArr);
00396 }
00397
00398 if (!printToFile)
00399 {
00400 queries->compress();
00401 knownQueries->compress();
00402 }
00403
00404 return ret;
00405 }
00406
00411 bool createComLineQueryPreds(const string& queryPredsStr,
00412 const Domain* const & domain,
00413 Database* const & db,
00414 GroundPredicateHashArray* const & queries,
00415 GroundPredicateHashArray* const & knownQueries,
00416 Array<int>* const & allPredGndingsAreQueries)
00417 {
00418 return createComLineQueryPreds(queryPredsStr, domain, db,
00419 queries, knownQueries,
00420 allPredGndingsAreQueries,
00421 false, cout, false, NULL, NULL);
00422 }
00423
00434 bool extractPredNames(string preds, const string* queryFile,
00435 StringHashArray& predNames)
00436 {
00437 predNames.clear();
00438
00439
00440 string::size_type cur = 0, ws, ltparen;
00441 string qpred, predName;
00442
00443 if (preds.length() > 0)
00444 {
00445 preds.append(" ");
00446
00447
00448 int balparen = 0;
00449 for (unsigned int i = 0; i < preds.length(); i++)
00450 {
00451 if (preds.at(i) == '(') balparen++;
00452 else if (preds.at(i) == ')') balparen--;
00453 else if (preds.at(i) == ',' && balparen == 0) preds.at(i) = ' ';
00454 }
00455
00456 while (preds.at(cur) == ' ') cur++;
00457 while (true)
00458 {
00459 ws = preds.find(" ", cur);
00460 if (ws == string::npos) break;
00461 qpred = preds.substr(cur,ws-cur+1);
00462 cur = ws+1;
00463 while (cur < preds.length() && preds.at(cur) == ' ') cur++;
00464 ltparen = qpred.find("(",0);
00465
00466 if (ltparen == string::npos)
00467 {
00468 ws = qpred.find(" ");
00469 if (ws != string::npos) qpred = qpred.substr(0,ws);
00470 predName = qpred;
00471 }
00472 else
00473 predName = qpred.substr(0,ltparen);
00474
00475 predNames.append(predName);
00476 }
00477 }
00478
00479 if (queryFile == NULL || queryFile->length() == 0) return true;
00480
00481
00482 ifstream in((*queryFile).c_str());
00483 if (!in.good())
00484 {
00485 cout << "ERROR: unable to open " << *queryFile << endl;
00486 return false;
00487 }
00488 string buffer;
00489 while (getline(in, buffer))
00490 {
00491 cur = 0;
00492 while (cur < buffer.length() && buffer.at(cur) == ' ') cur++;
00493 ltparen = buffer.find("(", cur);
00494 if (ltparen == string::npos) continue;
00495 predName = buffer.substr(cur, ltparen-cur);
00496 predNames.append(predName);
00497 }
00498
00499 in.close();
00500 return true;
00501 }
00502
00507 char getTruthValueFirstChar(const TruthValue& tv)
00508 {
00509 if (tv == TRUE) return 'T';
00510 if (tv == FALSE) return 'F';
00511 if (tv == UNKNOWN) return 'U';
00512 assert(false);
00513 exit(-1);
00514 return '#';
00515 }
00516
// Seeds rand() from the current wall-clock time. The seed combines the low
// 7 bits of the seconds (0177 octal) with the microseconds.
void setsrand()
{
  struct timeval tv;
  // POSIX leaves gettimeofday()'s behaviour unspecified when the timezone
  // argument is non-null, and the timezone was never used, so pass NULL.
  gettimeofday(&tv, NULL);
  unsigned int seed = (( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec;
  srand(seed);
}
00528
00529
00530
00531 void copyFileAndAppendDbFile(const string& srcFile, string& dstFile,
00532 const Array<string>& dbFilesArr,
00533 const Array<string>& constFilesArr)
00534 {
00535 ofstream out(dstFile.c_str());
00536 ifstream in(srcFile.c_str());
00537 if (!out.good()) { cout<<"ERROR: failed to open "<<dstFile<<endl;exit(-1);}
00538 if (!in.good()) { cout<<"ERROR: failed to open "<<srcFile<<endl;exit(-1);}
00539
00540 string buffer;
00541 while(getline(in, buffer)) out << buffer << endl;
00542 in.close();
00543
00544 out << endl;
00545 for (int i = 0; i < constFilesArr.size(); i++)
00546 out << "#include \"" << constFilesArr[i] << "\"" << endl;
00547 out << endl;
00548 for (int i = 0; i < dbFilesArr.size(); i++)
00549 out << "#include \"" << dbFilesArr[i] << "\"" << endl;
00550 out.close();
00551 }
00552
00553
00554 bool checkQueryPredsNotInClosedWorldPreds(const StringHashArray& qpredNames,
00555 const StringHashArray& cwPredNames)
00556 {
00557 bool ok = true;
00558 for (int i = 0; i < qpredNames.size(); i++)
00559 if (cwPredNames.contains(qpredNames[i]))
00560 {
00561 cout << "ERROR: query predicate " << qpredNames[i]
00562 << " cannot be specified as closed world" << endl;
00563 ok = false;
00564 }
00565 return ok;
00566 }
00567
00596 bool createQueryFilePreds(const string& queryFile,
00597 const Domain* const & domain,
00598 Database* const & db,
00599 GroundPredicateHashArray* const &queries,
00600 GroundPredicateHashArray* const &knownQueries,
00601 bool printToFile, ostream& out, bool amapPos,
00602 const GroundPredicateHashArray* const &trueQueries,
00603 const Array<double>* const & trueProbs)
00604 {
00605 if (queryFile.length() == 0) return true;
00606
00607 bool ret = true;
00608 ifstream in(queryFile.c_str());
00609 unsigned int line = 0;
00610 unsigned int cur;
00611 int constId, predId;
00612 bool ok;
00613 string predStr, predName, constant;
00614 Array<int> constIds;
00615 const PredicateTemplate* ptemplate;
00616
00617 while (getline(in, predStr))
00618 {
00619 line++;
00620 cur = 0;
00621
00622
00623 ok = Util::substr(predStr, cur, predName, "(");
00624 if (!ok) continue;
00625
00626 predId = domain->getPredicateId(predName.c_str());
00627 ptemplate = domain->getPredicateTemplate(predId);
00628
00629 if (predId < 0 || ptemplate == NULL)
00630 {
00631 cout << "ERROR: Cannot find " << predName << " in domain on line "
00632 << line << " of query file " << queryFile << endl;
00633 ret = false;
00634 continue;
00635 }
00636
00637
00638 constIds.clear();
00639 while (Util::substr(predStr, cur, constant, ","))
00640 {
00641 constId = domain->getConstantId(constant.c_str());
00642 constIds.append(constId);
00643 if (constId < 0)
00644 {
00645 cout << "ERROR: Cannot find " << constant << " in database on line "
00646 << line << " of query file " << queryFile << endl;
00647 ret = false;
00648 }
00649 }
00650
00651
00652 ok = Util::substr(predStr, cur, constant, ")");
00653 if (!ok)
00654 {
00655 cout << "ERROR: Failed to parse ground predicate on line " << line
00656 << " of query file " << queryFile << endl;
00657 ret = false;
00658 continue;
00659 }
00660
00661 constId = domain->getConstantId(constant.c_str());
00662 constIds.append(constId);
00663 if (constId < 0)
00664 {
00665 cout << "ERROR: Cannot find " << constant << " in database on line "
00666 << line << " of query file " << queryFile << endl;
00667 ret = false;
00668 }
00669
00670 if (!ret) continue;
00672
00673
00674 if (constIds.size() != ptemplate->getNumTerms())
00675 {
00676 cout << "ERROR: incorrect number of terms for " << predName
00677 << ". Expected " << ptemplate->getNumTerms() << ", given "
00678 << constIds.size() << endl;
00679 ret = false;
00680 continue;
00681 }
00682
00683 Predicate pred(ptemplate);
00684 for (int i = 0; i < constIds.size(); i++)
00685 {
00686 if (pred.getTermTypeAsInt(i) != domain->getConstantTypeId(constIds[i]))
00687 {
00688 cout << "ERROR: wrong type for term "
00689 << domain->getConstantName(constIds[i]) << " for predicate "
00690 << predName << " on line " << line << " of query file "
00691 << queryFile << endl;
00692 ret = false;
00693 continue;
00694 }
00695 pred.appendTerm(new Term(constIds[i], (void*)&pred, true));
00696 }
00697 if (!ret) continue;
00698
00699 assert(!db->isClosedWorld(predId));
00700
00701 TruthValue tv = db->getValue(&pred);
00702 GroundPredicate* gndPred = new GroundPredicate(&pred);
00703
00704
00705 if (printToFile) assert(tv != UNKNOWN);
00706 if (tv == UNKNOWN)
00707 {
00708 if (queries->append(gndPred) < 0) delete gndPred;
00709 }
00710 else
00711 {
00712
00713 if (printToFile)
00714 {
00715
00716
00717 if (trueQueries)
00718 {
00719 double prob = 0.0;
00720 if (domain->getDB()->getEvidenceStatus(&pred))
00721 {
00722
00723 continue;
00724
00725 }
00726 else
00727 {
00728 int found = trueQueries->find(gndPred);
00729 if (found >= 0) prob = (*trueProbs)[found];
00730 else
00731
00732 prob = (prob*10000+1/2.0)/(10000+1.0);
00733 }
00734 gndPred->print(out, domain); out << " " << prob << endl;
00735 }
00736 else
00737 {
00738 if (amapPos)
00739 {
00740 if (tv == TRUE)
00741 {
00742 pred.printWithStrVar(out, domain);
00743 out << endl;
00744 }
00745 }
00746 else
00747 {
00748 pred.printWithStrVar(out, domain);
00749 out << " " << tv << endl;
00750 }
00751 }
00752 delete gndPred;
00753 }
00754 else
00755 {
00756
00757
00758 if (knownQueries->append(gndPred) < 0) delete gndPred;
00759 }
00760 }
00761 }
00762
00763 in.close();
00764 return ret;
00765 }
00766
00771 bool createQueryFilePreds(const string& queryFile, const Domain* const & domain,
00772 Database* const & db,
00773 GroundPredicateHashArray* const &queries,
00774 GroundPredicateHashArray* const &knownQueries)
00775 {
00776 return createQueryFilePreds(queryFile, domain, db, queries, knownQueries,
00777 false, cout, false, NULL, NULL);
00778 }
00779
00780 void readPredValuesAndSetToUnknown(const StringHashArray& predNames,
00781 Domain *domain,
00782 Array<Predicate *> &queryPreds,
00783 Array<TruthValue> &queryPredValues,
00784 bool isQueryEvidence)
00785 {
00786 Array<Predicate*> ppreds;
00787
00788
00789 queryPreds.clear();
00790 queryPredValues.clear();
00791 for (int predno = 0; predno < predNames.size(); predno++)
00792 {
00793 ppreds.clear();
00794 int predid = domain->getPredicateId(predNames[predno].c_str());
00795 Predicate::createAllGroundings(predid, domain, ppreds);
00796 for (int gpredno = 0; gpredno < ppreds.size(); gpredno++)
00797 {
00798 Predicate *pred = ppreds[gpredno];
00799 TruthValue tv = domain->getDB()->getValue(pred);
00800 if (tv == UNKNOWN)
00801 domain->getDB()->setValue(pred,FALSE);
00802
00803
00804
00805
00806 if (isQueryEvidence && tv == UNKNOWN)
00807 delete pred;
00808 else
00809 queryPreds.append(pred);
00810 }
00811 }
00812
00813
00814 domain->getDB()->setValuesToUnknown(&queryPreds, &queryPredValues);
00815 }
00816
00827 void setPredsToGivenValues(const StringHashArray& predNames, Domain *domain,
00828 Array<TruthValue> &gpredValues)
00829 {
00830 Array<Predicate*> gpreds;
00831 Array<Predicate*> ppreds;
00832 Array<TruthValue> tmpValues;
00833
00834
00835 gpreds.clear();
00836 tmpValues.clear();
00837 for (int predno = 0; predno < predNames.size(); predno++)
00838 {
00839 ppreds.clear();
00840 int predid = domain->getPredicateId(predNames[predno].c_str());
00841 Predicate::createAllGroundings(predid, domain, ppreds);
00842
00843 gpreds.append(ppreds);
00844 }
00845
00846 domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);
00847 for (int gpredno = 0; gpredno < gpreds.size(); gpredno++)
00848 delete gpreds[gpredno];
00849 }
00850
00851
00858 int buildInference(Inference*& inference, Domain*& domain)
00859 {
00860 string inMLNFile, wkMLNFile, evidenceFile;
00861
00862 StringHashArray queryPredNames;
00863 StringHashArray owPredNames;
00864 StringHashArray cwPredNames;
00865 MLN* mln = NULL;
00866 Array<string> constFilesArr;
00867 Array<string> evidenceFilesArr;
00868
00869 Array<Predicate *> queryPreds;
00870 Array<TruthValue> queryPredValues;
00871
00872
00873
00874
00875
00876
00877 extractFileNames(ainMLNFiles, constFilesArr);
00878 assert(constFilesArr.size() >= 1);
00879 inMLNFile.append(constFilesArr[0]);
00880 constFilesArr.removeItem(0);
00881 extractFileNames(aevidenceFiles, evidenceFilesArr);
00882
00883 if (aqueryPredsStr) queryPredsStr.append(aqueryPredsStr);
00884 if (aqueryFile) queryFile.append(aqueryFile);
00885
00886 if (queryPredsStr.length() == 0 && queryFile.length() == 0)
00887 { cout << "No query predicates specified" << endl; return -1; }
00888
00889 if (agibbsInfer && amcmcNumChains < 2)
00890 {
00891 cout << "ERROR: there must be at least 2 MCMC chains in Gibbs sampling"
00892 << endl; return -1;
00893 }
00894
00895 if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer)
00896 {
00897 cout << "ERROR: must specify one of -ms/-simtp/-m/-a/-p flags." << endl;
00898 return -1;
00899 }
00900
00901
00902 if (queryPredsStr.length() > 0 || queryFile.length() > 0)
00903 {
00904 if (!extractPredNames(queryPredsStr, &queryFile, queryPredNames)) return -1;
00905 }
00906
00907 if (amwsMaxSteps <= 0)
00908 { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
00909
00910 if (amwsTries <= 0)
00911 { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
00912
00913
00914 if (aOpenWorldPredsStr)
00915 {
00916 if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames))
00917 return -1;
00918 assert(owPredNames.size() > 0);
00919 }
00920
00921
00922 if (aClosedWorldPredsStr)
00923 {
00924 if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames))
00925 return -1;
00926 assert(cwPredNames.size() > 0);
00927 if (!checkQueryPredsNotInClosedWorldPreds(queryPredNames, cwPredNames))
00928 return -1;
00929 }
00930
00931
00932
00933
00934
00935
00936
00937
00938
00939
00940
00941 SampleSatParams* ssparams = new SampleSatParams;
00942 ssparams->lateSa = assLateSa;
00943 ssparams->saRatio = assSaRatio;
00944 ssparams->saTemp = assSaTemp;
00945
00946
00947 MaxWalksatParams* mwsparams = new MaxWalksatParams;
00948 mwsparams->ssParams = ssparams;
00949 mwsparams->maxSteps = amwsMaxSteps;
00950 mwsparams->maxTries = amwsTries;
00951 mwsparams->targetCost = amwsTargetWt;
00952 mwsparams->hard = amwsHard;
00953
00954
00955 mwsparams->numSolutions = amwsNumSolutions;
00956 mwsparams->heuristic = amwsHeuristic;
00957 mwsparams->tabuLength = amwsTabuLength;
00958 mwsparams->lazyLowState = amwsLazyLowState;
00959
00960
00961 MCSatParams* msparams = new MCSatParams;
00962 msparams->mwsParams = mwsparams;
00963
00964 msparams->numChains = 1;
00965 msparams->burnMinSteps = amcmcBurnMinSteps;
00966 msparams->burnMaxSteps = amcmcBurnMaxSteps;
00967 msparams->minSteps = amcmcMinSteps;
00968 msparams->maxSteps = amcmcMaxSteps;
00969 msparams->maxSeconds = amcmcMaxSeconds;
00970 msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
00971
00972
00973 GibbsParams* gibbsparams = new GibbsParams;
00974 gibbsparams->mwsParams = mwsparams;
00975 gibbsparams->numChains = amcmcNumChains;
00976 gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00977 gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00978 gibbsparams->minSteps = amcmcMinSteps;
00979 gibbsparams->maxSteps = amcmcMaxSteps;
00980 gibbsparams->maxSeconds = amcmcMaxSeconds;
00981
00982 gibbsparams->gamma = 1 - agibbsDelta;
00983 gibbsparams->epsilonError = agibbsEpsilonError;
00984 gibbsparams->fracConverged = agibbsFracConverged;
00985 gibbsparams->walksatType = agibbsWalksatType;
00986 gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00987
00988
00989 SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00990 stparams->mwsParams = mwsparams;
00991 stparams->numChains = amcmcNumChains;
00992 stparams->burnMinSteps = amcmcBurnMinSteps;
00993 stparams->burnMaxSteps = amcmcBurnMaxSteps;
00994 stparams->minSteps = amcmcMinSteps;
00995 stparams->maxSteps = amcmcMaxSteps;
00996 stparams->maxSeconds = amcmcMaxSeconds;
00997
00998 stparams->subInterval = asimtpSubInterval;
00999 stparams->numST = asimtpNumST;
01000 stparams->numSwap = asimtpNumSwap;
01001
01003
01004 cout << "Reading formulas and evidence predicates..." << endl;
01005
01006
01007 string::size_type bslash = inMLNFile.rfind("/");
01008 string tmp = (bslash == string::npos) ?
01009 inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
01010 char buf[100];
01011 sprintf(buf, "%s%s", tmp.c_str(), ZZ_TMP_FILE_POSTFIX);
01012 wkMLNFile = buf;
01013 copyFileAndAppendDbFile(inMLNFile, wkMLNFile,
01014 evidenceFilesArr, constFilesArr);
01015
01016
01017 domain = new Domain;
01018 mln = new MLN();
01019 bool addUnitClauses = false;
01020 bool mustHaveWtOrFullStop = true;
01021 bool warnAboutDupGndPreds = true;
01022 bool flipWtsOfFlippedClause = true;
01023
01024 bool allPredsExceptQueriesAreCW = owPredNames.empty();
01025 Domain* forCheckingPlusTypes = NULL;
01026
01027
01028
01029 if (!runYYParser(mln, domain, wkMLNFile.c_str(), allPredsExceptQueriesAreCW,
01030 &owPredNames, &queryPredNames, addUnitClauses,
01031 warnAboutDupGndPreds, 0, mustHaveWtOrFullStop,
01032 forCheckingPlusTypes, true, flipWtsOfFlippedClause))
01033 {
01034 unlink(wkMLNFile.c_str());
01035 return -1;
01036 }
01037
01038 unlink(wkMLNFile.c_str());
01039
01041
01043 Array<int> allPredGndingsAreQueries;
01044
01045
01046
01047 if (!aLazy)
01048 {
01049 if (queryFile.length() > 0)
01050 {
01051 cout << "Reading query predicates that are specified in query file..."
01052 << endl;
01053 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
01054 &queries, &knownQueries);
01055 if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01056 }
01057
01058 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
01059 if (queryPredsStr.length() > 0)
01060 {
01061 cout << "Creating query predicates that are specified on command line..."
01062 << endl;
01063 bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(),
01064 &queries, &knownQueries,
01065 &allPredGndingsAreQueries);
01066 if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01067 }
01068 }
01069
01070
01071 bool markHardGndClauses = false;
01072 bool trackParentClauseWts = false;
01073
01074
01075 VariableState* state = new VariableState(&queries, NULL, NULL,
01076 &allPredGndingsAreQueries,
01077 markHardGndClauses,
01078 trackParentClauseWts,
01079 mln, domain, aLazy);
01080 bool trackClauseTrueCnts = false;
01081
01082 if (amapPos || amapAll || amcsatInfer || agibbsInfer || asimtpInfer)
01083 {
01084 if (amapPos || amapAll)
01085 {
01086
01087
01088 mwsparams->numSolutions = 1;
01089 inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts, mwsparams);
01090 }
01091 else if (amcsatInfer)
01092 {
01093 inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
01094 }
01095 else if (asimtpInfer)
01096 {
01097
01098
01099 mwsparams->numSolutions = 1;
01100 inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
01101 stparams);
01102 }
01103 else if (agibbsInfer)
01104 {
01105
01106
01107 mwsparams->numSolutions = 1;
01108 inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
01109 gibbsparams);
01110 }
01111 }
01112 return 1;
01113 }
01114
01115
// Hash map from a string key to a pointer to a const array of C strings,
// using the project's HashString / EqualString functors.
// NOTE(review): ownership of the pointed-to arrays is not visible here —
// presumably the map does not own them; confirm with users of this typedef.
typedef hash_map<string, const Array<const char*>*, HashString, EqualString>
StringToStrArrayMap;
01118
01119 #endif