[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest.hxx
1/************************************************************************/
2/* */
3/* Copyright 2014-2015 by Ullrich Koethe and Philip Schill */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef VIGRA_RF3_RANDOM_FOREST_HXX
36#define VIGRA_RF3_RANDOM_FOREST_HXX
37
38#include <type_traits>
39#include <thread>
40
41#include "../multi_shape.hxx"
42#include "../binary_forest.hxx"
43#include "../threadpool.hxx"
44#include "random_forest_common.hxx"
45
46
47
48namespace vigra
49{
50
51namespace rf3
52{
53
54/********************************************************/
55/* */
56/* rf3::RandomForest */
57/* */
58/********************************************************/
59
60/** \brief Random forest version 3.
61
62 vigra::rf3::RandomForest is typicall constructed via the factory function \ref vigra::rf3::random_forest().
63*/
64template <typename FEATURES,
65 typename LABELS,
66 typename SPLITTESTS = LessEqualSplitTest<typename FEATURES::value_type>,
67 typename ACCTYPE = ArgMaxVectorAcc<double>>
68class RandomForest
69{
70public:
71
72 typedef FEATURES Features;
73 typedef typename Features::value_type FeatureType;
74 typedef LABELS Labels;
75 typedef typename Labels::value_type LabelType;
76 typedef SPLITTESTS SplitTests;
77 typedef ACCTYPE ACC;
78 typedef typename ACC::input_type AccInputType;
79 typedef BinaryForest Graph;
80 typedef Graph::Node Node;
81
82 static ContainerTag const container_tag = VectorTag;
83
84 // FIXME:
85 // Once the support for Visual Studio 2012 is dropped, replace this struct with
86 // template <typename T>
87 // using NodeMap = PropertyMap<Node, T, container_tag>;
88 // Then the verbose typename NodeMap<T>::type, which typically shows up on NodeMap usages,
89 // can be replace with NodeMap<T>.
90 template <typename T>
91 struct NodeMap
92 {
94 };
95
96 // Default (empty) constructor.
97 RandomForest();
98
99 // Default constructor (copy all of the given stuff).
100 RandomForest(
101 Graph const & graph,
102 typename NodeMap<SplitTests>::type const & split_tests,
103 typename NodeMap<AccInputType>::type const & node_responses,
104 ProblemSpec<LabelType> const & problem_spec
105 );
106
107 /// \brief Grow this forest by incorporating the other.
108 void merge(
109 RandomForest const & other
110 );
111
112 /// \brief Predict the given data and return the average number of split comparisons.
113 /// \note labels must be a 1-D array with size <tt>features.shape(0)</tt>.
115 FEATURES const & features,
116 LABELS & labels,
117 int n_threads = -1,
118 const std::vector<size_t> & tree_indices = std::vector<size_t>()
119 ) const;
120
121 /// \brief Predict the probabilities of the given data and return the average number of split comparisons.
122 /// \note probs should have the shape (features.shape()[0], num_classes).
123 template <typename PROBS>
125 FEATURES const & features,
126 PROBS & probs,
127 int n_threads = -1,
128 const std::vector<size_t> & tree_indices = std::vector<size_t>()
129 ) const;
130
131 /// \brief For each data point in features, compute the corresponding leaf ids and return the average number of split comparisons.
132 /// \note ids should have the shape (features.shape()[0], num_trees).
133 template <typename IDS>
134 double leaf_ids(
135 FEATURES const & features,
136 IDS & ids,
137 int n_threads = -1,
138 const std::vector<size_t> tree_indices = std::vector<size_t>()
139 ) const;
140
141 /// \brief Return the number of nodes.
142 size_t num_nodes() const
143 {
144 return graph_.numNodes();
145 }
146
147 /// \brief Return the number of trees.
148 size_t num_trees() const
149 {
150 return graph_.numRoots();
151 }
152
153 /// \brief Return the number of classes.
154 size_t num_classes() const
155 {
156 return problem_spec_.num_classes_;
157 }
158
159 /// \brief Return the number of classes.
160 size_t num_features() const
161 {
162 return problem_spec_.num_features_;
163 }
164
165 /// \brief The graph structure.
166 Graph graph_;
167
168 /// \brief Contains a test for each internal node, that is used to determine whether given data goes to the left or the right child.
169 typename NodeMap<SplitTests>::type split_tests_;
170
171 /// \brief Contains the responses of each node (for example the most frequent label).
172 typename NodeMap<AccInputType>::type node_responses_;
173
174 /// \brief The specifications.
175 ProblemSpec<LabelType> problem_spec_;
176
177 /// \brief The options that were used for training.
179
180private:
181
182 /// \brief Compute the leaf ids of the instances in [from, to).
183 template <typename IDS, typename INDICES>
184 double leaf_ids_impl(
185 FEATURES const & features,
186 IDS & ids,
187 size_t from,
188 size_t to,
189 INDICES const & tree_indices
190 ) const;
191
192 template<typename PROBS>
193 void predict_probabilities_impl(
194 FEATURES const & features,
195 PROBS & probs,
196 const size_t i,
197 const std::vector<size_t> & tree_indices) const;
198
199};
200
201template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
202RandomForest<FEATURES, LABELS, SPLITTESTS, ACC>::RandomForest()
203 :
204 graph_(),
205 split_tests_(),
206 node_responses_(),
207 problem_spec_()
208{}
209
210template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
212 Graph const & graph,
213 typename NodeMap<SplitTests>::type const & split_tests,
214 typename NodeMap<AccInputType>::type const & node_responses,
215 ProblemSpec<LabelType> const & problem_spec
216) :
217 graph_(graph),
218 split_tests_(split_tests),
219 node_responses_(node_responses),
220 problem_spec_(problem_spec)
221{}
222
223template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
225 RandomForest const & other
226){
227 vigra_precondition(problem_spec_ == other.problem_spec_,
228 "RandomForest::merge(): You cannot merge with different problem specs.");
229
230 // FIXME: Eventually compare the options and only fix if the forests are compatible.
231
232 size_t const offset = num_nodes();
233 graph_.merge(other.graph_);
234 for (auto const & p : other.split_tests_)
235 {
236 split_tests_.insert(Node(p.first.id()+offset), p.second);
237 }
238 for (auto const & p : other.node_responses_)
239 {
240 node_responses_.insert(Node(p.first.id()+offset), p.second);
241 }
242}
243
244// FIXME TODO we don't support the selection of tree indices any more in predict_probabilities, might be a good idea
245// to re-enable this.
246template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
248 FEATURES const & features,
249 LABELS & labels,
250 int n_threads,
251 const std::vector<size_t> & tree_indices
252) const {
253 vigra_precondition(features.shape()[0] == labels.shape()[0],
254 "RandomForest::predict(): Shape mismatch between features and labels.");
255 vigra_precondition((size_t)features.shape()[1] == problem_spec_.num_features_,
256 "RandomForest::predict(): Number of features in prediction differs from training.");
257
258 MultiArray<2, double> probs(Shape2(features.shape()[0], problem_spec_.num_classes_));
259 predict_probabilities(features, probs, n_threads, tree_indices);
260 for (size_t i = 0; i < (size_t)features.shape()[0]; ++i)
261 {
262 auto const sub_probs = probs.template bind<0>(i);
263 auto it = std::max_element(sub_probs.begin(), sub_probs.end());
264 size_t const label = std::distance(sub_probs.begin(), it);
265 labels(i) = problem_spec_.distinct_classes_[label];
266 }
267}
268
269
270// FIXME TODO we don't support the selection of tree indices any more in predict_probabilities, might be a good idea
271// to re-enable this.
272template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
273template <typename PROBS>
275 FEATURES const & features,
276 PROBS & probs,
277 int n_threads,
278 const std::vector<size_t> & tree_indices
279) const {
280 vigra_precondition(features.shape()[0] == probs.shape()[0],
281 "RandomForest::predict_probabilities(): Shape mismatch between features and probabilities.");
282 vigra_precondition((size_t)features.shape()[1] == problem_spec_.num_features_,
283 "RandomForest::predict_probabilities(): Number of features in prediction differs from training.");
284 vigra_precondition((size_t)probs.shape()[1] == problem_spec_.num_classes_,
285 "RandomForest::predict_probabilities(): Number of labels in probabilities differs from training.");
286
287 // By default, actual_tree_indices is empty. In that case we want to use all trees.
288 // We need to make a copy. I really don't know how the old code did compile...
289 std::vector<size_t> tree_indices_cpy(tree_indices);
290 if (tree_indices_cpy.size() == 0)
291 {
292 tree_indices_cpy.resize(graph_.numRoots());
293 std::iota(tree_indices_cpy.begin(), tree_indices_cpy.end(), 0);
294 }
295 else {
296 // Check the tree indices.
297 std::sort(tree_indices_cpy.begin(), tree_indices_cpy.end());
298 tree_indices_cpy.erase(std::unique(tree_indices_cpy.begin(), tree_indices_cpy.end()), tree_indices_cpy.end());
299 for (auto i : tree_indices_cpy)
300 vigra_precondition(i < graph_.numRoots(), "RandomForest::leaf_ids(): Tree index out of range.");
301 }
302
303 size_t const num_instances = features.shape()[0];
304
305 if (n_threads == -1)
306 n_threads = std::thread::hardware_concurrency();
307 if (n_threads < 1)
308 n_threads = 1;
309
311 n_threads,
312 num_instances,
313 [&features,&probs,&tree_indices_cpy,this](size_t, size_t i) {
314 this->predict_probabilities_impl(features, probs, i, tree_indices_cpy);
315 }
316 );
317}
318
319template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
320template <typename PROBS>
321void RandomForest<FEATURES, LABELS, SPLITTESTS, ACC>::predict_probabilities_impl(
322 FEATURES const & features,
323 PROBS & probs,
324 const size_t i,
325 const std::vector<size_t> & tree_indices
326) const {
327
328 // instantiate the accumulation function and the vector to store the tree node results
329 ACC acc;
330 std::vector<AccInputType> tree_results;
331 tree_results.reserve(tree_indices.size());
332 auto const sub_features = features.template bind<0>(i);
333
334 // loop over the trees
335 for (auto k : tree_indices)
336 {
337 Node node = graph_.getRoot(k);
338 while (graph_.outDegree(node) > 0)
339 {
340 size_t const child_index = split_tests_.at(node)(sub_features);
341 node = graph_.getChild(node, child_index);
342 }
343 tree_results.emplace_back(node_responses_.at(node));
344 }
345
346 // write the tree results into the probabilities
347 auto sub_probs = probs.template bind<0>(i);
348 acc(tree_results.begin(), tree_results.end(), sub_probs.begin());
349}
350
351template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
352template <typename IDS>
354 FEATURES const & features,
355 IDS & ids,
356 int n_threads,
357 std::vector<size_t> tree_indices
358) const {
359 vigra_precondition(features.shape()[0] == ids.shape()[0],
360 "RandomForest::leaf_ids(): Shape mismatch between features and probabilities.");
361 vigra_precondition((size_t)features.shape()[1] == problem_spec_.num_features_,
362 "RandomForest::leaf_ids(): Number of features in prediction differs from training.");
363 vigra_precondition(ids.shape()[1] == graph_.numRoots(),
364 "RandomForest::leaf_ids(): Leaf array has wrong shape.");
365
366 // Check the tree indices.
367 std::sort(tree_indices.begin(), tree_indices.end());
368 tree_indices.erase(std::unique(tree_indices.begin(), tree_indices.end()), tree_indices.end());
369 for (auto i : tree_indices)
370 vigra_precondition(i < graph_.numRoots(), "RandomForest::leaf_ids(): Tree index out of range.");
371
372 // By default, actual_tree_indices is empty. In that case we want to use all trees.
373 if (tree_indices.size() == 0)
374 {
375 tree_indices.resize(graph_.numRoots());
376 std::iota(tree_indices.begin(), tree_indices.end(), 0);
377 }
378
379 size_t const num_instances = features.shape()[0];
380 if (n_threads == -1)
381 n_threads = std::thread::hardware_concurrency();
382 if (n_threads < 1)
383 n_threads = 1;
384 std::vector<double> split_comparisons(n_threads, 0.0);
385 std::vector<size_t> indices(num_instances);
386 std::iota(indices.begin(), indices.end(), 0);
387 std::fill(ids.begin(), ids.end(), -1);
389 n_threads,
390 indices.begin(),
391 indices.end(),
392 [this, &features, &ids, &split_comparisons, &tree_indices](size_t thread_id, size_t i) {
393 split_comparisons[thread_id] += this->leaf_ids_impl(features, ids, i, i+1, tree_indices);
394 }
395 );
396
397 double const sum_split_comparisons = std::accumulate(split_comparisons.begin(), split_comparisons.end(), 0.0);
398 return sum_split_comparisons / features.shape()[0];
399}
400
401template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
402template <typename IDS, typename INDICES>
403double RandomForest<FEATURES, LABELS, SPLITTESTS, ACC>::leaf_ids_impl(
404 FEATURES const & features,
405 IDS & ids,
406 size_t from,
407 size_t to,
408 INDICES const & tree_indices
409) const {
410 vigra_precondition(features.shape()[0] == ids.shape()[0],
411 "RandomForest::leaf_ids_impl(): Shape mismatch between features and labels.");
412 vigra_precondition(features.shape()[1] == problem_spec_.num_features_,
413 "RandomForest::leaf_ids_impl(): Number of Features in prediction differs from training.");
414 vigra_precondition(from >= 0 && from <= to && to <= (size_t)features.shape()[0],
415 "RandomForest::leaf_ids_impl(): Indices out of range.");
416 vigra_precondition(ids.shape()[1] == graph_.numRoots(),
417 "RandomForest::leaf_ids_impl(): Leaf array has wrong shape.");
418
419 double split_comparisons = 0.0;
420 for (size_t i = from; i < to; ++i)
421 {
422 auto const sub_features = features.template bind<0>(i);
423 for (auto k : tree_indices)
424 {
425 Node node = graph_.getRoot(k);
426 while (graph_.outDegree(node) > 0)
427 {
428 size_t const child_index = split_tests_.at(node)(sub_features);
429 node = graph_.getChild(node, child_index);
430 split_comparisons += 1.0;
431 }
432 ids(i, k) = node.id();
433 }
434 }
435 return split_comparisons;
436}
437
438
439
440} // namespace rf3
441} // namespace vigra
442
443#endif
BinaryForest stores a collection of rooted binary trees.
Definition binary_forest.hxx:65
detail::NodeDescriptor< index_type > Node
Node descriptor type of the present graph.
Definition binary_forest.hxx:70
Main MultiArray class containing the memory management.
Definition multi_array.hxx:2479
problem specification class for the random forest.
Definition rf_common.hxx:539
The PropertyMap is used to store Node or Arc information of graphs.
Definition graphs.hxx:411
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition random_forest.hxx:197
Options class for RandomForest version 3.
Definition random_forest_common.hxx:583
void predict_probabilities(FEATURES const &features, PROBS &probs, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the class probabilities of the given data and write them to probs.
Definition random_forest.hxx:274
size_t num_features() const
Return the number of features.
Definition random_forest.hxx:160
void merge(RandomForest const &other)
Grow this forest by incorporating the other.
Definition random_forest.hxx:224
void predict(FEATURES const &features, LABELS &labels, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the labels of the given data and write them to labels.
Definition random_forest.hxx:247
double leaf_ids(FEATURES const &features, IDS &ids, int n_threads=-1, const std::vector< size_t > tree_indices=std::vector< size_t >()) const
For each data point in features, compute the corresponding leaf ids and return the average number of ...
Definition random_forest.hxx:353
size_t num_trees() const
Return the number of trees.
Definition random_forest.hxx:148
size_t num_nodes() const
Return the number of nodes.
Definition random_forest.hxx:142
size_t num_classes() const
Return the number of classes.
Definition random_forest.hxx:154
Efficient computation of object statistics.
Definition accumulator-grammar.hxx:48
Random forest version 3.
Definition random_forest_3.hxx:66
void parallel_foreach(...)
Apply a functor to all items in a range in parallel.

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.2 (Mon Apr 14 2025)