Skip to content

Commit 91cbc5e

Browse files
authored
Implement Random Forest as a sub case to EnsembleLearner (#410)
* Implement Random Forest as a subset of Bagging with an additional constraint on predictors (or features) to be used. Introduce Dataset bootstrap implementation with indices denoted as '_with_indices' methods. * cargo fmt * clippy is far happier now :) * Add type alias for Random Forest as an EnsembleLearner with model type DecisionTree. * fix clippy code quality check * Add unit tests for bootstrap with indices and random forest * 📜 Add docs and example for RandomForest type alias. * lint
1 parent 3f2c202 commit 91cbc5e

File tree

8 files changed

+345
-48
lines changed

8 files changed

+345
-48
lines changed

algorithms/linfa-ensemble/README.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,39 @@
1515
You can find examples in the `examples/` directory. To run a bootstrap aggregation for an ensemble of decision trees (a Random Forest) use:
1616

1717
```bash
18-
$ cargo run --example randomforest_iris --release
18+
$ cargo run --example ensemble_iris --release
1919
```
2020

21+
The expected output should be
22+
```commandline
23+
An example using Bagging with Decision Tree on Iris Dataset
24+
Final Predictions:
25+
[0, 2, 0, 1, 1, 2, 2, 1, 0, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 2, 0], shape=[30], strides=[1], layout=CFcf (0xf), const ndim=1
26+
27+
classes | 0 | 1 | 2
28+
0 | 11 | 0 | 0
29+
1 | 0 | 7 | 1
30+
2 | 0 | 1 | 10
31+
32+
Test accuracy: 93.333336
33+
with default Decision Tree params,
34+
Ensemble Size: 100,
35+
Bootstrap Proportion: 0.7
36+
Feature selection proportion: 1
37+
38+
An example using a Random Forest on Iris Dataset
39+
Final Predictions:
40+
[0, 1, 0, 1, 1, 2, 2, 1, 0, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 2, 0], shape=[30], strides=[1], layout=CFcf (0xf), const ndim=1
41+
42+
classes | 0 | 1 | 2
43+
0 | 11 | 0 | 0
44+
1 | 0 | 8 | 0
45+
2 | 0 | 1 | 10
46+
47+
Test accuracy: 96.666664
48+
with default Decision Tree params,
49+
Ensemble Size: 100,
50+
Bootstrap Proportion: 0.7
51+
Feature selection proportion: 0.2
52+
```
2153

algorithms/linfa-ensemble/examples/bagging_iris.rs

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
2+
use linfa_ensemble::{EnsembleLearnerParams, RandomForestParams};
3+
use linfa_trees::DecisionTree;
4+
use ndarray_rand::rand::SeedableRng;
5+
use rand::rngs::SmallRng;
6+
7+
fn ensemble_learner(ensemble_size: usize, bootstrap_proportion: f64) {
8+
// Load dataset
9+
let mut rng = SmallRng::seed_from_u64(42);
10+
let (train, test) = linfa_datasets::iris()
11+
.shuffle(&mut rng)
12+
.split_with_ratio(0.8);
13+
14+
// Train ensemble learner model
15+
let model = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), rng)
16+
.ensemble_size(ensemble_size)
17+
.bootstrap_proportion(bootstrap_proportion)
18+
.fit(&train)
19+
.unwrap();
20+
21+
// Return highest ranking predictions
22+
let final_predictions_ensemble = model.predict(&test);
23+
println!("Final Predictions: \n{final_predictions_ensemble:?}");
24+
25+
let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();
26+
27+
println!("{cm:?}");
28+
println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {ensemble_size},\n Bootstrap Proportion: {bootstrap_proportion}.\n",
29+
100.0 * cm.accuracy());
30+
}
31+
32+
fn random_forest(ensemble_size: usize, bootstrap_proportion: f64, feature_proportion: f64) {
33+
let mut rng = SmallRng::seed_from_u64(42);
34+
let (train, test) = linfa_datasets::iris()
35+
.shuffle(&mut rng)
36+
.split_with_ratio(0.8);
37+
38+
// Train ensemble learner model
39+
let model = RandomForestParams::new_fixed_rng(DecisionTree::params(), rng)
40+
.ensemble_size(ensemble_size)
41+
.bootstrap_proportion(bootstrap_proportion)
42+
.feature_proportion(feature_proportion)
43+
.fit(&train)
44+
.unwrap();
45+
46+
// Return highest ranking predictions
47+
let final_predictions_ensemble = model.predict(&test);
48+
println!("Final Predictions: \n{final_predictions_ensemble:?}");
49+
50+
let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();
51+
52+
println!("{cm:?}");
53+
println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {ensemble_size},\n Bootstrap Proportion: {bootstrap_proportion}\n Feature selection proportion: {feature_proportion}.\n",
54+
100.0 * cm.accuracy());
55+
}
56+
57+
fn main() {
58+
// This is an example bagging with decision tree
59+
println!("An example using Bagging with Decision Tree on Iris Dataset");
60+
ensemble_learner(100, 0.7);
61+
// This is basically a Random Forest ensemble
62+
println!("An example using a Random Forest on Iris Dataset");
63+
random_forest(100, 0.7, 0.2);
64+
}

