{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c1bd341",
   "metadata": {},
   "source": [
    "# Классификация объектов KOI глубокой нейросетью с PSO и дообучением\n",
    "\n",
    "В работе автоматически выбирается ключевая зависимая переменная для задачи классификации, далее строится глубокая нейросеть с количеством скрытых слоёв больше 3.  \n",
    "Сначала веса подбираются методом **PSO (Particle Swarm Optimization)**, затем модель **дообучается градиентным методом**.  \n",
    "В конце выводятся все ключевые метрики на тестовой выборке для написания итогового вывода.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5528408b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import re\n",
    "import copy\n",
    "import math\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.metrics import (\n",
    "    accuracy_score,\n",
    "    balanced_accuracy_score,\n",
    "    precision_score,\n",
    "    recall_score,\n",
    "    f1_score,\n",
    "    roc_auc_score,\n",
    "    confusion_matrix,\n",
    "    classification_report,\n",
    "    log_loss,\n",
    ")\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b4170ddb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Используемый файл: NASA Exoplanet.csv\n",
      "Форма датасета: (9564, 49)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>kepid</th>\n",
       "      <th>kepoi_name</th>\n",
       "      <th>kepler_name</th>\n",
       "      <th>koi_disposition</th>\n",
       "      <th>koi_pdisposition</th>\n",
       "      <th>koi_score</th>\n",
       "      <th>koi_fpflag_nt</th>\n",
       "      <th>koi_fpflag_ss</th>\n",
       "      <th>koi_fpflag_co</th>\n",
       "      <th>koi_fpflag_ec</th>\n",
       "      <th>...</th>\n",
       "      <th>koi_steff_err2</th>\n",
       "      <th>koi_slogg</th>\n",
       "      <th>koi_slogg_err1</th>\n",
       "      <th>koi_slogg_err2</th>\n",
       "      <th>koi_srad</th>\n",
       "      <th>koi_srad_err1</th>\n",
       "      <th>koi_srad_err2</th>\n",
       "      <th>ra</th>\n",
       "      <th>dec</th>\n",
       "      <th>koi_kepmag</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>10797460</td>\n",
       "      <td>K00752.01</td>\n",
       "      <td>Kepler-227 b</td>\n",
       "      <td>CONFIRMED</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>1.000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>-81.0</td>\n",
       "      <td>4.467</td>\n",
       "      <td>0.064</td>\n",
       "      <td>-0.096</td>\n",
       "      <td>0.927</td>\n",
       "      <td>0.105</td>\n",
       "      <td>-0.061</td>\n",
       "      <td>291.93423</td>\n",
       "      <td>48.141651</td>\n",
       "      <td>15.347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>10797460</td>\n",
       "      <td>K00752.02</td>\n",
       "      <td>Kepler-227 c</td>\n",
       "      <td>CONFIRMED</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>0.969</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>-81.0</td>\n",
       "      <td>4.467</td>\n",
       "      <td>0.064</td>\n",
       "      <td>-0.096</td>\n",
       "      <td>0.927</td>\n",
       "      <td>0.105</td>\n",
       "      <td>-0.061</td>\n",
       "      <td>291.93423</td>\n",
       "      <td>48.141651</td>\n",
       "      <td>15.347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>10811496</td>\n",
       "      <td>K00753.01</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>-176.0</td>\n",
       "      <td>4.544</td>\n",
       "      <td>0.044</td>\n",
       "      <td>-0.176</td>\n",
       "      <td>0.868</td>\n",
       "      <td>0.233</td>\n",
       "      <td>-0.078</td>\n",
       "      <td>297.00482</td>\n",
       "      <td>48.134129</td>\n",
       "      <td>15.436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>10848459</td>\n",
       "      <td>K00754.01</td>\n",
       "      <td>NaN</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>-174.0</td>\n",
       "      <td>4.564</td>\n",
       "      <td>0.053</td>\n",
       "      <td>-0.168</td>\n",
       "      <td>0.791</td>\n",
       "      <td>0.201</td>\n",
       "      <td>-0.067</td>\n",
       "      <td>285.53461</td>\n",
       "      <td>48.285210</td>\n",
       "      <td>15.597</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>10854555</td>\n",
       "      <td>K00755.01</td>\n",
       "      <td>Kepler-664 b</td>\n",
       "      <td>CONFIRMED</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>1.000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>-211.0</td>\n",
       "      <td>4.438</td>\n",
       "      <td>0.070</td>\n",
       "      <td>-0.210</td>\n",
       "      <td>1.046</td>\n",
       "      <td>0.334</td>\n",
       "      <td>-0.133</td>\n",
       "      <td>288.75488</td>\n",
       "      <td>48.226200</td>\n",
       "      <td>15.509</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 49 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      kepid kepoi_name   kepler_name koi_disposition koi_pdisposition  \\\n",
       "0  10797460  K00752.01  Kepler-227 b       CONFIRMED        CANDIDATE   \n",
       "1  10797460  K00752.02  Kepler-227 c       CONFIRMED        CANDIDATE   \n",
       "2  10811496  K00753.01           NaN       CANDIDATE        CANDIDATE   \n",
       "3  10848459  K00754.01           NaN  FALSE POSITIVE   FALSE POSITIVE   \n",
       "4  10854555  K00755.01  Kepler-664 b       CONFIRMED        CANDIDATE   \n",
       "\n",
       "   koi_score  koi_fpflag_nt  koi_fpflag_ss  koi_fpflag_co  koi_fpflag_ec  ...  \\\n",
       "0      1.000              0              0              0              0  ...   \n",
       "1      0.969              0              0              0              0  ...   \n",
       "2      0.000              0              0              0              0  ...   \n",
       "3      0.000              0              1              0              0  ...   \n",
       "4      1.000              0              0              0              0  ...   \n",
       "\n",
       "   koi_steff_err2  koi_slogg  koi_slogg_err1  koi_slogg_err2  koi_srad  \\\n",
       "0           -81.0      4.467           0.064          -0.096     0.927   \n",
       "1           -81.0      4.467           0.064          -0.096     0.927   \n",
       "2          -176.0      4.544           0.044          -0.176     0.868   \n",
       "3          -174.0      4.564           0.053          -0.168     0.791   \n",
       "4          -211.0      4.438           0.070          -0.210     1.046   \n",
       "\n",
       "   koi_srad_err1  koi_srad_err2         ra        dec  koi_kepmag  \n",
       "0          0.105         -0.061  291.93423  48.141651      15.347  \n",
       "1          0.105         -0.061  291.93423  48.141651      15.347  \n",
       "2          0.233         -0.078  297.00482  48.134129      15.436  \n",
       "3          0.201         -0.067  285.53461  48.285210      15.597  \n",
       "4          0.334         -0.133  288.75488  48.226200      15.509  \n",
       "\n",
       "[5 rows x 49 columns]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Типы данных:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dtype</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>kepid</th>\n",
       "      <td>int64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>kepoi_name</th>\n",
       "      <td>object</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>kepler_name</th>\n",
       "      <td>object</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_disposition</th>\n",
       "      <td>object</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_pdisposition</th>\n",
       "      <td>object</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_score</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_fpflag_nt</th>\n",
       "      <td>int64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_fpflag_ss</th>\n",
       "      <td>int64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_fpflag_co</th>\n",
       "      <td>int64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_fpflag_ec</th>\n",
       "      <td>int64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_period</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_period_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_period_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_time0bk</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_time0bk_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_time0bk_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_impact</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_impact_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_impact_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_duration</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_duration_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_duration_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_depth</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_depth_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_depth_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_prad</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_prad_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_prad_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_teq</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_teq_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_teq_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_insol</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_insol_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_insol_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_model_snr</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_tce_plnt_num</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_tce_delivname</th>\n",
       "      <td>object</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_steff</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_steff_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_steff_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_slogg</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_slogg_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_slogg_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_srad</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_srad_err1</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_srad_err2</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ra</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dec</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>koi_kepmag</th>\n",
       "      <td>float64</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     dtype\n",
       "kepid                int64\n",
       "kepoi_name          object\n",
       "kepler_name         object\n",
       "koi_disposition     object\n",
       "koi_pdisposition    object\n",
       "koi_score          float64\n",
       "koi_fpflag_nt        int64\n",
       "koi_fpflag_ss        int64\n",
       "koi_fpflag_co        int64\n",
       "koi_fpflag_ec        int64\n",
       "koi_period         float64\n",
       "koi_period_err1    float64\n",
       "koi_period_err2    float64\n",
       "koi_time0bk        float64\n",
       "koi_time0bk_err1   float64\n",
       "koi_time0bk_err2   float64\n",
       "koi_impact         float64\n",
       "koi_impact_err1    float64\n",
       "koi_impact_err2    float64\n",
       "koi_duration       float64\n",
       "koi_duration_err1  float64\n",
       "koi_duration_err2  float64\n",
       "koi_depth          float64\n",
       "koi_depth_err1     float64\n",
       "koi_depth_err2     float64\n",
       "koi_prad           float64\n",
       "koi_prad_err1      float64\n",
       "koi_prad_err2      float64\n",
       "koi_teq            float64\n",
       "koi_teq_err1       float64\n",
       "koi_teq_err2       float64\n",
       "koi_insol          float64\n",
       "koi_insol_err1     float64\n",
       "koi_insol_err2     float64\n",
       "koi_model_snr      float64\n",
       "koi_tce_plnt_num   float64\n",
       "koi_tce_delivname   object\n",
       "koi_steff          float64\n",
       "koi_steff_err1     float64\n",
       "koi_steff_err2     float64\n",
       "koi_slogg          float64\n",
       "koi_slogg_err1     float64\n",
       "koi_slogg_err2     float64\n",
       "koi_srad           float64\n",
       "koi_srad_err1      float64\n",
       "koi_srad_err2      float64\n",
       "ra                 float64\n",
       "dec                float64\n",
       "koi_kepmag         float64"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "def find_csv_file():\n",
    "    candidates = [\n",
    "        \"NASA Exoplanet.csv\",\n",
    "        \"nasa_exoplanet.csv\",\n",
    "        \"NASA_Exoplanet.csv\",\n",
    "        \"koi.csv\",\n",
    "        \"kepler.csv\",\n",
    "    ]\n",
    "    for name in candidates:\n",
    "        if Path(name).exists():\n",
    "            return name\n",
    "\n",
    "    csv_files = sorted(Path(\".\").glob(\"*.csv\"))\n",
    "    preferred = []\n",
    "    for p in csv_files:\n",
    "        low = p.name.lower()\n",
    "        if any(key in low for key in [\"nasa\", \"exo\", \"kepler\", \"koi\"]):\n",
    "            preferred.append(p)\n",
    "    if preferred:\n",
    "        return preferred[0].as_posix()\n",
    "    if csv_files:\n",
    "        return csv_files[0].as_posix()\n",
    "    raise FileNotFoundError(\"CSV-файл не найден в текущей папке.\")\n",
    "\n",
    "csv_path = find_csv_file()\n",
    "print(\"Используемый файл:\", csv_path)\n",
    "\n",
    "df = pd.read_csv(csv_path)\n",
    "print(\"Форма датасета:\", df.shape)\n",
    "display(df.head())\n",
    "print(\"\\nТипы данных:\")\n",
    "display(df.dtypes.to_frame(\"dtype\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5cc71ee1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original_name</th>\n",
       "      <th>normalized_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>kepid</td>\n",
       "      <td>kepid</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>kepoi_name</td>\n",
       "      <td>kepoi_name</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>kepler_name</td>\n",
       "      <td>kepler_name</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>koi_disposition</td>\n",
       "      <td>koi_disposition</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>koi_pdisposition</td>\n",
       "      <td>koi_pdisposition</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>koi_score</td>\n",
       "      <td>koi_score</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>koi_fpflag_nt</td>\n",
       "      <td>koi_fpflag_nt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>koi_fpflag_ss</td>\n",
       "      <td>koi_fpflag_ss</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>koi_fpflag_co</td>\n",
       "      <td>koi_fpflag_co</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>koi_fpflag_ec</td>\n",
       "      <td>koi_fpflag_ec</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>koi_period</td>\n",
       "      <td>koi_period</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>koi_period_err1</td>\n",
       "      <td>koi_period_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>koi_period_err2</td>\n",
       "      <td>koi_period_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>koi_time0bk</td>\n",
       "      <td>koi_time0bk</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>koi_time0bk_err1</td>\n",
       "      <td>koi_time0bk_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>koi_time0bk_err2</td>\n",
       "      <td>koi_time0bk_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>koi_impact</td>\n",
       "      <td>koi_impact</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>koi_impact_err1</td>\n",
       "      <td>koi_impact_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>koi_impact_err2</td>\n",
       "      <td>koi_impact_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>koi_duration</td>\n",
       "      <td>koi_duration</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>koi_duration_err1</td>\n",
       "      <td>koi_duration_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>koi_duration_err2</td>\n",
       "      <td>koi_duration_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>koi_depth</td>\n",
       "      <td>koi_depth</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>koi_depth_err1</td>\n",
       "      <td>koi_depth_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>koi_depth_err2</td>\n",
       "      <td>koi_depth_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>koi_prad</td>\n",
       "      <td>koi_prad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>koi_prad_err1</td>\n",
       "      <td>koi_prad_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>koi_prad_err2</td>\n",
       "      <td>koi_prad_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>koi_teq</td>\n",
       "      <td>koi_teq</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>koi_teq_err1</td>\n",
       "      <td>koi_teq_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>koi_teq_err2</td>\n",
       "      <td>koi_teq_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>koi_insol</td>\n",
       "      <td>koi_insol</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>koi_insol_err1</td>\n",
       "      <td>koi_insol_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>koi_insol_err2</td>\n",
       "      <td>koi_insol_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>koi_model_snr</td>\n",
       "      <td>koi_model_snr</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>koi_tce_plnt_num</td>\n",
       "      <td>koi_tce_plnt_num</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>koi_tce_delivname</td>\n",
       "      <td>koi_tce_delivname</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>koi_steff</td>\n",
       "      <td>koi_steff</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>koi_steff_err1</td>\n",
       "      <td>koi_steff_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>koi_steff_err2</td>\n",
       "      <td>koi_steff_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>koi_slogg</td>\n",
       "      <td>koi_slogg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>koi_slogg_err1</td>\n",
       "      <td>koi_slogg_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>koi_slogg_err2</td>\n",
       "      <td>koi_slogg_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>koi_srad</td>\n",
       "      <td>koi_srad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>koi_srad_err1</td>\n",
       "      <td>koi_srad_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>koi_srad_err2</td>\n",
       "      <td>koi_srad_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>ra</td>\n",
       "      <td>ra</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>dec</td>\n",
       "      <td>dec</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>koi_kepmag</td>\n",
       "      <td>koi_kepmag</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        original_name    normalized_name\n",
       "0               kepid              kepid\n",
       "1          kepoi_name         kepoi_name\n",
       "2         kepler_name        kepler_name\n",
       "3     koi_disposition    koi_disposition\n",
       "4    koi_pdisposition   koi_pdisposition\n",
       "5           koi_score          koi_score\n",
       "6       koi_fpflag_nt      koi_fpflag_nt\n",
       "7       koi_fpflag_ss      koi_fpflag_ss\n",
       "8       koi_fpflag_co      koi_fpflag_co\n",
       "9       koi_fpflag_ec      koi_fpflag_ec\n",
       "10         koi_period         koi_period\n",
       "11    koi_period_err1    koi_period_err1\n",
       "12    koi_period_err2    koi_period_err2\n",
       "13        koi_time0bk        koi_time0bk\n",
       "14   koi_time0bk_err1   koi_time0bk_err1\n",
       "15   koi_time0bk_err2   koi_time0bk_err2\n",
       "16         koi_impact         koi_impact\n",
       "17    koi_impact_err1    koi_impact_err1\n",
       "18    koi_impact_err2    koi_impact_err2\n",
       "19       koi_duration       koi_duration\n",
       "20  koi_duration_err1  koi_duration_err1\n",
       "21  koi_duration_err2  koi_duration_err2\n",
       "22          koi_depth          koi_depth\n",
       "23     koi_depth_err1     koi_depth_err1\n",
       "24     koi_depth_err2     koi_depth_err2\n",
       "25           koi_prad           koi_prad\n",
       "26      koi_prad_err1      koi_prad_err1\n",
       "27      koi_prad_err2      koi_prad_err2\n",
       "28            koi_teq            koi_teq\n",
       "29       koi_teq_err1       koi_teq_err1\n",
       "30       koi_teq_err2       koi_teq_err2\n",
       "31          koi_insol          koi_insol\n",
       "32     koi_insol_err1     koi_insol_err1\n",
       "33     koi_insol_err2     koi_insol_err2\n",
       "34      koi_model_snr      koi_model_snr\n",
       "35   koi_tce_plnt_num   koi_tce_plnt_num\n",
       "36  koi_tce_delivname  koi_tce_delivname\n",
       "37          koi_steff          koi_steff\n",
       "38     koi_steff_err1     koi_steff_err1\n",
       "39     koi_steff_err2     koi_steff_err2\n",
       "40          koi_slogg          koi_slogg\n",
       "41     koi_slogg_err1     koi_slogg_err1\n",
       "42     koi_slogg_err2     koi_slogg_err2\n",
       "43           koi_srad           koi_srad\n",
       "44      koi_srad_err1      koi_srad_err1\n",
       "45      koi_srad_err2      koi_srad_err2\n",
       "46                 ra                 ra\n",
       "47                dec                dec\n",
       "48         koi_kepmag         koi_kepmag"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "def normalize_col_name(col):\n",
    "    col = str(col).strip().lower()\n",
    "    col = col.replace(\"%\", \"percent\")\n",
    "    col = re.sub(r\"[^0-9a-zA-Zа-яА-Я]+\", \"_\", col)\n",
    "    col = re.sub(r\"_+\", \"_\", col).strip(\"_\")\n",
    "    return col\n",
    "\n",
    "original_cols = df.columns.tolist()\n",
    "normalized_cols = [normalize_col_name(c) for c in original_cols]\n",
    "mapping_df = pd.DataFrame({\"original_name\": original_cols, \"normalized_name\": normalized_cols})\n",
    "display(mapping_df)\n",
    "\n",
    "df.columns = normalized_cols\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "680f5053",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Выбранная зависимая переменная: koi_pdisposition\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>score</th>\n",
       "      <th>n_unique</th>\n",
       "      <th>column</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>100</td>\n",
       "      <td>2</td>\n",
       "      <td>koi_pdisposition</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>100</td>\n",
       "      <td>3</td>\n",
       "      <td>koi_disposition</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20</td>\n",
       "      <td>2</td>\n",
       "      <td>koi_fpflag_co</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>20</td>\n",
       "      <td>2</td>\n",
       "      <td>koi_fpflag_ec</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>20</td>\n",
       "      <td>2</td>\n",
       "      <td>koi_fpflag_ss</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>20</td>\n",
       "      <td>3</td>\n",
       "      <td>koi_fpflag_nt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>20</td>\n",
       "      <td>3</td>\n",
       "      <td>koi_tce_delivname</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>20</td>\n",
       "      <td>8</td>\n",
       "      <td>koi_tce_plnt_num</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0</td>\n",
       "      <td>2745</td>\n",
       "      <td>kepler_name</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>-100</td>\n",
       "      <td>9564</td>\n",
       "      <td>kepoi_name</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   score  n_unique             column\n",
       "0    100         2   koi_pdisposition\n",
       "1    100         3    koi_disposition\n",
       "2     20         2      koi_fpflag_co\n",
       "3     20         2      koi_fpflag_ec\n",
       "4     20         2      koi_fpflag_ss\n",
       "5     20         3      koi_fpflag_nt\n",
       "6     20         3  koi_tce_delivname\n",
       "7     20         8   koi_tce_plnt_num\n",
       "8      0      2745        kepler_name\n",
       "9   -100      9564         kepoi_name"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "def choose_target_column(dataframe):\n",
    "    cols = dataframe.columns.tolist()\n",
    "    priority_patterns = [\n",
    "        r\"disposition\",\n",
    "        r\"target\",\n",
    "        r\"label\",\n",
    "        r\"class$\",\n",
    "        r\"status\",\n",
    "        r\"outcome\",\n",
    "        r\"risk\",\n",
    "        r\"quality\",\n",
    "    ]\n",
    "\n",
    "    candidates = []\n",
    "    n = len(dataframe)\n",
    "\n",
    "    for col in cols:\n",
    "        s = dataframe[col]\n",
    "        nunique = s.dropna().nunique()\n",
    "        if nunique < 2:\n",
    "            continue\n",
    "\n",
    "        is_class_like = (\n",
    "            s.dtype == \"object\"\n",
    "            or str(s.dtype).startswith(\"category\")\n",
    "            or nunique <= 20\n",
    "        )\n",
    "        if not is_class_like:\n",
    "            continue\n",
    "\n",
    "        score = 0\n",
    "        for i, pat in enumerate(priority_patterns[::-1], start=1):\n",
    "            if re.search(pat, col):\n",
    "                score += 10 * i\n",
    "\n",
    "        # более предпочтительны столбцы с небольшим числом классов\n",
    "        if 2 <= nunique <= 10:\n",
    "            score += 20\n",
    "        elif nunique <= 20:\n",
    "            score += 10\n",
    "\n",
    "        # целевая переменная не должна быть почти уникальной\n",
    "        if nunique / max(n, 1) > 0.3:\n",
    "            score -= 100\n",
    "\n",
    "        candidates.append((score, nunique, col))\n",
    "\n",
    "    if not candidates:\n",
    "        raise ValueError(\"Не удалось автоматически определить зависимую переменную для классификации.\")\n",
    "\n",
    "    candidates = sorted(candidates, key=lambda x: (-x[0], x[1], x[2]))\n",
    "    return candidates[0][2], pd.DataFrame(candidates, columns=[\"score\", \"n_unique\", \"column\"])\n",
    "\n",
    "target_col, target_candidates = choose_target_column(df)\n",
    "print(\"Выбранная зависимая переменная:\", target_col)\n",
    "display(target_candidates.head(15))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "675462f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class_index</th>\n",
       "      <th>class_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   class_index     class_label\n",
       "0            0       CANDIDATE\n",
       "1            1  FALSE POSITIVE"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Распределение классов:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "koi_pdisposition\n",
       "CANDIDATE         0.493204\n",
       "FALSE POSITIVE    0.506796\n",
       "Name: share, dtype: float64"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Очистка и подготовка целевой переменной\n",
    "target_raw = df[target_col].copy()\n",
    "\n",
    "if target_raw.dtype == \"object\":\n",
    "    target_clean = (\n",
    "        target_raw.astype(str)\n",
    "        .str.strip()\n",
    "        .replace({\"\": np.nan, \"nan\": np.nan, \"None\": np.nan, \"none\": np.nan, \"NULL\": np.nan, \"null\": np.nan})\n",
    "    )\n",
    "else:\n",
    "    target_clean = target_raw.copy()\n",
    "\n",
    "mask_target = target_clean.notna()\n",
    "df = df.loc[mask_target].reset_index(drop=True)\n",
    "target_clean = target_clean.loc[mask_target].reset_index(drop=True)\n",
    "\n",
    "label_encoder = LabelEncoder()\n",
    "y = label_encoder.fit_transform(target_clean.astype(str))\n",
    "\n",
    "class_mapping = pd.DataFrame({\n",
    "    \"class_index\": np.arange(len(label_encoder.classes_)),\n",
    "    \"class_label\": label_encoder.classes_\n",
    "})\n",
    "display(class_mapping)\n",
    "\n",
    "class_share = pd.Series(target_clean).value_counts(normalize=True).sort_index().rename(\"share\")\n",
    "print(\"Распределение классов:\")\n",
    "display(class_share)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0ff25260",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Удаляемые идентификаторы / несодержательные поля: ['kepoi_name', 'koi_teq_err1', 'koi_teq_err2']\n",
      "Число исходных признаков: 45\n",
      "Числовые признаки: 42\n",
      "Категориальные признаки: 3\n",
      "Первые числовые признаки: ['kepid', 'koi_score', 'koi_fpflag_nt', 'koi_fpflag_ss', 'koi_fpflag_co', 'koi_fpflag_ec', 'koi_period', 'koi_period_err1', 'koi_period_err2', 'koi_time0bk', 'koi_time0bk_err1', 'koi_time0bk_err2', 'koi_impact', 'koi_impact_err1', 'koi_impact_err2']\n",
      "Первые категориальные признаки: ['kepler_name', 'koi_disposition', 'koi_tce_delivname']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Отделяем признаки и убираем очевидные идентификаторы / почти уникальные текстовые поля\n",
    "X = df.drop(columns=[target_col]).copy()\n",
    "\n",
    "def find_drop_columns(dataframe):\n",
    "    to_drop = []\n",
    "    n = len(dataframe)\n",
    "    for col in dataframe.columns:\n",
    "        s = dataframe[col]\n",
    "        nunique = s.nunique(dropna=True)\n",
    "        unique_ratio = nunique / max(n, 1)\n",
    "\n",
    "        if re.search(r\"(^id$|_id$|^id_|name)\", col) and unique_ratio > 0.3:\n",
    "            to_drop.append(col)\n",
    "            continue\n",
    "\n",
    "        if s.dtype == \"object\" and unique_ratio > 0.95:\n",
    "            to_drop.append(col)\n",
    "            continue\n",
    "\n",
    "        if nunique <= 1:\n",
    "            to_drop.append(col)\n",
    "            continue\n",
    "\n",
    "    return sorted(set(to_drop))\n",
    "\n",
    "drop_cols = find_drop_columns(X)\n",
    "print(\"Удаляемые идентификаторы / несодержательные поля:\", drop_cols)\n",
    "\n",
    "X = X.drop(columns=drop_cols, errors=\"ignore\").copy()\n",
    "\n",
    "# Попытка перевести object-столбцы в числовой формат, если это возможно\n",
    "for col in X.columns:\n",
    "    if X[col].dtype == \"object\":\n",
    "        temp = X[col].astype(str).str.strip()\n",
    "        temp = temp.replace({\"\": np.nan, \"nan\": np.nan, \"None\": np.nan, \"none\": np.nan, \"NULL\": np.nan, \"null\": np.nan})\n",
    "        temp_num = pd.to_numeric(temp, errors=\"coerce\")\n",
    "        share_parsed = temp_num.notna().mean()\n",
    "        if share_parsed >= 0.85:\n",
    "            X[col] = temp_num\n",
    "        else:\n",
    "            X[col] = temp\n",
    "\n",
    "numeric_features = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "categorical_features = [c for c in X.columns if c not in numeric_features]\n",
    "\n",
    "print(\"Число исходных признаков:\", X.shape[1])\n",
    "print(\"Числовые признаки:\", len(numeric_features))\n",
    "print(\"Категориальные признаки:\", len(categorical_features))\n",
    "print(\"Первые числовые признаки:\", numeric_features[:15])\n",
    "print(\"Первые категориальные признаки:\", categorical_features[:15])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9886ad12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Размер train: (6120, 45)\n",
      "Размер val: (1531, 45)\n",
      "Размер test: (1913, 45)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Разбиение данных\n",
    "X_trainval, X_test, y_trainval, y_test = train_test_split(\n",
    "    X, y,\n",
    "    test_size=0.2,\n",
    "    random_state=RANDOM_STATE,\n",
    "    stratify=y\n",
    ")\n",
    "\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    X_trainval, y_trainval,\n",
    "    test_size=0.2,\n",
    "    random_state=RANDOM_STATE,\n",
    "    stratify=y_trainval\n",
    ")\n",
    "\n",
    "print(\"Размер train:\", X_train.shape)\n",
    "print(\"Размер val:\", X_val.shape)\n",
    "print(\"Размер test:\", X_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c7333379",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Число признаков после кодирования: 2240\n",
      "Первые 25 признаков после кодирования:\n",
      "['num__kepid' 'num__koi_score' 'num__koi_fpflag_nt' 'num__koi_fpflag_ss'\n",
      " 'num__koi_fpflag_co' 'num__koi_fpflag_ec' 'num__koi_period'\n",
      " 'num__koi_period_err1' 'num__koi_period_err2' 'num__koi_time0bk'\n",
      " 'num__koi_time0bk_err1' 'num__koi_time0bk_err2' 'num__koi_impact'\n",
      " 'num__koi_impact_err1' 'num__koi_impact_err2' 'num__koi_duration'\n",
      " 'num__koi_duration_err1' 'num__koi_duration_err2' 'num__koi_depth'\n",
      " 'num__koi_depth_err1' 'num__koi_depth_err2' 'num__koi_prad'\n",
      " 'num__koi_prad_err1' 'num__koi_prad_err2' 'num__koi_teq']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Предобработка\n",
    "try:\n",
    "    ohe = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False)\n",
    "except TypeError:\n",
    "    ohe = OneHotEncoder(handle_unknown=\"ignore\", sparse=False)\n",
    "\n",
    "numeric_transformer = Pipeline(steps=[\n",
    "    (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
    "    (\"scaler\", StandardScaler())\n",
    "])\n",
    "\n",
    "categorical_transformer = Pipeline(steps=[\n",
    "    (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n",
    "    (\"onehot\", ohe)\n",
    "])\n",
    "\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        (\"num\", numeric_transformer, numeric_features),\n",
    "        (\"cat\", categorical_transformer, categorical_features)\n",
    "    ],\n",
    "    remainder=\"drop\"\n",
    ")\n",
    "\n",
    "X_train_proc = preprocessor.fit_transform(X_train)\n",
    "X_val_proc = preprocessor.transform(X_val)\n",
    "X_test_proc = preprocessor.transform(X_test)\n",
    "X_trainval_proc = preprocessor.fit_transform(X_trainval)\n",
    "X_test_proc_final = preprocessor.transform(X_test)\n",
    "\n",
    "feature_names = preprocessor.get_feature_names_out()\n",
    "\n",
    "print(\"Число признаков после кодирования:\", len(feature_names))\n",
    "print(\"Первые 25 признаков после кодирования:\")\n",
    "print(feature_names[:25])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fbcd8812",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Размер train после oversampling: (6202, 1795)\n",
      "Распределение классов после oversampling:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0    3101\n",
       "1    3101\n",
       "Name: count, dtype: int64"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "def oversample_multiclass(X_array, y_array, random_state=42):\n",
    "    rng = np.random.default_rng(random_state)\n",
    "    classes, counts = np.unique(y_array, return_counts=True)\n",
    "    max_count = counts.max()\n",
    "\n",
    "    X_parts = []\n",
    "    y_parts = []\n",
    "\n",
    "    for cls, cnt in zip(classes, counts):\n",
    "        idx = np.where(y_array == cls)[0]\n",
    "        if cnt < max_count:\n",
    "            extra_idx = rng.choice(idx, size=max_count - cnt, replace=True)\n",
    "            idx = np.concatenate([idx, extra_idx])\n",
    "        rng.shuffle(idx)\n",
    "        X_parts.append(X_array[idx])\n",
    "        y_parts.append(y_array[idx])\n",
    "\n",
    "    X_bal = np.vstack(X_parts)\n",
    "    y_bal = np.concatenate(y_parts)\n",
    "\n",
    "    perm = rng.permutation(len(y_bal))\n",
    "    return X_bal[perm], y_bal[perm]\n",
    "\n",
    "X_train_bal, y_train_bal = oversample_multiclass(np.asarray(X_train_proc, dtype=np.float32), y_train, random_state=RANDOM_STATE)\n",
    "X_val_np = np.asarray(X_val_proc, dtype=np.float32)\n",
    "X_test_np = np.asarray(X_test_proc, dtype=np.float32)\n",
    "X_trainval_np = np.asarray(X_trainval_proc, dtype=np.float32)\n",
    "X_test_final_np = np.asarray(X_test_proc_final, dtype=np.float32)\n",
    "\n",
    "print(\"Размер train после oversampling:\", X_train_bal.shape)\n",
    "print(\"Распределение классов после oversampling:\")\n",
    "display(pd.Series(y_train_bal).value_counts().sort_index())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d4bf58b",
   "metadata": {},
   "source": [
    "## Глубокая нейросеть и PSO\n",
    "\n",
    "Используется архитектура с числом скрытых слоёв больше 3: **(128, 64, 32, 16)**.  \n",
    "PSO выполняется на сбалансированной подвыборке train, после чего лучшая найденная конфигурация весов используется для инициализации модели и дальнейшего дообучения.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e0b60499",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Архитектура скрытых слоёв: (128, 64, 32, 16)\n",
      "Полная структура сети: [1795, 128, 64, 32, 16, 1]\n",
      "Бинарная классификация: True\n"
     ]
    }
   ],
   "source": [
    "\n",
    "hidden_layers = (128, 64, 32, 16)\n",
    "input_dim = X_train_bal.shape[1]\n",
    "classes_unique = np.unique(y)\n",
    "n_classes = len(classes_unique)\n",
    "is_binary = (n_classes == 2)\n",
    "output_dim = 1 if is_binary else n_classes\n",
    "layer_sizes = [input_dim] + list(hidden_layers) + [output_dim]\n",
    "\n",
    "print(\"Архитектура скрытых слоёв:\", hidden_layers)\n",
    "print(\"Полная структура сети:\", layer_sizes)\n",
    "print(\"Бинарная классификация:\", is_binary)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3f8d880e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Число параметров сети: 240769\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Вспомогательные функции для ручного прямого прохода\n",
    "def relu(x):\n",
    "    return np.maximum(0.0, x)\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + np.exp(-np.clip(z, -40, 40)))\n",
    "\n",
    "def softmax(z):\n",
    "    z = z - np.max(z, axis=1, keepdims=True)\n",
    "    exp_z = np.exp(z)\n",
    "    return exp_z / np.clip(exp_z.sum(axis=1, keepdims=True), 1e-12, None)\n",
    "\n",
    "def build_shapes(layer_sizes):\n",
    "    shapes = []\n",
    "    for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]):\n",
    "        shapes.append((in_dim, out_dim))\n",
    "    return shapes\n",
    "\n",
    "weight_shapes = build_shapes(layer_sizes)\n",
    "bias_shapes = [(out_dim,) for out_dim in layer_sizes[1:]]\n",
    "\n",
    "def total_params_count(weight_shapes, bias_shapes):\n",
    "    return int(sum(np.prod(s) for s in weight_shapes) + sum(np.prod(s) for s in bias_shapes))\n",
    "\n",
    "param_size = total_params_count(weight_shapes, bias_shapes)\n",
    "print(\"Число параметров сети:\", param_size)\n",
    "\n",
    "def unpack_params(vector, weight_shapes, bias_shapes):\n",
    "    pos = 0\n",
    "    weights = []\n",
    "    biases = []\n",
    "\n",
    "    for shape in weight_shapes:\n",
    "        size = int(np.prod(shape))\n",
    "        w = vector[pos:pos+size].reshape(shape)\n",
    "        weights.append(w)\n",
    "        pos += size\n",
    "\n",
    "    for shape in bias_shapes:\n",
    "        size = int(np.prod(shape))\n",
    "        b = vector[pos:pos+size].reshape(shape)\n",
    "        biases.append(b)\n",
    "        pos += size\n",
    "\n",
    "    return weights, biases\n",
    "\n",
    "def forward_proba(X_array, vector, weight_shapes, bias_shapes, is_binary=False):\n",
    "    weights, biases = unpack_params(vector, weight_shapes, bias_shapes)\n",
    "    h = X_array\n",
    "    for idx, (w, b) in enumerate(zip(weights, biases)):\n",
    "        h = h @ w + b\n",
    "        if idx < len(weights) - 1:\n",
    "            h = relu(h)\n",
    "    if is_binary:\n",
    "        proba_pos = sigmoid(h).reshape(-1)\n",
    "        return proba_pos\n",
    "    return softmax(h)\n",
    "\n",
    "def cross_entropy_loss(y_true, proba, is_binary=False):\n",
    "    eps = 1e-12\n",
    "    if is_binary:\n",
    "        p = np.clip(proba.reshape(-1), eps, 1.0 - eps)\n",
    "        y_true = y_true.astype(np.float64)\n",
    "        return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))\n",
    "    p = np.clip(proba[np.arange(len(y_true)), y_true], eps, 1.0)\n",
    "    return -np.mean(np.log(p))\n",
    "\n",
    "def metrics_from_proba(y_true, proba, is_binary=False, threshold=0.5):\n",
    "    if is_binary:\n",
    "        proba_pos = np.asarray(proba).reshape(-1)\n",
    "        y_pred = (proba_pos >= threshold).astype(int)\n",
    "        res = {\n",
    "            \"accuracy\": accuracy_score(y_true, y_pred),\n",
    "            \"balanced_accuracy\": balanced_accuracy_score(y_true, y_pred),\n",
    "            \"precision_macro\": precision_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "            \"recall_macro\": recall_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "            \"f1_macro\": f1_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "        }\n",
    "        try:\n",
    "            res[\"roc_auc\"] = roc_auc_score(y_true, proba_pos)\n",
    "        except Exception:\n",
    "            res[\"roc_auc\"] = np.nan\n",
    "        return res, y_pred\n",
    "    y_pred = np.argmax(proba, axis=1)\n",
    "    res = {\n",
    "        \"accuracy\": accuracy_score(y_true, y_pred),\n",
    "        \"balanced_accuracy\": balanced_accuracy_score(y_true, y_pred),\n",
    "        \"precision_macro\": precision_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "        \"recall_macro\": recall_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "        \"f1_macro\": f1_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "    }\n",
    "    try:\n",
    "        res[\"roc_auc\"] = roc_auc_score(y_true, proba, multi_class=\"ovr\", average=\"macro\")\n",
    "    except Exception:\n",
    "        res[\"roc_auc\"] = np.nan\n",
    "    return res, y_pred\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7eedf9b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Размер подвыборки для PSO: (4000, 1795)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Подвыборка для PSO, чтобы ограничить время работы на больших данных\n",
    "rng = np.random.default_rng(RANDOM_STATE)\n",
    "\n",
    "max_pso_samples = min(4000, len(X_train_bal))\n",
    "if len(X_train_bal) > max_pso_samples:\n",
    "    pso_idx = rng.choice(len(X_train_bal), size=max_pso_samples, replace=False)\n",
    "    X_pso = X_train_bal[pso_idx]\n",
    "    y_pso = y_train_bal[pso_idx]\n",
    "else:\n",
    "    X_pso = X_train_bal.copy()\n",
    "    y_pso = y_train_bal.copy()\n",
    "\n",
    "print(\"Размер подвыборки для PSO:\", X_pso.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3985e855",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>iteration</th>\n",
       "      <th>best_score</th>\n",
       "      <th>best_train_loss</th>\n",
       "      <th>best_val_loss</th>\n",
       "      <th>best_val_f1_macro</th>\n",
       "      <th>mean_particle_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0.735761</td>\n",
       "      <td>0.625030</td>\n",
       "      <td>0.644012</td>\n",
       "      <td>0.737860</td>\n",
       "      <td>0.836567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0.735761</td>\n",
       "      <td>0.625030</td>\n",
       "      <td>0.644012</td>\n",
       "      <td>0.737860</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0.693688</td>\n",
       "      <td>0.598501</td>\n",
       "      <td>0.605414</td>\n",
       "      <td>0.747788</td>\n",
       "      <td>0.834892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>0.682013</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.588410</td>\n",
       "      <td>0.732563</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>0.674345</td>\n",
       "      <td>0.573870</td>\n",
       "      <td>0.582176</td>\n",
       "      <td>0.736661</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>0.674345</td>\n",
       "      <td>0.573870</td>\n",
       "      <td>0.582176</td>\n",
       "      <td>0.736661</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>0.660400</td>\n",
       "      <td>0.562579</td>\n",
       "      <td>0.594342</td>\n",
       "      <td>0.811262</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>0.660400</td>\n",
       "      <td>0.562579</td>\n",
       "      <td>0.594342</td>\n",
       "      <td>0.811262</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>0.640824</td>\n",
       "      <td>0.571202</td>\n",
       "      <td>0.567675</td>\n",
       "      <td>0.791004</td>\n",
       "      <td>0.763580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>0.611386</td>\n",
       "      <td>0.551305</td>\n",
       "      <td>0.544089</td>\n",
       "      <td>0.807722</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    iteration  best_score  best_train_loss  best_val_loss  best_val_f1_macro  \\\n",
       "5           6    0.735761         0.625030       0.644012           0.737860   \n",
       "6           7    0.735761         0.625030       0.644012           0.737860   \n",
       "7           8    0.693688         0.598501       0.605414           0.747788   \n",
       "8           9    0.682013              NaN       0.588410           0.732563   \n",
       "9          10    0.674345         0.573870       0.582176           0.736661   \n",
       "10         11    0.674345         0.573870       0.582176           0.736661   \n",
       "11         12    0.660400         0.562579       0.594342           0.811262   \n",
       "12         13    0.660400         0.562579       0.594342           0.811262   \n",
       "13         14    0.640824         0.571202       0.567675           0.791004   \n",
       "14         15    0.611386         0.551305       0.544089           0.807722   \n",
       "\n",
       "    mean_particle_score  \n",
       "5              0.836567  \n",
       "6                   NaN  \n",
       "7              0.834892  \n",
       "8                   NaN  \n",
       "9                   NaN  \n",
       "10                  NaN  \n",
       "11                  NaN  \n",
       "12                  NaN  \n",
       "13             0.763580  \n",
       "14                  NaN  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# PSO\n",
    "n_particles = 8\n",
    "n_iterations = 15\n",
    "w_inertia = 0.72\n",
    "c1 = 1.45\n",
    "c2 = 1.45\n",
    "init_scale = 0.08\n",
    "velocity_scale = 0.02\n",
    "\n",
    "particles = rng.normal(0.0, init_scale, size=(n_particles, param_size)).astype(np.float32)\n",
    "velocities = rng.normal(0.0, velocity_scale, size=(n_particles, param_size)).astype(np.float32)\n",
    "\n",
    "personal_best = particles.copy()\n",
    "personal_best_scores = np.full(n_particles, np.inf, dtype=np.float64)\n",
    "\n",
    "global_best = None\n",
    "global_best_score = np.inf\n",
    "global_best_info = {}\n",
    "\n",
    "pso_history = []\n",
    "\n",
    "for iteration in range(n_iterations):\n",
    "    particle_scores = []\n",
    "\n",
    "    for p in range(n_particles):\n",
    "        vec = particles[p]\n",
    "\n",
    "        train_proba = forward_proba(X_pso, vec, weight_shapes, bias_shapes, is_binary=is_binary)\n",
    "        val_proba = forward_proba(X_val_np, vec, weight_shapes, bias_shapes, is_binary=is_binary)\n",
    "\n",
    "        train_loss = cross_entropy_loss(y_pso, train_proba, is_binary=is_binary)\n",
    "        val_loss = cross_entropy_loss(y_val, val_proba, is_binary=is_binary)\n",
    "\n",
    "        val_metrics, _ = metrics_from_proba(y_val, val_proba, is_binary=is_binary)\n",
    "        val_f1 = val_metrics[\"f1_macro\"]\n",
    "\n",
    "        # минимизируем score: меньше loss и выше macro-F1\n",
    "        score = val_loss + 0.35 * (1.0 - val_f1)\n",
    "\n",
    "        particle_scores.append(score)\n",
    "\n",
    "        if score < personal_best_scores[p]:\n",
    "            personal_best_scores[p] = score\n",
    "            personal_best[p] = vec.copy()\n",
    "\n",
    "        if score < global_best_score:\n",
    "            global_best_score = score\n",
    "            global_best = vec.copy()\n",
    "            global_best_info = {\n",
    "                \"best_train_loss\": float(train_loss),\n",
    "                \"best_val_loss\": float(val_loss),\n",
    "                \"best_val_f1_macro\": float(val_f1),\n",
    "            }\n",
    "\n",
    "    # обновление скоростей и частиц\n",
    "    for p in range(n_particles):\n",
    "        r1 = rng.random(param_size, dtype=np.float32)\n",
    "        r2 = rng.random(param_size, dtype=np.float32)\n",
    "\n",
    "        velocities[p] = (\n",
    "            w_inertia * velocities[p]\n",
    "            + c1 * r1 * (personal_best[p] - particles[p])\n",
    "            + c2 * r2 * (global_best - particles[p])\n",
    "        )\n",
    "\n",
    "        particles[p] = particles[p] + velocities[p]\n",
    "        particles[p] = np.clip(particles[p], -2.5, 2.5)\n",
    "\n",
    "    pso_history.append({\n",
    "        \"iteration\": iteration + 1,\n",
    "        \"best_score\": float(global_best_score),\n",
    "        \"best_train_loss\": global_best_info.get(\"best_train_loss\", np.nan),\n",
    "        \"best_val_loss\": global_best_info.get(\"best_val_loss\", np.nan),\n",
    "        \"best_val_f1_macro\": global_best_info.get(\"best_val_f1_macro\", np.nan),\n",
    "        \"mean_particle_score\": float(np.mean(particle_scores)),\n",
    "    })\n",
    "\n",
    "pso_history_df = pd.DataFrame(pso_history)\n",
    "display(pso_history_df.tail(10))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fca1e0f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Метрики модели после PSO на validation:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.809928</td>\n",
       "      <td>0.808611</td>\n",
       "      <td>0.821479</td>\n",
       "      <td>0.808611</td>\n",
       "      <td>0.807722</td>\n",
       "      <td>0.887868</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   accuracy  balanced_accuracy  precision_macro  recall_macro  f1_macro  \\\n",
       "0  0.809928           0.808611         0.821479      0.808611  0.807722   \n",
       "\n",
       "    roc_auc  \n",
       "0  0.887868  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Метрики модели после PSO на test:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.796132</td>\n",
       "      <td>0.794674</td>\n",
       "      <td>0.808873</td>\n",
       "      <td>0.794674</td>\n",
       "      <td>0.793424</td>\n",
       "      <td>0.882042</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   accuracy  balanced_accuracy  precision_macro  recall_macro  f1_macro  \\\n",
       "0  0.796132           0.794674         0.808873      0.794674  0.793424   \n",
       "\n",
       "    roc_auc  \n",
       "0  0.882042  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Качество сети сразу после PSO\n",
    "pso_val_proba = forward_proba(X_val_np, global_best, weight_shapes, bias_shapes, is_binary=is_binary)\n",
    "pso_test_proba = forward_proba(X_test_np, global_best, weight_shapes, bias_shapes, is_binary=is_binary)\n",
    "\n",
    "pso_val_metrics, pso_val_pred = metrics_from_proba(y_val, pso_val_proba, is_binary=is_binary)\n",
    "pso_test_metrics, pso_test_pred = metrics_from_proba(y_test, pso_test_proba, is_binary=is_binary)\n",
    "\n",
    "print(\"Метрики модели после PSO на validation:\")\n",
    "display(pd.DataFrame([pso_val_metrics]))\n",
    "print(\"Метрики модели после PSO на test:\")\n",
    "display(pd.DataFrame([pso_test_metrics]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6544f719",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>val_accuracy</th>\n",
       "      <th>val_balanced_accuracy</th>\n",
       "      <th>val_f1_macro</th>\n",
       "      <th>val_roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>0.002394</td>\n",
       "      <td>0.996734</td>\n",
       "      <td>0.996743</td>\n",
       "      <td>0.996734</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>0.002204</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997405</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>0.002066</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997405</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>0.001955</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997405</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>0.001864</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>16</td>\n",
       "      <td>0.001787</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>17</td>\n",
       "      <td>0.001722</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>18</td>\n",
       "      <td>0.001665</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>19</td>\n",
       "      <td>0.001615</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>20</td>\n",
       "      <td>0.001570</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    epoch  train_loss  val_accuracy  val_balanced_accuracy  val_f1_macro  \\\n",
       "10     11    0.002394      0.996734               0.996743      0.996734   \n",
       "11     12    0.002204      0.997387               0.997405      0.997387   \n",
       "12     13    0.002066      0.997387               0.997405      0.997387   \n",
       "13     14    0.001955      0.997387               0.997405      0.997387   \n",
       "14     15    0.001864      0.998040               0.998067      0.998040   \n",
       "15     16    0.001787      0.998040               0.998067      0.998040   \n",
       "16     17    0.001722      0.998040               0.998067      0.998040   \n",
       "17     18    0.001665      0.998040               0.998067      0.998040   \n",
       "18     19    0.001615      0.998040               0.998067      0.998040   \n",
       "19     20    0.001570      0.998040               0.998067      0.998040   \n",
       "\n",
       "    val_roc_auc  \n",
       "10     0.997418  \n",
       "11     0.997418  \n",
       "12     0.997418  \n",
       "13     0.997418  \n",
       "14     0.997418  \n",
       "15     0.997419  \n",
       "16     0.997419  \n",
       "17     0.997419  \n",
       "18     0.997419  \n",
       "19     0.997419  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Инициализация и дообучение MLPClassifier весами из PSO\n",
    "mlp = MLPClassifier(\n",
    "    hidden_layer_sizes=hidden_layers,\n",
    "    activation=\"relu\",\n",
    "    solver=\"adam\",\n",
    "    alpha=1e-4,\n",
    "    batch_size=256,\n",
    "    learning_rate_init=1e-3,\n",
    "    max_iter=1,\n",
    "    warm_start=True,\n",
    "    shuffle=True,\n",
    "    random_state=RANDOM_STATE\n",
    ")\n",
    "\n",
    "# первый вызов partial_fit создаёт внутренние структуры модели\n",
    "mlp.partial_fit(X_train_bal, y_train_bal, classes=np.arange(n_classes))\n",
    "\n",
    "# перенос весов из PSO\n",
    "weights_best, biases_best = unpack_params(global_best, weight_shapes, bias_shapes)\n",
    "mlp.coefs_ = [w.astype(np.float64) for w in weights_best]\n",
    "mlp.intercepts_ = [b.astype(np.float64) for b in biases_best]\n",
    "\n",
    "fine_tune_history = []\n",
    "best_state = {\n",
    "    \"coefs_\": [w.copy() for w in mlp.coefs_],\n",
    "    \"intercepts_\": [b.copy() for b in mlp.intercepts_],\n",
    "    \"best_val_f1_macro\": -np.inf\n",
    "}\n",
    "\n",
    "n_epochs = 20\n",
    "for epoch in range(n_epochs):\n",
    "    mlp.partial_fit(X_train_bal, y_train_bal)\n",
    "\n",
    "    val_pred = np.asarray(mlp.predict(X_val_np)).reshape(-1)\n",
    "    val_proba_raw = mlp.predict_proba(X_val_np)\n",
    "\n",
    "    if is_binary:\n",
    "        if val_proba_raw.ndim == 2:\n",
    "            val_proba_for_metric = val_proba_raw[:, 1]\n",
    "        else:\n",
    "            val_proba_for_metric = np.asarray(val_proba_raw).reshape(-1)\n",
    "    else:\n",
    "        val_proba_for_metric = val_proba_raw\n",
    "\n",
    "    row = {\n",
    "        \"epoch\": epoch + 1,\n",
    "        \"train_loss\": float(mlp.loss_),\n",
    "        \"val_accuracy\": accuracy_score(y_val, val_pred),\n",
    "        \"val_balanced_accuracy\": balanced_accuracy_score(y_val, val_pred),\n",
    "        \"val_f1_macro\": f1_score(y_val, val_pred, average=\"macro\", zero_division=0),\n",
    "    }\n",
    "    try:\n",
    "        if is_binary:\n",
    "            row[\"val_roc_auc\"] = roc_auc_score(y_val, val_proba_for_metric)\n",
    "        else:\n",
    "            row[\"val_roc_auc\"] = roc_auc_score(y_val, val_proba_for_metric, multi_class=\"ovr\", average=\"macro\")\n",
    "    except Exception:\n",
    "        row[\"val_roc_auc\"] = np.nan\n",
    "\n",
    "    fine_tune_history.append(row)\n",
    "\n",
    "    if row[\"val_f1_macro\"] > best_state[\"best_val_f1_macro\"]:\n",
    "        best_state = {\n",
    "            \"coefs_\": [w.copy() for w in mlp.coefs_],\n",
    "            \"intercepts_\": [b.copy() for b in mlp.intercepts_],\n",
    "            \"best_val_f1_macro\": row[\"val_f1_macro\"]\n",
    "        }\n",
    "\n",
    "fine_tune_history_df = pd.DataFrame(fine_tune_history)\n",
    "display(fine_tune_history_df.tail(10))\n",
    "\n",
    "# восстановление лучшей версии модели по macro-F1 на validation\n",
    "mlp.coefs_ = [w.copy() for w in best_state[\"coefs_\"]]\n",
    "mlp.intercepts_ = [b.copy() for b in best_state[\"intercepts_\"]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "41785391",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Метрики финальной модели на validation:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.99804</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998021</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.99804</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   accuracy  balanced_accuracy  precision_macro  recall_macro  f1_macro  \\\n",
       "0   0.99804           0.998067         0.998021      0.998067   0.99804   \n",
       "\n",
       "    roc_auc  \n",
       "0  0.997418  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Метрики финальной модели на test:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.996864</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.998922</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   accuracy  balanced_accuracy  precision_macro  recall_macro  f1_macro  \\\n",
       "0  0.996864           0.996863         0.996863      0.996863  0.996863   \n",
       "\n",
       "    roc_auc  \n",
       "0  0.998922  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Финальная оценка\n",
    "ft_val_pred = np.asarray(mlp.predict(X_val_np)).reshape(-1)\n",
    "ft_val_proba_raw = mlp.predict_proba(X_val_np)\n",
    "\n",
    "ft_test_pred = np.asarray(mlp.predict(X_test_np)).reshape(-1)\n",
    "ft_test_proba_raw = mlp.predict_proba(X_test_np)\n",
    "\n",
    "if is_binary:\n",
    "    ft_val_proba = ft_val_proba_raw[:, 1] if np.ndim(ft_val_proba_raw) == 2 else np.asarray(ft_val_proba_raw).reshape(-1)\n",
    "    ft_test_proba = ft_test_proba_raw[:, 1] if np.ndim(ft_test_proba_raw) == 2 else np.asarray(ft_test_proba_raw).reshape(-1)\n",
    "else:\n",
    "    ft_val_proba = ft_val_proba_raw\n",
    "    ft_test_proba = ft_test_proba_raw\n",
    "\n",
    "def collect_metrics(y_true, y_pred, y_proba, is_binary=False):\n",
    "    res = {\n",
    "        \"accuracy\": accuracy_score(y_true, y_pred),\n",
    "        \"balanced_accuracy\": balanced_accuracy_score(y_true, y_pred),\n",
    "        \"precision_macro\": precision_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "        \"recall_macro\": recall_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "        \"f1_macro\": f1_score(y_true, y_pred, average=\"macro\", zero_division=0),\n",
    "    }\n",
    "    try:\n",
    "        if is_binary:\n",
    "            res[\"roc_auc\"] = roc_auc_score(y_true, y_proba)\n",
    "        else:\n",
    "            res[\"roc_auc\"] = roc_auc_score(y_true, y_proba, multi_class=\"ovr\", average=\"macro\")\n",
    "    except Exception:\n",
    "        res[\"roc_auc\"] = np.nan\n",
    "    return res\n",
    "\n",
    "ft_val_metrics = collect_metrics(y_val, ft_val_pred, ft_val_proba, is_binary=is_binary)\n",
    "ft_test_metrics = collect_metrics(y_test, ft_test_pred, ft_test_proba, is_binary=is_binary)\n",
    "\n",
    "print(\"Метрики финальной модели на validation:\")\n",
    "display(pd.DataFrame([ft_val_metrics]))\n",
    "print(\"Метрики финальной модели на test:\")\n",
    "display(pd.DataFrame([ft_test_metrics]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "54770815",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Глубокая нейросеть: PSO + дообучение</td>\n",
       "      <td>0.996864</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.998922</td>\n",
       "      <td>45</td>\n",
       "      <td>2240</td>\n",
       "      <td>(128, 64, 32, 16)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Глубокая нейросеть после PSO</td>\n",
       "      <td>0.796132</td>\n",
       "      <td>0.794674</td>\n",
       "      <td>0.808873</td>\n",
       "      <td>0.794674</td>\n",
       "      <td>0.793424</td>\n",
       "      <td>0.882042</td>\n",
       "      <td>45</td>\n",
       "      <td>2240</td>\n",
       "      <td>(128, 64, 32, 16)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                  model  accuracy  balanced_accuracy  \\\n",
       "0  Глубокая нейросеть: PSO + дообучение  0.996864           0.996863   \n",
       "1          Глубокая нейросеть после PSO  0.796132           0.794674   \n",
       "\n",
       "   precision_macro  recall_macro  f1_macro   roc_auc  n_original_features  \\\n",
       "0         0.996863      0.996863  0.996863  0.998922                   45   \n",
       "1         0.808873      0.794674  0.793424  0.882042                   45   \n",
       "\n",
       "   n_processed_features      hidden_layers  \n",
       "0                  2240  (128, 64, 32, 16)  \n",
       "1                  2240  (128, 64, 32, 16)  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Лучшая модель:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Глубокая нейросеть: PSO + дообучение</td>\n",
       "      <td>0.996864</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.998922</td>\n",
       "      <td>45</td>\n",
       "      <td>2240</td>\n",
       "      <td>(128, 64, 32, 16)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                  model  accuracy  balanced_accuracy  \\\n",
       "0  Глубокая нейросеть: PSO + дообучение  0.996864           0.996863   \n",
       "\n",
       "   precision_macro  recall_macro  f1_macro   roc_auc  n_original_features  \\\n",
       "0         0.996863      0.996863  0.996863  0.998922                   45   \n",
       "\n",
       "   n_processed_features      hidden_layers  \n",
       "0                  2240  (128, 64, 32, 16)  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Сравнение моделей\n",
    "results_df = pd.DataFrame([\n",
    "    {\n",
    "        \"model\": \"Глубокая нейросеть после PSO\",\n",
    "        \"accuracy\": pso_test_metrics[\"accuracy\"],\n",
    "        \"balanced_accuracy\": pso_test_metrics[\"balanced_accuracy\"],\n",
    "        \"precision_macro\": pso_test_metrics[\"precision_macro\"],\n",
    "        \"recall_macro\": pso_test_metrics[\"recall_macro\"],\n",
    "        \"f1_macro\": pso_test_metrics[\"f1_macro\"],\n",
    "        \"roc_auc\": pso_test_metrics[\"roc_auc\"],\n",
    "        \"n_original_features\": X.shape[1],\n",
    "        \"n_processed_features\": len(feature_names),\n",
    "        \"hidden_layers\": str(hidden_layers),\n",
    "    },\n",
    "    {\n",
    "        \"model\": \"Глубокая нейросеть: PSO + дообучение\",\n",
    "        \"accuracy\": ft_test_metrics[\"accuracy\"],\n",
    "        \"balanced_accuracy\": ft_test_metrics[\"balanced_accuracy\"],\n",
    "        \"precision_macro\": ft_test_metrics[\"precision_macro\"],\n",
    "        \"recall_macro\": ft_test_metrics[\"recall_macro\"],\n",
    "        \"f1_macro\": ft_test_metrics[\"f1_macro\"],\n",
    "        \"roc_auc\": ft_test_metrics[\"roc_auc\"],\n",
    "        \"n_original_features\": X.shape[1],\n",
    "        \"n_processed_features\": len(feature_names),\n",
    "        \"hidden_layers\": str(hidden_layers),\n",
    "    }\n",
    "]).sort_values([\"f1_macro\", \"balanced_accuracy\", \"accuracy\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "display(results_df)\n",
    "best_model_df = results_df.head(1)\n",
    "print(\"Лучшая модель:\")\n",
    "display(best_model_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "434faa41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Матрица ошибок для модели после PSO:\n",
      "[[652 291]\n",
      " [ 99 871]]\n",
      "\n",
      "Матрица ошибок для финальной модели:\n",
      "[[940   3]\n",
      " [  3 967]]\n",
      "\n",
      "Классы:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class_index</th>\n",
       "      <th>class_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   class_index     class_label\n",
       "0            0       CANDIDATE\n",
       "1            1  FALSE POSITIVE"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Классификационный отчёт для финальной модели:\n",
      "                precision    recall  f1-score   support\n",
      "\n",
      "     CANDIDATE       1.00      1.00      1.00       943\n",
      "FALSE POSITIVE       1.00      1.00      1.00       970\n",
      "\n",
      "      accuracy                           1.00      1913\n",
      "     macro avg       1.00      1.00      1.00      1913\n",
      "  weighted avg       1.00      1.00      1.00      1913\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Матрицы ошибок и короткие отчёты\n",
    "cm_pso = confusion_matrix(y_test, pso_test_pred)\n",
    "cm_ft = confusion_matrix(y_test, ft_test_pred)\n",
    "\n",
    "print(\"Матрица ошибок для модели после PSO:\")\n",
    "print(cm_pso)\n",
    "\n",
    "print(\"\\nМатрица ошибок для финальной модели:\")\n",
    "print(cm_ft)\n",
    "\n",
    "print(\"\\nКлассы:\")\n",
    "display(class_mapping)\n",
    "\n",
    "print(\"\\nКлассификационный отчёт для финальной модели:\")\n",
    "print(classification_report(y_test, ft_test_pred, target_names=label_encoder.classes_, zero_division=0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9e10c041",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>y_true</th>\n",
       "      <th>y_pred_pso</th>\n",
       "      <th>y_pred_final</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           y_true      y_pred_pso    y_pred_final\n",
       "0       CANDIDATE       CANDIDATE       CANDIDATE\n",
       "1  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "2  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "3  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "4  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "5       CANDIDATE       CANDIDATE       CANDIDATE\n",
       "6  FALSE POSITIVE       CANDIDATE  FALSE POSITIVE\n",
       "7       CANDIDATE       CANDIDATE       CANDIDATE\n",
       "8  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "9  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Первые 10 фактических и предсказанных значений\n",
    "pred_preview = pd.DataFrame({\n",
    "    \"y_true\": label_encoder.inverse_transform(y_test[:10]),\n",
    "    \"y_pred_pso\": label_encoder.inverse_transform(pso_test_pred[:10]),\n",
    "    \"y_pred_final\": label_encoder.inverse_transform(ft_test_pred[:10]),\n",
    "})\n",
    "display(pred_preview)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "43e463ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== КЛЮЧЕВАЯ ИНФОРМАЦИЯ ДЛЯ ВЫВОДА ===\n",
      "Зависимая переменная: koi_pdisposition\n",
      "\n",
      "Число исходных признаков: 45\n",
      "Число признаков после кодирования: 2240\n",
      "Архитектура скрытых слоёв: (128, 64, 32, 16)\n",
      "\n",
      "Распределение классов:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "koi_pdisposition\n",
       "CANDIDATE         0.493204\n",
       "FALSE POSITIVE    0.506796\n",
       "Name: share, dtype: float64"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Использованные исходные признаки:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>kepid</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>kepler_name</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>koi_disposition</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>koi_score</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>koi_fpflag_nt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>koi_fpflag_ss</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>koi_fpflag_co</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>koi_fpflag_ec</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>koi_period</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>koi_period_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>koi_period_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>koi_time0bk</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>koi_time0bk_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>koi_time0bk_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>koi_impact</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>koi_impact_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>koi_impact_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>koi_duration</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>koi_duration_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>koi_duration_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>koi_depth</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>koi_depth_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>koi_depth_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>koi_prad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>koi_prad_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>koi_prad_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>koi_teq</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>koi_insol</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>koi_insol_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>koi_insol_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>koi_model_snr</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>koi_tce_plnt_num</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>koi_tce_delivname</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>koi_steff</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>koi_steff_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>koi_steff_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>koi_slogg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>koi_slogg_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>koi_slogg_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>koi_srad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>koi_srad_err1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>koi_srad_err2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>ra</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>dec</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>koi_kepmag</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              feature\n",
       "0               kepid\n",
       "1         kepler_name\n",
       "2     koi_disposition\n",
       "3           koi_score\n",
       "4       koi_fpflag_nt\n",
       "5       koi_fpflag_ss\n",
       "6       koi_fpflag_co\n",
       "7       koi_fpflag_ec\n",
       "8          koi_period\n",
       "9     koi_period_err1\n",
       "10    koi_period_err2\n",
       "11        koi_time0bk\n",
       "12   koi_time0bk_err1\n",
       "13   koi_time0bk_err2\n",
       "14         koi_impact\n",
       "15    koi_impact_err1\n",
       "16    koi_impact_err2\n",
       "17       koi_duration\n",
       "18  koi_duration_err1\n",
       "19  koi_duration_err2\n",
       "20          koi_depth\n",
       "21     koi_depth_err1\n",
       "22     koi_depth_err2\n",
       "23           koi_prad\n",
       "24      koi_prad_err1\n",
       "25      koi_prad_err2\n",
       "26            koi_teq\n",
       "27          koi_insol\n",
       "28     koi_insol_err1\n",
       "29     koi_insol_err2\n",
       "30      koi_model_snr\n",
       "31   koi_tce_plnt_num\n",
       "32  koi_tce_delivname\n",
       "33          koi_steff\n",
       "34     koi_steff_err1\n",
       "35     koi_steff_err2\n",
       "36          koi_slogg\n",
       "37     koi_slogg_err1\n",
       "38     koi_slogg_err2\n",
       "39           koi_srad\n",
       "40      koi_srad_err1\n",
       "41      koi_srad_err2\n",
       "42                 ra\n",
       "43                dec\n",
       "44         koi_kepmag"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Метрики моделей на тестовой выборке:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Глубокая нейросеть: PSO + дообучение</td>\n",
       "      <td>0.996864</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.998922</td>\n",
       "      <td>45</td>\n",
       "      <td>2240</td>\n",
       "      <td>(128, 64, 32, 16)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Глубокая нейросеть после PSO</td>\n",
       "      <td>0.796132</td>\n",
       "      <td>0.794674</td>\n",
       "      <td>0.808873</td>\n",
       "      <td>0.794674</td>\n",
       "      <td>0.793424</td>\n",
       "      <td>0.882042</td>\n",
       "      <td>45</td>\n",
       "      <td>2240</td>\n",
       "      <td>(128, 64, 32, 16)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                  model  accuracy  balanced_accuracy  \\\n",
       "0  Глубокая нейросеть: PSO + дообучение  0.996864           0.996863   \n",
       "1          Глубокая нейросеть после PSO  0.796132           0.794674   \n",
       "\n",
       "   precision_macro  recall_macro  f1_macro   roc_auc  n_original_features  \\\n",
       "0         0.996863      0.996863  0.996863  0.998922                   45   \n",
       "1         0.808873      0.794674  0.793424  0.882042                   45   \n",
       "\n",
       "   n_processed_features      hidden_layers  \n",
       "0                  2240  (128, 64, 32, 16)  \n",
       "1                  2240  (128, 64, 32, 16)  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Лучшая модель:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>balanced_accuracy</th>\n",
       "      <th>precision_macro</th>\n",
       "      <th>recall_macro</th>\n",
       "      <th>f1_macro</th>\n",
       "      <th>roc_auc</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Глубокая нейросеть: PSO + дообучение</td>\n",
       "      <td>0.996864</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.996863</td>\n",
       "      <td>0.998922</td>\n",
       "      <td>45</td>\n",
       "      <td>2240</td>\n",
       "      <td>(128, 64, 32, 16)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                  model  accuracy  balanced_accuracy  \\\n",
       "0  Глубокая нейросеть: PSO + дообучение  0.996864           0.996863   \n",
       "\n",
       "   precision_macro  recall_macro  f1_macro   roc_auc  n_original_features  \\\n",
       "0         0.996863      0.996863  0.996863  0.998922                   45   \n",
       "\n",
       "   n_processed_features      hidden_layers  \n",
       "0                  2240  (128, 64, 32, 16)  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Последние итерации PSO:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>iteration</th>\n",
       "      <th>best_score</th>\n",
       "      <th>best_train_loss</th>\n",
       "      <th>best_val_loss</th>\n",
       "      <th>best_val_f1_macro</th>\n",
       "      <th>mean_particle_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0.735761</td>\n",
       "      <td>0.625030</td>\n",
       "      <td>0.644012</td>\n",
       "      <td>0.737860</td>\n",
       "      <td>0.836567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0.735761</td>\n",
       "      <td>0.625030</td>\n",
       "      <td>0.644012</td>\n",
       "      <td>0.737860</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0.693688</td>\n",
       "      <td>0.598501</td>\n",
       "      <td>0.605414</td>\n",
       "      <td>0.747788</td>\n",
       "      <td>0.834892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>0.682013</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.588410</td>\n",
       "      <td>0.732563</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>0.674345</td>\n",
       "      <td>0.573870</td>\n",
       "      <td>0.582176</td>\n",
       "      <td>0.736661</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>0.674345</td>\n",
       "      <td>0.573870</td>\n",
       "      <td>0.582176</td>\n",
       "      <td>0.736661</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>0.660400</td>\n",
       "      <td>0.562579</td>\n",
       "      <td>0.594342</td>\n",
       "      <td>0.811262</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>0.660400</td>\n",
       "      <td>0.562579</td>\n",
       "      <td>0.594342</td>\n",
       "      <td>0.811262</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>0.640824</td>\n",
       "      <td>0.571202</td>\n",
       "      <td>0.567675</td>\n",
       "      <td>0.791004</td>\n",
       "      <td>0.763580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>0.611386</td>\n",
       "      <td>0.551305</td>\n",
       "      <td>0.544089</td>\n",
       "      <td>0.807722</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    iteration  best_score  best_train_loss  best_val_loss  best_val_f1_macro  \\\n",
       "5           6    0.735761         0.625030       0.644012           0.737860   \n",
       "6           7    0.735761         0.625030       0.644012           0.737860   \n",
       "7           8    0.693688         0.598501       0.605414           0.747788   \n",
       "8           9    0.682013              NaN       0.588410           0.732563   \n",
       "9          10    0.674345         0.573870       0.582176           0.736661   \n",
       "10         11    0.674345         0.573870       0.582176           0.736661   \n",
       "11         12    0.660400         0.562579       0.594342           0.811262   \n",
       "12         13    0.660400         0.562579       0.594342           0.811262   \n",
       "13         14    0.640824         0.571202       0.567675           0.791004   \n",
       "14         15    0.611386         0.551305       0.544089           0.807722   \n",
       "\n",
       "    mean_particle_score  \n",
       "5              0.836567  \n",
       "6                   NaN  \n",
       "7              0.834892  \n",
       "8                   NaN  \n",
       "9                   NaN  \n",
       "10                  NaN  \n",
       "11                  NaN  \n",
       "12                  NaN  \n",
       "13             0.763580  \n",
       "14                  NaN  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Последние эпохи дообучения:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>val_accuracy</th>\n",
       "      <th>val_balanced_accuracy</th>\n",
       "      <th>val_f1_macro</th>\n",
       "      <th>val_roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>0.002394</td>\n",
       "      <td>0.996734</td>\n",
       "      <td>0.996743</td>\n",
       "      <td>0.996734</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>0.002204</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997405</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>0.002066</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997405</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>0.001955</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997405</td>\n",
       "      <td>0.997387</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>0.001864</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>16</td>\n",
       "      <td>0.001787</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>17</td>\n",
       "      <td>0.001722</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>18</td>\n",
       "      <td>0.001665</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>19</td>\n",
       "      <td>0.001615</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>20</td>\n",
       "      <td>0.001570</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.998067</td>\n",
       "      <td>0.998040</td>\n",
       "      <td>0.997419</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    epoch  train_loss  val_accuracy  val_balanced_accuracy  val_f1_macro  \\\n",
       "10     11    0.002394      0.996734               0.996743      0.996734   \n",
       "11     12    0.002204      0.997387               0.997405      0.997387   \n",
       "12     13    0.002066      0.997387               0.997405      0.997387   \n",
       "13     14    0.001955      0.997387               0.997405      0.997387   \n",
       "14     15    0.001864      0.998040               0.998067      0.998040   \n",
       "15     16    0.001787      0.998040               0.998067      0.998040   \n",
       "16     17    0.001722      0.998040               0.998067      0.998040   \n",
       "17     18    0.001665      0.998040               0.998067      0.998040   \n",
       "18     19    0.001615      0.998040               0.998067      0.998040   \n",
       "19     20    0.001570      0.998040               0.998067      0.998040   \n",
       "\n",
       "    val_roc_auc  \n",
       "10     0.997418  \n",
       "11     0.997418  \n",
       "12     0.997418  \n",
       "13     0.997418  \n",
       "14     0.997418  \n",
       "15     0.997419  \n",
       "16     0.997419  \n",
       "17     0.997419  \n",
       "18     0.997419  \n",
       "19     0.997419  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Первые 10 фактических и предсказанных значений на тесте для финальной модели:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>y_true</th>\n",
       "      <th>y_pred_pso</th>\n",
       "      <th>y_pred_final</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "      <td>CANDIDATE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "      <td>FALSE POSITIVE</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           y_true      y_pred_pso    y_pred_final\n",
       "0       CANDIDATE       CANDIDATE       CANDIDATE\n",
       "1  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "2  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "3  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "4  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "5       CANDIDATE       CANDIDATE       CANDIDATE\n",
       "6  FALSE POSITIVE       CANDIDATE  FALSE POSITIVE\n",
       "7       CANDIDATE       CANDIDATE       CANDIDATE\n",
       "8  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE\n",
       "9  FALSE POSITIVE  FALSE POSITIVE  FALSE POSITIVE"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Матрица ошибок для модели после PSO:\n",
      "[[652 291]\n",
      " [ 99 871]]\n",
      "Матрица ошибок для финальной модели:\n",
      "[[940   3]\n",
      " [  3 967]]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print(\"=== КЛЮЧЕВАЯ ИНФОРМАЦИЯ ДЛЯ ВЫВОДА ===\")\n",
    "print(\"Зависимая переменная:\", target_col)\n",
    "print()\n",
    "print(\"Число исходных признаков:\", X.shape[1])\n",
    "print(\"Число признаков после кодирования:\", len(feature_names))\n",
    "print(\"Архитектура скрытых слоёв:\", hidden_layers)\n",
    "print()\n",
    "print(\"Распределение классов:\")\n",
    "display(class_share)\n",
    "print(\"Использованные исходные признаки:\")\n",
    "display(pd.DataFrame({\"feature\": X.columns}))\n",
    "print(\"Метрики моделей на тестовой выборке:\")\n",
    "display(results_df)\n",
    "print(\"Лучшая модель:\")\n",
    "display(best_model_df)\n",
    "print(\"Последние итерации PSO:\")\n",
    "display(pso_history_df.tail(10))\n",
    "print(\"Последние эпохи дообучения:\")\n",
    "display(fine_tune_history_df.tail(10))\n",
    "print(\"Первые 10 фактических и предсказанных значений на тесте для финальной модели:\")\n",
    "display(pred_preview)\n",
    "print(\"Матрица ошибок для модели после PSO:\")\n",
    "print(cm_pso)\n",
    "print(\"Матрица ошибок для финальной модели:\")\n",
    "print(cm_ft)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f6e90a3",
   "metadata": {},
   "source": [
    "Итог\n",
    "\n",
    "Зависимой переменной выбрана koi_pdisposition. Построена глубокая нейросеть с архитектурой (128, 64, 32, 16), сначала обученная методом PSO, затем дообученная градиентным методом.\n",
    "\n",
    "Лучшая модель — PSO + дообучение: на тестовой выборке получено accuracy = 0.9969, balanced accuracy = 0.9969, F1-macro = 0.9969, ROC-AUC = 0.9989. По сравнению с моделью только после PSO качество резко улучшилось (accuracy: 0.796 → 0.997).\n",
    "\n",
    "Следовательно, дообучение после весовой эволюции существенно повышает точность классификации, и лучшей является финальная модель PSO + fine-tuning."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
