[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest.hxx | ![]() |
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 00036 00037 #ifndef VIGRA_RANDOM_FOREST_HXX 00038 #define VIGRA_RANDOM_FOREST_HXX 00039 00040 #include <iostream> 00041 #include <algorithm> 00042 #include <map> 00043 #include <set> 00044 #include <list> 00045 #include <numeric> 00046 #include "mathutil.hxx" 00047 #include "array_vector.hxx" 00048 #include "sized_int.hxx" 00049 #include "matrix.hxx" 00050 #include "random.hxx" 00051 #include "functorexpression.hxx" 00052 #include "random_forest/rf_common.hxx" 00053 #include "random_forest/rf_nodeproxy.hxx" 00054 #include "random_forest/rf_split.hxx" 00055 #include "random_forest/rf_decisionTree.hxx" 00056 #include "random_forest/rf_visitors.hxx" 00057 #include "random_forest/rf_region.hxx" 00058 #include "sampling.hxx" 00059 #include "random_forest/rf_preprocessing.hxx" 00060 #include "random_forest/rf_online_prediction_set.hxx" 00061 #include "random_forest/rf_earlystopping.hxx" 00062 #include "random_forest/rf_ridge_split.hxx" 00063 namespace vigra 00064 { 00065 00066 /** \addtogroup MachineLearning Machine Learning 00067 00068 This module provides classification algorithms that map 00069 features to labels or label probablities. 00070 Look at the RandomForest class first for a overview of most of the 00071 functionality provided as well as use cases. 00072 **/ 00073 //@{ 00074 00075 namespace detail 00076 { 00077 00078 00079 00080 /* \brief sampling option factory function 00081 */ 00082 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt) 00083 { 00084 SamplerOptions return_opt; 00085 return_opt.withReplacement(RF_opt.sample_with_replacement_); 00086 return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL); 00087 return return_opt; 00088 } 00089 }//namespace detail 00090 00091 /** Random Forest class 00092 * 00093 * \tparam <PrprocessorTag = ClassificationTag> Class used to preprocess 00094 * the input while learning and predicting. Currently Available: 00095 * ClassificationTag and RegressionTag. It is recommended to use 00096 * Splitfunctor::Preprocessor_t while using custom splitfunctors 00097 * as they may need the data to be in a different format. 00098 * \sa Preprocessor 00099 * 00100 * simple usage for classification (regression is not yet supported): 00101 * look at RandomForest::learn() as well as RandomForestOptions() for additional 00102 * options. 00103 * 00104 * \code 00105 * typedef xxx feature_t \\ replace xxx with whichever type 00106 * typedef yyy label_t \\ meme chose. 00107 * MultiArrayView<2, feature_t> f = get_some_features(); 00108 * MultiArrayView<2, label_t> l = get_some_labels(); 00109 * RandomForest<> rf() 00110 * double oob_error = rf.learn(f, l); 00111 * 00112 * MultiArrayView<2, feature_t> pf = get_some_unknown_features(); 00113 * MultiArrayView<2, label_t> prediction = allocate_space_for_response(); 00114 * MultiArrayView<2, double> prob = allocate_space_for_probability(); 00115 * 00116 * rf.predict_labels(pf, prediction); 00117 * rf.predict_probabilities(pf, prob); 00118 * 00119 * \endcode 00120 * 00121 * Additional information such as OOB Error and Variable Importance measures are accessed 00122 * via Visitors defined in rf::visitors. 00123 * Have a look at rf::split for other splitting methods. 00124 * 00125 */ 00126 template <class LabelType = double , class PreprocessorTag = ClassificationTag > 00127 class RandomForest 00128 { 00129 00130 public: 00131 //public typedefs 00132 typedef RandomForestOptions Options_t; 00133 typedef detail::DecisionTree DecisionTree_t; 00134 typedef ProblemSpec<LabelType> ProblemSpec_t; 00135 typedef GiniSplit Default_Split_t; 00136 typedef EarlyStoppStd Default_Stop_t; 00137 typedef rf::visitors::StopVisiting Default_Visitor_t; 00138 typedef DT_StackEntry<ArrayVectorView<Int32>::iterator> 00139 StackEntry_t; 00140 typedef LabelType LabelT; 00141 protected: 00142 00143 /** optimisation for predictLabels 00144 * */ 00145 mutable MultiArray<2, double> garbage_prediction_; 00146 00147 public: 00148 00149 //problem independent data. 00150 Options_t options_; 00151 //problem dependent data members - is only set if 00152 //a copy constructor, some sort of import 00153 //function or the learn function is called 00154 ArrayVector<DecisionTree_t> trees_; 00155 ProblemSpec_t ext_param_; 00156 /*mutable ArrayVector<int> tree_indices_;*/ 00157 rf::visitors::OnlineLearnVisitor online_visitor_; 00158 00159 00160 void reset() 00161 { 00162 ext_param_.clear(); 00163 trees_.clear(); 00164 } 00165 00166 public: 00167 00168 /** \name Contructors 00169 * Note: No copy Constructor specified as no pointers are manipulated 00170 * in this class 00171 */ 00172 /*\{*/ 00173 /**\brief default constructor 00174 * 00175 * \param options general options to the Random Forest. Must be of Type 00176 * Options_t 00177 * \param ext_param problem specific values that can be supplied 00178 * additionally. (class weights , labels etc) 00179 * \sa RandomForestOptions, ProblemSpec 00180 * 00181 */ 00182 RandomForest(Options_t const & options = Options_t(), 00183 ProblemSpec_t const & ext_param = ProblemSpec_t()) 00184 : 00185 options_(options), 00186 ext_param_(ext_param)/*, 00187 tree_indices_(options.tree_count_,0)*/ 00188 { 00189 /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii) 00190 tree_indices_[ii] = ii;*/ 00191 } 00192 00193 /**\brief Create RF from external source 00194 * \param treeCount Number of trees to add. 00195 * \param topology_begin 00196 * Iterator to a Container where the topology_ data 00197 * of the trees are stored. 00198 * Iterator should support at least treeCount forward 00199 * iterations. (i.e. topology_end - topology_begin >= treeCount 00200 * \param parameter_begin 00201 * iterator to a Container where the parameters_ data 00202 * of the trees are stored. Iterator should support at 00203 * least treeCount forward iterations. 00204 * \param problem_spec 00205 * Extrinsic parameters that specify the problem e.g. 00206 * ClassCount, featureCount etc. 00207 * \param options (optional) specify options used to train the original 00208 * Random forest. This parameter is not used anywhere 00209 * during prediction and thus is optional. 00210 * 00211 */ 00212 /* TODO: This constructor may be replaced by a Constructor using 00213 * NodeProxy iterators to encapsulate the underlying data type. 00214 */ 00215 template<class TopologyIterator, class ParameterIterator> 00216 RandomForest(int treeCount, 00217 TopologyIterator topology_begin, 00218 ParameterIterator parameter_begin, 00219 ProblemSpec_t const & problem_spec, 00220 Options_t const & options = Options_t()) 00221 : 00222 trees_(treeCount, DecisionTree_t(problem_spec)), 00223 ext_param_(problem_spec), 00224 options_(options) 00225 { 00226 for(unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin) 00227 { 00228 trees_[k].topology_ = *topology_begin; 00229 trees_[k].parameters_ = *parameter_begin; 00230 } 00231 } 00232 00233 /*\}*/ 00234 00235 00236 /** \name Data Access 00237 * data access interface - usage of member variables is deprecated 00238 */ 00239 00240 /*\{*/ 00241 00242 00243 /**\brief return external parameters for viewing 00244 * \return ProblemSpec_t 00245 */ 00246 ProblemSpec_t const & ext_param() const 00247 { 00248 vigra_precondition(ext_param_.used() == true, 00249 "RandomForest::ext_param(): " 00250 "Random forest has not been trained yet."); 00251 return ext_param_; 00252 } 00253 00254 /**\brief set external parameters 00255 * 00256 * \param in external parameters to be set 00257 * 00258 * set external parameters explicitly. 00259 * If Random Forest has not been trained the preprocessor will 00260 * either ignore filling values set this way or will throw an exception 00261 * if values specified manually do not match the value calculated 00262 & during the preparation step. 00263 */ 00264 void set_ext_param(ProblemSpec_t const & in) 00265 { 00266 vigra_precondition(ext_param_.used() == false, 00267 "RandomForest::set_ext_param():" 00268 "Random forest has been trained! Call reset()" 00269 "before specifying new extrinsic parameters."); 00270 } 00271 00272 /**\brief access random forest options 00273 * 00274 * \return random forest options 00275 */ 00276 Options_t & set_options() 00277 { 00278 return options; 00279 } 00280 00281 00282 /**\brief access const random forest options 00283 * 00284 * \return const Option_t 00285 */ 00286 Options_t const & options() const 00287 { 00288 return options_; 00289 } 00290 00291 /**\brief access const trees 00292 */ 00293 DecisionTree_t const & tree(int index) const 00294 { 00295 return trees_[index]; 00296 } 00297 00298 /**\brief access trees 00299 */ 00300 DecisionTree_t & tree(int index) 00301 { 00302 return trees_[index]; 00303 } 00304 00305 /*\}*/ 00306 00307 /**\brief return number of features used while 00308 * training. 00309 */ 00310 int feature_count() const 00311 { 00312 return ext_param_.column_count_; 00313 } 00314 00315 00316 /**\brief return number of features used while 00317 * training. 00318 * 00319 * deprecated. Use feature_count() instead. 00320 */ 00321 int column_count() const 00322 { 00323 return ext_param_.column_count_; 00324 } 00325 00326 /**\brief return number of classes used while 00327 * training. 00328 */ 00329 int class_count() const 00330 { 00331 return ext_param_.class_count_; 00332 } 00333 00334 /**\brief return number of trees 00335 */ 00336 int tree_count() const 00337 { 00338 return options_.tree_count_; 00339 } 00340 00341 00342 00343 template<class U,class C1, 00344 class U2, class C2, 00345 class Split_t, 00346 class Stop_t, 00347 class Visitor_t, 00348 class Random_t> 00349 void onlineLearn( MultiArrayView<2,U,C1> const & features, 00350 MultiArrayView<2,U2,C2> const & response, 00351 int new_start_index, 00352 Visitor_t visitor_, 00353 Split_t split_, 00354 Stop_t stop_, 00355 Random_t & random, 00356 bool adjust_thresholds=false); 00357 00358 template <class U, class C1, class U2,class C2> 00359 void onlineLearn( MultiArrayView<2, U, C1> const & features, 00360 MultiArrayView<2, U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false) 00361 { 00362 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed); 00363 onlineLearn(features, 00364 labels, 00365 new_start_index, 00366 rf_default(), 00367 rf_default(), 00368 rf_default(), 00369 rnd, 00370 adjust_thresholds); 00371 } 00372 00373 template<class U,class C1, 00374 class U2, class C2, 00375 class Split_t, 00376 class Stop_t, 00377 class Visitor_t, 00378 class Random_t> 00379 void reLearnTree(MultiArrayView<2,U,C1> const & features, 00380 MultiArrayView<2,U2,C2> const & response, 00381 int treeId, 00382 Visitor_t visitor_, 00383 Split_t split_, 00384 Stop_t stop_, 00385 Random_t & random); 00386 00387 template<class U, class C1, class U2, class C2> 00388 void reLearnTree(MultiArrayView<2, U, C1> const & features, 00389 MultiArrayView<2, U2, C2> const & labels, 00390 int treeId) 00391 { 00392 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed); 00393 reLearnTree(features, 00394 labels, 00395 treeId, 00396 rf_default(), 00397 rf_default(), 00398 rf_default(), 00399 rnd); 00400 } 00401 00402 00403 /**\name Learning 00404 * Following functions differ in the degree of customization 00405 * allowed 00406 */ 00407 /*\{*/ 00408 /**\brief learn on data with custom config and random number generator 00409 * 00410 * \param features a N x M matrix containing N samples with M 00411 * features 00412 * \param response a N x D matrix containing the corresponding 00413 * response. Current split functors assume D to 00414 * be 1 and ignore any additional columns. 00415 * This is not enforced to allow future support 00416 * for uncertain labels, label independent strata etc. 00417 * The Preprocessor specified during construction 00418 * should be able to handle features and labels 00419 * features and the labels. 00420 * see also: SplitFunctor, Preprocessing 00421 * 00422 * \param visitor visitor which is to be applied after each split, 00423 * tree and at the end. Use rf_default for using 00424 * default value. (No Visitors) 00425 * see also: rf::visitors 00426 * \param split split functor to be used to calculate each split 00427 * use rf_default() for using default value. (GiniSplit) 00428 * see also: rf::split 00429 * \param stop 00430 * predicate to be used to calculate each split 00431 * use rf_default() for using default value. (EarlyStoppStd) 00432 * \param random RandomNumberGenerator to be used. Use 00433 * rf_default() to use default value.(RandomMT19337) 00434 * 00435 * 00436 */ 00437 template <class U, class C1, 00438 class U2,class C2, 00439 class Split_t, 00440 class Stop_t, 00441 class Visitor_t, 00442 class Random_t> 00443 void learn( MultiArrayView<2, U, C1> const & features, 00444 MultiArrayView<2, U2,C2> const & response, 00445 Visitor_t visitor, 00446 Split_t split, 00447 Stop_t stop, 00448 Random_t const & random); 00449 00450 template <class U, class C1, 00451 class U2,class C2, 00452 class Split_t, 00453 class Stop_t, 00454 class Visitor_t> 00455 void learn( MultiArrayView<2, U, C1> const & features, 00456 MultiArrayView<2, U2,C2> const & response, 00457 Visitor_t visitor, 00458 Split_t split, 00459 Stop_t stop) 00460 00461 { 00462 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed); 00463 learn( features, 00464 response, 00465 visitor, 00466 split, 00467 stop, 00468 rnd); 00469 } 00470 00471 template <class U, class C1, class U2,class C2, class Visitor_t> 00472 void learn( MultiArrayView<2, U, C1> const & features, 00473 MultiArrayView<2, U2,C2> const & labels, 00474 Visitor_t visitor) 00475 { 00476 learn( features, 00477 labels, 00478 visitor, 00479 rf_default(), 00480 rf_default()); 00481 } 00482 00483 template <class U, class C1, class U2,class C2, 00484 class Visitor_t, class Split_t> 00485 void learn( MultiArrayView<2, U, C1> const & features, 00486 MultiArrayView<2, U2,C2> const & labels, 00487 Visitor_t visitor, 00488 Split_t split) 00489 { 00490 learn( features, 00491 labels, 00492 visitor, 00493 split, 00494 rf_default()); 00495 } 00496 00497 /**\brief learn on data with default configuration 00498 * 00499 * \param features a N x M matrix containing N samples with M 00500 * features 00501 * \param labels a N x D matrix containing the corresponding 00502 * N labels. Current split functors assume D to 00503 * be 1 and ignore any additional columns. 00504 * this is not enforced to allow future support 00505 * for uncertain labels. 00506 * 00507 * learning is done with: 00508 * 00509 * \sa rf::split, EarlyStoppStd 00510 * 00511 * - Randomly seeded random number generator 00512 * - default gini split functor as described by Breiman 00513 * - default The standard early stopping criterion 00514 */ 00515 template <class U, class C1, class U2,class C2> 00516 void learn( MultiArrayView<2, U, C1> const & features, 00517 MultiArrayView<2, U2,C2> const & labels) 00518 { 00519 learn( features, 00520 labels, 00521 rf_default(), 00522 rf_default(), 00523 rf_default()); 00524 } 00525 /*\}*/ 00526 00527 00528 00529 /**\name prediction 00530 */ 00531 /*\{*/ 00532 /** \brief predict a label given a feature. 00533 * 00534 * \param features: a 1 by featureCount matrix containing 00535 * data point to be predicted (this only works in 00536 * classification setting) 00537 * \param stop: early stopping critierion 00538 * \return double value representing class. You can use the 00539 * predictLabels() function together with the 00540 * rf.external_parameter().class_type_ attribute 00541 * to get back the same type used during learning. 00542 */ 00543 template <class U, class C, class Stop> 00544 LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const; 00545 00546 template <class U, class C> 00547 LabelType predictLabel(MultiArrayView<2, U, C>const & features) 00548 { 00549 return predictLabel(features, rf_default()); 00550 } 00551 /** \brief predict a label with features and class priors 00552 * 00553 * \param features: same as above. 00554 * \param prior: iterator to prior weighting of classes 00555 * \return sam as above. 00556 */ 00557 template <class U, class C> 00558 LabelType predictLabel(MultiArrayView<2, U, C> const & features, 00559 ArrayVectorView<double> prior) const; 00560 00561 /** \brief predict multiple labels with given features 00562 * 00563 * \param features: a n by featureCount matrix containing 00564 * data point to be predicted (this only works in 00565 * classification setting) 00566 * \param labels: a n by 1 matrix passed by reference to store 00567 * output. 00568 */ 00569 template <class U, class C1, class T, class C2> 00570 void predictLabels(MultiArrayView<2, U, C1>const & features, 00571 MultiArrayView<2, T, C2> & labels) const 00572 { 00573 vigra_precondition(features.shape(0) == labels.shape(0), 00574 "RandomForest::predictLabels(): Label array has wrong size."); 00575 for(int k=0; k<features.shape(0); ++k) 00576 labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default())); 00577 } 00578 00579 template <class U, class C1, class T, class C2, class Stop> 00580 void predictLabels(MultiArrayView<2, U, C1>const & features, 00581 MultiArrayView<2, T, C2> & labels, 00582 Stop & stop) const 00583 { 00584 vigra_precondition(features.shape(0) == labels.shape(0), 00585 "RandomForest::predictLabels(): Label array has wrong size."); 00586 for(int k=0; k<features.shape(0); ++k) 00587 labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop)); 00588 } 00589 /** \brief predict the class probabilities for multiple labels 00590 * 00591 * \param features same as above 00592 * \param prob a n x class_count_ matrix. passed by reference to 00593 * save class probabilities 00594 * \param stop earlystopping criterion 00595 * \sa EarlyStopping 00596 */ 00597 template <class U, class C1, class T, class C2, class Stop> 00598 void predictProbabilities(MultiArrayView<2, U, C1>const & features, 00599 MultiArrayView<2, T, C2> & prob, 00600 Stop & stop) const; 00601 template <class T1,class T2, class C> 00602 void predictProbabilities(OnlinePredictionSet<T1> & predictionSet, 00603 MultiArrayView<2, T2, C> & prob); 00604 00605 /** \brief predict the class probabilities for multiple labels 00606 * 00607 * \param features same as above 00608 * \param prob a n x class_count_ matrix. passed by reference to 00609 * save class probabilities 00610 */ 00611 template <class U, class C1, class T, class C2> 00612 void predictProbabilities(MultiArrayView<2, U, C1>const & features, 00613 MultiArrayView<2, T, C2> & prob) const 00614 { 00615 predictProbabilities(features, prob, rf_default()); 00616 } 00617 00618 00619 /*\}*/ 00620 00621 }; 00622 00623 00624 template <class LabelType, class PreprocessorTag> 00625 template<class U,class C1, 00626 class U2, class C2, 00627 class Split_t, 00628 class Stop_t, 00629 class Visitor_t, 00630 class Random_t> 00631 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features, 00632 MultiArrayView<2,U2,C2> const & response, 00633 int new_start_index, 00634 Visitor_t visitor_, 00635 Split_t split_, 00636 Stop_t stop_, 00637 Random_t & random, 00638 bool adjust_thresholds) 00639 { 00640 online_visitor_.activate(); 00641 online_visitor_.adjust_thresholds=adjust_thresholds; 00642 00643 using namespace rf; 00644 //typedefs 00645 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t; 00646 typedef UniformIntRandomFunctor<Random_t> 00647 RandFunctor_t; 00648 // default values and initialization 00649 // Value Chooser chooses second argument as value if first argument 00650 // is of type RF_DEFAULT. (thanks to template magic - don't care about 00651 // it - just smile and wave. 00652 00653 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 00654 Default_Stop_t default_stop(options_); 00655 typename RF_CHOOSER(Stop_t)::type stop 00656 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 00657 Default_Split_t default_split; 00658 typename RF_CHOOSER(Split_t)::type split 00659 = RF_CHOOSER(Split_t)::choose(split_, default_split); 00660 rf::visitors::StopVisiting stopvisiting; 00661 typedef rf::visitors::detail::VisitorNode 00662 <rf::visitors::OnlineLearnVisitor, 00663 typename RF_CHOOSER(Visitor_t)::type> 00664 IntermedVis; 00665 IntermedVis 00666 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting)); 00667 #undef RF_CHOOSER 00668 00669 // Preprocess the data to get something the split functor can work 00670 // with. Also fill the ext_param structure by preprocessing 00671 // option parameters that could only be completely evaluated 00672 // when the training data is known. 00673 ext_param_.class_count_=0; 00674 Preprocessor_t preprocessor( features, response, 00675 options_, ext_param_); 00676 00677 // Make stl compatible random functor. 00678 RandFunctor_t randint ( random); 00679 00680 // Give the Split functor information about the data. 00681 split.set_external_parameters(ext_param_); 00682 stop.set_external_parameters(ext_param_); 00683 00684 00685 //Create poisson samples 00686 PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_)); 00687 00688 //TODO: visitors for online learning 00689 //visitor.visit_at_beginning(*this, preprocessor); 00690 00691 // THE MAIN EFFING RF LOOP - YEAY DUDE! 00692 for(int ii = 0; ii < (int)trees_.size(); ++ii) 00693 { 00694 online_visitor_.tree_id=ii; 00695 poisson_sampler.sample(); 00696 std::map<int,int> leaf_parents; 00697 leaf_parents.clear(); 00698 //Get all the leaf nodes for that sample 00699 for(int s=0;s<poisson_sampler.numOfSamples();++s) 00700 { 00701 int sample=poisson_sampler[s]; 00702 online_visitor_.current_label=preprocessor.response()(sample,0); 00703 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent; 00704 int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_); 00705 00706 00707 //Add to the list for that leaf 00708 online_visitor_.add_to_index_list(ii,leaf,sample); 00709 //TODO: Class count? 00710 //Store parent 00711 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0) 00712 { 00713 leaf_parents[leaf]=online_visitor_.last_node_id; 00714 } 00715 } 00716 00717 00718 std::map<int,int>::iterator leaf_iterator; 00719 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator) 00720 { 00721 int leaf=leaf_iterator->first; 00722 int parent=leaf_iterator->second; 00723 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf]; 00724 ArrayVector<Int32> indeces; 00725 indeces.clear(); 00726 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]); 00727 StackEntry_t stack_entry(indeces.begin(), 00728 indeces.end(), 00729 ext_param_.class_count_); 00730 00731 00732 if(parent!=-1) 00733 { 00734 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf) 00735 { 00736 stack_entry.leftParent=parent; 00737 } 00738 else 00739 { 00740 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong"); 00741 stack_entry.rightParent=parent; 00742 } 00743 } 00744 //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf); 00745 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1); 00746 //Now, the last one moved onto leaf 00747 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf); 00748 //Now it should be classified correctly! 00749 } 00750 00751 /*visitor 00752 .visit_after_tree( *this, 00753 preprocessor, 00754 poisson_sampler, 00755 stack_entry, 00756 ii);*/ 00757 } 00758 00759 //visitor.visit_at_end(*this, preprocessor); 00760 online_visitor_.deactivate(); 00761 } 00762 00763 template<class LabelType, class PreprocessorTag> 00764 template<class U,class C1, 00765 class U2, class C2, 00766 class Split_t, 00767 class Stop_t, 00768 class Visitor_t, 00769 class Random_t> 00770 void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features, 00771 MultiArrayView<2,U2,C2> const & response, 00772 int treeId, 00773 Visitor_t visitor_, 00774 Split_t split_, 00775 Stop_t stop_, 00776 Random_t & random) 00777 { 00778 using namespace rf; 00779 00780 00781 typedef UniformIntRandomFunctor<Random_t> 00782 RandFunctor_t; 00783 00784 // See rf_preprocessing.hxx for more info on this 00785 ext_param_.class_count_=0; 00786 typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t; 00787 00788 // default values and initialization 00789 // Value Chooser chooses second argument as value if first argument 00790 // is of type RF_DEFAULT. (thanks to template magic - don't care about 00791 // it - just smile and wave. 00792 00793 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 00794 Default_Stop_t default_stop(options_); 00795 typename RF_CHOOSER(Stop_t)::type stop 00796 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 00797 Default_Split_t default_split; 00798 typename RF_CHOOSER(Split_t)::type split 00799 = RF_CHOOSER(Split_t)::choose(split_, default_split); 00800 rf::visitors::StopVisiting stopvisiting; 00801 typedef rf::visitors::detail::VisitorNode 00802 <rf::visitors::OnlineLearnVisitor, 00803 typename RF_CHOOSER(Visitor_t)::type> IntermedVis; 00804 IntermedVis 00805 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting)); 00806 #undef RF_CHOOSER 00807 vigra_precondition(options_.prepare_online_learning_,"reLearnTree: Re learning trees only makes sense, if online learning is enabled"); 00808 online_visitor_.activate(); 00809 00810 // Make stl compatible random functor. 00811 RandFunctor_t randint ( random); 00812 00813 // Preprocess the data to get something the split functor can work 00814 // with. Also fill the ext_param structure by preprocessing 00815 // option parameters that could only be completely evaluated 00816 // when the training data is known. 00817 Preprocessor_t preprocessor( features, response, 00818 options_, ext_param_); 00819 00820 // Give the Split functor information about the data. 00821 split.set_external_parameters(ext_param_); 00822 stop.set_external_parameters(ext_param_); 00823 00824 /**\todo replace this crappy class out. It uses function pointers. 00825 * and is making code slower according to me. 00826 * Comment from Nathan: This is copied from Rahul, so me=Rahul 00827 */ 00828 Sampler<Random_t > sampler(preprocessor.strata().begin(), 00829 preprocessor.strata().end(), 00830 detail::make_sampler_opt(options_) 00831 .sampleSize(ext_param().actual_msample_), 00832 random); 00833 //initialize First region/node/stack entry 00834 sampler 00835 .sample(); 00836 00837 StackEntry_t 00838 first_stack_entry( sampler.sampledIndices().begin(), 00839 sampler.sampledIndices().end(), 00840 ext_param_.class_count_); 00841 first_stack_entry 00842 .set_oob_range( sampler.oobIndices().begin(), 00843 sampler.oobIndices().end()); 00844 online_visitor_.reset_tree(treeId); 00845 online_visitor_.tree_id=treeId; 00846 trees_[treeId].reset(); 00847 trees_[treeId] 00848 .learn( preprocessor.features(), 00849 preprocessor.response(), 00850 first_stack_entry, 00851 split, 00852 stop, 00853 visitor, 00854 randint); 00855 visitor 00856 .visit_after_tree( *this, 00857 preprocessor, 00858 sampler, 00859 first_stack_entry, 00860 treeId); 00861 00862 online_visitor_.deactivate(); 00863 } 00864 00865 template <class LabelType, class PreprocessorTag> 00866 template <class U, class C1, 00867 class U2,class C2, 00868 class Split_t, 00869 class Stop_t, 00870 class Visitor_t, 00871 class Random_t> 00872 void RandomForest<LabelType, PreprocessorTag>:: 00873 learn( MultiArrayView<2, U, C1> const & features, 00874 MultiArrayView<2, U2,C2> const & response, 00875 Visitor_t visitor_, 00876 Split_t split_, 00877 Stop_t stop_, 00878 Random_t const & random) 00879 { 00880 using namespace rf; 00881 //this->reset(); 00882 //typedefs 00883 typedef UniformIntRandomFunctor<Random_t> 00884 RandFunctor_t; 00885 00886 // See rf_preprocessing.hxx for more info on this 00887 typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t; 00888 00889 // default values and initialization 00890 // Value Chooser chooses second argument as value if first argument 00891 // is of type RF_DEFAULT. (thanks to template magic - don't care about 00892 // it - just smile and wave. 00893 00894 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 00895 Default_Stop_t default_stop(options_); 00896 typename RF_CHOOSER(Stop_t)::type stop 00897 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 00898 Default_Split_t default_split; 00899 typename RF_CHOOSER(Split_t)::type split 00900 = RF_CHOOSER(Split_t)::choose(split_, default_split); 00901 rf::visitors::StopVisiting stopvisiting; 00902 typedef rf::visitors::detail::VisitorNode< 00903 rf::visitors::OnlineLearnVisitor, 00904 typename RF_CHOOSER(Visitor_t)::type> IntermedVis; 00905 IntermedVis 00906 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting)); 00907 #undef RF_CHOOSER 00908 if(options_.prepare_online_learning_) 00909 online_visitor_.activate(); 00910 else 00911 online_visitor_.deactivate(); 00912 00913 00914 // Make stl compatible random functor. 00915 RandFunctor_t randint ( random); 00916 00917 00918 // Preprocess the data to get something the split functor can work 00919 // with. Also fill the ext_param structure by preprocessing 00920 // option parameters that could only be completely evaluated 00921 // when the training data is known. 00922 Preprocessor_t preprocessor( features, response, 00923 options_, ext_param_); 00924 00925 // Give the Split functor information about the data. 00926 split.set_external_parameters(ext_param_); 00927 stop.set_external_parameters(ext_param_); 00928 00929 00930 //initialize trees. 00931 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_)); 00932 00933 Sampler<Random_t > sampler(preprocessor.strata().begin(), 00934 preprocessor.strata().end(), 00935 detail::make_sampler_opt(options_) 00936 .sampleSize(ext_param().actual_msample_), 00937 random); 00938 00939 visitor.visit_at_beginning(*this, preprocessor); 00940 // THE MAIN EFFING RF LOOP - YEAY DUDE! 00941 00942 for(int ii = 0; ii < (int)trees_.size(); ++ii) 00943 { 00944 //initialize First region/node/stack entry 00945 sampler 00946 .sample(); 00947 StackEntry_t 00948 first_stack_entry( sampler.sampledIndices().begin(), 00949 sampler.sampledIndices().end(), 00950 ext_param_.class_count_); 00951 first_stack_entry 00952 .set_oob_range( sampler.oobIndices().begin(), 00953 sampler.oobIndices().end()); 00954 trees_[ii] 00955 .learn( preprocessor.features(), 00956 preprocessor.response(), 00957 first_stack_entry, 00958 split, 00959 stop, 00960 visitor, 00961 randint); 00962 visitor 00963 .visit_after_tree( *this, 00964 preprocessor, 00965 sampler, 00966 first_stack_entry, 00967 ii); 00968 } 00969 00970 visitor.visit_at_end(*this, preprocessor); 00971 // Only for online learning? 00972 online_visitor_.deactivate(); 00973 } 00974 00975 00976 00977 00978 template <class LabelType, class Tag> 00979 template <class U, class C, class Stop> 00980 LabelType RandomForest<LabelType, Tag> 00981 ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const 00982 { 00983 vigra_precondition(columnCount(features) >= ext_param_.column_count_, 00984 "RandomForestn::predictLabel():" 00985 " Too few columns in feature matrix."); 00986 vigra_precondition(rowCount(features) == 1, 00987 "RandomForestn::predictLabel():" 00988 " Feature matrix must have a singlerow."); 00989 typedef MultiArrayShape<2>::type Shp; 00990 garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0); 00991 LabelType d; 00992 predictProbabilities(features, garbage_prediction_, stop); 00993 ext_param_.to_classlabel(argMax(garbage_prediction_), d); 00994 return d; 00995 } 00996 00997 00998 //Same thing as above with priors for each label !!! 00999 template <class LabelType, class PreprocessorTag> 01000 template <class U, class C> 01001 LabelType RandomForest<LabelType, PreprocessorTag> 01002 ::predictLabel( MultiArrayView<2, U, C> const & features, 01003 ArrayVectorView<double> priors) const 01004 { 01005 using namespace functor; 01006 vigra_precondition(columnCount(features) >= ext_param_.column_count_, 01007 "RandomForestn::predictLabel(): Too few columns in feature matrix."); 01008 vigra_precondition(rowCount(features) == 1, 01009 "RandomForestn::predictLabel():" 01010 " Feature matrix must have a single row."); 01011 Matrix<double> prob(1,ext_param_.class_count_); 01012 predictProbabilities(features, prob); 01013 std::transform( prob.begin(), prob.end(), 01014 priors.begin(), prob.begin(), 01015 Arg1()*Arg2()); 01016 LabelType d; 01017 ext_param_.to_classlabel(argMax(prob), d); 01018 return d; 01019 } 01020 01021 template<class LabelType,class PreprocessorTag> 01022 template <class T1,class T2, class C> 01023 void RandomForest<LabelType,PreprocessorTag> 01024 ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet, 01025 MultiArrayView<2, T2, C> & prob) 01026 { 01027 //Features are n xp 01028 //prob is n x NumOfLabel probaility for each feature in each class 01029 01030 vigra_precondition(rowCount(predictionSet.features) == rowCount(prob), 01031 "RandomFroest::predictProbabilities():" 01032 " Feature matrix and probability matrix size misnmatch."); 01033 // num of features must be bigger than num of features in Random forest training 01034 // but why bigger? 01035 vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_, 01036 "RandomForestn::predictProbabilities():" 01037 " Too few columns in feature matrix."); 01038 vigra_precondition( columnCount(prob) 01039 == (MultiArrayIndex)ext_param_.class_count_, 01040 "RandomForestn::predictProbabilities():" 01041 " Probability matrix must have as many columns as there are classes."); 01042 prob.init(0.0); 01043 //store total weights 01044 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0); 01045 //Go through all trees 01046 int set_id=-1; 01047 for(int k=0; k<options_.tree_count_; ++k) 01048 { 01049 set_id=(set_id+1) % predictionSet.indices[0].size(); 01050 typedef std::set<SampleRange<T1> > my_set; 01051 typedef typename my_set::iterator set_it; 01052 //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it; 01053 //Build a stack with all the ranges we have 01054 std::vector<std::pair<int,set_it> > stack; 01055 stack.clear(); 01056 set_it i; 01057 for(i=predictionSet.ranges[set_id].begin();i!=predictionSet.ranges[set_id].end();++i) 01058 stack.push_back(std::pair<int,set_it>(2,i)); 01059 //get weights predicted by single tree 01060 int num_decisions=0; 01061 while(!stack.empty()) 01062 { 01063 set_it range=stack.back().second; 01064 int index=stack.back().first; 01065 stack.pop_back(); 01066 ++num_decisions; 01067 01068 if(trees_[k].isLeafNode(trees_[k].topology_[index])) 01069 { 01070 ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_, 01071 trees_[k].parameters_, 01072 index).prob_begin(); 01073 for(int i=range->start;i!=range->end;++i) 01074 { 01075 //update votecount. 01076 for(int l=0; l<ext_param_.class_count_; ++l) 01077 { 01078 prob(predictionSet.indices[set_id][i], l) += (T2)weights[l]; 01079 //every weight in totalWeight. 01080 totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l]; 01081 } 01082 } 01083 } 01084 01085 else 01086 { 01087 if(trees_[k].topology_[index]!=i_ThresholdNode) 01088 { 01089 throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes"); 01090 } 01091 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index); 01092 if(range->min_boundaries[node.column()]>=node.threshold()) 01093 { 01094 //Everything goes to right child 01095 stack.push_back(std::pair<int,set_it>(node.child(1),range)); 01096 continue; 01097 } 01098 if(range->max_boundaries[node.column()]<node.threshold()) 01099 { 01100 //Everything goes to the left child 01101 stack.push_back(std::pair<int,set_it>(node.child(0),range)); 01102 continue; 01103 } 01104 //We have to split at this node 01105 SampleRange<T1> new_range=*range; 01106 new_range.min_boundaries[node.column()]=FLT_MAX; 01107 range->max_boundaries[node.column()]=-FLT_MAX; 01108 new_range.start=new_range.end=range->end; 01109 int i=range->start; 01110 while(i!=range->end) 01111 { 01112 //Decide for range->indices[i] 01113 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold()) 01114 { 01115 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()], 01116 predictionSet.features(predictionSet.indices[set_id][i],node.column())); 01117 --range->end; 01118 --new_range.start; 01119 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]); 01120 01121 } 01122 else 01123 { 01124 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()], 01125 predictionSet.features(predictionSet.indices[set_id][i],node.column())); 01126 ++i; 01127 } 01128 } 01129 //The old one ... 01130 if(range->start==range->end) 01131 { 01132 predictionSet.ranges[set_id].erase(range); 01133 } 01134 else 01135 { 01136 stack.push_back(std::pair<int,set_it>(node.child(0),range)); 01137 } 01138 //And the new one ... 01139 if(new_range.start!=new_range.end) 01140 { 01141 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range); 01142 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first)); 01143 } 01144 } 01145 } 01146 predictionSet.cumulativePredTime[k]=num_decisions; 01147 } 01148 for(unsigned int i=0;i<totalWeights.size();++i) 01149 { 01150 double test=0.0; 01151 //Normalise votes in each row by total VoteCount (totalWeight 01152 for(int l=0; l<ext_param_.class_count_; ++l) 01153 { 01154 test+=prob(i,l); 01155 prob(i, l) /= totalWeights[i]; 01156 } 01157 assert(test==totalWeights[i]); 01158 assert(totalWeights[i]>0.0); 01159 } 01160 } 01161 01162 template <class LabelType, class PreprocessorTag> 01163 template <class U, class C1, class T, class C2, class Stop_t> 01164 void RandomForest<LabelType, PreprocessorTag> 01165 ::predictProbabilities(MultiArrayView<2, U, C1>const & features, 01166 MultiArrayView<2, T, C2> & prob, 01167 Stop_t & stop_) const 01168 { 01169 //Features are n xp 01170 //prob is n x NumOfLabel probability for each feature in each class 01171 01172 vigra_precondition(rowCount(features) == rowCount(prob), 01173 "RandomForestn::predictProbabilities():" 01174 " Feature matrix and probability matrix size mismatch."); 01175 01176 // num of features must be bigger than num of features in Random forest training 01177 // but why bigger? 01178 vigra_precondition( columnCount(features) >= ext_param_.column_count_, 01179 "RandomForestn::predictProbabilities():" 01180 " Too few columns in feature matrix."); 01181 vigra_precondition( columnCount(prob) 01182 == (MultiArrayIndex)ext_param_.class_count_, 01183 "RandomForestn::predictProbabilities():" 01184 " Probability matrix must have as many columns as there are classes."); 01185 01186 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 01187 Default_Stop_t default_stop(options_); 01188 typename RF_CHOOSER(Stop_t)::type & stop 01189 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 01190 #undef RF_CHOOSER 01191 stop.set_external_parameters(ext_param_, tree_count()); 01192 prob.init(NumericTraits<T>::zero()); 01193 /* This code was originally there for testing early stopping 01194 * - we wanted the order of the trees to be randomized 01195 if(tree_indices_.size() != 0) 01196 { 01197 std::random_shuffle(tree_indices_.begin(), 01198 tree_indices_.end()); 01199 } 01200 */ 01201 //Classify for each row. 01202 for(int row=0; row < rowCount(features); ++row) 01203 { 01204 ArrayVector<double>::const_iterator weights; 01205 01206 //totalWeight == totalVoteCount! 01207 double totalWeight = 0.0; 01208 01209 //Let each tree classify... 01210 for(int k=0; k<options_.tree_count_; ++k) 01211 { 01212 //get weights predicted by single tree 01213 weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row)); 01214 01215 //update votecount. 01216 int weighted = options_.predict_weighted_; 01217 for(int l=0; l<ext_param_.class_count_; ++l) 01218 { 01219 double cur_w = weights[l] * (weighted * (*(weights-1)) 01220 + (1-weighted)); 01221 prob(row, l) += (T)cur_w; 01222 //every weight in totalWeight. 01223 totalWeight += cur_w; 01224 } 01225 if(stop.after_prediction(weights, 01226 k, 01227 rowVector(prob, row), 01228 totalWeight)) 01229 { 01230 break; 01231 } 01232 } 01233 01234 //Normalise votes in each row by total VoteCount (totalWeight 01235 for(int l=0; l< ext_param_.class_count_; ++l) 01236 { 01237 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight); 01238 } 01239 } 01240 01241 } 01242 01243 //@} 01244 01245 } // namespace vigra 01246 01247 #include "random_forest/rf_algorithm.hxx" 01248 #endif // VIGRA_RANDOM_FOREST_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|