# 12) Well-fit scatter grid: WW metrics vs generalization gap (per dataset × 3 metrics)
#
# Draws one row of scatter plots per dataset and one column per metric
# (alpha, num_traps, ERG_gap), each plotted against
# generalization_gap = train_accuracy - test_accuracy.
# Reads `combined_df` from the notebook namespace and prints a message instead
# of plotting when it is missing, empty, or lacks the accuracy columns.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

if 'combined_df' not in globals() or combined_df.empty:
    print('No combined results available. Run the training/aggregation cells first.')
else:
    scatter_df = combined_df.copy()

    # DataFrame.get() hands back the bare default when the column is absent, so
    # chaining .fillna() onto it would raise AttributeError on a plain str.
    # Normalise an existing column, or create it from the default, explicitly.
    for label_col, default_label in (('case_type', 'good'), ('overfit_mode', 'none')):
        if label_col in scatter_df.columns:
            scatter_df[label_col] = scatter_df[label_col].fillna(default_label).astype(str)
        else:
            scatter_df[label_col] = default_label

    # Every row not explicitly tagged 'overfit' is treated as a well-fit run.
    good_df = scatter_df[scatter_df['case_type'] != 'overfit'].copy()

    # Coerce the plotted columns to numbers; unparsable entries become NaN and
    # are dropped per panel further down.
    required_cols = ['train_accuracy', 'test_accuracy', 'alpha', 'num_traps', 'ERG_gap']
    for col in required_cols:
        if col in good_df.columns:
            good_df[col] = pd.to_numeric(good_df[col], errors='coerce')

    if 'train_accuracy' in good_df.columns and 'test_accuracy' in good_df.columns:
        good_df['generalization_gap'] = good_df['train_accuracy'] - good_df['test_accuracy']

    if good_df.empty:
        print('No well-fit (good) rows found for plotting.')
    elif 'generalization_gap' not in good_df.columns:
        print('Missing required columns: train_accuracy and/or test_accuracy.')
    else:
        if 'dataset_uid' not in good_df.columns:
            # No per-dataset key available: pool every run into one grid row.
            good_df['dataset_uid'] = 'well_fit_models'

        # Compute the string key once; it is reused for ordering and for
        # selecting each dataset's rows inside the loop.
        dataset_keys = good_df['dataset_uid'].astype(str)
        dataset_order = list(dict.fromkeys(dataset_keys.tolist()))  # first-seen order, de-duplicated

        # Sample at least three points so a one- or two-dataset grid still gets
        # well-separated colours; zip() ignores any surplus palette entries.
        palette = plt.cm.tab20(np.linspace(0, 1, max(len(dataset_order), 3)))
        dataset_colors = dict(zip(dataset_order, palette))

        # (column name in the frame, label shown on the x-axis / title)
        metric_specs = [
            ('alpha', 'alpha'),
            ('num_traps', 'num_traps'),
            ('ERG_gap', 'ERG_gap'),
        ]

        n_rows = len(dataset_order)
        n_cols = len(metric_specs)
        # squeeze=False keeps `axes` 2-D even when there is a single dataset row.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(5.1 * n_cols, 3.7 * n_rows), squeeze=False)

        for row_idx, dataset_uid in enumerate(dataset_order):
            row_df = good_df[dataset_keys == dataset_uid]
            row_color = dataset_colors[dataset_uid]

            for col_idx, (metric_col, metric_label) in enumerate(metric_specs):
                ax = axes[row_idx, col_idx]

                if metric_col not in row_df.columns:
                    ax.text(0.5, 0.5, 'Missing required columns', ha='center', va='center', transform=ax.transAxes)
                    ax.set_axis_off()
                else:
                    cur = row_df.dropna(subset=[metric_col, 'generalization_gap'])

                    if cur.empty:
                        ax.text(0.5, 0.5, f'No data for {dataset_uid}', ha='center', va='center', transform=ax.transAxes)
                    else:
                        ax.scatter(
                            cur[metric_col],
                            cur['generalization_gap'],
                            color=row_color,
                            edgecolors='black',
                            alpha=0.8,
                            s=42,
                        )

                    ax.set_xlabel(metric_label)
                    ax.set_ylabel('generalization_gap')
                    ax.grid(alpha=0.2)

                # Titles and dataset tags are applied to every panel, including
                # placeholder ones, so the grid stays readable when data is missing.
                if row_idx == 0:
                    ax.set_title(f'{metric_label} vs generalization_gap')

                if col_idx == 0:
                    ax.text(
                        0.03,
                        0.97,
                        str(dataset_uid),
                        transform=ax.transAxes,
                        va='top',
                        ha='left',
                        fontsize=9,
                        fontweight='bold',
                        color=row_color,
                    )

        fig.suptitle('Well-fit models: alpha/num_traps/ERG_gap vs generalization_gap by dataset', y=1.002)
        fig.tight_layout()
        plt.show()