Skip to content

Commit 4341a4d

Browse files
New: print out trees in compact form
1 parent 57e23a4 commit 4341a4d

File tree

1 file changed

+52
-3
lines changed

1 file changed

+52
-3
lines changed

notebooks/Alhazen.ipynb

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,8 @@
715715
" e.g., 'num(<integer>)'\n",
716716
" rule : The production rule.\n",
717717
" '''\n",
718-
" def __init__(self, name: str, rule: str, friendly_name: str = None) -> None:\n",
718+
" def __init__(self, name: str, rule: str, /, \n",
719+
" friendly_name: str = None) -> None:\n",
719720
" super().__init__(name, rule, rule, friendly_name=friendly_name)\n",
720721
"\n",
721722
" def name_rep(self) -> str:\n",
@@ -2213,7 +2214,7 @@
22132214
"source": [
22142215
"class InputSpecification:\n",
22152216
" '''\n",
2216-
" This class represents a complet input specification of a new input. A input specification\n",
2217+
" This class represents a complete input specification of a new input. A input specification\n",
22172218
" consists of one or more requirements.\n",
22182219
" requirements : Is a list of all requirements that must be used.\n",
22192220
" '''\n",
@@ -2669,7 +2670,8 @@
26692670
"\n",
26702671
" self._all_features = extract_existence(self._grammar) + extract_numeric(self._grammar)\n",
26712672
" self._feature_names = [f.name for f in self._all_features]\n",
2672-
" print(f\"Features: {self._feature_names}\")\n",
2673+
" print(\"Features:\", \", \".join(f.friendly_name() \n",
2674+
" for f in self._all_features))\n",
26732675
"\n",
26742676
" def _add_new_data(self, exec_data, feature_data):\n",
26752677
" joined_data = exec_data.join(feature_data.drop(['sample'], axis=1))\n",
@@ -2821,6 +2823,53 @@
28212823
"show_decision_tree(remove_unequal_decisions(final_tree), all_feature_names)"
28222824
]
28232825
},
2826+
{
2827+
"cell_type": "code",
2828+
"execution_count": null,
2829+
"metadata": {},
2830+
"outputs": [],
2831+
"source": [
2832+
"import math\n",
2833+
"\n",
2834+
"def friendly_decision_tree(clf, feature_names, class_names = ['NO_BUG', 'BUG']):\n",
2835+
" def _tree(index, indent=0):\n",
2836+
" s = \"\"\n",
2837+
" feature = clf.tree_.feature[index]\n",
2838+
" feature_name = feature_names[feature]\n",
2839+
" threshold = clf.tree_.threshold[index]\n",
2840+
" value = clf.tree_.value[index]\n",
2841+
" class_ = int(value[0][0])\n",
2842+
" class_name = class_names[class_]\n",
2843+
" left = clf.tree_.children_left[index]\n",
2844+
" right = clf.tree_.children_right[index]\n",
2845+
" if left == right:\n",
2846+
" # Leaf node\n",
2847+
" s += \" \" * indent + class_name + \"\\n\"\n",
2848+
" else:\n",
2849+
" if math.isclose(threshold, 0.5):\n",
2850+
" s += \" \" * indent + f\"if {feature_name}:\\n\"\n",
2851+
" s += _tree(right, indent + 2)\n",
2852+
" s += \" \" * indent + f\"else:\\n\"\n",
2853+
" s += _tree(left, indent + 2)\n",
2854+
" else:\n",
2855+
" s += \" \" * indent + f\"if {feature_name} <= {threshold:.4f}:\\n\"\n",
2856+
" s += _tree(left, indent + 2)\n",
2857+
" s += \" \" * indent + f\"else:\\n\"\n",
2858+
" s += _tree(right, indent + 2)\n",
2859+
" return s\n",
2860+
"\n",
2861+
" return _tree(0)"
2862+
]
2863+
},
2864+
{
2865+
"cell_type": "code",
2866+
"execution_count": null,
2867+
"metadata": {},
2868+
"outputs": [],
2869+
"source": [
2870+
"print(friendly_decision_tree(final_tree, all_feature_names))"
2871+
]
2872+
},
28242873
{
28252874
"cell_type": "markdown",
28262875
"metadata": {},

0 commit comments

Comments
 (0)