Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5560,6 +5560,113 @@
" plt.show()\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 12) Well-fit scatter grid: WW metrics vs generalization gap (per dataset \u00d7 3 metrics)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"if 'combined_df' not in globals() or combined_df.empty:\n",
" print('No combined results available. Run the training/aggregation cells first.')\n",
"else:\n",
" scatter_df = combined_df.copy()\n",
" scatter_df['case_type'] = scatter_df.get('case_type', 'good').fillna('good').astype(str)\n",
" scatter_df['overfit_mode'] = scatter_df.get('overfit_mode', 'none').fillna('none').astype(str)\n",
"\n",
" good_df = scatter_df[scatter_df['case_type'] != 'overfit'].copy()\n",
"\n",
" required_cols = ['train_accuracy', 'test_accuracy', 'alpha', 'num_traps', 'ERG_gap']\n",
" for col in required_cols:\n",
" if col in good_df.columns:\n",
" good_df[col] = pd.to_numeric(good_df[col], errors='coerce')\n",
"\n",
" if 'train_accuracy' in good_df.columns and 'test_accuracy' in good_df.columns:\n",
" good_df['generalization_gap'] = good_df['train_accuracy'] - good_df['test_accuracy']\n",
"\n",
" if good_df.empty:\n",
" print('No well-fit (good) rows found for plotting.')\n",
" elif 'generalization_gap' not in good_df.columns:\n",
" print('Missing required columns: train_accuracy and/or test_accuracy.')\n",
" else:\n",
" dataset_col = 'dataset_uid' if 'dataset_uid' in good_df.columns else None\n",
" if dataset_col is None:\n",
" good_df['dataset_uid'] = 'well_fit_models'\n",
" dataset_col = 'dataset_uid'\n",
"\n",
" dataset_order = list(dict.fromkeys(good_df[dataset_col].astype(str).tolist()))\n",
"\n",
" palette = plt.cm.tab20(np.linspace(0, 1, max(len(dataset_order), 3)))\n",
" dataset_colors = {dataset: palette[i] for i, dataset in enumerate(dataset_order)}\n",
"\n",
" metric_specs = [\n",
" ('alpha', 'alpha'),\n",
" ('num_traps', 'num_traps'),\n",
" ('ERG_gap', 'ERG_gap'),\n",
" ]\n",
"\n",
" n_rows = len(dataset_order)\n",
" n_cols = len(metric_specs)\n",
" fig, axes = plt.subplots(n_rows, n_cols, figsize=(5.1 * n_cols, 3.7 * n_rows), squeeze=False)\n",
"\n",
" for row_idx, dataset_uid in enumerate(dataset_order):\n",
" row_df = good_df[good_df[dataset_col].astype(str) == dataset_uid].copy()\n",
"\n",
" for col_idx, (metric_col, metric_label) in enumerate(metric_specs):\n",
" ax = axes[row_idx, col_idx]\n",
"\n",
" if metric_col not in row_df.columns or 'generalization_gap' not in row_df.columns:\n",
" ax.text(0.5, 0.5, 'Missing required columns', ha='center', va='center', transform=ax.transAxes)\n",
" ax.set_axis_off()\n",
" continue\n",
"\n",
" cur = row_df.dropna(subset=[metric_col, 'generalization_gap'])\n",
"\n",
" if cur.empty:\n",
" ax.text(0.5, 0.5, f'No data for {dataset_uid}', ha='center', va='center', transform=ax.transAxes)\n",
" ax.set_xlabel(metric_label)\n",
" ax.set_ylabel('generalization_gap')\n",
" ax.grid(alpha=0.2)\n",
" continue\n",
"\n",
" ax.scatter(\n",
" cur[metric_col],\n",
" cur['generalization_gap'],\n",
" color=dataset_colors.get(dataset_uid, '#777777'),\n",
" edgecolors='black',\n",
" alpha=0.8,\n",
" s=42,\n",
" )\n",
" ax.set_xlabel(metric_label)\n",
" ax.set_ylabel('generalization_gap')\n",
" ax.grid(alpha=0.2)\n",
"\n",
" if row_idx == 0:\n",
" ax.set_title(f'{metric_label} vs generalization_gap')\n",
"\n",
" if col_idx == 0:\n",
" ax.text(\n",
" 0.03,\n",
" 0.97,\n",
" str(dataset_uid),\n",
" transform=ax.transAxes,\n",
" va='top',\n",
" ha='left',\n",
" fontsize=9,\n",
" fontweight='bold',\n",
" color=dataset_colors.get(dataset_uid, '#333333'),\n",
" )\n",
"\n",
" fig.suptitle('Well-fit models: alpha/num_traps/ERG_gap vs generalization_gap by dataset', y=1.002)\n",
" fig.tight_layout()\n",
" plt.show()\n"
]
}
],
"metadata": {
Expand All @@ -5580,4 +5687,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}