[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_algorithm_prototyping.hxx VIGRA

00001 #ifndef VIGRA_RF_ALGORITHM_HXX
00002 #define VIGRA_RF_ALGORTIHM_HXX
00003 /* First idea for algorithms class.
00004  * delete file once all methods and classes have been ported
00005  */
#include <algorithm>
#include <cstddef>
#include <vector>
00007 
00008 namespace vigra
00009 {
00010 namespace rf
00011 {
00012 /** This namespace contains all algorithms developed for feature 
00013  * selection
00014  *
00015  */
00016 namespace algorithms
00017 {
00018 
/** Default output container: a linear sequence of integer
 *  feature indices.
 */
typedef std::vector<int> SelFeats_t;
00022 
00023 
/** Variable selection error summary.
 *
 *  Holds the error rate measured before and after a variable
 *  selection run. The struct is a plain aggregate and can be
 *  initialised as <tt>VarSelectOutput s = {before, after};</tt>
 */
struct VarSelectOutput
{
    // The stored values carry a trailing underscore because a data
    // member and a member function of the same class cannot share a
    // name; the accessors below keep the documented names.

    /** Stored error rate before variable selection. */
    double before_;

    /** Stored error rate after variable selection. */
    double after_;


    /** Error rate before variable
     * selection
     */
    double before() const
    {
        return before_;
    }


    /** Error rate after variable selection
     */
    double after() const
    {
        return after_;
    }


    /** before - after
     *  (positive when the selection reduced the error rate)
     */
    double improvement() const
    {
        return before_ - after_;
    }
};
00056 
00057 
00058 /** Classifier typedef.
00059  *  has been kept as typedef as somebody may want to apply the
00060  *  same algorithms with different classifiers - made to simplify
00061  *  porting when necessary
00062  */
00063 typedef RandomForest<>   ClassifierT;
00064 
00065 /** Perform forward selection using Random FOrests
00066  * 
00067  * \param features  Matrix containing the features used.
00068  * \param label     Matrix containing the corresponding labels
00069  * \param selected_feats 
00070  *                  - output. Linear container which will contain
00071  *                  the selected features.
00072  * \output Selection summary.
00073  * \param tolerance between best solution and selected solution 
00074  * with less features (fraction of error with all features)
00075  *
00076  *  FeatureT and LabelT should be vigra::MultiArray compatible 
00077  *  SelFeatsT should be a back insertion containerr i.e. std::vector
00078  *  or vigra::ArrayVector
00079  */
00080 template<class FeatureT, class LabelT>
00081 VarSelectOutput forward_select(     FeatureT & features, 
00082                                     LabelT & labels, 
00083                                     SelFeatsT & selected_feats,
00084                                     double tolerance = 0.0)
00085 {
00086     int featureCount = features.shape(1);
00087     std::vector<int> selected;
00088     std::vector<int> not_selected;
00089     for(int ii = 0; ii < featureCount; ++ii)
00090     {
00091         not_selected.push_back(ii);
00092     }
00093     while(not_selected.size() != 0)
00094     {
00095         std::vector<int> current_errors(not_selected.size(), 1);
00096         for(int ii = 0; ii < not_selected.size(); ++ii)
00097         {
00098             selected.push_back(not_selected[ii]);
00099             MultiArray<2, double> cur_feats = choose( features, 
00100                                                       selected.begin(), 
00101                                                       selected.end());
00102             selected.pop_back();
00103             visitors::OOB_Error oob;
00104             visitors::RandomForestProgressVisitor progress;
00105             ClassifierT classifier;
00106             classifier.learn(cur_feats, 
00107                              labels, 
00108                              create_visitor(oob, progress));
00109             current_errors.push_back(oob.oob_breiman);
00110         }
00111         int pos = std::min_element(current_errors.begin(),
00112                                    current_errors.end()) 
00113                   -     current_errors.begin();
00114         selected.push_back(not_selected[pos]);
00115         errors.push_back(current_errors[pos]);
00116         not_selected.erase(pos);
00117     }
00118 }
00119 
00120 /** Perform backward elimination using Random Forests
00121  * 
00122  * \param features  Matrix containing the features used.
00123  * \param label     Matrix containing the corresponding labels
00124  * \param selected_feats 
00125  *                  - output. Linear container which will contain
00126  *                  the selected features.
00127  * \output Selection summary.
00128  * \param tolerance between best solution and selected solution 
00129  * with less features (fraction of error with all features)
00130  *
00131  *  FeatureT and LabelT should be vigra::MultiArray compatible 
00132  *  SelFeatsT should be a back insertion containerr i.e. std::vector
00133  *  or vigra::ArrayVector
00134  */
00135 template<class FeatureT, class LabelT>
00136 VarSelectOutput backward_eliminate( FeatureT & features, 
00137                                     LabelT & labels, 
00138                                     SelFeatsT & selected_feats,
00139                                     double tolerance = 0.0);
00140 {
00141     int featureCount = features.shape(1);
00142     std::vector<int> selected;
00143     std::vector<int> not_selected;
00144     for(int ii = 0; ii < featureCount; ++ii)
00145     {
00146         selected.push_back(ii);
00147     }
00148     while(selected.size() != 0)
00149     {
00150         std::vector<int> current_errors(not_selected.size(), 1);
00151         for(int ii = 0; ii < not_selected.size(); ++ii)
00152         {
00153             selected.push_back(not_selected[ii]);
00154             MultiArray<2, double> cur_feats = choose( features, 
00155                                                       selected.begin(), 
00156                                                       selected.end());
00157             selected.pop_back();
00158             visitors::OOB_Error oob;
00159             visitors::RandomForestProgressVisitor progress;
00160             ClassifierT classifier;
00161             classifier.learn(cur_feats, 
00162                              labels, 
00163                              create_visitor(oob, progress));
00164             current_errors.push_back(oob.oob_breiman);
00165         }
00166         int pos = std::min_element(current_errors.begin(),
00167                                    current_errors.end()) 
00168                   -     current_errors.begin();
00169         selected.push_back(not_selected[pos]);
00170         errors.push_back(current_errors[pos]);
00171         not_selected.erase(pos);
00172     }
00173 
00174 }
00175 
00176 
00177 /** Perform rank selection using Random Forests and a fixed predefined
00178  *  ranking.
00179  * 
00180  * \param features  Matrix containing the features used.
00181  * \param label     Matrix containing the corresponding labels
00182  * \param ranking   ranking of features by relevance (i.e. rf variable
00183  *                  importance measures)
00184  * \param selected_feats 
00185  *                  - output. Linear container which will contain
00186  *                  the selected features.
00187  * \output Selection summary.
00188  * \param tolerance between best solution and selected solution 
00189  * with less features (fraction of error with all features)
00190  *
00191  *  FeatureT and LabelT should be vigra::MultiArray compatible 
00192  *  SelFeatsT should be a back insertion containerr i.e. std::vector
00193  *  or vigra::ArrayVector
00194  */
00195 template<class FeatureT, class LabelT, class RankingT>
00196 VarSelectOutput rank_select(            FeatureT & features, 
00197                                         LabelT & labels, 
00198                                         RankingT & ranking,
00199                                         SelFeatsT & selected_feats,
00200                                         double tolerance = 0.0);
00201 {
00202     typename RankingT::iterator iter = ranking.begin();
00203     for(; iter != ranking.end(); ++iter)
00204     {
00205         MultiArray<2, double> cur_feats = choose( features, 
00206                                                   ranking.begin(), 
00207                                                   iter);
00208         ClassifierT classifier;
00209         classifier.learn(cur_feats, 
00210                          labels, 
00211                          create_visitor(oob, progress));
00212         errors.push_back(oob.oob_breiman);
00213 
00214     }
00215 
00216 }
00217 
00218 }//namespace algorithms
00219 }//namespace rf
00220 }//namespace vigra
00221 
00222 #undef //VIGRA_RF_ALGORITHM_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (25 Nov 2010)