00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef MCMC_H_
00067 #define MCMC_H_
00068
00069 #include "inference.h"
00070 #include "mcmcparams.h"
00071
00072
// Switch for the diagnostic tracing in class MCMC: when true, the sampling
// routines print progress messages to cout.
static const bool mcmcdebug(false);
00074
00079 class MCMC : public Inference
00080 {
00081 public:
00082
00089 MCMC(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090 MCMCParams* params)
00091 : Inference(state, seed, trackClauseTrueCnts)
00092 {
00093
00094 numChains_ = params->numChains;
00095 burnMinSteps_ = params->burnMinSteps;
00096 burnMaxSteps_ = params->burnMaxSteps;
00097 minSteps_ = params->minSteps;
00098 maxSteps_ = params->maxSteps;
00099 maxSeconds_ = params->maxSeconds;
00100 }
00101
00105 ~MCMC() {}
00106
00110 void printProbabilities(ostream& out)
00111 {
00112 for (int i = 0; i < state_->getNumAtoms(); i++)
00113 {
00114 double prob = getProbTrue(i);
00115
00116
00117 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00118 state_->printGndPred(i, out);
00119 out << " " << prob << endl;
00120 }
00121 }
00122
00131 void getPredsWithNonZeroProb(vector<string>& nonZeroPreds,
00132 vector<float>& probs)
00133 {
00134 nonZeroPreds.clear();
00135 probs.clear();
00136 for (int i = 0; i < state_->getNumAtoms(); i++)
00137 {
00138 double prob = getProbTrue(i);
00139 if (prob > 0)
00140 {
00141
00142 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00143 ostringstream oss(ostringstream::out);
00144 state_->printGndPred(i, oss);
00145 nonZeroPreds.push_back(oss.str());
00146 probs.push_back(prob);
00147 }
00148 }
00149 }
00150
00157 double getProbability(GroundPredicate* const& gndPred)
00158 {
00159 int idx = state_->getGndPredIndex(gndPred);
00160 double prob = 0.0;
00161 if (idx >= 0) prob = getProbTrue(idx);
00162
00163 return (prob*10000 + 1/2.0)/(10000 + 1.0);
00164 }
00165
00169 void printTruePreds(ostream& out)
00170 {
00171 for (int i = 0; i < state_->getNumAtoms(); i++)
00172 {
00173 double prob = getProbTrue(i);
00174
00175
00176 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00177 if (prob >= 0.5) state_->printGndPred(i, out);
00178 }
00179 }
00180
00181 protected:
00182
00191 void initTruthValuesAndWts(const int& numChains)
00192 {
00193 int numPreds = state_->getNumAtoms();
00194 truthValues_.growToSize(numPreds);
00195 wtsWhenFalse_.growToSize(numPreds);
00196 wtsWhenTrue_.growToSize(numPreds);
00197 for (int i = 0; i < numPreds; i++)
00198 {
00199 truthValues_[i].growToSize(numChains, false);
00200 wtsWhenFalse_[i].growToSize(numChains, 0);
00201 wtsWhenTrue_[i].growToSize(numChains, 0);
00202 }
00203
00204 int numClauses = state_->getNumClauses();
00205 numTrueLits_.growToSize(numClauses);
00206 for (int i = 0; i < numClauses; i++)
00207 {
00208 numTrueLits_[i].growToSize(numChains, 0);
00209 }
00210 }
00211
00216 void initNumTrue()
00217 {
00218 int numPreds = state_->getNumAtoms();
00219 numTrue_.growToSize(numPreds);
00220 for (int i = 0; i < numTrue_.size(); i++)
00221 numTrue_[i] = 0;
00222 }
00223
00230 void initNumTrueLits(const int& numChains)
00231 {
00232 for (int i = 0; i < state_->getNumClauses(); i++)
00233 {
00234 GroundClause* gndClause = state_->getGndClause(i);
00235 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00236 {
00237 const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1;
00238 for (int c = 0; c < numChains; c++)
00239 {
00240 if (truthValues_[atomIdx][c] == gndClause->getGroundPredicateSense(j))
00241 {
00242 numTrueLits_[i][c]++;
00243 assert(numTrueLits_[i][c] <= state_->getNumAtoms());
00244 }
00245 }
00246 }
00247 }
00248 }
00249
00257 void randomInitGndPredsTruthValues(const int& numChains)
00258 {
00259 for (int c = 0; c < numChains; c++)
00260 {
00261
00262 for (int i = 0; i < state_->getNumBlocks(); i++)
00263 {
00264
00265 if (state_->getBlockEvidence(i))
00266 {
00267
00268 setOthersInBlockToFalse(c, -1, i);
00269 continue;
00270 }
00271
00272 Array<int>& block = state_->getBlockArray(i);
00273 int chosen = random() % block.size();
00274 truthValues_[block[chosen]][c] = true;
00275 setOthersInBlockToFalse(c, chosen, i);
00276 }
00277
00278
00279 for (int i = 0; i < truthValues_.size(); i++)
00280 {
00281
00282 if (state_->getBlockIndex(i) == -1)
00283 {
00284 bool tv = genTruthValueForProb(0.5);
00285 truthValues_[i][c] = tv;
00286 }
00287 }
00288 }
00289 }
00290
00297 bool genTruthValueForProb(const double& p)
00298 {
00299 if (p == 1.0) return true;
00300 if (p == 0.0) return false;
00301 bool r = random() <= p*RAND_MAX;
00302 return r;
00303 }
00304
00314 double getProbabilityOfPred(const int& predIdx, const int& chainIdx,
00315 const double& invTemp)
00316 {
00317
00318 if (numChains_ > 1)
00319 {
00320 return 1.0 /
00321 ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] -
00322 wtsWhenTrue_[predIdx][chainIdx]) *
00323 invTemp));
00324 }
00325 else
00326 {
00327 GroundPredicate* gndPred = state_->getGndPred(predIdx);
00328 return 1.0 /
00329 ( 1.0 + exp((gndPred->getWtWhenFalse() -
00330 gndPred->getWtWhenTrue()) *
00331 invTemp));
00332 }
00333 }
00334
00343 void setOthersInBlockToFalse(const int& chainIdx, const int& atomIdx,
00344 const int& blockIdx)
00345 {
00346 Array<int>& block = state_->getBlockArray(blockIdx);
00347 for (int i = 0; i < block.size(); i++)
00348 {
00349 if (i != atomIdx)
00350 truthValues_[block[i]][chainIdx] = false;
00351 }
00352 }
00353
00364 void performGibbsStep(const int& chainIdx, const bool& burningIn,
00365 GroundPredicateHashArray& affectedGndPreds,
00366 Array<int>& affectedGndPredIndices)
00367 {
00368 if (mcmcdebug) cout << "Gibbs step" << endl;
00369
00370
00371 for (int i = 0; i < state_->getNumBlocks(); i++)
00372 {
00373
00374 if (state_->getBlockEvidence(i)) continue;
00375
00376 Array<int>& block = state_->getBlockArray(i);
00377
00378 int chosen = gibbsSampleFromBlock(chainIdx, block, 1);
00379
00380 bool truthValue;
00381 GroundPredicate* gndPred = state_->getGndPred(block[chosen]);
00382 if (numChains_ > 1) truthValue = truthValues_[block[chosen]][chainIdx];
00383 else truthValue = gndPred->getTruthValue();
00384
00385
00386 if (!truthValue)
00387 {
00388 for (int j = 0; j < block.size(); j++)
00389 {
00390
00391 bool otherTruthValue;
00392 GroundPredicate* otherGndPred = state_->getGndPred(block[j]);
00393 if (numChains_ > 1)
00394 otherTruthValue = truthValues_[block[j]][chainIdx];
00395 else
00396 otherTruthValue = otherGndPred->getTruthValue();
00397 if (otherTruthValue)
00398 {
00399
00400 if (numChains_ > 1)
00401 truthValues_[block[j]][chainIdx] = false;
00402 else
00403 otherGndPred->setTruthValue(false);
00404
00405 affectedGndPreds.clear();
00406 affectedGndPredIndices.clear();
00407 gndPredFlippedUpdates(block[j], chainIdx, affectedGndPreds,
00408 affectedGndPredIndices);
00409 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00410 chainIdx);
00411 }
00412 }
00413
00414
00415 if (numChains_ > 1) truthValues_[block[chosen]][chainIdx] = true;
00416 else gndPred->setTruthValue(true);
00417 affectedGndPreds.clear();
00418 affectedGndPredIndices.clear();
00419 gndPredFlippedUpdates(block[chosen], chainIdx, affectedGndPreds,
00420 affectedGndPredIndices);
00421 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00422 chainIdx);
00423 }
00424
00425
00426
00427 if (!burningIn) numTrue_[block[chosen]]++;
00428 }
00429
00430
00431 for (int i = 0; i < state_->getNumAtoms(); i++)
00432 {
00433
00434 if (state_->getBlockIndex(i) >= 0) continue;
00435
00436 if (mcmcdebug)
00437 {
00438 cout << "Chain " << chainIdx << ": Probability of pred "
00439 << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl;
00440 }
00441
00442 bool newAssignment
00443 = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1));
00444
00445
00446 bool truthValue;
00447 GroundPredicate* gndPred = state_->getGndPred(i);
00448 if (numChains_ > 1) truthValue = truthValues_[i][chainIdx];
00449 else truthValue = gndPred->getTruthValue();
00450
00451 if (newAssignment != truthValue)
00452 {
00453 if (mcmcdebug)
00454 {
00455 cout << "Chain " << chainIdx << ": Changing truth value of pred "
00456 << i << " to " << newAssignment << endl;
00457 }
00458
00459 if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment;
00460 else gndPred->setTruthValue(newAssignment);
00461 affectedGndPreds.clear();
00462 affectedGndPredIndices.clear();
00463 gndPredFlippedUpdates(i, chainIdx, affectedGndPreds,
00464 affectedGndPredIndices);
00465 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00466 chainIdx);
00467 }
00468
00469
00470
00471 if (!burningIn && newAssignment) numTrue_[i]++;
00472 }
00473
00474 if (!burningIn && trackClauseTrueCnts_)
00475 state_->getNumClauseGndings(clauseTrueCnts_, true);
00476
00477 if (mcmcdebug) cout << "End of Gibbs step" << endl;
00478 }
00479
00488 void updateWtsForGndPreds(GroundPredicateHashArray& gndPreds,
00489 Array<int>& gndPredIndices,
00490 const int& chainIdx)
00491 {
00492 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
00493
00494 for (int g = 0; g < gndPreds.size(); g++)
00495 {
00496 double wtIfNoChange = 0, wtIfInverted = 0, wt;
00497
00498 Array<int>& negGndClauses =
00499 state_->getNegOccurenceArray(gndPredIndices[g] + 1);
00500 Array<int>& posGndClauses =
00501 state_->getPosOccurenceArray(gndPredIndices[g] + 1);
00502 int gndClauseIdx;
00503 bool sense;
00504
00505 if (mcmcdebug)
00506 {
00507 cout << "Ground clauses in which pred " << g << " occurs neg.: "
00508 << negGndClauses.size() << endl;
00509 cout << "Ground clauses in which pred " << g << " occurs pos.: "
00510 << posGndClauses.size() << endl;
00511 }
00512
00513 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00514 {
00515 if (i < negGndClauses.size())
00516 {
00517 gndClauseIdx = negGndClauses[i];
00518 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
00519 sense = false;
00520 }
00521 else
00522 {
00523 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00524 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
00525 sense = true;
00526 }
00527
00528 GroundClause* gndClause = state_->getGndClause(gndClauseIdx);
00529 if (gndClause->isHardClause())
00530 wt = state_->getClauseCost(gndClauseIdx);
00531 else
00532 wt = gndClause->getWt();
00533
00534 int numSatLiterals;
00535 if (numChains_ > 1)
00536 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
00537 else
00538 numSatLiterals = state_->getNumTrueLits(gndClauseIdx);
00539 if (numSatLiterals > 1)
00540 {
00541
00542
00543 if (wt > 0)
00544 {
00545 wtIfNoChange += wt;
00546 wtIfInverted += wt;
00547 }
00548 }
00549 else
00550 if (numSatLiterals == 1)
00551 {
00552 if (wt > 0) wtIfNoChange += wt;
00553
00554 bool truthValue;
00555 if (numChains_ > 1)
00556 truthValue = truthValues_[gndPredIndices[g]][chainIdx];
00557 else
00558 truthValue = gndPreds[g]->getTruthValue();
00559
00560 if (truthValue == sense)
00561 {
00562
00563 if (wt < 0) wtIfInverted += abs(wt);
00564 }
00565 else
00566 {
00567
00568 if (wt > 0) wtIfInverted += wt;
00569 }
00570 }
00571 else
00572 if (numSatLiterals == 0)
00573 {
00574
00575 if (wt > 0) wtIfInverted += wt;
00576 else if (wt < 0) wtIfNoChange += abs(wt);
00577 }
00578 }
00579
00580 if (mcmcdebug)
00581 {
00582 cout << "wtIfNoChange of pred " << g << ": "
00583 << wtIfNoChange << endl;
00584 cout << "wtIfInverted of pred " << g << ": "
00585 << wtIfInverted << endl;
00586 }
00587
00588
00589 if (numChains_ > 1)
00590 {
00591 if (truthValues_[gndPredIndices[g]][chainIdx])
00592 {
00593 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
00594 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
00595 }
00596 else
00597 {
00598 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
00599 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
00600 }
00601 }
00602 else
00603 {
00604 if (gndPreds[g]->getTruthValue())
00605 {
00606 gndPreds[g]->setWtWhenTrue(wtIfNoChange);
00607 gndPreds[g]->setWtWhenFalse(wtIfInverted);
00608 }
00609 else
00610 {
00611 gndPreds[g]->setWtWhenFalse(wtIfNoChange);
00612 gndPreds[g]->setWtWhenTrue(wtIfInverted);
00613 }
00614 }
00615 }
00616 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
00617 }
00618
00628 int gibbsSampleFromBlock(const int& chainIdx, const Array<int>& block,
00629 const double& invTemp)
00630 {
00631 Array<double> numerators;
00632 double denominator = 0;
00633
00634 for (int i = 0; i < block.size(); i++)
00635 {
00636 double prob = getProbabilityOfPred(block[i], chainIdx, invTemp);
00637 numerators.append(prob);
00638 denominator += prob;
00639 }
00640 double r = random();
00641 double numSum = 0.0;
00642 for (int i = 0; i < block.size(); i++)
00643 {
00644 numSum += numerators[i];
00645 if (r < ((numSum / denominator) * RAND_MAX))
00646 {
00647 return i;
00648 }
00649 }
00650 return block.size() - 1;
00651 }
00652
00661 void gndPredFlippedUpdates(const int& gndPredIdx, const int& chainIdx,
00662 GroundPredicateHashArray& affectedGndPreds,
00663 Array<int>& affectedGndPredIndices)
00664 {
00665 if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl;
00666 int numAtoms = state_->getNumAtoms();
00667 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00668 affectedGndPreds.append(gndPred, numAtoms);
00669 affectedGndPredIndices.append(gndPredIdx);
00670 assert(affectedGndPreds.size() <= numAtoms);
00671
00672 Array<int>& negGndClauses =
00673 state_->getNegOccurenceArray(gndPredIdx + 1);
00674 Array<int>& posGndClauses =
00675 state_->getPosOccurenceArray(gndPredIdx + 1);
00676 int gndClauseIdx;
00677 GroundClause* gndClause;
00678 bool sense;
00679
00680
00681 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00682 {
00683 if (i < negGndClauses.size())
00684 {
00685 gndClauseIdx = negGndClauses[i];
00686 sense = false;
00687 }
00688 else
00689 {
00690 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00691 sense = true;
00692 }
00693 gndClause = state_->getGndClause(gndClauseIdx);
00694
00695
00696 if (numChains_ > 1)
00697 {
00698 if (truthValues_[gndPredIdx][chainIdx] == sense)
00699 numTrueLits_[gndClauseIdx][chainIdx]++;
00700 else
00701 numTrueLits_[gndClauseIdx][chainIdx]--;
00702 }
00703 else
00704 {
00705 if (gndPred->getTruthValue() == sense)
00706 state_->incrementNumTrueLits(gndClauseIdx);
00707 else
00708 state_->decrementNumTrueLits(gndClauseIdx);
00709 }
00710
00711 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00712 {
00713 const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr();
00714 GroundPredicate* pred =
00715 (GroundPredicate*)gndClause->getGroundPredicate(j,
00716 (GroundPredicateHashArray*)gpha);
00717 affectedGndPreds.append(pred, numAtoms);
00718 affectedGndPredIndices.append(
00719 abs(gndClause->getGroundPredicateIndex(j)) - 1);
00720 assert(affectedGndPreds.size() <= numAtoms);
00721 }
00722 }
00723 if (mcmcdebug) cout << "Leaving gndPredFlippedUpdates" << endl;
00724 }
00725
00726 double getProbTrue(const int& predIdx) const { return numTrue_[predIdx]; }
00727
00728 void setProbTrue(const int& predIdx, const double& p)
00729 {
00730 assert(p >= 0);
00731 numTrue_[predIdx] = p;
00732 }
00733
00740 void saveLowStateToChain(const int& chainIdx)
00741 {
00742 for (int i = 0; i < state_->getNumAtoms(); i++)
00743 truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1);
00744 }
00745
00751 void setMCMCParameters(MCMCParams* params)
00752 {
00753
00754 numChains_ = params->numChains;
00755 burnMinSteps_ = params->burnMinSteps;
00756 burnMaxSteps_ = params->burnMaxSteps;
00757 minSteps_ = params->minSteps;
00758 maxSteps_ = params->maxSteps;
00759 maxSeconds_ = params->maxSeconds;
00760 }
00761
00762 protected:
00763
00765
00766 int numChains_;
00767
00768 int burnMinSteps_;
00769
00770 int burnMaxSteps_;
00771
00772 int minSteps_;
00773
00774 int maxSteps_;
00775
00776 int maxSeconds_;
00778
00779
00780 Array<Array<bool> > truthValues_;
00781
00782 Array<Array<double> > wtsWhenFalse_;
00783
00784 Array<Array<double> > wtsWhenTrue_;
00785
00786
00787
00788 Array<double> numTrue_;
00789
00790
00791
00792 Array<Array<int> > numTrueLits_;
00793 };
00794
00795 #endif