learnwts.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef LEARNWTS_H_NOV_23_2005
00067 #define LEARNWTS_H_NOV_23_2005
00068 
00069 #include <sys/time.h>
00070 #include "util.h"
00071 #include "timer.h"
00072 #include "fol.h"
00073 #include "mln.h"
00074 #include "indextranslator.h"
00075 
00076 
00077 extern const char* ZZ_TMP_FILE_POSTFIX; //defined in fol.y; appended to the MLN's base file name by createDomainAndMLN() to name its temporary file
00078 const bool DOMAINS_SHARE_DATA_STRUCT = true; //when true, createDomainsAndMLNs() points every domain after the first at domain 0's type/predicate/function structures, and deleteDomains() NULLs those pointers before deleting such a domain
00079 
00087 void extractFileNames(const char* const & namesStr, Array<string>& namesArray)
00088 {
00089   if (namesStr == NULL) return;
00090   string s(namesStr);
00091   s = Util::trim(s);
00092   if (s.length() == 0) return;
00093   s.append(",");
00094   string::size_type cur = 0;
00095   string::size_type comma;
00096   string name;
00097   while (true)
00098   {
00099     comma = s.find(",", cur);
00100     if (comma == string::npos) return;
00101     name = s.substr(cur, comma-cur);
00102     namesArray.append(name);
00103     cur = comma+1;
00104   }
00105 }
00106 
00114 void extractArgs(const char* const & argsStr, int& argc, char** argv)
00115 {
00116   argc = 0;
00117   if (argsStr == NULL) return;
00118   string s(argsStr);
00119   s = Util::trim(s);
00120   if (s.length() == 0) return;
00121   s.append(" ");
00122   string::size_type cur = 0;
00123   string::size_type blank;
00124   string arg;
00125 
00126   while (true)
00127   {
00128     blank = s.find(" ", cur);
00129     if (blank == string::npos) return;
00130     arg = s.substr(cur, blank-cur);
00131     arg = Util::trim(arg);
00132     memset(argv[argc], '\0', 30);
00133     arg.copy(argv[argc], arg.length());
00134     argc++;
00135     cur = blank + 1;
00136   }
00137 }
00138 
00139 void createDomainAndMLN(Array<Domain*>& domains, Array<MLN*>& mlns,
00140                         const string& inMLNFile, ostringstream& constFiles,
00141                         ostringstream& dbFiles,
00142                         const StringHashArray* const & nonEvidPredNames,
00143                         const bool& addUnitClauses, const double& priorMean,
00144                         const bool& checkPlusTypes, const bool& mwsLazy,
00145                         const bool& allPredsExceptQueriesAreCW,
00146                         const StringHashArray* const & owPredNames)
00147 {
00148   string::size_type bslash = inMLNFile.rfind("/");
00149   string tmp = (bslash == string::npos) ? 
00150                inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
00151   char tmpInMLN[100];
00152   sprintf(tmpInMLN, "%s%s",  tmp.c_str(), ZZ_TMP_FILE_POSTFIX);
00153 
00154   ofstream out(tmpInMLN);
00155   ifstream in(inMLNFile.c_str());
00156   if (!out.good()) { cout<<"ERROR: failed to open "<<tmpInMLN <<endl; exit(-1);}
00157   if (!in.good())  { cout<<"ERROR: failed to open "<<inMLNFile<<endl; exit(-1);}
00158 
00159   string buffer;
00160   while(getline(in, buffer)) out << buffer << endl;
00161   in.close();
00162 
00163   out << constFiles.str() << endl 
00164       << dbFiles.str() << endl;
00165   out.close();
00166   
00167   // read the formulas from the input MLN
00168   Domain* domain = new Domain;
00169   MLN* mln = new MLN();
00170   
00171     // Unknown evidence atoms are filled in by EM
00172   //bool allPredsExceptQueriesAreCW = true;
00173   bool warnAboutDupGndPreds = true;
00174   bool mustHaveWtOrFullStop = false;
00175   bool flipWtsOfFlippedClause = false;
00176   Domain* domain0 = (checkPlusTypes) ? domains[0] : NULL;
00177 
00178   bool ok = runYYParser(mln, domain, tmpInMLN, allPredsExceptQueriesAreCW, 
00179                         owPredNames, nonEvidPredNames, addUnitClauses, 
00180                         warnAboutDupGndPreds, priorMean, mustHaveWtOrFullStop,
00181                         domain0, mwsLazy, flipWtsOfFlippedClause);
00182 
00183   if (!ok) { unlink(tmpInMLN); exit(-1); }
00184   domains.append(domain);
00185   mlns.append(mln);
00186   unlink(tmpInMLN);
00187 }
00188 
00189 
00190 void createDomainsAndMLNs(Array<Domain*>& domains, Array<MLN*>& mlns, 
00191                           const bool& multipleDatabases,
00192                           const string& inMLNFile,
00193                           const Array<string>& constFilesArr,
00194                           const Array<string>& dbFilesArr,
00195                           const StringHashArray* const & nonEvidPredNames,
00196                           const bool& addUnitClauses, const double& priorMean,
00197                           const bool& mwsLazy,
00198                           const bool& allPredsExceptQueriesAreCW,
00199                           const StringHashArray* const & owPredNames)
00200 {
00201   if (!multipleDatabases)
00202   {
00203     ostringstream constFilesStream, dbFilesStream;
00204     for (int i = 0; i < constFilesArr.size(); i++) 
00205       constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00206     for (int i = 0; i < dbFilesArr.size(); i++)    
00207       dbFilesStream << "#include \"" << dbFilesArr[i] << "\"" << endl;
00208     createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream, 
00209                        dbFilesStream, nonEvidPredNames,
00210                        addUnitClauses, priorMean, false, mwsLazy,
00211                        allPredsExceptQueriesAreCW, owPredNames);
00212   }
00213   else
00214   {   //if multiple databases
00215     for (int i = 0; i < dbFilesArr.size(); i++) // for each domain
00216     {
00217       cout << "parsing MLN and creating domain " << i << "..." << endl;
00218       ostringstream constFilesStream, dbFilesStream;
00219       if (constFilesArr.size() > 0)
00220         constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00221       dbFilesStream    << "#include \"" << dbFilesArr[i]    << "\"" << endl;
00222       
00223       bool checkPlusTypes = (i > 0);
00224 
00225       createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream,
00226                          dbFilesStream, nonEvidPredNames,
00227                          addUnitClauses, priorMean, checkPlusTypes, mwsLazy,
00228                          allPredsExceptQueriesAreCW, owPredNames);
00229 
00230         // let the domains share data structures
00231       if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00232       {
00233         const ClauseHashArray* carr = mlns[i]->getClauses();
00234         for (int j = 0; j < carr->size(); j++)
00235         {
00236           Clause* c = (*carr)[j];
00237           for (int k = 0; k < c->getNumPredicates(); k++)
00238           {
00239             Predicate* p = c->getPredicate(k);
00240             const PredicateTemplate* t 
00241               = domains[0]->getPredicateTemplate(p->getName());
00242             assert(t);
00243             p->setTemplate((PredicateTemplate*)t);
00244           }
00245         }
00246 
00247         ((Domain*)domains[i])->replaceTypeDualMap((
00248                                          DualMap*)domains[0]->getTypeDualMap());
00249         ((Domain*)domains[i])->replaceStrToPredTemplateMapAndPredDualMap(
00250                   (StrToPredTemplateMap*) domains[0]->getStrToPredTemplateMap(),
00251                   (DualMap*) domains[0]->getPredDualMap());
00252         ((Domain*)domains[i])->replaceStrToFuncTemplateMapAndFuncDualMap(
00253                   (StrToFuncTemplateMap*) domains[0]->getStrToFuncTemplateMap(),
00254                   (DualMap*) domains[0]->getFuncDualMap());
00255         ((Domain*)domains[i])->replaceEqualPredTemplate(
00256                         (PredicateTemplate*)domains[0]->getEqualPredTemplate());
00257         ((Domain*)domains[i])->replaceFuncSet(
00258                                         (FunctionSet*)domains[0]->getFuncSet());
00259       }
00260 
00261     } // for each domain
00262   }
00263 
00264   //commented out: not true when there are domains with different constants
00265   //int numClauses = mlns[0]->getNumClauses();
00266   //for (int i = 1; i < mlns.size(); i++) 
00267   //  assert(mlns[i]->getNumClauses() == numClauses);
00268   //numClauses = 0; //avoid compilation warning
00269 }
00270 
00271 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns, 
00272                            Array<Domain*>& domains, const Array<double>& wts, 
00273                            IndexTranslator* const& indexTrans)
00274 {
00275     //assign the optimal weights belonging to clauses (and none of those 
00276     //belonging to existentially quantified formulas) to the MLNs
00277   double* wwts = (double*) wts.getItems();
00278   indexTrans->assignNonTiedClauseWtsToMLNs(++wwts);
00279 
00280 
00281     // output the predicate declaration
00282   out << "//predicate declarations" << endl;
00283   domains[0]->printPredicateTemplates(out);
00284   out << endl;
00285 
00286   // output the function declarations
00287   out << "//function declarations" << endl;
00288   domains[0]->printFunctionTemplates(out);
00289   out << endl;
00290 
00291   mlns[0]->printMLNNonExistFormulas(out, domains[0]);
00292 
00293   const ClauseHashArray* clauseOrdering = indexTrans->getClauseOrdering();
00294   const StringHashArray* exFormOrdering = indexTrans->getExistFormulaOrdering();
00295   for (int i = 0; i < exFormOrdering->size(); i++)
00296   {
00297       // output the original formula and its weight
00298     out.width(0); out << "// "; out.width(6); 
00299     out << wts[1+clauseOrdering->size()+i] << "  " <<(*exFormOrdering)[i]<<endl;
00300     out << wts[1+clauseOrdering->size()+i] << "  " <<(*exFormOrdering)[i]<<endl;
00301     out << endl;
00302   }
00303 }
00304 
00305 
00306 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns, 
00307                            Array<Domain*>& domains, const Array<double>& wts)
00308 {
00309     // assign the optimal weights to the clauses in all MLNs
00310   for (int i = 0; i < mlns.size(); i++)
00311   {
00312     MLN* mln = mlns[i];
00313     const ClauseHashArray* clauses = mln->getClauses();
00314     for (int i = 0; i < clauses->size(); i++) 
00315       (*clauses)[i]->setWt(wts[i+1]);
00316   }
00317 
00318     // output the predicate declaration
00319   out << "//predicate declarations" << endl;
00320   domains[0]->printPredicateTemplates(out);
00321   out << endl;
00322 
00323   // output the function declarations
00324   out << "//function declarations" << endl;
00325   domains[0]->printFunctionTemplates(out);
00326   out << endl;
00327   mlns[0]->printMLN(out, domains[0]);
00328 }
00329 
00330 
00331 void deleteDomains(Array<Domain*>& domains)
00332 {
00333   for (int i = 0; i < domains.size(); i++) 
00334   {
00335     if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00336     {
00337       ((Domain*)domains[i])->setTypeDualMap(NULL);
00338       ((Domain*)domains[i])->setStrToPredTemplateMapAndPredDualMap(NULL, NULL);
00339       ((Domain*)domains[i])->setStrToFuncTemplateMapAndFuncDualMap(NULL, NULL);
00340       ((Domain*)domains[i])->setEqualPredTemplate(NULL);
00341       ((Domain*)domains[i])->setFuncSet(NULL);
00342     }
00343     delete domains[i];
00344   }
00345 }
00346 
00347 
00348 #endif

Generated on Wed Feb 14 15:15:18 2007 for Alchemy by  doxygen 1.5.1