algorithms/linfa-ensemble/src/algorithm.rs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,31 @@ use linfa::{
55
traits::*,
66
DatasetBase,
77
};
8+
use linfa_trees::DecisionTree;
89
use ndarray::{Array2, Axis, Zip};
910
use rand::Rng;
1011
use std::{cmp::Eq, collections::HashMap, hash::Hash};
1112

13+
pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;
14+
1215
pub struct EnsembleLearner<M> {
1316
pub models: Vec<M>,
17+
pub model_features: Vec<Vec<usize>>,
1418
}
1519

1620
impl<M> EnsembleLearner<M> {
1721
// Generates prediction iterator returning predictions from each model
1822
pub fn generate_predictions<'b, R: Records, T>(
1923
&'b self,
20-
x: &'b R,
24+
x: &'b [R],
2125
) -> impl Iterator<Item = T> + 'b
2226
where
2327
M: Predict<&'b R, T>,
2428
{
25-
self.models.iter().map(move |m| m.predict(x))
29+
self.models
30+
.iter()
31+
.zip(x.iter())
32+
.map(move |(m, sub_data)| m.predict(sub_data))
2633
}
2734
}
2835

@@ -40,7 +47,12 @@ where
4047
"The number of data points must match the number of outputs."
4148
);
4249

43-
let predictions = self.generate_predictions(x);
50+
let sub_datas = self
51+
.model_features
52+
.iter()
53+
.map(|feat| x.select(Axis(1), feat))
54+
.collect::<Vec<_>>();
55+
let predictions = self.generate_predictions(&sub_datas);
4456

4557
// prediction map has same shape as y_array, but the elements are maps
4658
let mut prediction_maps = y_array.map(|_| HashMap::new());
@@ -81,23 +93,30 @@ where
8193
&self,
8294
dataset: &DatasetBase<Array2<D>, T>,
8395
) -> core::result::Result<Self::Object, Error> {
84-
let mut models = Vec::new();
96+
let mut models = Vec::with_capacity(self.ensemble_size);
97+
let mut model_features = Vec::with_capacity(self.ensemble_size);
8598
let mut rng = self.rng.clone();
8699

100+
// Compute dataset and the subset of features ratio to be selected
87101
let dataset_size =
88102
((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
103+
let n_feat = dataset.records.ncols();
104+
let n_sub = ((n_feat as f64) * self.feature_proportion).ceil() as usize;
89105

90-
let iter = dataset.bootstrap_samples(dataset_size, &mut rng);
91-
92-
for train in iter {
106+
let iter = dataset.bootstrap_with_indices((dataset_size, n_sub), &mut rng);
107+
for (train, _, feature_selected) in iter {
93108
let model = self.model_params.fit(&train).unwrap();
94109
models.push(model);
110+
model_features.push(feature_selected);
95111

96112
if models.len() == self.ensemble_size {
97113
break;
98114
}
99115
}
100116

101-
Ok(EnsembleLearner { models })
117+
Ok(EnsembleLearner {
118+
models,
119+
model_features,
120+
})
102121
}
103122
}

algorithms/linfa-ensemble/src/hyperparams.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use linfa::{
22
error::{Error, Result},
33
ParamGuard,
44
};
5+
use linfa_trees::DecisionTreeParams;
56
use rand::rngs::ThreadRng;
67
use rand::Rng;
78

@@ -11,6 +12,8 @@ pub struct EnsembleLearnerValidParams<P, R> {
1112
pub ensemble_size: usize,
1213
/// The proportion of the total number of training samples that should be given to each model for training
1314
pub bootstrap_proportion: f64,
15+
/// The proportion of the total number of training feature that should be given to each model for training
16+
pub feature_proportion: f64,
1417
/// The model parameters for the base model
1518
pub model_params: P,
1619
pub rng: R,
@@ -19,6 +22,8 @@ pub struct EnsembleLearnerValidParams<P, R> {
1922
#[derive(Clone, Copy, Debug, PartialEq)]
2023
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
2124

25+
pub type RandomForestParams<F, L, R> = EnsembleLearnerParams<DecisionTreeParams<F, L>, R>;
26+
2227
impl<P> EnsembleLearnerParams<P, ThreadRng> {
2328
pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
2429
Self::new_fixed_rng(model_params, rand::thread_rng())
@@ -30,6 +35,7 @@ impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
3035
Self(EnsembleLearnerValidParams {
3136
ensemble_size: 1,
3237
bootstrap_proportion: 1.0,
38+
feature_proportion: 1.0,
3339
model_params,
3440
rng,
3541
})
@@ -44,6 +50,11 @@ impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
4450
self.0.bootstrap_proportion = proportion;
4551
self
4652
}
53+
54+
pub fn feature_proportion(mut self, proportion: f64) -> Self {
55+
self.0.feature_proportion = proportion;
56+
self
57+
}
4758
}
4859

4960
impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
@@ -61,6 +72,11 @@ impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
6172
"Ensemble size should be less than one, but was {}",
6273
self.0.ensemble_size
6374
)))
75+
} else if self.0.feature_proportion > 1.0 || self.0.feature_proportion <= 0.0 {
76+
Err(Error::Parameters(format!(
77+
"Feature proportion should be greater than zero and less than or equal to one, but was {}",
78+
self.0.feature_proportion
79+
)))
6480
} else {
6581
Ok(&self.0)
6682
}

