---
title: Decision Trees
sidebar_label: Decision Trees
description: "Understanding recursive partitioning, Entropy, Gini Impurity, and how to prevent overfitting in tree-based models."
tags: [machine-learning, supervised-learning, classification, decision-trees, cart]
---

A **Decision Tree** is a non-parametric supervised learning method used for both classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.

Think of a Decision Tree as a flow chart where each internal node represents a "test" on an attribute, each branch represents the outcome of the test, and each leaf node represents a class label.

## 1. Anatomy of a Tree

* **Root Node:** The very top node that represents the entire dataset. It is the first split.
* **Internal Node:** A point where the data is split based on a specific feature.
* **Leaf Node:** A terminal node that holds the final prediction. No further splits occur here.
* **Branches:** The paths connecting nodes based on the outcome of a decision.

## 2. How the Tree Decides to Split

The algorithm aims to split the data into subsets that are as "pure" as possible. A subset is pure if all data points in it belong to the same class.

### Gini Impurity

This is the default splitting criterion in Scikit-Learn. It measures how often a randomly chosen sample from a node would be misclassified if it were labeled at random according to the class distribution in that node.

$$
Gini = 1 - \sum_{i=1}^{n} (p_i)^2
$$

**Where:**

* $p_i$ is the proportion of samples in the node that belong to class $i$.
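
Gini impurity is easy to verify by hand. The sketch below is a minimal illustration with made-up class counts (not Scikit-Learn's internal code):

```python
import numpy as np

def gini_impurity(class_counts):
    """Gini impurity of a node, given the number of samples per class."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()                # class proportions p_i
    return 1.0 - np.sum(p ** 2)

print(gini_impurity([10, 0]))      # 0.0 -> a pure node has zero impurity
print(gini_impurity([5, 5]))       # 0.5 -> a perfectly mixed binary node is worst
```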

### Information Gain (Entropy)

Entropy, a concept from information theory, measures the "disorder" or uncertainty in the data. **Information Gain** is the reduction in entropy produced by a split; at each node, the tree picks the split that yields the largest gain.

$$
H(S) = -\sum_{i=1}^{n} p_i \log_2(p_i)
$$

**Where:**

* $p_i$ is the proportion of instances in class $i$.
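
Entropy alone only scores a single node. To compare candidate splits, the gain is usually written as the parent's entropy minus the weighted entropy of the children:

$$
IG(S, A) = H(S) - \sum_{v \in \text{values}(A)} \frac{|S_v|}{|S|} H(S_v)
$$

where $S_v$ is the subset of $S$ for which feature $A$ takes value $v$. The sketch below is a small illustration with made-up class counts, not library code:

```python
import numpy as np

def entropy(class_counts):
    """Shannon entropy H(S) of a node, given the number of samples per class."""
    p = np.asarray(class_counts, dtype=float)
    p = p[p > 0] / p.sum()         # drop empty classes, normalize to proportions
    return -np.sum(p * np.log2(p))

def information_gain(parent_counts, children_counts):
    """Reduction in entropy achieved by splitting the parent into the given children."""
    total = sum(sum(child) for child in children_counts)
    children_entropy = sum(sum(child) / total * entropy(child) for child in children_counts)
    return entropy(parent_counts) - children_entropy

# Parent node: 10 vs. 10 samples (entropy = 1 bit).
# This candidate split produces two fairly pure children, so the gain is high.
print(information_gain([10, 10], [[9, 1], [1, 9]]))   # ~0.53
```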

## 3. The Problem of Overfitting

Decision Trees are notorious for **overfitting**. Left unchecked, a tree will continue to split until every single data point has its own leaf, essentially "memorizing" the training data rather than finding patterns.

**How to stop the tree from growing too much** (the sketch after this list shows how these map onto Scikit-Learn parameters):
* **max_depth:** Limit how "tall" the tree can get.
* **min_samples_split:** The minimum number of samples required to split an internal node.
* **min_samples_leaf:** The minimum number of samples required to be at a leaf node.
* **Pruning:** Removing branches that provide little power to classify instances.
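
A minimal sketch of how these constraints map onto Scikit-Learn's `DecisionTreeClassifier` (the Iris data and the numeric values are illustrative placeholders, not tuned recommendations):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)   # any labeled dataset works here

# Pre-pruning: stop the tree from growing too far in the first place
constrained = DecisionTreeClassifier(
    max_depth=4,            # limit how "tall" the tree can get
    min_samples_split=20,   # a node needs at least 20 samples before it may be split
    min_samples_leaf=10,    # every leaf must keep at least 10 samples
    random_state=42,
).fit(X, y)

# Post-pruning: cost-complexity pruning removes branches that add little value.
# cost_complexity_pruning_path suggests candidate ccp_alpha values to try.
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X, y)
pruned = DecisionTreeClassifier(ccp_alpha=path.ccp_alphas[-2],   # a heavily pruned tree near the end of the path
                                random_state=42).fit(X, y)
```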

```mermaid
graph LR
    X["$$X$$ (Training Data)"] --> ODT["Overfitted Decision Tree"]

    ODT --> O1["$$\text{Very Deep Tree}$$"]
    O1 --> O2["$$\text{Many Splits}$$"]
    O2 --> O3["$$\text{Memorizes Noise}$$"]
    O3 --> O4["$$\text{Low Bias,\ High Variance}$$"]
    O4 --> O5["$$\text{Training Accuracy} \approx 100\%$$"]
    O5 --> O6["$$\text{Poor Generalization}$$"]

    X --> PDT["Pruned Decision Tree"]

    PDT --> P1["$$\text{Limited Depth}$$"]
    P1 --> P2["$$\text{Fewer Splits}$$"]
    P2 --> P3["$$\text{Removes Irrelevant Branches}$$"]
    P3 --> P4["$$\text{Balanced Bias–Variance}$$"]
    P4 --> P5["$$\text{Better Test Accuracy}$$"]
    P5 --> P6["$$\text{Good Generalization}$$"]

    O6 -.->|"$$\text{Comparison}$$"| P6
```

In this diagram, we see two paths from the same training data: one leading to an overfitted decision tree and the other to a pruned decision tree. The overfitted tree has very low bias but high variance, resulting in nearly perfect training accuracy but poor generalization to new data. In contrast, the pruned tree balances bias and variance, leading to better test accuracy and generalization.
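
To see this in numbers, the sketch below (synthetic data and illustrative settings) trains an unconstrained tree and a depth-limited tree on the same data. Exact figures will vary, but the unconstrained tree typically scores near 100% on the training set while generalizing worse than the constrained one:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# A synthetic, slightly noisy classification problem
X, y = make_classification(n_samples=2000, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Unconstrained tree: grows until the training data is memorized
overfit = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Constrained ("pruned") tree: limited depth and a minimum leaf size
pruned = DecisionTreeClassifier(max_depth=5, min_samples_leaf=20,
                                random_state=42).fit(X_train, y_train)

for name, model in [("overfitted", overfit), ("pruned", pruned)]:
    print(f"{name}: train = {model.score(X_train, y_train):.2f}, "
          f"test = {model.score(X_test, y_test):.2f}")
```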

## 4. Implementation with Scikit-Learn

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Example data: the Iris dataset stands in for any feature matrix and labels
iris = load_iris()
feature_cols = iris.feature_names
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# 1. Initialize with constraints to prevent overfitting
model = DecisionTreeClassifier(max_depth=3, criterion='gini')

# 2. Train
model.fit(X_train, y_train)

# 3. Visualize the Tree
plt.figure(figsize=(12, 8))
plot_tree(model, filled=True, feature_names=feature_cols)
plt.show()
```
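
As a follow-up to the snippet above (it reuses `model`, `X_test`, `y_test`, and `feature_cols` defined there), the fitted tree can be scored on held-out data and its learned rules printed as plain text, which is what makes trees so easy to explain:

```python
from sklearn.tree import export_text

# Accuracy on data the tree has never seen
print("Test accuracy:", model.score(X_test, y_test))

# Human-readable if/else rules learned by the tree
print(export_text(model, feature_names=feature_cols))
```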

## 5. Pros and Cons

| Advantages | Disadvantages |
| --- | --- |
| **Interpretable:** Easy to explain to non-technical stakeholders. | **High Variance:** Small changes in data can result in a completely different tree. |
| **No Scaling Required:** Does not require feature normalization or standardization. | **Overfitting:** Extremely prone to capturing noise in the data. |
| Handles both numerical and categorical data. | **Bias:** Can create biased trees if some classes dominate. |

## 6. Visualising a Decision Tree

```mermaid
graph TD
    A["Is Income > $50k?"] -->|Yes| B["Is Credit Score > 700?"]
    A -->|No| C["Reject Loan"]
    B -->|Yes| D["Approve Loan"]
    B -->|No| E["Reject Loan"]

    style A fill:#f3e5f5,stroke:#7b1fa2,color:#333
    style D fill:#e8f5e9,stroke:#2e7d32,color:#333
    style C fill:#ffebee,stroke:#c62828,color:#333
    style E fill:#ffebee,stroke:#c62828,color:#333
```

## References for More Details

* **[Scikit-Learn Tree Module](https://scikit-learn.org/stable/modules/tree.html):** Understanding the algorithmic implementation (CART).

---

**A single Decision Tree is an unstable, high-variance learner. To build a truly robust model, we combine hundreds of trees into a "forest."**