|
1 | | -from data_preprocessing import preprocess_data |
2 | | -from model_training import train_models |
| 1 | +# main.py |
| 2 | + |
| 3 | + |
| 4 | +import os |
3 | 5 | import pandas as pd |
| 6 | +import numpy as np |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import seaborn as sns |
| 9 | +import joblib |
| 10 | +import shap |
| 11 | +import xgboost as xgb |
| 12 | + |
| 13 | +from sklearn.model_selection import train_test_split, GridSearchCV, KFold |
| 14 | +from sklearn.metrics import accuracy_score, confusion_matrix, classification_report |
| 15 | +from sklearn.linear_model import LogisticRegression |
| 16 | +from sklearn.ensemble import RandomForestClassifier |
| 17 | +from sklearn.neural_network import MLPClassifier |
| 18 | +from sklearn.inspection import permutation_importance |
| 19 | + |
| 20 | +# Définition des chemins |
| 21 | +MODELS_PATH = "/content/drive/MyDrive/Titanic-Survival-Predict-main/models" |
| 22 | + |
| 23 | +# Chargement des données |
| 24 | +train_df = pd.read_csv('/content/drive/My Drive/Titanic-Survival-Predict-main/train_cleaned.csv') |
| 25 | +test_df = pd.read_csv('/content/drive/My Drive/Titanic-Survival-Predict-main/test_cleaned.csv') |
| 26 | + |
| 27 | +# Vérification des données |
| 28 | +assert train_df.shape[0] > 0, "Le dataset d'entraînement est vide" |
| 29 | +assert test_df.shape[0] > 0, "Le dataset de test est vide" |
| 30 | + |
| 31 | +# Préparation des données |
| 32 | +X = train_df.drop(columns=['Survived']) |
| 33 | +y = train_df['Survived'] |
| 34 | + |
| 35 | +X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42) |
| 36 | + |
| 37 | +# Entraînement des modèles |
| 38 | +print("\n--- Entraînement des modèles de base ---") |
| 39 | +logreg = LogisticRegression(max_iter=1000, random_state=42).fit(X_train, y_train) |
| 40 | +rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train) |
| 41 | +mlp = MLPClassifier(alpha=0.06, hidden_layer_sizes=(50, 50), learning_rate_init=0.03, max_iter=158).fit(X_train, y_train) |
| 42 | +xgb_model = xgb.XGBClassifier(use_label_encoder=False, enable_categorical=True, eval_metric='logloss', random_state=42).fit(X_train, y_train) |
| 43 | + |
| 44 | +# Vérification des modèles |
| 45 | +assert logreg, "Erreur lors de l'entraînement de la Régression Logistique" |
| 46 | +assert rf, "Erreur lors de l'entraînement du RandomForest" |
| 47 | +assert mlp, "Erreur lors de l'entraînement du MLP" |
| 48 | +assert xgb_model, "Erreur lors de l'entraînement de XGBoost" |
| 49 | + |
| 50 | +# Évaluation des modèles |
| 51 | +print("\n--- Évaluation des modèles de base ---") |
| 52 | +for model, name in zip([logreg, rf, mlp, xgb_model], ["Logistic Regression", "RandomForest", "MLP", "XGBoost"]): |
| 53 | + y_pred = model.predict(X_val) |
| 54 | + acc = accuracy_score(y_val, y_pred) |
| 55 | + f1 = classification_report(y_val, y_pred, output_dict=True)["macro avg"]["f1-score"] |
| 56 | + print(f"{name} - Accuracy: {acc:.4f}, F1-Score: {f1:.4f}") |
| 57 | + |
| 58 | +# Sauvegarde des modèles |
| 59 | +print("\n--- Sauvegarde des modèles ---") |
| 60 | +os.makedirs(MODELS_PATH, exist_ok=True) |
| 61 | +joblib.dump(logreg, f"{MODELS_PATH}/logreg.pkl") |
| 62 | +joblib.dump(rf, f"{MODELS_PATH}/random_forest.pkl") |
| 63 | +joblib.dump(mlp, f"{MODELS_PATH}/mlp.pkl") |
| 64 | +joblib.dump(xgb_model, f"{MODELS_PATH}/xgboost.pkl") |
| 65 | + |
| 66 | +print(f"Les modèles ont été sauvegardés dans {MODELS_PATH}") |
| 67 | + |
| 68 | +# Chargement des modèles sauvegardés pour validation |
| 69 | +print("\n--- Chargement des modèles sauvegardés ---") |
| 70 | +logreg_loaded = joblib.load(f"{MODELS_PATH}/logreg.pkl") |
| 71 | +rf_loaded = joblib.load(f"{MODELS_PATH}/random_forest.pkl") |
| 72 | +mlp_loaded = joblib.load(f"{MODELS_PATH}/mlp.pkl") |
| 73 | +xgb_loaded = joblib.load(f"{MODELS_PATH}/xgboost.pkl") |
| 74 | + |
| 75 | +assert logreg_loaded, "Erreur lors du chargement de la Régression Logistique" |
| 76 | +assert rf_loaded, "Erreur lors du chargement du RandomForest" |
| 77 | +assert mlp_loaded, "Erreur lors du chargement du MLP" |
| 78 | +assert xgb_loaded, "Erreur lors du chargement du XGBoost" |
4 | 79 |
|
5 | | -# Charger les données |
6 | | -train_df, test_df = preprocess_data("data/train.csv", "data/test.csv") |
| 80 | +print("Les modèles ont été correctement chargés.") |
7 | 81 |
|
8 | | -# Entraîner les modèles |
9 | | -train_models(train_df) |
| 82 | +# Passage à la Partie 9 (IA Explicable) |
| 83 | +print("\n--- Exécution de la Partie 9: IA Explicable (XAI) ---") |
| 84 | +os.system("python partie9_xai.py") |
| 85 | + |
0 commit comments