Skip to content

Commit d5107cd

Browse files
Update main.py
1 parent c0567e1 commit d5107cd

File tree

1 file changed

+82
-6
lines changed

1 file changed

+82
-6
lines changed

src/main.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,85 @@
1-
from data_preprocessing import preprocess_data
2-
from model_training import train_models
1+
# main.py
2+
3+
4+
import os
35
import pandas as pd
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
import seaborn as sns
9+
import joblib
10+
import shap
11+
import xgboost as xgb
12+
13+
from sklearn.model_selection import train_test_split, GridSearchCV, KFold
14+
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
15+
from sklearn.linear_model import LogisticRegression
16+
from sklearn.ensemble import RandomForestClassifier
17+
from sklearn.neural_network import MLPClassifier
18+
from sklearn.inspection import permutation_importance
19+
20+
# Définition des chemins
21+
MODELS_PATH = "/content/drive/MyDrive/Titanic-Survival-Predict-main/models"
22+
23+
# Chargement des données
24+
train_df = pd.read_csv('/content/drive/My Drive/Titanic-Survival-Predict-main/train_cleaned.csv')
25+
test_df = pd.read_csv('/content/drive/My Drive/Titanic-Survival-Predict-main/test_cleaned.csv')
26+
27+
# Vérification des données
28+
assert train_df.shape[0] > 0, "Le dataset d'entraînement est vide"
29+
assert test_df.shape[0] > 0, "Le dataset de test est vide"
30+
31+
# Préparation des données
32+
X = train_df.drop(columns=['Survived'])
33+
y = train_df['Survived']
34+
35+
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
36+
37+
# Entraînement des modèles
38+
print("\n--- Entraînement des modèles de base ---")
39+
logreg = LogisticRegression(max_iter=1000, random_state=42).fit(X_train, y_train)
40+
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)
41+
mlp = MLPClassifier(alpha=0.06, hidden_layer_sizes=(50, 50), learning_rate_init=0.03, max_iter=158).fit(X_train, y_train)
42+
xgb_model = xgb.XGBClassifier(use_label_encoder=False, enable_categorical=True, eval_metric='logloss', random_state=42).fit(X_train, y_train)
43+
44+
# Vérification des modèles
45+
assert logreg, "Erreur lors de l'entraînement de la Régression Logistique"
46+
assert rf, "Erreur lors de l'entraînement du RandomForest"
47+
assert mlp, "Erreur lors de l'entraînement du MLP"
48+
assert xgb_model, "Erreur lors de l'entraînement de XGBoost"
49+
50+
# Évaluation des modèles
51+
print("\n--- Évaluation des modèles de base ---")
52+
for model, name in zip([logreg, rf, mlp, xgb_model], ["Logistic Regression", "RandomForest", "MLP", "XGBoost"]):
53+
y_pred = model.predict(X_val)
54+
acc = accuracy_score(y_val, y_pred)
55+
f1 = classification_report(y_val, y_pred, output_dict=True)["macro avg"]["f1-score"]
56+
print(f"{name} - Accuracy: {acc:.4f}, F1-Score: {f1:.4f}")
57+
58+
# Sauvegarde des modèles
59+
print("\n--- Sauvegarde des modèles ---")
60+
os.makedirs(MODELS_PATH, exist_ok=True)
61+
joblib.dump(logreg, f"{MODELS_PATH}/logreg.pkl")
62+
joblib.dump(rf, f"{MODELS_PATH}/random_forest.pkl")
63+
joblib.dump(mlp, f"{MODELS_PATH}/mlp.pkl")
64+
joblib.dump(xgb_model, f"{MODELS_PATH}/xgboost.pkl")
65+
66+
print(f"Les modèles ont été sauvegardés dans {MODELS_PATH}")
67+
68+
# Chargement des modèles sauvegardés pour validation
69+
print("\n--- Chargement des modèles sauvegardés ---")
70+
logreg_loaded = joblib.load(f"{MODELS_PATH}/logreg.pkl")
71+
rf_loaded = joblib.load(f"{MODELS_PATH}/random_forest.pkl")
72+
mlp_loaded = joblib.load(f"{MODELS_PATH}/mlp.pkl")
73+
xgb_loaded = joblib.load(f"{MODELS_PATH}/xgboost.pkl")
74+
75+
assert logreg_loaded, "Erreur lors du chargement de la Régression Logistique"
76+
assert rf_loaded, "Erreur lors du chargement du RandomForest"
77+
assert mlp_loaded, "Erreur lors du chargement du MLP"
78+
assert xgb_loaded, "Erreur lors du chargement du XGBoost"
479

5-
# Charger les données
6-
train_df, test_df = preprocess_data("data/train.csv", "data/test.csv")
80+
print("Les modèles ont été correctement chargés.")
781

8-
# Entraîner les modèles
9-
train_models(train_df)
82+
# Passage à la Partie 9 (IA Explicable)
83+
print("\n--- Exécution de la Partie 9: IA Explicable (XAI) ---")
84+
os.system("python partie9_xai.py")
85+

0 commit comments

Comments
 (0)