mcmc.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef MCMC_H_
00067 #define MCMC_H_
00068 
00069 #include "inference.h"
00070 #include "mcmcparams.h"
00071 
// Debug switch: when set to true, the MCMC routines in this header print
// trace output (Gibbs steps, probabilities, weight updates) to cout.
const bool mcmcdebug(false);
00074 
00079 class MCMC : public Inference
00080 {
00081  public:
00082 
00089   MCMC(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090        MCMCParams* params)
00091     : Inference(state, seed, trackClauseTrueCnts)
00092   {
00093       // User-set parameters
00094     numChains_ = params->numChains;
00095     burnMinSteps_ = params->burnMinSteps;
00096     burnMaxSteps_ = params->burnMaxSteps;
00097     minSteps_ = params->minSteps;
00098     maxSteps_ = params->maxSteps;
00099     maxSeconds_ = params->maxSeconds;
00100   }
00101 
00105   ~MCMC() {}
00106 
00110   void printProbabilities(ostream& out)
00111   {
00112     for (int i = 0; i < state_->getNumAtoms(); i++)
00113     {
00114       double prob = getProbTrue(i);
00115 
00116         // Uniform smoothing
00117       prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00118       state_->printGndPred(i, out);
00119       out << " " << prob << endl;
00120     }    
00121   }
00122 
00131   void getPredsWithNonZeroProb(vector<string>& nonZeroPreds,
00132                                vector<float>& probs)
00133   {
00134     nonZeroPreds.clear();
00135     probs.clear();
00136     for (int i = 0; i < state_->getNumAtoms(); i++)
00137     {
00138       double prob = getProbTrue(i);
00139       if (prob > 0)
00140       {
00141           // Uniform smoothing
00142         prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00143         ostringstream oss(ostringstream::out);
00144         state_->printGndPred(i, oss);
00145         nonZeroPreds.push_back(oss.str());
00146         probs.push_back(prob);
00147       }
00148     }
00149   }
00150 
00157   double getProbability(GroundPredicate* const& gndPred)
00158   {
00159     int idx = state_->getGndPredIndex(gndPred);
00160     double prob = 0.0;
00161     if (idx >= 0) prob = getProbTrue(idx);
00162       // Uniform smoothing
00163     return (prob*10000 + 1/2.0)/(10000 + 1.0);
00164   }
00165 
00169   void printTruePreds(ostream& out)
00170   {
00171     for (int i = 0; i < state_->getNumAtoms(); i++)
00172     {
00173       double prob = getProbTrue(i);
00174 
00175         // Uniform smoothing
00176       prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00177       if (prob >= 0.5) state_->printGndPred(i, out);
00178     }    
00179   }
00180 
00181  protected:
00182 
00191   void initTruthValuesAndWts(const int& numChains)
00192   {
00193     int numPreds = state_->getNumAtoms();
00194     truthValues_.growToSize(numPreds);
00195     wtsWhenFalse_.growToSize(numPreds);
00196     wtsWhenTrue_.growToSize(numPreds);
00197     for (int i = 0; i < numPreds; i++)
00198     {
00199       truthValues_[i].growToSize(numChains, false);
00200       wtsWhenFalse_[i].growToSize(numChains, 0);
00201       wtsWhenTrue_[i].growToSize(numChains, 0);
00202     }
00203     
00204     int numClauses = state_->getNumClauses();
00205     numTrueLits_.growToSize(numClauses);
00206     for (int i = 0; i < numClauses; i++)
00207     {
00208       numTrueLits_[i].growToSize(numChains, 0);
00209     }
00210   }
00211 
00216   void initNumTrue()
00217   {
00218     int numPreds = state_->getNumAtoms();
00219     numTrue_.growToSize(numPreds);
00220     for (int i = 0; i < numTrue_.size(); i++)
00221       numTrue_[i] = 0;
00222   }
00223 
00230   void initNumTrueLits(const int& numChains)
00231   {
00232     for (int i = 0; i < state_->getNumClauses(); i++)
00233     {
00234       GroundClause* gndClause = state_->getGndClause(i);
00235       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00236       {
00237         const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1;
00238         for (int c = 0; c < numChains; c++)
00239         {
00240           if (truthValues_[atomIdx][c] == gndClause->getGroundPredicateSense(j))
00241           {
00242             numTrueLits_[i][c]++;
00243             assert(numTrueLits_[i][c] <= state_->getNumAtoms());
00244           }
00245         }
00246       }
00247     }
00248   }
00249  
00257   void randomInitGndPredsTruthValues(const int& numChains)
00258   {
00259     for (int c = 0; c < numChains; c++)
00260     {
00261         // For each block: select one to set to true
00262       for (int i = 0; i < state_->getNumBlocks(); i++)
00263       {
00264           // If evidence atom exists, then all others are false
00265         if (state_->getBlockEvidence(i))
00266         {
00267             // If 2nd argument is -1, then all are set to false
00268           setOthersInBlockToFalse(c, -1, i);
00269           continue;
00270         }
00271         
00272         Array<int>& block = state_->getBlockArray(i);
00273         int chosen = random() % block.size();
00274         truthValues_[block[chosen]][c] = true;
00275         setOthersInBlockToFalse(c, chosen, i);
00276       }
00277       
00278         // Random tv for all not in blocks
00279       for (int i = 0; i < truthValues_.size(); i++)
00280       {
00281           // Predicates in blocks have been handled above
00282         if (state_->getBlockIndex(i) == -1)
00283         {
00284           bool tv = genTruthValueForProb(0.5);
00285           truthValues_[i][c] = tv;
00286         }
00287       }
00288     }
00289   }
00290 
00297   bool genTruthValueForProb(const double& p)
00298   {
00299     if (p == 1.0) return true;
00300     if (p == 0.0) return false;
00301     bool r = random() <= p*RAND_MAX;
00302     return r;
00303   }
00304 
00314   double getProbabilityOfPred(const int& predIdx, const int& chainIdx,
00315                               const double& invTemp)
00316   {
00317       // Different for multi-chain
00318     if (numChains_ > 1)
00319     {
00320       return 1.0 /
00321              ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] - 
00322                           wtsWhenTrue_[predIdx][chainIdx]) *
00323                           invTemp));      
00324     }
00325     else
00326     {
00327       GroundPredicate* gndPred = state_->getGndPred(predIdx);
00328       return 1.0 /
00329              ( 1.0 + exp((gndPred->getWtWhenFalse() - 
00330                           gndPred->getWtWhenTrue()) *
00331                           invTemp));
00332     }
00333   }
00334  
00343   void setOthersInBlockToFalse(const int& chainIdx, const int& atomIdx,
00344                                const int& blockIdx)
00345   {
00346     Array<int>& block = state_->getBlockArray(blockIdx);
00347     for (int i = 0; i < block.size(); i++)
00348     {
00349       if (i != atomIdx)
00350         truthValues_[block[i]][chainIdx] = false;
00351     }
00352   }
00353 
00364   void performGibbsStep(const int& chainIdx, const bool& burningIn,
00365                         GroundPredicateHashArray& affectedGndPreds,
00366                         Array<int>& affectedGndPredIndices)
00367   {
00368     if (mcmcdebug) cout << "Gibbs step" << endl;
00369 
00370       // For each block: select one to set to true
00371     for (int i = 0; i < state_->getNumBlocks(); i++)
00372     {
00373         // If evidence atom exists, then all others stay false
00374       if (state_->getBlockEvidence(i)) continue;
00375  
00376       Array<int>& block = state_->getBlockArray(i);
00377         // chosen is index in the block, block[chosen] is index in gndPreds_
00378       int chosen = gibbsSampleFromBlock(chainIdx, block, 1);
00379         // Truth values are stored differently for multi-chain
00380       bool truthValue;
00381       GroundPredicate* gndPred = state_->getGndPred(block[chosen]);
00382       if (numChains_ > 1) truthValue = truthValues_[block[chosen]][chainIdx];
00383       else truthValue = gndPred->getTruthValue();
00384         // If chosen pred was false, then need to set previous true
00385         // one to false and update wts
00386       if (!truthValue)
00387       {
00388         for (int j = 0; j < block.size(); j++)
00389         {
00390             // Truth values are stored differently for multi-chain
00391           bool otherTruthValue;
00392           GroundPredicate* otherGndPred = state_->getGndPred(block[j]);
00393           if (numChains_ > 1)
00394             otherTruthValue = truthValues_[block[j]][chainIdx];
00395           else
00396             otherTruthValue = otherGndPred->getTruthValue();
00397           if (otherTruthValue)
00398           {
00399               // Truth values are stored differently for multi-chain
00400             if (numChains_ > 1)
00401               truthValues_[block[j]][chainIdx] = false;
00402             else
00403               otherGndPred->setTruthValue(false);
00404               
00405             affectedGndPreds.clear();
00406             affectedGndPredIndices.clear();
00407             gndPredFlippedUpdates(block[j], chainIdx, affectedGndPreds,
00408                                   affectedGndPredIndices);
00409             updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00410                                  chainIdx);
00411           }
00412         }
00413           // Set truth value and update wts for chosen atom
00414           // Truth values are stored differently for multi-chain
00415         if (numChains_ > 1) truthValues_[block[chosen]][chainIdx] = true;
00416         else gndPred->setTruthValue(true);
00417         affectedGndPreds.clear();
00418         affectedGndPredIndices.clear();
00419         gndPredFlippedUpdates(block[chosen], chainIdx, affectedGndPreds,
00420                               affectedGndPredIndices);
00421         updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00422                              chainIdx);
00423       }
00424 
00425         // If in actual gibbs sampling phase, track the num of times
00426         // the ground predicate is set to true
00427       if (!burningIn) numTrue_[block[chosen]]++;
00428     }
00429 
00430       // Now go through all preds not in blocks
00431     for (int i = 0; i < state_->getNumAtoms(); i++)
00432     {
00433         // Predicates in blocks have been handled above
00434       if (state_->getBlockIndex(i) >= 0) continue;
00435 
00436       if (mcmcdebug)
00437       {
00438         cout << "Chain " << chainIdx << ": Probability of pred "
00439              << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl;
00440       }
00441       
00442       bool newAssignment
00443         = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1));
00444 
00445         // Truth values are stored differently for multi-chain
00446       bool truthValue;
00447       GroundPredicate* gndPred = state_->getGndPred(i);
00448       if (numChains_ > 1) truthValue = truthValues_[i][chainIdx];
00449       else truthValue = gndPred->getTruthValue();
00450         // If gndPred is flipped, do updates & find all affected gndPreds
00451       if (newAssignment != truthValue)
00452       {
00453         if (mcmcdebug)
00454         {
00455           cout << "Chain " << chainIdx << ": Changing truth value of pred "
00456                << i << " to " << newAssignment << endl;
00457         }
00458         
00459         if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment;
00460         else gndPred->setTruthValue(newAssignment);
00461         affectedGndPreds.clear();
00462         affectedGndPredIndices.clear();
00463         gndPredFlippedUpdates(i, chainIdx, affectedGndPreds,
00464                               affectedGndPredIndices);
00465         updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00466                              chainIdx);
00467       }
00468 
00469         // If in actual gibbs sampling phase, track the num of times
00470         // the ground predicate is set to true
00471       if (!burningIn && newAssignment) numTrue_[i]++;
00472     }
00473       // If keeping track of true clause groundings
00474     if (!burningIn && trackClauseTrueCnts_)
00475       state_->getNumClauseGndings(clauseTrueCnts_, true);
00476 
00477     if (mcmcdebug) cout << "End of Gibbs step" << endl;
00478   }
00479 
00488   void updateWtsForGndPreds(GroundPredicateHashArray& gndPreds,
00489                             Array<int>& gndPredIndices,
00490                             const int& chainIdx)
00491   {
00492     if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
00493       // for each ground predicate whose MB has changed
00494     for (int g = 0; g < gndPreds.size(); g++)
00495     {
00496       double wtIfNoChange = 0, wtIfInverted = 0, wt;
00497         // Ground clauses in which this pred occurs
00498       Array<int>& negGndClauses =
00499         state_->getNegOccurenceArray(gndPredIndices[g] + 1);
00500       Array<int>& posGndClauses =
00501         state_->getPosOccurenceArray(gndPredIndices[g] + 1);
00502       int gndClauseIdx;
00503       bool sense;
00504       
00505       if (mcmcdebug)
00506       {
00507         cout << "Ground clauses in which pred " << g << " occurs neg.: "
00508              << negGndClauses.size() << endl;
00509         cout << "Ground clauses in which pred " << g << " occurs pos.: "
00510              << posGndClauses.size() << endl;
00511       }
00512       
00513       for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00514       {
00515         if (i < negGndClauses.size())
00516         {
00517           gndClauseIdx = negGndClauses[i];
00518           if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
00519           sense = false;
00520         }
00521         else
00522         {
00523           gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00524           if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
00525           sense = true;
00526         }
00527         
00528         GroundClause* gndClause = state_->getGndClause(gndClauseIdx);
00529         if (gndClause->isHardClause())
00530           wt = state_->getClauseCost(gndClauseIdx);
00531         else
00532           wt = gndClause->getWt();
00533           // NumTrueLits are stored differently for multi-chain
00534         int numSatLiterals;
00535         if (numChains_ > 1)
00536           numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
00537         else
00538           numSatLiterals = state_->getNumTrueLits(gndClauseIdx);
00539         if (numSatLiterals > 1)
00540         {
00541             // Some other literal is making it sat, so it doesn't matter
00542             // if pos. clause. If neg., nothing can be done to unsatisfy it.
00543           if (wt > 0)
00544           {
00545             wtIfNoChange += wt;
00546             wtIfInverted += wt;
00547           }
00548         }
00549         else 
00550         if (numSatLiterals == 1) 
00551         {
00552           if (wt > 0) wtIfNoChange += wt;
00553             // Truth values are stored differently for multi-chain
00554           bool truthValue;
00555           if (numChains_ > 1)
00556             truthValue = truthValues_[gndPredIndices[g]][chainIdx];
00557           else
00558             truthValue = gndPreds[g]->getTruthValue();
00559             // If the current truth value is the same as its sense in gndClause
00560           if (truthValue == sense) 
00561           {
00562             // This gndPred is the only one making this function satisfied
00563             if (wt < 0) wtIfInverted += abs(wt);
00564           }
00565           else 
00566           {
00567               // Some other literal is making it satisfied
00568             if (wt > 0) wtIfInverted += wt;
00569           }
00570         }
00571         else
00572         if (numSatLiterals == 0) 
00573         {
00574           // None satisfy, so when gndPred switch to its negative, it'll satisfy
00575           if (wt > 0) wtIfInverted += wt;
00576           else if (wt < 0) wtIfNoChange += abs(wt);
00577         }
00578       } // for each ground clause that gndPred appears in
00579 
00580       if (mcmcdebug)
00581       {
00582         cout << "wtIfNoChange of pred " << g << ": "
00583              << wtIfNoChange << endl;
00584         cout << "wtIfInverted of pred " << g << ": "
00585              << wtIfInverted << endl;
00586       }
00587 
00588         // Clause info is stored differently for multi-chain
00589       if (numChains_ > 1)
00590       {
00591         if (truthValues_[gndPredIndices[g]][chainIdx]) 
00592         {
00593           wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
00594           wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
00595         }
00596         else 
00597         {
00598           wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
00599           wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
00600         }
00601       }
00602       else
00603       { // Single chain
00604         if (gndPreds[g]->getTruthValue())
00605         {
00606           gndPreds[g]->setWtWhenTrue(wtIfNoChange);
00607           gndPreds[g]->setWtWhenFalse(wtIfInverted);
00608         }
00609         else
00610         {
00611           gndPreds[g]->setWtWhenFalse(wtIfNoChange);
00612           gndPreds[g]->setWtWhenTrue(wtIfInverted);            
00613         }
00614       }
00615     } // for each ground predicate whose MB has changed
00616     if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
00617   }
00618 
00628   int gibbsSampleFromBlock(const int& chainIdx, const Array<int>& block,
00629                            const double& invTemp)
00630   {
00631     Array<double> numerators;
00632     double denominator = 0;
00633     
00634     for (int i = 0; i < block.size(); i++)
00635     {
00636       double prob = getProbabilityOfPred(block[i], chainIdx, invTemp);
00637       numerators.append(prob);
00638       denominator += prob;
00639     }
00640     double r = random();
00641     double numSum = 0.0;
00642     for (int i = 0; i < block.size(); i++)
00643     {
00644       numSum += numerators[i];
00645       if (r < ((numSum / denominator) * RAND_MAX))
00646       {
00647         return i;
00648       }
00649     }
00650     return block.size() - 1;
00651   }
00652 
00661   void gndPredFlippedUpdates(const int& gndPredIdx, const int& chainIdx,
00662                              GroundPredicateHashArray& affectedGndPreds,
00663                              Array<int>& affectedGndPredIndices)
00664   {
00665     if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl;
00666     int numAtoms = state_->getNumAtoms();
00667     GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00668     affectedGndPreds.append(gndPred, numAtoms);
00669     affectedGndPredIndices.append(gndPredIdx);
00670     assert(affectedGndPreds.size() <= numAtoms);
00671 
00672     Array<int>& negGndClauses =
00673       state_->getNegOccurenceArray(gndPredIdx + 1);
00674     Array<int>& posGndClauses =
00675       state_->getPosOccurenceArray(gndPredIdx + 1);
00676     int gndClauseIdx;
00677     GroundClause* gndClause; 
00678     bool sense;
00679 
00680       // Find the Markov blanket of this ground predicate
00681     for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00682     {
00683       if (i < negGndClauses.size())
00684       {
00685         gndClauseIdx = negGndClauses[i];
00686         sense = false;
00687       }
00688       else
00689       {
00690         gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00691         sense = true;
00692       }
00693       gndClause = state_->getGndClause(gndClauseIdx);
00694 
00695         // Different for multi-chain
00696       if (numChains_ > 1)
00697       {        
00698         if (truthValues_[gndPredIdx][chainIdx] == sense)
00699           numTrueLits_[gndClauseIdx][chainIdx]++;
00700         else
00701           numTrueLits_[gndClauseIdx][chainIdx]--;
00702       }
00703       else
00704       { // Single chain
00705         if (gndPred->getTruthValue() == sense)
00706           state_->incrementNumTrueLits(gndClauseIdx);
00707         else
00708           state_->decrementNumTrueLits(gndClauseIdx);
00709       }
00710       
00711       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00712       {
00713         const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr();
00714         GroundPredicate* pred = 
00715           (GroundPredicate*)gndClause->getGroundPredicate(j,
00716             (GroundPredicateHashArray*)gpha);
00717         affectedGndPreds.append(pred, numAtoms);
00718         affectedGndPredIndices.append(
00719                                abs(gndClause->getGroundPredicateIndex(j)) - 1);
00720         assert(affectedGndPreds.size() <= numAtoms);
00721       }
00722     }
00723     if (mcmcdebug) cout << "Leaving gndPredFlippedUpdates" << endl;
00724   }
00725 
00726   double getProbTrue(const int& predIdx) const { return numTrue_[predIdx]; }
00727   
00728   void setProbTrue(const int& predIdx, const double& p)
00729   { 
00730     assert(p >= 0);
00731     numTrue_[predIdx] = p;
00732   }
00733 
00740   void saveLowStateToChain(const int& chainIdx)
00741   {
00742     for (int i = 0; i < state_->getNumAtoms(); i++)
00743       truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1);
00744   }
00745 
00751   void setMCMCParameters(MCMCParams* params)
00752   {
00753       // User-set parameters
00754     numChains_ = params->numChains;
00755     burnMinSteps_ = params->burnMinSteps;
00756     burnMaxSteps_ = params->burnMaxSteps;
00757     minSteps_ = params->minSteps;
00758     maxSteps_ = params->maxSteps;
00759     maxSeconds_ = params->maxSeconds;    
00760   }
00761   
00762  protected:
00763  
00765     // No. of chains which MCMC will use
00766   int numChains_;
00767     // Min. no. of burn-in steps MCMC will take per chain
00768   int burnMinSteps_;
00769     // Max. no. of burn-in steps MCMC will take per chain
00770   int burnMaxSteps_;
00771     // Min. no. of sampling steps MCMC will take per chain
00772   int minSteps_;
00773     // Max. no. of sampling steps MCMC will take per chain
00774   int maxSteps_;
00775     // Max. no. of seconds MCMC should run
00776   int maxSeconds_;
00778 
00779     // Truth values in each chain for each ground predicate (truthValues_[p][c])
00780   Array<Array<bool> > truthValues_;
00781     // Wts when false in each chain for each ground predicate
00782   Array<Array<double> > wtsWhenFalse_;
00783     // Wts when true in each chain for each groud predicate
00784   Array<Array<double> > wtsWhenTrue_;
00785 
00786     // Number of times each ground predicate is set to true
00787     // overloaded to hold probability that ground predicate is true
00788   Array<double> numTrue_; // numTrue_[p]
00789 
00790     // Num. of satisfying literals in each chain for each groud predicate
00791     // numTrueLits_[clause][chain]
00792   Array<Array<int> > numTrueLits_;
00793 };
00794 
00795 #endif /*MCMC_H_*/

Generated on Wed Feb 14 15:15:17 2007 for Alchemy by  doxygen 1.5.1