algorithms/linfa-ensemble/src/lib.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
//!
66
//! ## Bootstrap Aggregation (aka Bagging)
77
//!
8-
//! A typical example of ensemble method is Bootstrapo AGgregation, which combines the predictions of
8+
//! A typical example of ensemble method is Bootstrap Aggregation, which combines the predictions of
99
//! several decision trees (see `linfa-trees`) trained on different samples subset of the training dataset.
1010
//!
11+
//! ## Random Forest
12+
//!
13+
//! A special case of Bootstrap Aggregation using decision trees (see `linfa-trees`) with random feature
14+
//! selection. A typical number of random predictors to be selected is $\sqrt{p}$ with $p$ being
15+
//! the number of available features.
16+
//!
1117
//! ## Reference
1218
//!
1319
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
20+
//! * [An Introduction to Statistical Learning](https://www.statlearning.com/)
1421
//!
1522
//! ## Example
1623
//!
@@ -32,15 +39,44 @@
3239
//!
3340
//! // Train the model on the iris dataset
3441
//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
35-
//! .ensemble_size(100)
36-
//! .bootstrap_proportion(0.7)
42+
//! .ensemble_size(100) // Number of Decision Tree to fit
43+
//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
3744
//! .fit(&train)
3845
//! .unwrap();
3946
//!
4047
//! // Make predictions on the test set
4148
//! let predictions = bagging_model.predict(&test);
4249
//! ```
4350
//!
51+
//! This example shows how to train a Random Forest model using 100 decision trees,
52+
//! each trained on 70% of the training data (bootstrap sampling) and using only
53+
//! 30% of the available features.
54+
//!
55+
//! ```no_run
56+
//! use linfa::prelude::{Fit, Predict};
57+
//! use linfa_ensemble::RandomForestParams;
58+
//! use linfa_trees::DecisionTree;
59+
//! use ndarray_rand::rand::SeedableRng;
60+
//! use rand::rngs::SmallRng;
61+
//!
62+
//! // Load Iris dataset
63+
//! let mut rng = SmallRng::seed_from_u64(42);
64+
//! let (train, test) = linfa_datasets::iris()
65+
//! .shuffle(&mut rng)
66+
//! .split_with_ratio(0.8);
67+
//!
68+
//! // Train the model on the iris dataset
69+
//! let bagging_model = RandomForestParams::new(DecisionTree::params())
70+
//! .ensemble_size(100) // Number of Decision Tree to fit
71+
//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
72+
//! .feature_proportion(0.3) // Select only 30% of the feature
73+
//! .fit(&train)
74+
//! .unwrap();
75+
//!
76+
//! // Make predictions on the test set
77+
//! let predictions = bagging_model.predict(&test);
78+
//! ```
79+
4480
mod algorithm;
4581
mod hyperparams;
4682

@@ -55,14 +91,35 @@ mod tests {
5591
use ndarray_rand::rand::SeedableRng;
5692
use rand::rngs::SmallRng;
5793

94+
#[test]
95+
fn test_random_forest_accuracy_on_iris_dataset() {
96+
let mut rng = SmallRng::seed_from_u64(42);
97+
let (train, test) = linfa_datasets::iris()
98+
.shuffle(&mut rng)
99+
.split_with_ratio(0.8);
100+
101+
let model = RandomForestParams::new_fixed_rng(DecisionTree::params(), rng)
102+
.ensemble_size(100)
103+
.bootstrap_proportion(0.7)
104+
.feature_proportion(0.3)
105+
.fit(&train)
106+
.unwrap();
107+
108+
let predictions = model.predict(&test);
109+
110+
let cm = predictions.confusion_matrix(&test).unwrap();
111+
let acc = cm.accuracy();
112+
assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc);
113+
}
114+
58115
#[test]
59116
fn test_ensemble_learner_accuracy_on_iris_dataset() {
60117
let mut rng = SmallRng::seed_from_u64(42);
61118
let (train, test) = linfa_datasets::iris()
62119
.shuffle(&mut rng)
63120
.split_with_ratio(0.8);
64121

65-
let model = EnsembleLearnerParams::new(DecisionTree::params())
122+
let model = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), rng)
66123
.ensemble_size(100)
67124
.bootstrap_proportion(0.7)
68125
.fit(&train)

0 commit comments

Comments
 (0)