votedperceptron.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef VOTED_PERCEPTRON_H_OCT_30_2005
00067 #define VOTED_PERCEPTRON_H_OCT_30_2005
00068 
00069 #include "infer.h"
00070 #include "clause.h"
00071 #include "timer.h"
00072 #include "indextranslator.h"
00073 #include "maxwalksat.h"
00074 
// Debug switch: when true, per-clause counts and intermediate values are
// printed during weight initialization and gradient computation.
const bool vpdebug = false;
// Floor for initial weights: a log-odds weight whose magnitude is below this
// value is replaced by it (initializeWts() also uses it for clauses with no
// counted groundings).
const double EPSILON = 1e-5;
00077 
00082 class VotedPerceptron 
00083 {
00084  public:
00085 
00100   VotedPerceptron(const Array<Inference*>& inferences,
00101                   const StringHashArray& nonEvidPredNames,
00102                   IndexTranslator* const & idxTrans, const bool& lazyInference,
00103                   const bool& rescaleGradient, const bool& withEM)
00104     : domainCnt_(inferences.size()), idxTrans_(idxTrans),
00105       lazyInference_(lazyInference), rescaleGradient_(rescaleGradient),
00106       withEM_(withEM)
00107   { 
00108     cout << endl << "Constructing voted perceptron..." << endl << endl;
00109 
00110     inferences_.append(inferences);
00111     logOddsPerDomain_.growToSize(domainCnt_);
00112     clauseCntPerDomain_.growToSize(domainCnt_);
00113     
00114     for (int i = 0; i < domainCnt_; i++)
00115     {
00116       clauseCntPerDomain_[i] =
00117         inferences_[i]->getState()->getMLN()->getNumClauses();
00118       logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0);
00119     }
00120 
00121     totalTrueCnts_.growToSize(domainCnt_);
00122     defaultTrueCnts_.growToSize(domainCnt_);
00123     relevantClausesPerDomain_.growToSize(domainCnt_);
00124     //relevantClausesFormulas_ is set in findRelevantClausesFormulas()
00125 
00126     findRelevantClauses(nonEvidPredNames);
00127     findRelevantClausesFormulas();
00128 
00129       // Initialize the clause wts for lazy version
00130     if (lazyInference_)
00131     {
00132       findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(nonEvidPredNames);
00133     
00134       for (int i = 0; i < domainCnt_; i++)
00135       {
00136         const MLN* mln = inferences_[i]->getState()->getMLN();
00137         Array<double>& logOdds = logOddsPerDomain_[i];
00138         assert(mln->getNumClauses() == logOdds.size());
00139         for (int j = 0; j < mln->getNumClauses(); j++)
00140           ((Clause*) mln->getClause(j))->setWt(logOdds[j]);
00141       }
00142     }
00143       // Initialize the clause wts for eager version
00144     else
00145     {      
00146       initializeWts();
00147     }
00148     
00149       // Initialize the inference / state
00150     for (int i = 0; i < inferences_.size(); i++)
00151       inferences_[i]->init();
00152   }
00153 
00154 
00155   ~VotedPerceptron() 
00156   {
00157     for (int i = 0; i < trainTrueCnts_.size(); i++)
00158       delete[] trainTrueCnts_[i];
00159   }
00160 
00161 
00162     // set the prior means and std devs.
00163   void setMeansStdDevs(const int& arrSize, const double* const & priorMeans, 
00164                        const double* const & priorStdDevs) 
00165   {
00166     if (arrSize < 0) 
00167     {
00168       usePrior_ = false;
00169       priorMeans_ = NULL;
00170       priorStdDevs_ = NULL;
00171     } 
00172     else 
00173     {
00174       //cout << "arr size = " << arrSize<<", clause count = "<<clauseCnt_<<endl;
00175       usePrior_ = true;
00176       priorMeans_ = priorMeans;
00177       priorStdDevs_ = priorStdDevs;
00178 
00179       //cout << "\t\t Mean \t\t Std Deviation" << endl;
00180       //for (int i = 0; i < arrSize; i++) 
00181       //  cout << i << "\t\t" << priorMeans_[i]<<"\t\t"<<priorStdDevs_[i]<<endl;
00182     }
00183   }
00184 
00185 
00186     // learn the weights
00187   void learnWeights(double* const & weights, const int& numWeights,
00188                     const int& maxIter, const double& learningRate,
00189                     const double& momentum, bool initWithLogOdds,
00190                     const int& mwsMaxSubsequentSteps) 
00191   {
00192     //cout << "Learning weights discriminatively... " << endl;
00193     memset(weights, 0, numWeights*sizeof(double));
00194 
00195     double* averageWeights = new double[numWeights];
00196     double* gradient = new double[numWeights];
00197     double* lastchange = new double[numWeights];
00198 
00199       // Set the initial weight to the average log odds across domains/databases
00200     if (initWithLogOdds)
00201     {
00202         // If there is one db or the clauses for multiple databases line up
00203       if (idxTrans_ == NULL)
00204       {
00205         for (int i = 0; i < domainCnt_; i++)
00206         {
00207           Array<double>& logOdds = logOddsPerDomain_[i];
00208           assert(numWeights == logOdds.size());
00209           for (int j = 0; j < logOdds.size(); j++) weights[j] += logOdds[j];
00210         }
00211       }
00212       else
00213       { //the clauses for multiple databases do not line up
00214         const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain 
00215           = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00216 
00217         Array<int> numLogOdds; 
00218         Array<double> wtsForDomain;
00219         numLogOdds.growToSize(numWeights);
00220         wtsForDomain.growToSize(numWeights);
00221       
00222         for (int i = 0; i < domainCnt_; i++)
00223         {
00224           memset((int*)numLogOdds.getItems(), 0, numLogOdds.size()*sizeof(int));
00225           memset((double*)wtsForDomain.getItems(), 0,
00226                  wtsForDomain.size()*sizeof(double));
00227 
00228           Array<double>& logOdds = logOddsPerDomain_[i];
00229         
00230             // Map the each log odds of a clause to the weight of a
00231             // clause/formula
00232           for (int j = 0; j < logOdds.size(); j++)
00233           {
00234             Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];          
00235             for (int k = 0; k < idxDivs->size(); k++)
00236             {
00237               wtsForDomain[ (*idxDivs)[k].idx ] += logOdds[j];
00238               numLogOdds[ (*idxDivs)[k].idx ]++;
00239             }
00240           }
00241 
00242           for (int j = 0; j < numWeights; j++)
00243             if (numLogOdds[j] > 0) weights[j] += wtsForDomain[j]/numLogOdds[j];  
00244         }
00245       }
00246     }
00247 
00248       // Initialize weights, averageWeights, lastchange
00249     for (int i = 0; i < numWeights; i++) 
00250     {      
00251       weights[i] /= domainCnt_;
00252       averageWeights[i] = weights[i];
00253       lastchange[i] = 0.0;
00254     }
00255 
00256     for (int iter = 1; iter <= maxIter; iter++) 
00257     {
00258       cout << endl << "Iteration " << iter << " : " << endl << endl;
00259 
00260         // In 3rd iteration, we want to tell MWS to perform subsequentSteps
00261         // (Iter. 1 is random assigment if initial weights are 0)
00262       if (iter == 3)
00263       {
00264         for (int i = 0; i < inferences_.size(); i++)
00265         {
00266             // Check if using MWS
00267           if (MaxWalkSat* mws = dynamic_cast<MaxWalkSat*>(inferences_[i]))
00268           {
00269             mws->setMaxSteps(mwsMaxSubsequentSteps);
00270           }
00271         }
00272       }
00273 
00274       cout << "Getting the gradient.. " << endl;
00275       getGradient(weights, gradient, numWeights);
00276       cout << endl; 
00277 
00278         // Add gradient to weights
00279       for (int w = 0; w < numWeights; w++) 
00280       {
00281         double wchange = gradient[w] * learningRate + lastchange[w] * momentum;
00282         cout << "clause/formula " << w << ": wtChange = " << wchange;
00283         cout << "  oldWt = " << weights[w];
00284         weights[w] += wchange;
00285         lastchange[w] = wchange;
00286         cout << "  newWt = " << weights[w];
00287         averageWeights[w] = (iter * averageWeights[w] + weights[w])/(iter + 1);
00288         cout << "  averageWt = " << averageWeights[w] << endl;
00289       }
00290       // done with an iteration
00291     }
00292     
00293     cout << endl << "Learned Weights : " << endl;
00294     for (int w = 0; w < numWeights; w++) 
00295     {
00296       weights[w] = averageWeights[w];
00297       cout << w << ":" << weights[w] << endl;
00298     }
00299 
00300     delete [] averageWeights;
00301     delete [] gradient;
00302     delete [] lastchange;
00303     
00304     resetDBs();
00305   }
00306  
00307  
00308  private:
00309  
00313   void resetDBs() 
00314   {
00315     if (!lazyInference_)
00316     {
00317       for (int i = 0; i < domainCnt_; i++) 
00318       {
00319         VariableState* state = inferences_[i]->getState();
00320         Database* db = state->getDomain()->getDB();
00321           // Change known NE to original values
00322         const GroundPredicateHashArray* knePreds = state->getKnePreds();
00323         const Array<TruthValue>* knePredValues = state->getKnePredValues();      
00324         db->setValuesToGivenValues(knePreds, knePredValues);
00325           // Set unknown NE back to UKNOWN
00326         const GroundPredicateHashArray* unePreds = state->getUnePreds();
00327         for (int predno = 0; predno < unePreds->size(); predno++) 
00328           db->setValue((*unePreds)[predno], UNKNOWN);
00329       }
00330     }
00331   }
00332 
00338   void findRelevantClauses(const StringHashArray& nonEvidPredNames) 
00339   {
00340     for (int d = 0; d < domainCnt_; d++)
00341     {
00342       int clauseCnt = clauseCntPerDomain_[d];
00343       Array<bool>& relevantClauses = relevantClausesPerDomain_[d];
00344       relevantClauses.growToSize(clauseCnt);
00345       memset((bool*)relevantClauses.getItems(), false, 
00346              relevantClauses.size()*sizeof(bool));
00347       const Domain* domain = inferences_[d]->getState()->getDomain();
00348       const MLN* mln = inferences_[d]->getState()->getMLN();
00349     
00350       const Array<IndexClause*>* indclauses;
00351       const Clause* clause;
00352       int predid, clauseid;
00353       for (int i = 0; i < nonEvidPredNames.size(); i++)
00354       {
00355         predid = domain->getPredicateId(nonEvidPredNames[i].c_str());
00356         //cout << "finding the relevant clauses for predid = " << predid 
00357         //     << " in domain " << d << endl;
00358         indclauses = mln->getClausesContainingPred(predid);
00359         if (indclauses) 
00360         {
00361           for (int j = 0; j < indclauses->size(); j++) 
00362           {
00363             clause = (*indclauses)[j]->clause;                  
00364             clauseid = mln->findClauseIdx(clause);
00365             relevantClauses[clauseid] = true;
00366             //cout << clauseid << " ";
00367           }
00368           //cout<<endl;
00369         }
00370       }    
00371     }
00372   }
00373 
00374   
00375   void findRelevantClausesFormulas()
00376   {
00377     if (idxTrans_ == NULL)
00378     {
00379       Array<bool>& relevantClauses = relevantClausesPerDomain_[0];
00380       relevantClausesFormulas_.growToSize(relevantClauses.size());
00381       for (int i = 0; i < relevantClauses.size(); i++)
00382         relevantClausesFormulas_[i] = relevantClauses[i];
00383     }
00384     else
00385     {
00386       idxTrans_->setRelevantClausesFormulas(relevantClausesFormulas_,
00387                                             relevantClausesPerDomain_[0]);
00388       cout << "Relevant clauses/formulas:" << endl;
00389       idxTrans_->printRelevantClausesFormulas(cout, relevantClausesFormulas_);
00390       cout << endl;
00391     }
00392   }
00393 
00394 
00404   void calculateCounts(Array<double>& trueCnt, Array<double>& falseCnt,
00405                        const int& domainIdx, const bool& hasUnknownPreds) 
00406   {
00407     Clause* clause;
00408     double tmpUnknownCnt;
00409     int clauseCnt = clauseCntPerDomain_[domainIdx];
00410     Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00411     const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00412     const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00413 
00414     for (int clauseno = 0; clauseno < clauseCnt; clauseno++) 
00415     {
00416       if (!relevantClauses[clauseno]) 
00417       {
00418         continue;
00419         //cout << "\n\nthis is an irrelevant clause.." << endl;
00420       }
00421       clause = (Clause*) mln->getClause(clauseno);
00422       clause->getNumTrueFalseUnknownGroundings(domain, domain->getDB(), 
00423                                                hasUnknownPreds,
00424                                                trueCnt[clauseno],
00425                                                falseCnt[clauseno],
00426                                                tmpUnknownCnt);
00427       assert(hasUnknownPreds || (tmpUnknownCnt==0));
00428     }
00429   }
00430 
00431 
00432   void initializeWts()
00433   {
00434     cout << "Initializing weights ..." << endl;
00435     Array<double *> trainFalseCnts;
00436     trainTrueCnts_.growToSize(domainCnt_);
00437     trainFalseCnts.growToSize(domainCnt_);
00438   
00439     for (int i = 0; i < domainCnt_; i++)
00440     {
00441       int clauseCnt = clauseCntPerDomain_[i];
00442       VariableState* state = inferences_[i]->getState();
00443       const GroundPredicateHashArray* unePreds = state->getUnePreds();
00444       const GroundPredicateHashArray* knePreds = state->getKnePreds();
00445 
00446       trainTrueCnts_[i] = new double[clauseCnt];
00447       trainFalseCnts[i] = new double[clauseCnt];
00448 
00449       int totalPreds = unePreds->size() + knePreds->size();
00450         // Used to store gnd preds to be ignored in the count because they are
00451         // UNKNOWN
00452       Array<bool>* unknownPred = new Array<bool>;
00453       unknownPred->growToSize(totalPreds, false);
00454       for (int predno = 0; predno < totalPreds; predno++) 
00455       {
00456         GroundPredicate* p;
00457         if (predno < unePreds->size())
00458           p = (*unePreds)[predno];
00459         else
00460           p = (*knePreds)[predno - unePreds->size()];
00461         TruthValue tv = state->getDomain()->getDB()->getValue(p);
00462 
00463         //assert(tv != UNKNOWN);
00464         if (tv == TRUE)
00465         {
00466           state->setValueOfAtom(predno + 1, true);
00467           p->setTruthValue(true);
00468         }
00469         else
00470         {
00471           state->setValueOfAtom(predno + 1, false);
00472           p->setTruthValue(false);
00473             // Can have unknown truth values when using EM. We want to ignore
00474             // these when performing the counts
00475           if (tv == UNKNOWN)
00476           {
00477             (*unknownPred)[predno] = true;
00478           }
00479         }
00480       }
00481 
00482       state->initMakeBreakCostWatch();
00483       //cout<<"getting true cnts => "<<endl;
00484       state->getNumClauseGndingsWithUnknown(trainTrueCnts_[i], clauseCnt, true,
00485                                             unknownPred);
00486       //cout<<endl;
00487       //cout<<"getting false cnts => "<<endl;
00488       state->getNumClauseGndingsWithUnknown(trainFalseCnts[i], clauseCnt, false,
00489                                             unknownPred);
00490       delete unknownPred;
00491       if (vpdebug)
00492       {
00493         for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00494         {
00495           cout << clauseno << " : tc = " << trainTrueCnts_[i][clauseno]
00496                << " ** fc = " << trainFalseCnts[i][clauseno] << endl;
00497         }
00498       }
00499     }
00500 
00501     double tc,fc;
00502     cout << "List of CNF Clauses : " << endl;
00503     for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++)
00504     {
00505       if (!relevantClausesPerDomain_[0][clauseno])
00506       {
00507         for (int i = 0; i < domainCnt_; i++)
00508         {
00509           Array<double>& logOdds = logOddsPerDomain_[i];
00510           logOdds[clauseno] = 0.0;
00511         }
00512         continue;
00513       }
00514       //cout << endl << endl;
00515       cout << clauseno << ":";
00516       const Clause* clause =
00517         inferences_[0]->getState()->getMLN()->getClause(clauseno);
00518       //cout << (*fncArr)[clauseno]->formula <<endl;
00519       clause->print(cout, inferences_[0]->getState()->getDomain());
00520       cout << endl;
00521       
00522       tc = 0.0; fc = 0.0;
00523       for (int i = 0; i < domainCnt_;i++)
00524       {
00525         tc += trainTrueCnts_[i][clauseno];
00526         fc += trainFalseCnts[i][clauseno];
00527       }
00528         
00529       //cout << "true count  = " << tc << endl;
00530       //cout << "false count = " << fc << endl;
00531         
00532       double weight = 0.0;
00533       double totalCnt = tc + fc;
00534                 
00535       if (totalCnt == 0) 
00536       {
00537         //cout << "NOTE: Total count is 0 for clause " << clauseno << endl;
00538         weight = EPSILON;
00539       } 
00540       else 
00541       {
00542         double prob =  tc / (tc+fc);
00543         if (prob == 0) prob = 0.00001;
00544         if (prob == 1) prob = 0.99999;
00545         weight = log(prob/(1-prob));
00546           //if weight exactly equals 0, make it small non zero, so that clause  
00547           //is not ignored during the construction of the MRF
00548         //if(weight == 0) weight = 0.0001;
00549           //commented above - make sure all weights are positive in the
00550           //beginning
00551         //if(weight < EPSILON) weight = EPSILON;
00552         if (abs(weight) < EPSILON) weight = EPSILON;
00553           //cout << "Prob " << prob << " becomes weight of " << weight << endl;
00554       }
00555       for (int i = 0; i < domainCnt_; i++) 
00556       {
00557         Array<double>& logOdds = logOddsPerDomain_[i];
00558         logOdds[clauseno] = weight;
00559       }
00560     }
00561     cout << endl;
00562     
00563     for (int i = 0; i < trainFalseCnts.size(); i++)
00564       delete[] trainFalseCnts[i];
00565   }
00566 
00575   void findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(
00576                                        const StringHashArray& nonEvidPredNames)
00577   {
00578     bool hasUnknownPreds;
00579     Array<Array<double> > totalFalseCnts; 
00580     Array<Array<double> > defaultFalseCnts;
00581     totalFalseCnts.growToSize(domainCnt_);
00582     defaultFalseCnts.growToSize(domainCnt_);
00583     
00584     Array<Predicate*> gpreds;
00585     Array<Predicate*> ppreds;
00586     Array<TruthValue> gpredValues;
00587     Array<TruthValue> tmpValues;
00588 
00589     for (int i = 0; i < domainCnt_; i++) 
00590     {
00591       const Domain* domain = inferences_[i]->getState()->getDomain();
00592       int clauseCnt = clauseCntPerDomain_[i];
00593       domain->getDB()->setPerformingInference(false);
00594 
00595       //cout << endl << "Getting the counts for the domain " << i << endl;
00596       gpreds.clear();
00597       gpredValues.clear();
00598       tmpValues.clear();
00599       for (int predno = 0; predno < nonEvidPredNames.size(); predno++) 
00600       {
00601         ppreds.clear();
00602         int predid = domain->getPredicateId(nonEvidPredNames[predno].c_str());
00603         Predicate::createAllGroundings(predid, domain, ppreds);
00604         //cout<<"size of gnd for pred " << predid << " = "<<ppreds.size()<<endl;
00605         gpreds.append(ppreds);
00606       }
00607       
00608       domain->getDB()->alterTruthValue(&gpreds, UNKNOWN, FALSE, &gpredValues);
00609           
00610       //cout <<"size of unknown set for domain "<<i<<" = "<<gpreds.size()<<endl;
00611       //cout << "size of the values " << i << " = " << gpredValues.size()<<endl;
00612         
00613       hasUnknownPreds = false;
00614       
00615       Array<double>& trueCnt = totalTrueCnts_[i];
00616       Array<double>& falseCnt = totalFalseCnts[i];
00617       trueCnt.growToSize(clauseCnt);
00618       falseCnt.growToSize(clauseCnt);
00619       calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00620 
00621       //cout << "got the total counts..\n\n\n" << endl;
00622       
00623       hasUnknownPreds = true;
00624 
00625       domain->getDB()->setValuesToUnknown(&gpreds, &tmpValues);
00626 
00627       Array<double>& dTrueCnt = defaultTrueCnts_[i];
00628       Array<double>& dFalseCnt = defaultFalseCnts[i];
00629       dTrueCnt.growToSize(clauseCnt);
00630       dFalseCnt.growToSize(clauseCnt);
00631       calculateCounts(dTrueCnt, dFalseCnt, i, hasUnknownPreds);
00632 
00633       //commented out: no need to revert the grounded non-evidence predicates
00634       //               to their initial values because we want to set ALL of
00635       //               them to UNKNOWN
00636       //assert(gpreds.size() == gpredValues.size());
00637       //domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);
00638           
00639       //cout << "the ground predicates are :" << endl;
00640       for (int predno = 0; predno < gpreds.size(); predno++) 
00641         delete gpreds[predno];
00642 
00643       domain->getDB()->setPerformingInference(true);
00644     }
00645     //cout << endl << endl;
00646     //cout << "got the default counts..." << endl;     
00647     for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++) 
00648     {
00649       double tc = 0;
00650       double fc = 0;
00651       for (int i = 0; i < domainCnt_; i++) 
00652       {
00653         Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00654         Array<double>& logOdds = logOddsPerDomain_[i];
00655       
00656         if (!relevantClauses[clauseno]) { logOdds[clauseno] = 0; continue; }
00657         tc += totalTrueCnts_[i][clauseno] - defaultTrueCnts_[i][clauseno];
00658         fc += totalFalseCnts[i][clauseno] - defaultFalseCnts[i][clauseno];
00659 
00660         if (vpdebug)
00661           cout << clauseno << " : tc = " << tc << " ** fc = "<< fc <<endl;      
00662       }
00663       
00664       double weight = 0.0;
00665 
00666       if ((tc + fc) == 0) 
00667       {
00668         //cout << "NOTE: Total count is 0 for clause " << clauseno << endl;
00669       } 
00670       else 
00671       {
00672         double prob = tc / (tc+fc);
00673         if (prob == 0) prob = 0.00001;
00674         if (prob == 1) prob = 0.99999;
00675         weight = log(prob / (1-prob));
00676             //if weight exactly equals 0, make it small non zero, so that clause
00677             //is not ignored during the construction of the MRF
00678         //if (weight == 0) weight = 0.0001;
00679         if (abs(weight) < EPSILON) weight = EPSILON;
00680           //cout << "Prob " << prob << " becomes weight of " << weight << endl;
00681       }
00682       
00683         // Set logOdds in all domains to the weight calculated
00684       for(int i = 0; i < domainCnt_; i++) 
00685       { 
00686         Array<double>& logOdds = logOddsPerDomain_[i];
00687         logOdds[clauseno] = weight;
00688       }
00689     }
00690   }
00691  
00692   
00696   void infer() 
00697   {
00698     for (int i = 0; i < domainCnt_; i++) 
00699     {
00700       VariableState* state = inferences_[i]->getState();
00701       state->setGndClausesWtsToSumOfParentWts();
00702       //inferences_[i]->init();
00703         // MWS: Search is started from state at end of last iteration
00704       state->init();
00705       inferences_[i]->infer();
00706       state->saveLowStateToGndPreds();
00707     }
00708   }
00709 
00714   void fillInMissingValues()
00715   {
00716     assert(withEM_);
00717     cout << "Filling in missing data ..." << endl;
00718       // Get values of initial unknown preds by producing MAP state of
00719       // unknown preds given known evidence and non-evidence preds (VPEM)
00720     Array<Array<TruthValue> > ueValues;
00721     ueValues.growToSize(domainCnt_);
00722     for (int i = 0; i < domainCnt_; i++)
00723     {
00724       VariableState* state = inferences_[i]->getState();
00725       const Domain* domain = state->getDomain();
00726       const GroundPredicateHashArray* knePreds = state->getKnePreds();
00727       const Array<TruthValue>* knePredValues = state->getKnePredValues();
00728 
00729         // Mark known non-evidence preds as evidence
00730       domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
00731 
00732         // Infer missing values
00733       state->setGndClausesWtsToSumOfParentWts();
00734         // MWS: Search is started from state at end of last iteration
00735       state->init();
00736       inferences_[i]->infer();
00737       state->saveLowStateToGndPreds();
00738 
00739       if (vpdebug)
00740       {
00741         cout << "Inferred following values: " << endl;
00742         inferences_[i]->printProbabilities(cout);
00743       }
00744 
00745         // Compute counts
00746       if (lazyInference_)
00747       {
00748         Array<double>& trueCnt = totalTrueCnts_[i];
00749         Array<double> falseCnt;
00750         bool hasUnknownPreds = false;
00751         falseCnt.growToSize(trueCnt.size());
00752         calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00753       }
00754       else
00755       {
00756         int clauseCnt = clauseCntPerDomain_[i];
00757         state->initMakeBreakCostWatch();
00758         //cout<<"getting true cnts => "<<endl;
00759         const Array<double>* clauseTrueCnts =
00760           inferences_[i]->getClauseTrueCnts();
00761         assert(clauseTrueCnts->size() == clauseCnt);
00762         for (int j = 0; j < clauseCnt; j++)
00763           trainTrueCnts_[i][j] = (*clauseTrueCnts)[j];
00764       }
00765 
00766         // Set evidence values back
00767       //assert(uePreds.size() == ueValues[i].size());
00768       //domain->getDB()->setValuesToGivenValues(&uePreds, &ueValues[i]);
00769         // Set non-evidence values to unknown
00770       Array<TruthValue> tmpValues;
00771       tmpValues.growToSize(knePreds->size());
00772       domain->getDB()->setValuesToUnknown(knePreds, &tmpValues);
00773     }
00774     cout << "Done filling in missing data" << endl;    
00775   }
00776 
00777   void getGradientForDomain(double* const & gradient, const int& domainIdx)
00778   {
00779     Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00780     int clauseCnt = clauseCntPerDomain_[domainIdx];
00781     double* trainCnts = NULL;
00782     double* inferredCnts = NULL;
00783     double* clauseTrainCnts = new double[clauseCnt]; 
00784     double* clauseInferredCnts = new double[clauseCnt];
00785     double trainCnt, inferredCnt;
00786     Array<double>& totalTrueCnts = totalTrueCnts_[domainIdx];
00787     Array<double>& defaultTrueCnts = defaultTrueCnts_[domainIdx];    
00788     const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00789     const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00790 
00791     memset(clauseTrainCnts, 0, clauseCnt*sizeof(double));
00792     memset(clauseInferredCnts, 0, clauseCnt*sizeof(double));
00793 
00794     if (!lazyInference_)
00795     {
00796       if (!inferredCnts) inferredCnts = new double[clauseCnt];
00797 
00798       const Array<double>* clauseTrueCnts =
00799         inferences_[domainIdx]->getClauseTrueCnts();
00800       assert(clauseTrueCnts->size() == clauseCnt);
00801       for (int i = 0; i < clauseCnt; i++)
00802         inferredCnts[i] = (*clauseTrueCnts)[i];
00803       trainCnts = trainTrueCnts_[domainIdx];
00804     }
00805       //loop over all the training examples
00806     //cout << "\t\ttrain count\t\t\t\tinferred count" << endl << endl;
00807     for (int clauseno = 0; clauseno < clauseCnt; clauseno++) 
00808     {
00809       if (!relevantClauses[clauseno]) continue;
00810       
00811       if (lazyInference_)
00812       {
00813         Clause* clause = (Clause*) mln->getClause(clauseno);
00814 
00815         trainCnt = totalTrueCnts[clauseno];
00816         inferredCnt =
00817           clause->getNumTrueGroundings(domain, domain->getDB(), false);
00818         trainCnt -= defaultTrueCnts[clauseno];
00819         inferredCnt -= defaultTrueCnts[clauseno];
00820       
00821         clauseTrainCnts[clauseno] += trainCnt;
00822         clauseInferredCnts[clauseno] += inferredCnt;
00823       }
00824       else
00825       {
00826         clauseTrainCnts[clauseno] += trainCnts[clauseno];
00827         clauseInferredCnts[clauseno] += inferredCnts[clauseno];
00828       }
00829       //cout << clauseno << ":\t\t" <<trainCnt<<"\t\t\t\t"<<inferredCnt<<endl;
00830     }
00831 
00832     if (vpdebug)
00833     {
00834       cout << "net counts : " << endl;
00835       cout << "\t\ttrain count\t\t\t\tinferred count" << endl << endl;
00836     }
00837 
00838     for (int clauseno = 0; clauseno < clauseCnt; clauseno++) 
00839     {
00840       if (!relevantClauses[clauseno]) continue;
00841       
00842       if (vpdebug)
00843         cout << clauseno << ":\t\t" << clauseTrainCnts[clauseno] << "\t\t\t\t"
00844              << clauseInferredCnts[clauseno] << endl;
00845       if (rescaleGradient_ && clauseTrainCnts[clauseno] > 0)
00846       {
00847         gradient[clauseno] += 
00848           (clauseTrainCnts[clauseno] - clauseInferredCnts[clauseno])
00849             / clauseTrainCnts[clauseno];
00850       }
00851       else
00852       {
00853         gradient[clauseno] += clauseTrainCnts[clauseno] - 
00854                               clauseInferredCnts[clauseno];
00855       }
00856     }
00857 
00858     delete[] clauseTrainCnts;
00859     delete[] clauseInferredCnts;
00860   }
00861 
00862 
00863     // Get the gradient 
00864   void getGradient(double* const & weights, double* const & gradient,
00865                    const int numWts) 
00866   {
00867     // Set the weights and run inference
00868     
00869     //cout << "New Weights = **** " << endl << endl;
00870     
00871       // If there is one db or the clauses for multiple databases line up
00872     if (idxTrans_ == NULL)
00873     {
00874       int clauseCnt = clauseCntPerDomain_[0];
00875       for (int i = 0; i < domainCnt_; i++)
00876       {
00877         Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00878         assert(clauseCntPerDomain_[i] == clauseCnt);
00879         const MLN* mln = inferences_[i]->getState()->getMLN();
00880         
00881         for (int j = 0; j < clauseCnt; j++) 
00882         {
00883           Clause* c = (Clause*) mln->getClause(j);
00884           if (relevantClauses[j]) c->setWt(weights[j]);
00885           else                    c->setWt(0);
00886         }
00887       }
00888     }
00889     else
00890     {   // The clauses for multiple databases do not line up
00891       Array<Array<double> >* wtsPerDomain = idxTrans_->getWtsPerDomain();
00892       const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain 
00893         = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00894       
00895       for (int i = 0; i < domainCnt_; i++)
00896       {
00897         Array<double>& wts = (*wtsPerDomain)[i];
00898         memset((double*)wts.getItems(), 0, wts.size()*sizeof(double));
00899 
00900           //map clause/formula weights to clause weights
00901         for (int j = 0; j < wts.size(); j++)
00902         {
00903           Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];          
00904           for (int k = 0; k < idxDivs->size(); k++)
00905             wts[j] += weights[ (*idxDivs)[k].idx ] / (*idxDivs)[k].div;
00906         }
00907       }
00908       
00909       for (int i = 0; i < domainCnt_; i++)
00910       {
00911         Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00912         int clauseCnt = clauseCntPerDomain_[i];
00913         Array<double>& wts = (*wtsPerDomain)[i];
00914         assert(wts.size() == clauseCnt);
00915         const MLN* mln = inferences_[i]->getState()->getMLN();
00916 
00917         for (int j = 0; j < clauseCnt; j++)
00918         {
00919           Clause* c = (Clause*) mln->getClause(j);
00920           if (relevantClauses[j]) c->setWt(wts[j]);
00921           else                   c->setWt(0);
00922         }
00923       }
00924     }
00925     //for (int i = 0; i < numWts; i++) cout << i << " : " << weights[i] << endl;
00926 
00927     if (withEM_) fillInMissingValues();
00928     cout << "Running inference ..." << endl;
00929     infer();
00930     cout << "Done with inference" << endl;
00931 
00932       // Compute the gradient
00933     memset(gradient, 0, numWts*sizeof(double));
00934 
00935       // There is one DB or the clauses of multiple DBs line up
00936     if (idxTrans_ == NULL)
00937     {
00938       for (int i = 0; i < domainCnt_; i++) 
00939       {           
00940         //cout << "For domain number " << i << endl << endl; 
00941         getGradientForDomain(gradient, i);        
00942       }
00943     }
00944     else
00945     {
00946         // The clauses for multiple databases do not line up
00947       Array<Array<double> >* gradsPerDomain = idxTrans_->getGradsPerDomain();
00948       const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain 
00949         = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00950      
00951       for (int i = 0; i < domainCnt_; i++) 
00952       {           
00953         //cout << "For domain number " << i << endl << endl; 
00954 
00955         Array<double>& grads = (*gradsPerDomain)[i];
00956         memset((double*)grads.getItems(), 0, grads.size()*sizeof(double));
00957         
00958         getGradientForDomain((double*)grads.getItems(), i);
00959         
00960           // map clause gradient to clause/formula gradients
00961         assert(grads.size() == clauseCntPerDomain_[i]);
00962         for (int j = 0; j < grads.size(); j++)
00963         {
00964           Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];          
00965           for (int k = 0; k < idxDivs->size(); k++)
00966             gradient[ (*idxDivs)[k].idx ] += grads[j] / (*idxDivs)[k].div;
00967         }
00968       }
00969     }
00970 
00971       // Add the deriative of the prior 
00972     if (usePrior_) 
00973     {
00974           for (int i = 0; i < numWts; i++) 
00975       {
00976         if (!relevantClausesFormulas_[i]) continue;
00977         double priorDerivative = -(weights[i]-priorMeans_[i])/
00978                                  (priorStdDevs_[i]*priorStdDevs_[i]);
00979         //cout << i << " : " << "gradient : " << gradient[i]
00980         //     << "  prior gradient : " << priorDerivative;
00981         gradient[i] += priorDerivative; 
00982             //cout << "  net gradient : " << gradient[i] << endl; 
00983       }
00984     }
00985   }
00986 
00987 
00988  private:  // Learner state; per-domain arrays used above are indexed by domain number (0 .. domainCnt_-1)
00989   int domainCnt_;  // number of domains; bound of every per-domain loop
00990   //Array<Domain*> domains_;  
00991   //Array<MLN*> mlns_;
00992   Array<Array<double> > logOddsPerDomain_;  // NOTE(review): not referenced in this part of the file; presumably per-domain log-odds values - confirm
00993   Array<int> clauseCntPerDomain_;  // number of clauses in each domain's MLN
00994 
00995         // Lazy inference only (lazyInference_): per-domain, per-clause counts
00996   Array<Array<double> > totalTrueCnts_; // true-grounding count used as the training count
00997   Array<Array<double> > defaultTrueCnts_;  // subtracted from both the training and the inferred count
00998 
00999   Array<Array<bool> > relevantClausesPerDomain_;  // [domain][clause]; false => clause weight set to 0 and clause skipped in the gradient
01000   Array<bool> relevantClausesFormulas_;  // per weight index; false => no prior derivative added for that weight
01001 
01002         // Used to compute cnts from mrf (non-lazy path): per-domain training counts, indexed by clause
01003   Array<double*> trainTrueCnts_;  // NOTE(review): ownership of these arrays is not visible here
01004 
01005   bool usePrior_;  // if true, getGradient() adds the prior's derivative
01006   const double* priorMeans_, * priorStdDevs_; // per-weight prior mean and std. dev.; derivative added is -(w - mean) / stddev^2
01007 
01008   IndexTranslator* idxTrans_; //not owned by object; don't delete. NULL when there is one db or the clauses of all dbs line up
01009   
01010   bool lazyInference_;  // true: count true groundings per clause; false: use the inference's clause true counts
01011   bool rescaleGradient_;  // divide each clause's gradient by its training count when that count is > 0
01012   bool isQueryEvidence_;  // NOTE(review): not referenced in this part of the file; meaning unconfirmed
01013 
01014   Array<Inference*> inferences_;  // one per domain; supplies the domain's state, MLN and Domain
01015   
01016     // If true, fillInMissingValues() runs before each inference in getGradient()
01017   bool withEM_;
01018 };
01019 
01020 
01021 #endif

Generated on Wed Feb 14 15:15:18 2007 for Alchemy by  doxygen 1.5.1