Skip to content

Commit 640f5ab

Browse files
Create xai.py
1 parent 4110337 commit 640f5ab

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

tests/xai.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# partie 9
2+
3+
# xai.py
4+
5+
6+
7+
import numpy as np
8+
import pandas as pd
9+
import shap
10+
import matplotlib.pyplot as plt
11+
from sklearn.inspection import permutation_importance
12+
13+
# Assurer que les modèles et les données existent
14+
assert 'logreg' in globals(), "Le modèle de régression logistique (logreg) n'est pas défini."
15+
assert 'rf' in globals(), "Le modèle RandomForest (rf) n'est pas défini."
16+
assert 'mlp' in globals(), "Le modèle MLP (mlp) n'est pas défini."
17+
assert 'X_train' in globals(), "X_train n'est pas défini."
18+
assert 'X_val' in globals(), "X_val n'est pas défini."
19+
assert 'y_val' in globals(), "y_val n'est pas défini."
20+
21+
print("Toutes les assertions sont validées. Exécution du code XAI...")
22+
23+
# ---- Explication avec SHAP pour RandomForest ----
24+
explainer = shap.Explainer(rf)
25+
shap_values = explainer(X_train)
26+
27+
# Visualisation SHAP pour la première observation
28+
plt.title("Graphique en cascade SHAP pour la première observation")
29+
shap.plots.waterfall(shap_values[0, :, 1])
30+
plt.show()
31+
32+
# ---- Analyse des coefficients pour la régression logistique ----
33+
coefficients = logreg.coef_[0]
34+
feature_names = X_train.columns
35+
coef_feature_pairs = sorted(zip(coefficients, feature_names), key=lambda x: abs(x[0]), reverse=True)
36+
37+
# Sélection des 10 principales caractéristiques
38+
sorted_coefficients, sorted_feature_names = zip(*coef_feature_pairs[:10])
39+
plt.figure(figsize=(10, 6))
40+
plt.barh(sorted_feature_names, sorted_coefficients, color='skyblue')
41+
plt.xlabel('Valeur des coefficients')
42+
plt.ylabel('Nom des caractéristiques')
43+
plt.title('Top 10 des coefficients de la régression logistique')
44+
plt.gca().invert_yaxis()
45+
plt.show()
46+
47+
# ---- Analyse des importances des caractéristiques pour RandomForest ----
48+
importances = rf.feature_importances_
49+
feature_importance_df = pd.DataFrame({'Feature': X_train.columns, 'Importance': importances}).sort_values(by='Importance', ascending=False)
50+
51+
# Sélection des 10 principales caractéristiques
52+
top_features = feature_importance_df.head(10)
53+
plt.figure(figsize=(12, 6))
54+
plt.barh(top_features['Feature'], top_features['Importance'], color='lightgreen')
55+
plt.xlabel('Importance des caractéristiques')
56+
plt.ylabel('Nom des caractéristiques')
57+
plt.title('Top 10 des caractéristiques les plus importantes (RandomForest)')
58+
plt.gca().invert_yaxis()
59+
plt.show()
60+
61+
# ---- Importances par permutation pour MLP ----
62+
perm_importance = permutation_importance(mlp, X_val, y_val, n_repeats=10, random_state=42)
63+
feature_importances = perm_importance.importances_mean
64+
sorted_idx = feature_importances.argsort()[::-1]
65+
66+
# Sélection des 10 principales caractéristiques
67+
top_features = [X_train.columns[i] for i in sorted_idx[:10]]
68+
top_importances = feature_importances[sorted_idx[:10]]
69+
plt.figure(figsize=(10, 6))
70+
plt.barh(top_features, top_importances, color='cornflowerblue')
71+
plt.xlabel("Importance des caractéristiques")
72+
plt.ylabel("Nom des caractéristiques")
73+
plt.title("Top 10 importances des caractéristiques par permutation")
74+
plt.gca().invert_yaxis()
75+
plt.show()
76+
77+
print("Analyse XAI terminée avec succès !")

0 commit comments

Comments
 (0)