|
| 1 | +# partie 9 |
| 2 | + |
| 3 | +# xai.py |
| 4 | + |
| 5 | + |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import shap |
| 10 | +import matplotlib.pyplot as plt |
| 11 | +from sklearn.inspection import permutation_importance |
| 12 | + |
| 13 | +# Assurer que les modèles et les données existent |
| 14 | +assert 'logreg' in globals(), "Le modèle de régression logistique (logreg) n'est pas défini." |
| 15 | +assert 'rf' in globals(), "Le modèle RandomForest (rf) n'est pas défini." |
| 16 | +assert 'mlp' in globals(), "Le modèle MLP (mlp) n'est pas défini." |
| 17 | +assert 'X_train' in globals(), "X_train n'est pas défini." |
| 18 | +assert 'X_val' in globals(), "X_val n'est pas défini." |
| 19 | +assert 'y_val' in globals(), "y_val n'est pas défini." |
| 20 | + |
| 21 | +print("Toutes les assertions sont validées. Exécution du code XAI...") |
| 22 | + |
| 23 | +# ---- Explication avec SHAP pour RandomForest ---- |
| 24 | +explainer = shap.Explainer(rf) |
| 25 | +shap_values = explainer(X_train) |
| 26 | + |
| 27 | +# Visualisation SHAP pour la première observation |
| 28 | +plt.title("Graphique en cascade SHAP pour la première observation") |
| 29 | +shap.plots.waterfall(shap_values[0, :, 1]) |
| 30 | +plt.show() |
| 31 | + |
| 32 | +# ---- Analyse des coefficients pour la régression logistique ---- |
| 33 | +coefficients = logreg.coef_[0] |
| 34 | +feature_names = X_train.columns |
| 35 | +coef_feature_pairs = sorted(zip(coefficients, feature_names), key=lambda x: abs(x[0]), reverse=True) |
| 36 | + |
| 37 | +# Sélection des 10 principales caractéristiques |
| 38 | +sorted_coefficients, sorted_feature_names = zip(*coef_feature_pairs[:10]) |
| 39 | +plt.figure(figsize=(10, 6)) |
| 40 | +plt.barh(sorted_feature_names, sorted_coefficients, color='skyblue') |
| 41 | +plt.xlabel('Valeur des coefficients') |
| 42 | +plt.ylabel('Nom des caractéristiques') |
| 43 | +plt.title('Top 10 des coefficients de la régression logistique') |
| 44 | +plt.gca().invert_yaxis() |
| 45 | +plt.show() |
| 46 | + |
| 47 | +# ---- Analyse des importances des caractéristiques pour RandomForest ---- |
| 48 | +importances = rf.feature_importances_ |
| 49 | +feature_importance_df = pd.DataFrame({'Feature': X_train.columns, 'Importance': importances}).sort_values(by='Importance', ascending=False) |
| 50 | + |
| 51 | +# Sélection des 10 principales caractéristiques |
| 52 | +top_features = feature_importance_df.head(10) |
| 53 | +plt.figure(figsize=(12, 6)) |
| 54 | +plt.barh(top_features['Feature'], top_features['Importance'], color='lightgreen') |
| 55 | +plt.xlabel('Importance des caractéristiques') |
| 56 | +plt.ylabel('Nom des caractéristiques') |
| 57 | +plt.title('Top 10 des caractéristiques les plus importantes (RandomForest)') |
| 58 | +plt.gca().invert_yaxis() |
| 59 | +plt.show() |
| 60 | + |
| 61 | +# ---- Importances par permutation pour MLP ---- |
| 62 | +perm_importance = permutation_importance(mlp, X_val, y_val, n_repeats=10, random_state=42) |
| 63 | +feature_importances = perm_importance.importances_mean |
| 64 | +sorted_idx = feature_importances.argsort()[::-1] |
| 65 | + |
| 66 | +# Sélection des 10 principales caractéristiques |
| 67 | +top_features = [X_train.columns[i] for i in sorted_idx[:10]] |
| 68 | +top_importances = feature_importances[sorted_idx[:10]] |
| 69 | +plt.figure(figsize=(10, 6)) |
| 70 | +plt.barh(top_features, top_importances, color='cornflowerblue') |
| 71 | +plt.xlabel("Importance des caractéristiques") |
| 72 | +plt.ylabel("Nom des caractéristiques") |
| 73 | +plt.title("Top 10 importances des caractéristiques par permutation") |
| 74 | +plt.gca().invert_yaxis() |
| 75 | +plt.show() |
| 76 | + |
| 77 | +print("Analyse XAI terminée avec succès !") |
0 commit comments