#include "MlpLearnAlgorithm.hpp" #include "Dataset.hpp" #include "Population.hpp" #include "MlpFactory.hpp" #include "MlpIndividual.hpp" #include "MlpObjectives.hpp" #include "Mlp.hpp" #include "Random.hpp" using namespace jymlp; #include using boost::property_tree::ptree; #define BOOST_FILESYSTEM_NO_DEPRECATED #include #include #include #include #include #include using namespace std; void MlpLearnAlgorithm::impl_setup(const ptree & all){ // FIXME: TODO } void MlpLearnAlgorithm::setup(const ptree & all, const string & outputdirname, const string & datadirname){ Algorithm::setup(all,outputdirname); _hack_copy_of_ptree = all; _datadirname = datadirname; } void MlpLearnAlgorithm::run(){ /* TODO: Most of these first things should go to impl_setup! */ const ptree & all = _hack_copy_of_ptree; const ptree & sub = all.get_child("MLP_learn"); string datafname = _datadirname + sub.get("datafile"); string task = sub.get("task"); size_t popsz = sub.get("evolution.pop_size"); int maxiter = sub.get("evolution.max_iter"); string eoper = sub.get("evolution.evol_oper"); int bpiter = sub.get("improvement.imp_steps"); double bpstep = sub.get("improvement.imp_inilen"); string improve_method = sub.get("improvement.imp_obj"); shared_ptr ds(new Dataset(datafname,-1.0,+1.0)); // Note binvec encoding! istringstream iss("init_size: " + sub.get("init.init_arch")); MlpFactory mf(iss); mt19937 mt; MlpObjectives meobj(sub.get("evolution.evol_obj"),ds->getNClasses()); for (size_t i=0;i cweights(ds->getNClasses()); for (auto & d: cweights) d=1.0/ds->getNClasses(); // default to equal class weights. unique_ptr nextop; // Create, write and read some: vector unparented(1); unparented[0] = 0; // Create unparented. for (size_t i=0;i(mf.create(&mt)), ds); ni->setId(nextID()); if (improve_method == "mse") { nextop = unique_ptr (new MlpBackpropUnaryOperator(bpiter,bpstep,cweights,false,false,false)); } else if (improve_method == "mse(direction=randomclass)"){ randomize1(cweights, &mt); nextop = unique_ptr (new MlpBackpropUnaryOperator(bpiter,bpstep,cweights,false,false,false)); } else if (improve_method == "mse(direction=rerandomclass)"){ randomize1(cweights, &mt); nextop = unique_ptr (new MlpBackpropUnaryOperator(bpiter,bpstep,cweights,true,false,false)); } else if (improve_method == "ssw(amount=rerandom);mse(direction=rerandomclass)"){ randomize1(cweights, &mt); nextop = unique_ptr (new MlpBackpropUnaryOperator(bpiter,bpstep,cweights,true,true,false)); } else if (improve_method == "mee/mse"){ randomize1(cweights, &mt); nextop = unique_ptr (new MlpBackpropUnaryOperator(bpiter,bpstep,cweights,true,true,true)); } else { cerr << "Unknown improvement method: " << improve_method << endl; exit(1); // FIXME: Implement exception facilities. } ni->addUnaryOperator(move(nextop)); markAsParents(ni->getId(),unparented); parent_pop.adopt(unique_ptr(ni)); } parent_pop.evaluate(); // Start by NSGA-II start step: parent_pop.assignRankAndCrowdingDistance(); ofstream ofs(_outputdirname + "pop0.dat", ios::trunc); parent_pop.toStream(ofs); ofs.close(); // Trace also the 0 pop: for (size_t ii = 0; ii mockpar(1); // as of yet, only one parent possible. TODO: Crossovers etc. // FIXME: Parent information can only be updated by the combination process! for (size_t ii = 0; ii("datafile"); ofstream ofs3(_outputdirname + "pop_final.dat", ios::trunc); parent_pop.toStream(ofs3); // TODO: History is unnecessary, if separate populations are stored? // (The auxiliary non-dominated archive is a completely different story...) 
  ofstream ofs2(_outputdirname + "pop_history.dat", ios::trunc);
  history.toStream(ofs2);
  // FIXME: Should update this during the run, not only at the end.

  /* And.. it segfaults.. hmm.. should debug this properly
  Population nondom_all(meobj.nobj());
  nondom_all.fillNondominatedSortFrom(history);
  ofstream ofsa(_outputdirname + "pop_nondom.dat", ios::trunc);
  nondom_all.toStream(ofsa);
  */

  // Parent graph:
  ofstream ofs4(_outputdirname + "parents.dat", ios::trunc);
  writeParentHistory(ofs4);
  ofs4.close();

  // Traced measures:
  ofstream ofs5(_outputdirname + "traced.dat", ios::trunc);
  writeTracedMeasures(ofs5);
}
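/*
 * For reference, a minimal sketch of the property-tree structure that run() reads
 * under the "MLP_learn" key. The key names are taken from the get() calls above;
 * the example values and the INFO-style notation are assumptions, not the project's
 * canonical configuration format:
 *
 *   MLP_learn
 *   {
 *       datafile   mydata.dat       ; appended to the data directory name
 *       task       classification
 *       init
 *       {
 *           init_arch ...           ; fed to MlpFactory as "init_size: <init_arch>"
 *       }
 *       evolution
 *       {
 *           pop_size   20
 *           max_iter   100
 *           evol_oper  ...
 *           evol_obj   ...
 *       }
 *       improvement
 *       {
 *           imp_steps  10
 *           imp_inilen 0.1
 *           imp_obj    mse          ; see the dispatch in run() for accepted values
 *       }
 *   }
 */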