00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef VOTED_PERCEPTRON_H_OCT_30_2005
00067 #define VOTED_PERCEPTRON_H_OCT_30_2005
00068
00069 #include "infer.h"
00070 #include "clause.h"
00071 #include "timer.h"
00072 #include "indextranslator.h"
00073 #include "maxwalksat.h"
00074
// Compile-time switch for the extra diagnostic output guarded by
// `if (vpdebug)` in the methods of VotedPerceptron below.
00075 const bool vpdebug = false;
// Floor applied to the magnitude of an initial log-odds weight: a weight whose
// absolute value is smaller than this is replaced by EPSILON. initializeWts()
// also uses it as the weight of a clause that has no counted groundings.
00076 const double EPSILON=.00001;
00077
00082 class VotedPerceptron
00083 {
00084 public:
00085
00100 VotedPerceptron(const Array<Inference*>& inferences,
00101 const StringHashArray& nonEvidPredNames,
00102 IndexTranslator* const & idxTrans, const bool& lazyInference,
00103 const bool& rescaleGradient, const bool& withEM)
00104 : domainCnt_(inferences.size()), idxTrans_(idxTrans),
00105 lazyInference_(lazyInference), rescaleGradient_(rescaleGradient),
00106 withEM_(withEM)
00107 {
00108 cout << endl << "Constructing voted perceptron..." << endl << endl;
00109
00110 inferences_.append(inferences);
00111 logOddsPerDomain_.growToSize(domainCnt_);
00112 clauseCntPerDomain_.growToSize(domainCnt_);
00113
00114 for (int i = 0; i < domainCnt_; i++)
00115 {
00116 clauseCntPerDomain_[i] =
00117 inferences_[i]->getState()->getMLN()->getNumClauses();
00118 logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0);
00119 }
00120
00121 totalTrueCnts_.growToSize(domainCnt_);
00122 defaultTrueCnts_.growToSize(domainCnt_);
00123 relevantClausesPerDomain_.growToSize(domainCnt_);
00124
00125
00126 findRelevantClauses(nonEvidPredNames);
00127 findRelevantClausesFormulas();
00128
00129
00130 if (lazyInference_)
00131 {
00132 findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(nonEvidPredNames);
00133
00134 for (int i = 0; i < domainCnt_; i++)
00135 {
00136 const MLN* mln = inferences_[i]->getState()->getMLN();
00137 Array<double>& logOdds = logOddsPerDomain_[i];
00138 assert(mln->getNumClauses() == logOdds.size());
00139 for (int j = 0; j < mln->getNumClauses(); j++)
00140 ((Clause*) mln->getClause(j))->setWt(logOdds[j]);
00141 }
00142 }
00143
00144 else
00145 {
00146 initializeWts();
00147 }
00148
00149
00150 for (int i = 0; i < inferences_.size(); i++)
00151 inferences_[i]->init();
00152 }
00153
00154
00155 ~VotedPerceptron()
00156 {
00157 for (int i = 0; i < trainTrueCnts_.size(); i++)
00158 delete[] trainTrueCnts_[i];
00159 }
00160
00161
00162
00163 void setMeansStdDevs(const int& arrSize, const double* const & priorMeans,
00164 const double* const & priorStdDevs)
00165 {
00166 if (arrSize < 0)
00167 {
00168 usePrior_ = false;
00169 priorMeans_ = NULL;
00170 priorStdDevs_ = NULL;
00171 }
00172 else
00173 {
00174
00175 usePrior_ = true;
00176 priorMeans_ = priorMeans;
00177 priorStdDevs_ = priorStdDevs;
00178
00179
00180
00181
00182 }
00183 }
00184
00185
00186
00187 void learnWeights(double* const & weights, const int& numWeights,
00188 const int& maxIter, const double& learningRate,
00189 const double& momentum, bool initWithLogOdds,
00190 const int& mwsMaxSubsequentSteps)
00191 {
00192
00193 memset(weights, 0, numWeights*sizeof(double));
00194
00195 double* averageWeights = new double[numWeights];
00196 double* gradient = new double[numWeights];
00197 double* lastchange = new double[numWeights];
00198
00199
00200 if (initWithLogOdds)
00201 {
00202
00203 if (idxTrans_ == NULL)
00204 {
00205 for (int i = 0; i < domainCnt_; i++)
00206 {
00207 Array<double>& logOdds = logOddsPerDomain_[i];
00208 assert(numWeights == logOdds.size());
00209 for (int j = 0; j < logOdds.size(); j++) weights[j] += logOdds[j];
00210 }
00211 }
00212 else
00213 {
00214 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00215 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00216
00217 Array<int> numLogOdds;
00218 Array<double> wtsForDomain;
00219 numLogOdds.growToSize(numWeights);
00220 wtsForDomain.growToSize(numWeights);
00221
00222 for (int i = 0; i < domainCnt_; i++)
00223 {
00224 memset((int*)numLogOdds.getItems(), 0, numLogOdds.size()*sizeof(int));
00225 memset((double*)wtsForDomain.getItems(), 0,
00226 wtsForDomain.size()*sizeof(double));
00227
00228 Array<double>& logOdds = logOddsPerDomain_[i];
00229
00230
00231
00232 for (int j = 0; j < logOdds.size(); j++)
00233 {
00234 Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];
00235 for (int k = 0; k < idxDivs->size(); k++)
00236 {
00237 wtsForDomain[ (*idxDivs)[k].idx ] += logOdds[j];
00238 numLogOdds[ (*idxDivs)[k].idx ]++;
00239 }
00240 }
00241
00242 for (int j = 0; j < numWeights; j++)
00243 if (numLogOdds[j] > 0) weights[j] += wtsForDomain[j]/numLogOdds[j];
00244 }
00245 }
00246 }
00247
00248
00249 for (int i = 0; i < numWeights; i++)
00250 {
00251 weights[i] /= domainCnt_;
00252 averageWeights[i] = weights[i];
00253 lastchange[i] = 0.0;
00254 }
00255
00256 for (int iter = 1; iter <= maxIter; iter++)
00257 {
00258 cout << endl << "Iteration " << iter << " : " << endl << endl;
00259
00260
00261
00262 if (iter == 3)
00263 {
00264 for (int i = 0; i < inferences_.size(); i++)
00265 {
00266
00267 if (MaxWalkSat* mws = dynamic_cast<MaxWalkSat*>(inferences_[i]))
00268 {
00269 mws->setMaxSteps(mwsMaxSubsequentSteps);
00270 }
00271 }
00272 }
00273
00274 cout << "Getting the gradient.. " << endl;
00275 getGradient(weights, gradient, numWeights);
00276 cout << endl;
00277
00278
00279 for (int w = 0; w < numWeights; w++)
00280 {
00281 double wchange = gradient[w] * learningRate + lastchange[w] * momentum;
00282 cout << "clause/formula " << w << ": wtChange = " << wchange;
00283 cout << " oldWt = " << weights[w];
00284 weights[w] += wchange;
00285 lastchange[w] = wchange;
00286 cout << " newWt = " << weights[w];
00287 averageWeights[w] = (iter * averageWeights[w] + weights[w])/(iter + 1);
00288 cout << " averageWt = " << averageWeights[w] << endl;
00289 }
00290
00291 }
00292
00293 cout << endl << "Learned Weights : " << endl;
00294 for (int w = 0; w < numWeights; w++)
00295 {
00296 weights[w] = averageWeights[w];
00297 cout << w << ":" << weights[w] << endl;
00298 }
00299
00300 delete [] averageWeights;
00301 delete [] gradient;
00302 delete [] lastchange;
00303
00304 resetDBs();
00305 }
00306
00307
00308 private:
00309
00313 void resetDBs()
00314 {
00315 if (!lazyInference_)
00316 {
00317 for (int i = 0; i < domainCnt_; i++)
00318 {
00319 VariableState* state = inferences_[i]->getState();
00320 Database* db = state->getDomain()->getDB();
00321
00322 const GroundPredicateHashArray* knePreds = state->getKnePreds();
00323 const Array<TruthValue>* knePredValues = state->getKnePredValues();
00324 db->setValuesToGivenValues(knePreds, knePredValues);
00325
00326 const GroundPredicateHashArray* unePreds = state->getUnePreds();
00327 for (int predno = 0; predno < unePreds->size(); predno++)
00328 db->setValue((*unePreds)[predno], UNKNOWN);
00329 }
00330 }
00331 }
00332
00338 void findRelevantClauses(const StringHashArray& nonEvidPredNames)
00339 {
00340 for (int d = 0; d < domainCnt_; d++)
00341 {
00342 int clauseCnt = clauseCntPerDomain_[d];
00343 Array<bool>& relevantClauses = relevantClausesPerDomain_[d];
00344 relevantClauses.growToSize(clauseCnt);
00345 memset((bool*)relevantClauses.getItems(), false,
00346 relevantClauses.size()*sizeof(bool));
00347 const Domain* domain = inferences_[d]->getState()->getDomain();
00348 const MLN* mln = inferences_[d]->getState()->getMLN();
00349
00350 const Array<IndexClause*>* indclauses;
00351 const Clause* clause;
00352 int predid, clauseid;
00353 for (int i = 0; i < nonEvidPredNames.size(); i++)
00354 {
00355 predid = domain->getPredicateId(nonEvidPredNames[i].c_str());
00356
00357
00358 indclauses = mln->getClausesContainingPred(predid);
00359 if (indclauses)
00360 {
00361 for (int j = 0; j < indclauses->size(); j++)
00362 {
00363 clause = (*indclauses)[j]->clause;
00364 clauseid = mln->findClauseIdx(clause);
00365 relevantClauses[clauseid] = true;
00366
00367 }
00368
00369 }
00370 }
00371 }
00372 }
00373
00374
00375 void findRelevantClausesFormulas()
00376 {
00377 if (idxTrans_ == NULL)
00378 {
00379 Array<bool>& relevantClauses = relevantClausesPerDomain_[0];
00380 relevantClausesFormulas_.growToSize(relevantClauses.size());
00381 for (int i = 0; i < relevantClauses.size(); i++)
00382 relevantClausesFormulas_[i] = relevantClauses[i];
00383 }
00384 else
00385 {
00386 idxTrans_->setRelevantClausesFormulas(relevantClausesFormulas_,
00387 relevantClausesPerDomain_[0]);
00388 cout << "Relevant clauses/formulas:" << endl;
00389 idxTrans_->printRelevantClausesFormulas(cout, relevantClausesFormulas_);
00390 cout << endl;
00391 }
00392 }
00393
00394
00404 void calculateCounts(Array<double>& trueCnt, Array<double>& falseCnt,
00405 const int& domainIdx, const bool& hasUnknownPreds)
00406 {
00407 Clause* clause;
00408 double tmpUnknownCnt;
00409 int clauseCnt = clauseCntPerDomain_[domainIdx];
00410 Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00411 const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00412 const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00413
00414 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00415 {
00416 if (!relevantClauses[clauseno])
00417 {
00418 continue;
00419
00420 }
00421 clause = (Clause*) mln->getClause(clauseno);
00422 clause->getNumTrueFalseUnknownGroundings(domain, domain->getDB(),
00423 hasUnknownPreds,
00424 trueCnt[clauseno],
00425 falseCnt[clauseno],
00426 tmpUnknownCnt);
00427 assert(hasUnknownPreds || (tmpUnknownCnt==0));
00428 }
00429 }
00430
00431
00432 void initializeWts()
00433 {
00434 cout << "Initializing weights ..." << endl;
00435 Array<double *> trainFalseCnts;
00436 trainTrueCnts_.growToSize(domainCnt_);
00437 trainFalseCnts.growToSize(domainCnt_);
00438
00439 for (int i = 0; i < domainCnt_; i++)
00440 {
00441 int clauseCnt = clauseCntPerDomain_[i];
00442 VariableState* state = inferences_[i]->getState();
00443 const GroundPredicateHashArray* unePreds = state->getUnePreds();
00444 const GroundPredicateHashArray* knePreds = state->getKnePreds();
00445
00446 trainTrueCnts_[i] = new double[clauseCnt];
00447 trainFalseCnts[i] = new double[clauseCnt];
00448
00449 int totalPreds = unePreds->size() + knePreds->size();
00450
00451
00452 Array<bool>* unknownPred = new Array<bool>;
00453 unknownPred->growToSize(totalPreds, false);
00454 for (int predno = 0; predno < totalPreds; predno++)
00455 {
00456 GroundPredicate* p;
00457 if (predno < unePreds->size())
00458 p = (*unePreds)[predno];
00459 else
00460 p = (*knePreds)[predno - unePreds->size()];
00461 TruthValue tv = state->getDomain()->getDB()->getValue(p);
00462
00463
00464 if (tv == TRUE)
00465 {
00466 state->setValueOfAtom(predno + 1, true);
00467 p->setTruthValue(true);
00468 }
00469 else
00470 {
00471 state->setValueOfAtom(predno + 1, false);
00472 p->setTruthValue(false);
00473
00474
00475 if (tv == UNKNOWN)
00476 {
00477 (*unknownPred)[predno] = true;
00478 }
00479 }
00480 }
00481
00482 state->initMakeBreakCostWatch();
00483
00484 state->getNumClauseGndingsWithUnknown(trainTrueCnts_[i], clauseCnt, true,
00485 unknownPred);
00486
00487
00488 state->getNumClauseGndingsWithUnknown(trainFalseCnts[i], clauseCnt, false,
00489 unknownPred);
00490 delete unknownPred;
00491 if (vpdebug)
00492 {
00493 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00494 {
00495 cout << clauseno << " : tc = " << trainTrueCnts_[i][clauseno]
00496 << " ** fc = " << trainFalseCnts[i][clauseno] << endl;
00497 }
00498 }
00499 }
00500
00501 double tc,fc;
00502 cout << "List of CNF Clauses : " << endl;
00503 for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++)
00504 {
00505 if (!relevantClausesPerDomain_[0][clauseno])
00506 {
00507 for (int i = 0; i < domainCnt_; i++)
00508 {
00509 Array<double>& logOdds = logOddsPerDomain_[i];
00510 logOdds[clauseno] = 0.0;
00511 }
00512 continue;
00513 }
00514
00515 cout << clauseno << ":";
00516 const Clause* clause =
00517 inferences_[0]->getState()->getMLN()->getClause(clauseno);
00518
00519 clause->print(cout, inferences_[0]->getState()->getDomain());
00520 cout << endl;
00521
00522 tc = 0.0; fc = 0.0;
00523 for (int i = 0; i < domainCnt_;i++)
00524 {
00525 tc += trainTrueCnts_[i][clauseno];
00526 fc += trainFalseCnts[i][clauseno];
00527 }
00528
00529
00530
00531
00532 double weight = 0.0;
00533 double totalCnt = tc + fc;
00534
00535 if (totalCnt == 0)
00536 {
00537
00538 weight = EPSILON;
00539 }
00540 else
00541 {
00542 double prob = tc / (tc+fc);
00543 if (prob == 0) prob = 0.00001;
00544 if (prob == 1) prob = 0.99999;
00545 weight = log(prob/(1-prob));
00546
00547
00548
00549
00550
00551
00552 if (abs(weight) < EPSILON) weight = EPSILON;
00553
00554 }
00555 for (int i = 0; i < domainCnt_; i++)
00556 {
00557 Array<double>& logOdds = logOddsPerDomain_[i];
00558 logOdds[clauseno] = weight;
00559 }
00560 }
00561 cout << endl;
00562
00563 for (int i = 0; i < trainFalseCnts.size(); i++)
00564 delete[] trainFalseCnts[i];
00565 }
00566
00575 void findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(
00576 const StringHashArray& nonEvidPredNames)
00577 {
00578 bool hasUnknownPreds;
00579 Array<Array<double> > totalFalseCnts;
00580 Array<Array<double> > defaultFalseCnts;
00581 totalFalseCnts.growToSize(domainCnt_);
00582 defaultFalseCnts.growToSize(domainCnt_);
00583
00584 Array<Predicate*> gpreds;
00585 Array<Predicate*> ppreds;
00586 Array<TruthValue> gpredValues;
00587 Array<TruthValue> tmpValues;
00588
00589 for (int i = 0; i < domainCnt_; i++)
00590 {
00591 const Domain* domain = inferences_[i]->getState()->getDomain();
00592 int clauseCnt = clauseCntPerDomain_[i];
00593 domain->getDB()->setPerformingInference(false);
00594
00595
00596 gpreds.clear();
00597 gpredValues.clear();
00598 tmpValues.clear();
00599 for (int predno = 0; predno < nonEvidPredNames.size(); predno++)
00600 {
00601 ppreds.clear();
00602 int predid = domain->getPredicateId(nonEvidPredNames[predno].c_str());
00603 Predicate::createAllGroundings(predid, domain, ppreds);
00604
00605 gpreds.append(ppreds);
00606 }
00607
00608 domain->getDB()->alterTruthValue(&gpreds, UNKNOWN, FALSE, &gpredValues);
00609
00610
00611
00612
00613 hasUnknownPreds = false;
00614
00615 Array<double>& trueCnt = totalTrueCnts_[i];
00616 Array<double>& falseCnt = totalFalseCnts[i];
00617 trueCnt.growToSize(clauseCnt);
00618 falseCnt.growToSize(clauseCnt);
00619 calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00620
00621
00622
00623 hasUnknownPreds = true;
00624
00625 domain->getDB()->setValuesToUnknown(&gpreds, &tmpValues);
00626
00627 Array<double>& dTrueCnt = defaultTrueCnts_[i];
00628 Array<double>& dFalseCnt = defaultFalseCnts[i];
00629 dTrueCnt.growToSize(clauseCnt);
00630 dFalseCnt.growToSize(clauseCnt);
00631 calculateCounts(dTrueCnt, dFalseCnt, i, hasUnknownPreds);
00632
00633
00634
00635
00636
00637
00638
00639
00640 for (int predno = 0; predno < gpreds.size(); predno++)
00641 delete gpreds[predno];
00642
00643 domain->getDB()->setPerformingInference(true);
00644 }
00645
00646
00647 for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++)
00648 {
00649 double tc = 0;
00650 double fc = 0;
00651 for (int i = 0; i < domainCnt_; i++)
00652 {
00653 Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00654 Array<double>& logOdds = logOddsPerDomain_[i];
00655
00656 if (!relevantClauses[clauseno]) { logOdds[clauseno] = 0; continue; }
00657 tc += totalTrueCnts_[i][clauseno] - defaultTrueCnts_[i][clauseno];
00658 fc += totalFalseCnts[i][clauseno] - defaultFalseCnts[i][clauseno];
00659
00660 if (vpdebug)
00661 cout << clauseno << " : tc = " << tc << " ** fc = "<< fc <<endl;
00662 }
00663
00664 double weight = 0.0;
00665
00666 if ((tc + fc) == 0)
00667 {
00668
00669 }
00670 else
00671 {
00672 double prob = tc / (tc+fc);
00673 if (prob == 0) prob = 0.00001;
00674 if (prob == 1) prob = 0.99999;
00675 weight = log(prob / (1-prob));
00676
00677
00678
00679 if (abs(weight) < EPSILON) weight = EPSILON;
00680
00681 }
00682
00683
00684 for(int i = 0; i < domainCnt_; i++)
00685 {
00686 Array<double>& logOdds = logOddsPerDomain_[i];
00687 logOdds[clauseno] = weight;
00688 }
00689 }
00690 }
00691
00692
00696 void infer()
00697 {
00698 for (int i = 0; i < domainCnt_; i++)
00699 {
00700 VariableState* state = inferences_[i]->getState();
00701 state->setGndClausesWtsToSumOfParentWts();
00702
00703
00704 state->init();
00705 inferences_[i]->infer();
00706 state->saveLowStateToGndPreds();
00707 }
00708 }
00709
00714 void fillInMissingValues()
00715 {
00716 assert(withEM_);
00717 cout << "Filling in missing data ..." << endl;
00718
00719
00720 Array<Array<TruthValue> > ueValues;
00721 ueValues.growToSize(domainCnt_);
00722 for (int i = 0; i < domainCnt_; i++)
00723 {
00724 VariableState* state = inferences_[i]->getState();
00725 const Domain* domain = state->getDomain();
00726 const GroundPredicateHashArray* knePreds = state->getKnePreds();
00727 const Array<TruthValue>* knePredValues = state->getKnePredValues();
00728
00729
00730 domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
00731
00732
00733 state->setGndClausesWtsToSumOfParentWts();
00734
00735 state->init();
00736 inferences_[i]->infer();
00737 state->saveLowStateToGndPreds();
00738
00739 if (vpdebug)
00740 {
00741 cout << "Inferred following values: " << endl;
00742 inferences_[i]->printProbabilities(cout);
00743 }
00744
00745
00746 if (lazyInference_)
00747 {
00748 Array<double>& trueCnt = totalTrueCnts_[i];
00749 Array<double> falseCnt;
00750 bool hasUnknownPreds = false;
00751 falseCnt.growToSize(trueCnt.size());
00752 calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00753 }
00754 else
00755 {
00756 int clauseCnt = clauseCntPerDomain_[i];
00757 state->initMakeBreakCostWatch();
00758
00759 const Array<double>* clauseTrueCnts =
00760 inferences_[i]->getClauseTrueCnts();
00761 assert(clauseTrueCnts->size() == clauseCnt);
00762 for (int j = 0; j < clauseCnt; j++)
00763 trainTrueCnts_[i][j] = (*clauseTrueCnts)[j];
00764 }
00765
00766
00767
00768
00769
00770 Array<TruthValue> tmpValues;
00771 tmpValues.growToSize(knePreds->size());
00772 domain->getDB()->setValuesToUnknown(knePreds, &tmpValues);
00773 }
00774 cout << "Done filling in missing data" << endl;
00775 }
00776
00777 void getGradientForDomain(double* const & gradient, const int& domainIdx)
00778 {
00779 Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00780 int clauseCnt = clauseCntPerDomain_[domainIdx];
00781 double* trainCnts = NULL;
00782 double* inferredCnts = NULL;
00783 double* clauseTrainCnts = new double[clauseCnt];
00784 double* clauseInferredCnts = new double[clauseCnt];
00785 double trainCnt, inferredCnt;
00786 Array<double>& totalTrueCnts = totalTrueCnts_[domainIdx];
00787 Array<double>& defaultTrueCnts = defaultTrueCnts_[domainIdx];
00788 const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00789 const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00790
00791 memset(clauseTrainCnts, 0, clauseCnt*sizeof(double));
00792 memset(clauseInferredCnts, 0, clauseCnt*sizeof(double));
00793
00794 if (!lazyInference_)
00795 {
00796 if (!inferredCnts) inferredCnts = new double[clauseCnt];
00797
00798 const Array<double>* clauseTrueCnts =
00799 inferences_[domainIdx]->getClauseTrueCnts();
00800 assert(clauseTrueCnts->size() == clauseCnt);
00801 for (int i = 0; i < clauseCnt; i++)
00802 inferredCnts[i] = (*clauseTrueCnts)[i];
00803 trainCnts = trainTrueCnts_[domainIdx];
00804 }
00805
00806
00807 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00808 {
00809 if (!relevantClauses[clauseno]) continue;
00810
00811 if (lazyInference_)
00812 {
00813 Clause* clause = (Clause*) mln->getClause(clauseno);
00814
00815 trainCnt = totalTrueCnts[clauseno];
00816 inferredCnt =
00817 clause->getNumTrueGroundings(domain, domain->getDB(), false);
00818 trainCnt -= defaultTrueCnts[clauseno];
00819 inferredCnt -= defaultTrueCnts[clauseno];
00820
00821 clauseTrainCnts[clauseno] += trainCnt;
00822 clauseInferredCnts[clauseno] += inferredCnt;
00823 }
00824 else
00825 {
00826 clauseTrainCnts[clauseno] += trainCnts[clauseno];
00827 clauseInferredCnts[clauseno] += inferredCnts[clauseno];
00828 }
00829
00830 }
00831
00832 if (vpdebug)
00833 {
00834 cout << "net counts : " << endl;
00835 cout << "\t\ttrain count\t\t\t\tinferred count" << endl << endl;
00836 }
00837
00838 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00839 {
00840 if (!relevantClauses[clauseno]) continue;
00841
00842 if (vpdebug)
00843 cout << clauseno << ":\t\t" << clauseTrainCnts[clauseno] << "\t\t\t\t"
00844 << clauseInferredCnts[clauseno] << endl;
00845 if (rescaleGradient_ && clauseTrainCnts[clauseno] > 0)
00846 {
00847 gradient[clauseno] +=
00848 (clauseTrainCnts[clauseno] - clauseInferredCnts[clauseno])
00849 / clauseTrainCnts[clauseno];
00850 }
00851 else
00852 {
00853 gradient[clauseno] += clauseTrainCnts[clauseno] -
00854 clauseInferredCnts[clauseno];
00855 }
00856 }
00857
00858 delete[] clauseTrainCnts;
00859 delete[] clauseInferredCnts;
00860 }
00861
00862
00863
00864 void getGradient(double* const & weights, double* const & gradient,
00865 const int numWts)
00866 {
00867
00868
00869
00870
00871
00872 if (idxTrans_ == NULL)
00873 {
00874 int clauseCnt = clauseCntPerDomain_[0];
00875 for (int i = 0; i < domainCnt_; i++)
00876 {
00877 Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00878 assert(clauseCntPerDomain_[i] == clauseCnt);
00879 const MLN* mln = inferences_[i]->getState()->getMLN();
00880
00881 for (int j = 0; j < clauseCnt; j++)
00882 {
00883 Clause* c = (Clause*) mln->getClause(j);
00884 if (relevantClauses[j]) c->setWt(weights[j]);
00885 else c->setWt(0);
00886 }
00887 }
00888 }
00889 else
00890 {
00891 Array<Array<double> >* wtsPerDomain = idxTrans_->getWtsPerDomain();
00892 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00893 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00894
00895 for (int i = 0; i < domainCnt_; i++)
00896 {
00897 Array<double>& wts = (*wtsPerDomain)[i];
00898 memset((double*)wts.getItems(), 0, wts.size()*sizeof(double));
00899
00900
00901 for (int j = 0; j < wts.size(); j++)
00902 {
00903 Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];
00904 for (int k = 0; k < idxDivs->size(); k++)
00905 wts[j] += weights[ (*idxDivs)[k].idx ] / (*idxDivs)[k].div;
00906 }
00907 }
00908
00909 for (int i = 0; i < domainCnt_; i++)
00910 {
00911 Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00912 int clauseCnt = clauseCntPerDomain_[i];
00913 Array<double>& wts = (*wtsPerDomain)[i];
00914 assert(wts.size() == clauseCnt);
00915 const MLN* mln = inferences_[i]->getState()->getMLN();
00916
00917 for (int j = 0; j < clauseCnt; j++)
00918 {
00919 Clause* c = (Clause*) mln->getClause(j);
00920 if (relevantClauses[j]) c->setWt(wts[j]);
00921 else c->setWt(0);
00922 }
00923 }
00924 }
00925
00926
00927 if (withEM_) fillInMissingValues();
00928 cout << "Running inference ..." << endl;
00929 infer();
00930 cout << "Done with inference" << endl;
00931
00932
00933 memset(gradient, 0, numWts*sizeof(double));
00934
00935
00936 if (idxTrans_ == NULL)
00937 {
00938 for (int i = 0; i < domainCnt_; i++)
00939 {
00940
00941 getGradientForDomain(gradient, i);
00942 }
00943 }
00944 else
00945 {
00946
00947 Array<Array<double> >* gradsPerDomain = idxTrans_->getGradsPerDomain();
00948 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00949 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00950
00951 for (int i = 0; i < domainCnt_; i++)
00952 {
00953
00954
00955 Array<double>& grads = (*gradsPerDomain)[i];
00956 memset((double*)grads.getItems(), 0, grads.size()*sizeof(double));
00957
00958 getGradientForDomain((double*)grads.getItems(), i);
00959
00960
00961 assert(grads.size() == clauseCntPerDomain_[i]);
00962 for (int j = 0; j < grads.size(); j++)
00963 {
00964 Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];
00965 for (int k = 0; k < idxDivs->size(); k++)
00966 gradient[ (*idxDivs)[k].idx ] += grads[j] / (*idxDivs)[k].div;
00967 }
00968 }
00969 }
00970
00971
00972 if (usePrior_)
00973 {
00974 for (int i = 0; i < numWts; i++)
00975 {
00976 if (!relevantClausesFormulas_[i]) continue;
00977 double priorDerivative = -(weights[i]-priorMeans_[i])/
00978 (priorStdDevs_[i]*priorStdDevs_[i]);
00979
00980
00981 gradient[i] += priorDerivative;
00982
00983 }
00984 }
00985 }
00986
00987
00988 private:
// Number of domains; equals the size of the inferences array passed to the
// constructor.
00989 int domainCnt_;
00990
00991
// logOddsPerDomain_[d][c]: initial log-odds weight of clause c in domain d.
00992 Array<Array<double> > logOddsPerDomain_;
// clauseCntPerDomain_[d]: number of clauses in the MLN of domain d.
00993 Array<int> clauseCntPerDomain_;
00994
00995
// Lazy mode only. totalTrueCnts_[d][c]: true-grounding count of clause c used
// as the training count. defaultTrueCnts_[d][c]: the count taken with the
// non-evidence groundings set to unknown; subtracted from both the training
// and the inferred count in getGradientForDomain().
00996 Array<Array<double> > totalTrueCnts_;
00997 Array<Array<double> > defaultTrueCnts_;
00998
// relevantClausesPerDomain_[d][c]: true if clause c of domain d contains a
// non-evidence predicate (see findRelevantClauses()).
00999 Array<Array<bool> > relevantClausesPerDomain_;
// The same flag per weight index (clause or clause/formula), derived from
// domain 0 in findRelevantClausesFormulas().
01000 Array<bool> relevantClausesFormulas_;
01001
01002
// Non-lazy mode only. One new[]-allocated array of training true counts per
// domain; allocated in initializeWts() and freed by the destructor.
01003 Array<double*> trainTrueCnts_;
01004
// Whether getGradient() adds the derivative of the prior.
01005 bool usePrior_;
// Prior mean and standard deviation per weight; set by setMeansStdDevs() and
// never deleted by this class.
01006 const double* priorMeans_, * priorStdDevs_;
01007
// Translator between clause indexes and clause/formula indexes; NULL when
// weights map one-to-one onto clauses. Never deleted by this class.
01008 IndexTranslator* idxTrans_;
01009
// Selects the lazy code paths.
01010 bool lazyInference_;
// If true, a clause's gradient is divided by its training count when that
// count is positive.
01011 bool rescaleGradient_;
// NOTE(review): not read or written anywhere in the visible code.
01012 bool isQueryEvidence_;
01013
// One inference object per domain; pointers copied from the constructor
// argument and never deleted by this class.
01014 Array<Inference*> inferences_;
01015
01016
// If true, getGradient() calls fillInMissingValues() before inference.
01017 bool withEM_;
01018 };
01019
01020
01021 #endif