{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3abb67c4",
   "metadata": {},
   "source": [
    "# Регрессия урожайности: нейросети, отбор признаков, эволюционный поиск архитектуры и прототипы\n",
    "\n",
    "В ноутбуке:\n",
    "- автоматически загружается датасет и нормализуются названия столбцов;\n",
    "- определяется ключевая зависимая переменная;\n",
    "- строится базовая нейросеть регрессии по всем признакам;\n",
    "- оценивается важность признаков;\n",
    "- по отобранным признакам строится вторая нейросеть с эволюционным поиском архитектуры;\n",
    "- выбирается лучшая модель на тестовой выборке;\n",
    "- по лучшей модели определяются основные прототипы и интерпретируются их признаки.\n",
    "\n",
    "Финального текстового вывода в ноутбуке нет. В последней ячейке печатается вся информация, необходимая для корректного вывода по результатам выполнения."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "23640eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import copy\n",
    "import math\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.base import BaseEstimator, RegressorMixin\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error\n",
    "from sklearn.inspection import permutation_importance\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "pd.set_option(\"display.max_columns\", 200)\n",
    "pd.set_option(\"display.width\", 200)\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "40f02f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_csv_file():\n",
    "    candidates = [\n",
    "        \"yield_prediction_dataset.csv\",\n",
    "        \"/mnt/data/yield_prediction_dataset.csv\",\n",
    "        \"Yield_Prediction_Dataset.csv\",\n",
    "        \"/mnt/data/Yield_Prediction_Dataset.csv\",\n",
    "        \"crop_yield_prediction.csv\",\n",
    "        \"/mnt/data/crop_yield_prediction.csv\",\n",
    "    ]\n",
    "    for path in candidates:\n",
    "        if os.path.exists(path):\n",
    "            return path\n",
    "\n",
    "    for folder in [\".\", \"/mnt/data\"]:\n",
    "        if os.path.isdir(folder):\n",
    "            for name in os.listdir(folder):\n",
    "                low = name.lower()\n",
    "                if low.endswith(\".csv\") and (\"yield\" in low or \"crop\" in low):\n",
    "                    return os.path.join(folder, name)\n",
    "\n",
    "    raise FileNotFoundError(\"CSV-файл датасета не найден.\")\n",
    "\n",
    "\n",
    "def normalize_column_name(col):\n",
    "    col = str(col).strip().lower()\n",
    "    col = col.replace(\"%\", \" pct \")\n",
    "    col = col.replace(\"/\", \" \")\n",
    "    col = col.replace(\"-\", \"_\")\n",
    "    col = col.replace(\"(\", \" \").replace(\")\", \" \")\n",
    "    col = re.sub(r\"[^a-z0-9а-я_]+\", \"_\", col)\n",
    "    col = re.sub(r\"_+\", \"_\", col).strip(\"_\")\n",
    "    return col\n",
    "\n",
    "\n",
    "def clean_dataframe(df):\n",
    "    df = df.copy()\n",
    "    original_cols = df.columns.tolist()\n",
    "    df.columns = [normalize_column_name(c) for c in df.columns]\n",
    "\n",
    "    # удаляем полностью пустые и служебные столбцы\n",
    "    drop_cols = []\n",
    "    for c in df.columns:\n",
    "        if c.startswith(\"unnamed\"):\n",
    "            drop_cols.append(c)\n",
    "            continue\n",
    "        if df[c].isna().all():\n",
    "            drop_cols.append(c)\n",
    "            continue\n",
    "        if df[c].astype(str).str.strip().replace({\"\": np.nan, \"nan\": np.nan, \"None\": np.nan}).isna().all():\n",
    "            drop_cols.append(c)\n",
    "    df = df.drop(columns=drop_cols, errors=\"ignore\")\n",
    "\n",
    "    # дата\n",
    "    if \"date_of_image\" in df.columns:\n",
    "        dt = pd.to_datetime(df[\"date_of_image\"], dayfirst=True, errors=\"coerce\")\n",
    "        df[\"date_year\"] = dt.dt.year\n",
    "        df[\"date_month\"] = dt.dt.month\n",
    "        df[\"date_day\"] = dt.dt.day\n",
    "        df[\"date_dayofyear\"] = dt.dt.dayofyear\n",
    "        df = df.drop(columns=[\"date_of_image\"])\n",
    "\n",
    "    return df, original_cols\n",
    "\n",
    "\n",
    "def detect_target_column(df):\n",
    "    preferred = [\n",
    "        \"yield\", \"target\", \"price\", \"sales\", \"score\", \"output\"\n",
    "    ]\n",
    "    for col in preferred:\n",
    "        if col in df.columns:\n",
    "            return col\n",
    "\n",
    "    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()\n",
    "    if not numeric_cols:\n",
    "        raise ValueError(\"Не найден числовой столбец для регрессии.\")\n",
    "    return numeric_cols[-1]\n",
    "\n",
    "\n",
    "def split_features(df, target_col):\n",
    "    X = df.drop(columns=[target_col]).copy()\n",
    "    y = pd.to_numeric(df[target_col], errors=\"coerce\")\n",
    "    mask = y.notna()\n",
    "    X = X.loc[mask].reset_index(drop=True)\n",
    "    y = y.loc[mask].reset_index(drop=True)\n",
    "\n",
    "    # попытка привести объектные числовые колонки к числу\n",
    "    for c in X.columns:\n",
    "        if X[c].dtype == \"object\":\n",
    "            converted = pd.to_numeric(X[c], errors=\"coerce\")\n",
    "            # если значительная доля значений распарсилась как число, считаем столбец числовым\n",
    "            if converted.notna().mean() >= 0.8:\n",
    "                X[c] = converted\n",
    "\n",
    "    numeric_features = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "    categorical_features = [c for c in X.columns if c not in numeric_features]\n",
    "    return X, y, numeric_features, categorical_features\n",
    "\n",
    "\n",
    "def make_preprocessor(numeric_features, categorical_features):\n",
    "    num_pipe = Pipeline([\n",
    "        (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
    "        (\"scaler\", StandardScaler())\n",
    "    ])\n",
    "    cat_pipe = Pipeline([\n",
    "        (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n",
    "        (\"onehot\", OneHotEncoder(handle_unknown=\"ignore\"))\n",
    "    ])\n",
    "    return ColumnTransformer([\n",
    "        (\"num\", num_pipe, numeric_features),\n",
    "        (\"cat\", cat_pipe, categorical_features)\n",
    "    ], remainder=\"drop\", sparse_threshold=0.0)\n",
    "\n",
    "\n",
    "def regression_metrics(y_true, y_pred):\n",
    "    return {\n",
    "        \"r2\": float(r2_score(y_true, y_pred)),\n",
    "        \"rmse\": float(np.sqrt(mean_squared_error(y_true, y_pred))),\n",
    "        \"mae\": float(mean_absolute_error(y_true, y_pred)),\n",
    "    }\n",
    "\n",
    "\n",
    "class InverseScaledRegressorWrapper(RegressorMixin, BaseEstimator):\n",
    "    def __init__(self, model, y_scaler):\n",
    "        self.model = model\n",
    "        self.y_scaler = y_scaler\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        pred_scaled = self.model.predict(X)\n",
    "        return self.y_scaler.inverse_transform(np.asarray(pred_scaled).reshape(-1, 1)).ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "37b0252a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Используемый файл: yield_prediction_dataset.csv\n",
      "Форма исходного датасета: (1625, 15)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>field_id</th>\n",
       "      <th>date_of_image</th>\n",
       "      <th>latitude</th>\n",
       "      <th>longitude</th>\n",
       "      <th>NDVI</th>\n",
       "      <th>GNDVI</th>\n",
       "      <th>NDWI</th>\n",
       "      <th>SAVI</th>\n",
       "      <th>soil_moisture</th>\n",
       "      <th>temperature</th>\n",
       "      <th>rainfall</th>\n",
       "      <th>crop_type</th>\n",
       "      <th>yield</th>\n",
       "      <th>Unnamed: 13</th>\n",
       "      <th>Unnamed: 14</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>01-01-2023</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.060190</td>\n",
       "      <td>0.084801</td>\n",
       "      <td>-0.084801</td>\n",
       "      <td>0.090280</td>\n",
       "      <td>46.119353</td>\n",
       "      <td>9.884229</td>\n",
       "      <td>1.662354</td>\n",
       "      <td>Rice</td>\n",
       "      <td>40.218031</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>16-01-2023</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.213957</td>\n",
       "      <td>0.222009</td>\n",
       "      <td>-0.222009</td>\n",
       "      <td>0.320896</td>\n",
       "      <td>37.542525</td>\n",
       "      <td>13.967073</td>\n",
       "      <td>8.446302</td>\n",
       "      <td>Rice</td>\n",
       "      <td>30.870338</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>31-01-2023</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.403306</td>\n",
       "      <td>0.431204</td>\n",
       "      <td>-0.431204</td>\n",
       "      <td>0.604837</td>\n",
       "      <td>24.926279</td>\n",
       "      <td>13.590147</td>\n",
       "      <td>3.862833</td>\n",
       "      <td>Rice</td>\n",
       "      <td>45.330050</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>15-02-2023</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.418187</td>\n",
       "      <td>0.444132</td>\n",
       "      <td>-0.444132</td>\n",
       "      <td>0.627144</td>\n",
       "      <td>24.114157</td>\n",
       "      <td>12.343355</td>\n",
       "      <td>16.623542</td>\n",
       "      <td>Rice</td>\n",
       "      <td>49.711781</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>03-02-2023</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.375138</td>\n",
       "      <td>0.387985</td>\n",
       "      <td>-0.387985</td>\n",
       "      <td>0.562591</td>\n",
       "      <td>27.420927</td>\n",
       "      <td>11.007707</td>\n",
       "      <td>9.496210</td>\n",
       "      <td>Rice</td>\n",
       "      <td>34.542646</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  field_id date_of_image   latitude  longitude      NDVI     GNDVI      NDWI      SAVI  soil_moisture  temperature   rainfall crop_type      yield  Unnamed: 13  Unnamed: 14\n",
       "0  Field_1    01-01-2023  22.625231  88.497925  0.060190  0.084801 -0.084801  0.090280      46.119353     9.884229   1.662354      Rice  40.218031          NaN          NaN\n",
       "1  Field_1    16-01-2023  22.625231  88.497925  0.213957  0.222009 -0.222009  0.320896      37.542525    13.967073   8.446302      Rice  30.870338          NaN          NaN\n",
       "2  Field_1    31-01-2023  22.625231  88.497925  0.403306  0.431204 -0.431204  0.604837      24.926279    13.590147   3.862833      Rice  45.330050          NaN          NaN\n",
       "3  Field_1    15-02-2023  22.625231  88.497925  0.418187  0.444132 -0.444132  0.627144      24.114157    12.343355  16.623542      Rice  49.711781          NaN          NaN\n",
       "4  Field_1    03-02-2023  22.625231  88.497925  0.375138  0.387985 -0.387985  0.562591      27.420927    11.007707   9.496210      Rice  34.542646          NaN          NaN"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Названия колонок после нормализации:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original_name</th>\n",
       "      <th>normalized_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>field_id</td>\n",
       "      <td>field_id</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>date_of_image</td>\n",
       "      <td>date_of_image</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>latitude</td>\n",
       "      <td>latitude</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>longitude</td>\n",
       "      <td>longitude</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NDVI</td>\n",
       "      <td>ndvi</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>GNDVI</td>\n",
       "      <td>gndvi</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>NDWI</td>\n",
       "      <td>ndwi</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>SAVI</td>\n",
       "      <td>savi</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>soil_moisture</td>\n",
       "      <td>soil_moisture</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>temperature</td>\n",
       "      <td>temperature</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>rainfall</td>\n",
       "      <td>rainfall</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>crop_type</td>\n",
       "      <td>crop_type</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>yield</td>\n",
       "      <td>yield</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Unnamed: 13</td>\n",
       "      <td>unnamed_13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>Unnamed: 14</td>\n",
       "      <td>unnamed_14</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    original_name normalized_name\n",
       "0        field_id        field_id\n",
       "1   date_of_image   date_of_image\n",
       "2        latitude        latitude\n",
       "3       longitude       longitude\n",
       "4            NDVI            ndvi\n",
       "5           GNDVI           gndvi\n",
       "6            NDWI            ndwi\n",
       "7            SAVI            savi\n",
       "8   soil_moisture   soil_moisture\n",
       "9     temperature     temperature\n",
       "10       rainfall        rainfall\n",
       "11      crop_type       crop_type\n",
       "12          yield           yield\n",
       "13    Unnamed: 13      unnamed_13\n",
       "14    Unnamed: 14      unnamed_14"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Выбранная зависимая переменная: yield\n",
      "\n",
      "Форма таблицы после очистки: (1625, 16)\n",
      "Число исходных признаков: 15\n",
      "Число числовых признаков: 13\n",
      "Число категориальных признаков: 2\n",
      "Категориальные признаки: ['field_id', 'crop_type']\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>field_id</th>\n",
       "      <th>latitude</th>\n",
       "      <th>longitude</th>\n",
       "      <th>ndvi</th>\n",
       "      <th>gndvi</th>\n",
       "      <th>ndwi</th>\n",
       "      <th>savi</th>\n",
       "      <th>soil_moisture</th>\n",
       "      <th>temperature</th>\n",
       "      <th>rainfall</th>\n",
       "      <th>crop_type</th>\n",
       "      <th>yield</th>\n",
       "      <th>date_year</th>\n",
       "      <th>date_month</th>\n",
       "      <th>date_day</th>\n",
       "      <th>date_dayofyear</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.060190</td>\n",
       "      <td>0.084801</td>\n",
       "      <td>-0.084801</td>\n",
       "      <td>0.090280</td>\n",
       "      <td>46.119353</td>\n",
       "      <td>9.884229</td>\n",
       "      <td>1.662354</td>\n",
       "      <td>Rice</td>\n",
       "      <td>40.218031</td>\n",
       "      <td>2023</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.213957</td>\n",
       "      <td>0.222009</td>\n",
       "      <td>-0.222009</td>\n",
       "      <td>0.320896</td>\n",
       "      <td>37.542525</td>\n",
       "      <td>13.967073</td>\n",
       "      <td>8.446302</td>\n",
       "      <td>Rice</td>\n",
       "      <td>30.870338</td>\n",
       "      <td>2023</td>\n",
       "      <td>1</td>\n",
       "      <td>16</td>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.403306</td>\n",
       "      <td>0.431204</td>\n",
       "      <td>-0.431204</td>\n",
       "      <td>0.604837</td>\n",
       "      <td>24.926279</td>\n",
       "      <td>13.590147</td>\n",
       "      <td>3.862833</td>\n",
       "      <td>Rice</td>\n",
       "      <td>45.330050</td>\n",
       "      <td>2023</td>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.418187</td>\n",
       "      <td>0.444132</td>\n",
       "      <td>-0.444132</td>\n",
       "      <td>0.627144</td>\n",
       "      <td>24.114157</td>\n",
       "      <td>12.343355</td>\n",
       "      <td>16.623542</td>\n",
       "      <td>Rice</td>\n",
       "      <td>49.711781</td>\n",
       "      <td>2023</td>\n",
       "      <td>2</td>\n",
       "      <td>15</td>\n",
       "      <td>46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Field_1</td>\n",
       "      <td>22.625231</td>\n",
       "      <td>88.497925</td>\n",
       "      <td>0.375138</td>\n",
       "      <td>0.387985</td>\n",
       "      <td>-0.387985</td>\n",
       "      <td>0.562591</td>\n",
       "      <td>27.420927</td>\n",
       "      <td>11.007707</td>\n",
       "      <td>9.496210</td>\n",
       "      <td>Rice</td>\n",
       "      <td>34.542646</td>\n",
       "      <td>2023</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  field_id   latitude  longitude      ndvi     gndvi      ndwi      savi  soil_moisture  temperature   rainfall crop_type      yield  date_year  date_month  date_day  date_dayofyear\n",
       "0  Field_1  22.625231  88.497925  0.060190  0.084801 -0.084801  0.090280      46.119353     9.884229   1.662354      Rice  40.218031       2023           1         1               1\n",
       "1  Field_1  22.625231  88.497925  0.213957  0.222009 -0.222009  0.320896      37.542525    13.967073   8.446302      Rice  30.870338       2023           1        16              16\n",
       "2  Field_1  22.625231  88.497925  0.403306  0.431204 -0.431204  0.604837      24.926279    13.590147   3.862833      Rice  45.330050       2023           1        31              31\n",
       "3  Field_1  22.625231  88.497925  0.418187  0.444132 -0.444132  0.627144      24.114157    12.343355  16.623542      Rice  49.711781       2023           2        15              46\n",
       "4  Field_1  22.625231  88.497925  0.375138  0.387985 -0.387985  0.562591      27.420927    11.007707   9.496210      Rice  34.542646       2023           2         3              34"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "csv_path = find_csv_file()\n",
    "raw_df = pd.read_csv(csv_path)\n",
    "df, original_columns = clean_dataframe(raw_df)\n",
    "\n",
    "print(\"Используемый файл:\", csv_path)\n",
    "print(\"Форма исходного датасета:\", raw_df.shape)\n",
    "display(raw_df.head())\n",
    "\n",
    "print(\"\\nНазвания колонок после нормализации:\")\n",
    "display(pd.DataFrame({\n",
    "    \"original_name\": original_columns,\n",
    "    \"normalized_name\": [normalize_column_name(c) for c in original_columns]\n",
    "}))\n",
    "\n",
    "target_col = detect_target_column(df)\n",
    "print(\"Выбранная зависимая переменная:\", target_col)\n",
    "\n",
    "X, y, numeric_features, categorical_features = split_features(df, target_col)\n",
    "\n",
    "print(\"\\nФорма таблицы после очистки:\", df.shape)\n",
    "print(\"Число исходных признаков:\", X.shape[1])\n",
    "print(\"Число числовых признаков:\", len(numeric_features))\n",
    "print(\"Число категориальных признаков:\", len(categorical_features))\n",
    "print(\"Категориальные признаки:\", categorical_features)\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c99d32fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Размер train: (975, 15)\n",
      "Размер validation: (325, 15)\n",
      "Размер test: (325, 15)\n"
     ]
    }
   ],
   "source": [
    "X_temp, X_test, y_temp, y_test = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=RANDOM_STATE\n",
    ")\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    X_temp, y_temp, test_size=0.25, random_state=RANDOM_STATE\n",
    ")\n",
    "\n",
    "print(\"Размер train:\", X_train.shape)\n",
    "print(\"Размер validation:\", X_val.shape)\n",
    "print(\"Размер test:\", X_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e9e74e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_mlp_regression(\n",
    "    X_train, y_train,\n",
    "    X_val, y_val,\n",
    "    numeric_features, categorical_features,\n",
    "    hidden_layers=(64, 32),\n",
    "    alpha=1e-4,\n",
    "    learning_rate_init=1e-3,\n",
    "    max_epochs=120,\n",
    "    batch_size=64,\n",
    "    random_state=RANDOM_STATE,\n",
    "):\n",
    "    preprocessor = make_preprocessor(numeric_features, categorical_features)\n",
    "    X_train_proc = preprocessor.fit_transform(X_train)\n",
    "    X_val_proc = preprocessor.transform(X_val)\n",
    "\n",
    "    feature_names = preprocessor.get_feature_names_out().tolist()\n",
    "\n",
    "    y_scaler = StandardScaler()\n",
    "    y_train_scaled = y_scaler.fit_transform(np.asarray(y_train).reshape(-1, 1)).ravel()\n",
    "\n",
    "    mlp = MLPRegressor(\n",
    "        hidden_layer_sizes=hidden_layers,\n",
    "        activation=\"relu\",\n",
    "        solver=\"adam\",\n",
    "        alpha=alpha,\n",
    "        learning_rate_init=learning_rate_init,\n",
    "        batch_size=batch_size,\n",
    "        max_iter=1,\n",
    "        warm_start=True,\n",
    "        shuffle=True,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "\n",
    "    history = []\n",
    "    best_model = None\n",
    "    best_val_rmse = np.inf\n",
    "    patience = 20\n",
    "    no_improve = 0\n",
    "\n",
    "    for epoch in range(max_epochs):\n",
    "        mlp.fit(X_train_proc, y_train_scaled)\n",
    "\n",
    "        pred_train = y_scaler.inverse_transform(mlp.predict(X_train_proc).reshape(-1, 1)).ravel()\n",
    "        pred_val = y_scaler.inverse_transform(mlp.predict(X_val_proc).reshape(-1, 1)).ravel()\n",
    "\n",
    "        train_m = regression_metrics(y_train, pred_train)\n",
    "        val_m = regression_metrics(y_val, pred_val)\n",
    "\n",
    "        history.append({\n",
    "            \"epoch\": epoch + 1,\n",
    "            \"train_loss\": float(mlp.loss_),\n",
    "            \"val_r2\": val_m[\"r2\"],\n",
    "            \"val_rmse\": val_m[\"rmse\"],\n",
    "            \"val_mae\": val_m[\"mae\"],\n",
    "        })\n",
    "\n",
    "        if val_m[\"rmse\"] < best_val_rmse - 1e-8:\n",
    "            best_val_rmse = val_m[\"rmse\"]\n",
    "            best_model = copy.deepcopy(mlp)\n",
    "            no_improve = 0\n",
    "        else:\n",
    "            no_improve += 1\n",
    "\n",
    "        if no_improve >= patience:\n",
    "            break\n",
    "\n",
    "    if best_model is None:\n",
    "        best_model = mlp\n",
    "\n",
    "    return {\n",
    "        \"model\": best_model,\n",
    "        \"preprocessor\": preprocessor,\n",
    "        \"y_scaler\": y_scaler,\n",
    "        \"feature_names\": feature_names,\n",
    "        \"history\": pd.DataFrame(history),\n",
    "        \"X_train_proc\": X_train_proc,\n",
    "        \"X_val_proc\": X_val_proc,\n",
    "    }\n",
    "\n",
    "\n",
    "def evaluate_result(result, X_train, y_train, X_val, y_val, X_test, y_test):\n",
    "    model = result[\"model\"]\n",
    "    preprocessor = result[\"preprocessor\"]\n",
    "    y_scaler = result[\"y_scaler\"]\n",
    "\n",
    "    X_train_proc = preprocessor.transform(X_train)\n",
    "    X_val_proc = preprocessor.transform(X_val)\n",
    "    X_test_proc = preprocessor.transform(X_test)\n",
    "\n",
    "    pred_train = y_scaler.inverse_transform(model.predict(X_train_proc).reshape(-1, 1)).ravel()\n",
    "    pred_val = y_scaler.inverse_transform(model.predict(X_val_proc).reshape(-1, 1)).ravel()\n",
    "    pred_test = y_scaler.inverse_transform(model.predict(X_test_proc).reshape(-1, 1)).ravel()\n",
    "\n",
    "    metrics_train = regression_metrics(y_train, pred_train)\n",
    "    metrics_val = regression_metrics(y_val, pred_val)\n",
    "    metrics_test = regression_metrics(y_test, pred_test)\n",
    "\n",
    "    return {\n",
    "        \"pred_train\": pred_train,\n",
    "        \"pred_val\": pred_val,\n",
    "        \"pred_test\": pred_test,\n",
    "        \"metrics_train\": metrics_train,\n",
    "        \"metrics_val\": metrics_val,\n",
    "        \"metrics_test\": metrics_test,\n",
    "        \"X_test_proc\": X_test_proc,\n",
    "        \"X_val_proc\": X_val_proc,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7ef6a60",
   "metadata": {},
   "source": [
    "## Базовая нейросеть по всем признакам"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d421161a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>split</th>\n",
       "      <th>r2</th>\n",
       "      <th>rmse</th>\n",
       "      <th>mae</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>0.994751</td>\n",
       "      <td>0.582905</td>\n",
       "      <td>0.406545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>validation</td>\n",
       "      <td>0.870878</td>\n",
       "      <td>2.828249</td>\n",
       "      <td>1.775249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test</td>\n",
       "      <td>0.902929</td>\n",
       "      <td>2.743203</td>\n",
       "      <td>1.755315</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        split        r2      rmse       mae\n",
       "0       train  0.994751  0.582905  0.406545\n",
       "1  validation  0.870878  2.828249  1.775249\n",
       "2        test  0.902929  2.743203  1.755315"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>val_r2</th>\n",
       "      <th>val_rmse</th>\n",
       "      <th>val_mae</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>130</th>\n",
       "      <td>131</td>\n",
       "      <td>0.002935</td>\n",
       "      <td>0.869096</td>\n",
       "      <td>2.847705</td>\n",
       "      <td>1.784982</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>132</td>\n",
       "      <td>0.003854</td>\n",
       "      <td>0.869229</td>\n",
       "      <td>2.846251</td>\n",
       "      <td>1.777142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>133</td>\n",
       "      <td>0.003546</td>\n",
       "      <td>0.867845</td>\n",
       "      <td>2.861276</td>\n",
       "      <td>1.812772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>134</td>\n",
       "      <td>0.003689</td>\n",
       "      <td>0.867460</td>\n",
       "      <td>2.865439</td>\n",
       "      <td>1.822093</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>135</td>\n",
       "      <td>0.002960</td>\n",
       "      <td>0.870878</td>\n",
       "      <td>2.828249</td>\n",
       "      <td>1.775249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>136</td>\n",
       "      <td>0.004358</td>\n",
       "      <td>0.870413</td>\n",
       "      <td>2.833335</td>\n",
       "      <td>1.799200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>136</th>\n",
       "      <td>137</td>\n",
       "      <td>0.003298</td>\n",
       "      <td>0.869790</td>\n",
       "      <td>2.840137</td>\n",
       "      <td>1.761328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>138</td>\n",
       "      <td>0.003451</td>\n",
       "      <td>0.868436</td>\n",
       "      <td>2.854873</td>\n",
       "      <td>1.820676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>138</th>\n",
       "      <td>139</td>\n",
       "      <td>0.003324</td>\n",
       "      <td>0.869806</td>\n",
       "      <td>2.839963</td>\n",
       "      <td>1.781549</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139</th>\n",
       "      <td>140</td>\n",
       "      <td>0.003620</td>\n",
       "      <td>0.870025</td>\n",
       "      <td>2.837579</td>\n",
       "      <td>1.791714</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     epoch  train_loss    val_r2  val_rmse   val_mae\n",
       "130    131    0.002935  0.869096  2.847705  1.784982\n",
       "131    132    0.003854  0.869229  2.846251  1.777142\n",
       "132    133    0.003546  0.867845  2.861276  1.812772\n",
       "133    134    0.003689  0.867460  2.865439  1.822093\n",
       "134    135    0.002960  0.870878  2.828249  1.775249\n",
       "135    136    0.004358  0.870413  2.833335  1.799200\n",
       "136    137    0.003298  0.869790  2.840137  1.761328\n",
       "137    138    0.003451  0.868436  2.854873  1.820676\n",
       "138    139    0.003324  0.869806  2.839963  1.781549\n",
       "139    140    0.003620  0.870025  2.837579  1.791714"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "baseline_result = fit_mlp_regression(\n",
    "    X_train, y_train,\n",
    "    X_val, y_val,\n",
    "    numeric_features, categorical_features,\n",
    "    hidden_layers=(64, 32),\n",
    "    alpha=1e-4,\n",
    "    learning_rate_init=1e-3,\n",
    "    max_epochs=140,\n",
    "    batch_size=64,\n",
    "    random_state=RANDOM_STATE,\n",
    ")\n",
    "baseline_eval = evaluate_result(\n",
    "    baseline_result, X_train, y_train, X_val, y_val, X_test, y_test\n",
    ")\n",
    "\n",
    "metrics_baseline_df = pd.DataFrame([\n",
    "    {\"split\": \"train\", **baseline_eval[\"metrics_train\"]},\n",
    "    {\"split\": \"validation\", **baseline_eval[\"metrics_val\"]},\n",
    "    {\"split\": \"test\", **baseline_eval[\"metrics_test\"]},\n",
    "])\n",
    "display(metrics_baseline_df)\n",
    "display(baseline_result[\"history\"].tail(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dab5771d",
   "metadata": {},
   "source": [
    "## Важность признаков\n",
    "\n",
    "Для оценки важности используется перестановочная важность на валидационной выборке. Дополнительно признаки агрегируются обратно к исходным столбцам."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7cfc5569",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "      <th>permutation_importance_mean</th>\n",
       "      <th>permutation_importance_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>rainfall</td>\n",
       "      <td>5.739523</td>\n",
       "      <td>0.193029</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ndvi</td>\n",
       "      <td>2.208541</td>\n",
       "      <td>0.169049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>savi</td>\n",
       "      <td>1.538535</td>\n",
       "      <td>0.151130</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>crop_type</td>\n",
       "      <td>1.415012</td>\n",
       "      <td>0.608425</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ndwi</td>\n",
       "      <td>0.890150</td>\n",
       "      <td>0.098273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>latitude</td>\n",
       "      <td>0.798413</td>\n",
       "      <td>0.087298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>field_id</td>\n",
       "      <td>0.692456</td>\n",
       "      <td>0.718612</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>soil_moisture</td>\n",
       "      <td>0.609792</td>\n",
       "      <td>0.092161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>gndvi</td>\n",
       "      <td>0.568960</td>\n",
       "      <td>0.064340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>date_dayofyear</td>\n",
       "      <td>0.462667</td>\n",
       "      <td>0.074318</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          feature  permutation_importance_mean  permutation_importance_std\n",
       "0        rainfall                     5.739523                    0.193029\n",
       "1            ndvi                     2.208541                    0.169049\n",
       "2            savi                     1.538535                    0.151130\n",
       "3       crop_type                     1.415012                    0.608425\n",
       "4            ndwi                     0.890150                    0.098273\n",
       "5        latitude                     0.798413                    0.087298\n",
       "6        field_id                     0.692456                    0.718612\n",
       "7   soil_moisture                     0.609792                    0.092161\n",
       "8           gndvi                     0.568960                    0.064340\n",
       "9  date_dayofyear                     0.462667                    0.074318"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Отобранные признаки:\n",
      "['rainfall', 'ndvi', 'savi', 'crop_type', 'ndwi', 'latitude', 'field_id', 'soil_moisture']\n"
     ]
    }
   ],
   "source": [
    "wrapped_baseline = InverseScaledRegressorWrapper(\n",
    "    baseline_result[\"model\"],\n",
    "    baseline_result[\"y_scaler\"]\n",
    ")\n",
    "\n",
    "perm = permutation_importance(\n",
    "    estimator=wrapped_baseline,\n",
    "    X=baseline_eval[\"X_val_proc\"],\n",
    "    y=y_val,\n",
    "    scoring=\"neg_root_mean_squared_error\",\n",
    "    n_repeats=10,\n",
    "    random_state=RANDOM_STATE\n",
    ")\n",
    "\n",
    "perm_df = pd.DataFrame({\n",
    "    \"processed_feature\": baseline_result[\"feature_names\"],\n",
    "    \"permutation_importance_mean\": perm.importances_mean,\n",
    "    \"permutation_importance_std\": perm.importances_std\n",
    "})\n",
    "\n",
    "def processed_to_original(name, numeric_features, categorical_features):\n",
    "    if name.startswith(\"num__\"):\n",
    "        return name.replace(\"num__\", \"\", 1)\n",
    "    if name.startswith(\"cat__\"):\n",
    "        suffix = name.replace(\"cat__\", \"\", 1)\n",
    "        for c in sorted(categorical_features, key=len, reverse=True):\n",
    "            prefix = c + \"_\"\n",
    "            if suffix.startswith(prefix):\n",
    "                return c\n",
    "        return suffix\n",
    "    return name\n",
    "\n",
    "perm_df[\"feature\"] = perm_df[\"processed_feature\"].apply(\n",
    "    lambda x: processed_to_original(x, numeric_features, categorical_features)\n",
    ")\n",
    "\n",
    "importance_df = (\n",
    "    perm_df.groupby(\"feature\", as_index=False)[[\"permutation_importance_mean\", \"permutation_importance_std\"]]\n",
    "    .sum()\n",
    "    .sort_values(\"permutation_importance_mean\", ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "display(importance_df.head(10))\n",
    "\n",
    "positive_features = importance_df.loc[\n",
    "    importance_df[\"permutation_importance_mean\"] > 0, \"feature\"\n",
    "].tolist()\n",
    "\n",
    "if len(positive_features) >= 5:\n",
    "    selected_features = positive_features[:min(8, len(positive_features))]\n",
    "else:\n",
    "    selected_features = importance_df[\"feature\"].head(min(6, len(importance_df))).tolist()\n",
    "\n",
    "print(\"Отобранные признаки:\")\n",
    "print(selected_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42a13b1b",
   "metadata": {},
   "source": [
    "## Вторая нейросеть: только по отобранным признакам + эволюционный поиск архитектуры"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8c1c0e6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>generation</th>\n",
       "      <th>rank_in_generation</th>\n",
       "      <th>hidden_layers</th>\n",
       "      <th>alpha</th>\n",
       "      <th>learning_rate_init</th>\n",
       "      <th>rmse_val</th>\n",
       "      <th>mae_val</th>\n",
       "      <th>r2_val</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000004</td>\n",
       "      <td>0.009822</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "      <td>0.885422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>6</td>\n",
       "      <td>3</td>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>(80, 40, 20)</td>\n",
       "      <td>0.000005</td>\n",
       "      <td>0.005015</td>\n",
       "      <td>2.683304</td>\n",
       "      <td>1.368121</td>\n",
       "      <td>0.883774</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   generation  rank_in_generation hidden_layers     alpha  learning_rate_init  rmse_val   mae_val    r2_val\n",
       "0           6                   1         (64,)  0.000004            0.009822  2.664205  1.501977  0.885422\n",
       "1           3                   1         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "2           4                   1         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "3           5                   1         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "4           6                   2         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "5           3                   2         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "6           4                   2         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "7           5                   2         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "8           6                   3         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "9           2                   1  (80, 40, 20)  0.000005            0.005015  2.683304  1.368121  0.883774"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>split</th>\n",
       "      <th>r2</th>\n",
       "      <th>rmse</th>\n",
       "      <th>mae</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>0.969089</td>\n",
       "      <td>1.414586</td>\n",
       "      <td>0.905347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>validation</td>\n",
       "      <td>0.885422</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test</td>\n",
       "      <td>0.933225</td>\n",
       "      <td>2.275202</td>\n",
       "      <td>1.291376</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        split        r2      rmse       mae\n",
       "0       train  0.969089  1.414586  0.905347\n",
       "1  validation  0.885422  2.664205  1.501977\n",
       "2        test  0.933225  2.275202  1.291376"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>val_r2</th>\n",
       "      <th>val_rmse</th>\n",
       "      <th>val_mae</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>41</td>\n",
       "      <td>0.019114</td>\n",
       "      <td>0.877719</td>\n",
       "      <td>2.752308</td>\n",
       "      <td>1.439200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>42</td>\n",
       "      <td>0.020414</td>\n",
       "      <td>0.871173</td>\n",
       "      <td>2.825014</td>\n",
       "      <td>1.502276</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>43</td>\n",
       "      <td>0.020595</td>\n",
       "      <td>0.869560</td>\n",
       "      <td>2.842648</td>\n",
       "      <td>1.457928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>44</td>\n",
       "      <td>0.020510</td>\n",
       "      <td>0.866645</td>\n",
       "      <td>2.874235</td>\n",
       "      <td>1.521252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>45</td>\n",
       "      <td>0.020608</td>\n",
       "      <td>0.871446</td>\n",
       "      <td>2.822021</td>\n",
       "      <td>1.508952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>46</td>\n",
       "      <td>0.020722</td>\n",
       "      <td>0.869417</td>\n",
       "      <td>2.844210</td>\n",
       "      <td>1.433634</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>47</td>\n",
       "      <td>0.018139</td>\n",
       "      <td>0.875383</td>\n",
       "      <td>2.778477</td>\n",
       "      <td>1.497307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>48</td>\n",
       "      <td>0.020332</td>\n",
       "      <td>0.864971</td>\n",
       "      <td>2.892221</td>\n",
       "      <td>1.498261</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>49</td>\n",
       "      <td>0.021497</td>\n",
       "      <td>0.863142</td>\n",
       "      <td>2.911740</td>\n",
       "      <td>1.533948</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>50</td>\n",
       "      <td>0.019801</td>\n",
       "      <td>0.870129</td>\n",
       "      <td>2.836441</td>\n",
       "      <td>1.450026</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    epoch  train_loss    val_r2  val_rmse   val_mae\n",
       "40     41    0.019114  0.877719  2.752308  1.439200\n",
       "41     42    0.020414  0.871173  2.825014  1.502276\n",
       "42     43    0.020595  0.869560  2.842648  1.457928\n",
       "43     44    0.020510  0.866645  2.874235  1.521252\n",
       "44     45    0.020608  0.871446  2.822021  1.508952\n",
       "45     46    0.020722  0.869417  2.844210  1.433634\n",
       "46     47    0.018139  0.875383  2.778477  1.497307\n",
       "47     48    0.020332  0.864971  2.892221  1.498261\n",
       "48     49    0.021497  0.863142  2.911740  1.533948\n",
       "49     50    0.019801  0.870129  2.836441  1.450026"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def random_hidden_layers():\n",
    "    templates = [\n",
    "        (16,), (24,), (32,), (48,), (64,), (72,), (96,),\n",
    "        (16, 8), (24, 12), (32, 16), (48, 24), (64, 32), (80, 40),\n",
    "        (32, 16, 8), (48, 24, 12), (64, 32, 16), (96, 32, 8),\n",
    "        (64, 16), (96, 48, 16), (128, 64, 32), (80, 40, 20), (64, 32, 16, 8)\n",
    "    ]\n",
    "    return templates[np.random.randint(0, len(templates))]\n",
    "\n",
    "\n",
    "def sample_candidate():\n",
    "    return {\n",
    "        \"hidden_layers\": random_hidden_layers(),\n",
    "        \"alpha\": float(10 ** np.random.uniform(-6, -2)),\n",
    "        \"learning_rate_init\": float(10 ** np.random.uniform(-4, -2)),\n",
    "    }\n",
    "\n",
    "\n",
    "def mutate_candidate(base):\n",
    "    cand = copy.deepcopy(base)\n",
    "    if np.random.rand() < 0.5:\n",
    "        cand[\"hidden_layers\"] = random_hidden_layers()\n",
    "    if np.random.rand() < 0.7:\n",
    "        cand[\"alpha\"] = float(np.clip(cand[\"alpha\"] * (10 ** np.random.uniform(-0.5, 0.5)), 1e-6, 1e-1))\n",
    "    if np.random.rand() < 0.7:\n",
    "        cand[\"learning_rate_init\"] = float(np.clip(cand[\"learning_rate_init\"] * (10 ** np.random.uniform(-0.5, 0.5)), 1e-4, 1e-2))\n",
    "    return cand\n",
    "\n",
    "\n",
    "def crossover_candidate(a, b):\n",
    "    return {\n",
    "        \"hidden_layers\": a[\"hidden_layers\"] if np.random.rand() < 0.5 else b[\"hidden_layers\"],\n",
    "        \"alpha\": float(np.sqrt(a[\"alpha\"] * b[\"alpha\"])),\n",
    "        \"learning_rate_init\": float(np.sqrt(a[\"learning_rate_init\"] * b[\"learning_rate_init\"])),\n",
    "    }\n",
    "\n",
    "\n",
    "def evaluate_candidate(candidate, X_train_s, y_train_s, X_val_s, y_val_s, n_feats, c_feats):\n",
    "    result = fit_mlp_regression(\n",
    "        X_train_s, y_train_s,\n",
    "        X_val_s, y_val_s,\n",
    "        n_feats, c_feats,\n",
    "        hidden_layers=candidate[\"hidden_layers\"],\n",
    "        alpha=candidate[\"alpha\"],\n",
    "        learning_rate_init=candidate[\"learning_rate_init\"],\n",
    "        max_epochs=90,\n",
    "        batch_size=64,\n",
    "        random_state=RANDOM_STATE,\n",
    "    )\n",
    "    eval_result = evaluate_result(result, X_train_s, y_train_s, X_val_s, y_val_s, X_val_s, y_val_s)\n",
    "    row = {\n",
    "        \"hidden_layers\": candidate[\"hidden_layers\"],\n",
    "        \"alpha\": candidate[\"alpha\"],\n",
    "        \"learning_rate_init\": candidate[\"learning_rate_init\"],\n",
    "        \"rmse_val\": eval_result[\"metrics_val\"][\"rmse\"],\n",
    "        \"mae_val\": eval_result[\"metrics_val\"][\"mae\"],\n",
    "        \"r2_val\": eval_result[\"metrics_val\"][\"r2\"],\n",
    "        \"result\": result,\n",
    "    }\n",
    "    return row\n",
    "\n",
    "\n",
    "X_train_sel = X_train[selected_features].copy()\n",
    "X_val_sel = X_val[selected_features].copy()\n",
    "X_test_sel = X_test[selected_features].copy()\n",
    "\n",
    "numeric_features_sel = X_train_sel.select_dtypes(include=[np.number]).columns.tolist()\n",
    "categorical_features_sel = [c for c in X_train_sel.columns if c not in numeric_features_sel]\n",
    "\n",
    "population_size = 8\n",
    "n_generations = 6\n",
    "elite_size = 3\n",
    "\n",
    "population = [sample_candidate() for _ in range(population_size)]\n",
    "search_rows = []\n",
    "\n",
    "for generation in range(n_generations):\n",
    "    evaluated = [\n",
    "        evaluate_candidate(c, X_train_sel, y_train, X_val_sel, y_val, numeric_features_sel, categorical_features_sel)\n",
    "        for c in population\n",
    "    ]\n",
    "    evaluated = sorted(evaluated, key=lambda x: x[\"rmse_val\"])\n",
    "    elites = evaluated[:elite_size]\n",
    "\n",
    "    for rank, item in enumerate(evaluated):\n",
    "        search_rows.append({\n",
    "            \"generation\": generation + 1,\n",
    "            \"rank_in_generation\": rank + 1,\n",
    "            \"hidden_layers\": item[\"hidden_layers\"],\n",
    "            \"alpha\": item[\"alpha\"],\n",
    "            \"learning_rate_init\": item[\"learning_rate_init\"],\n",
    "            \"rmse_val\": item[\"rmse_val\"],\n",
    "            \"mae_val\": item[\"mae_val\"],\n",
    "            \"r2_val\": item[\"r2_val\"],\n",
    "        })\n",
    "\n",
    "    new_population = [\n",
    "        {\n",
    "            \"hidden_layers\": e[\"hidden_layers\"],\n",
    "            \"alpha\": e[\"alpha\"],\n",
    "            \"learning_rate_init\": e[\"learning_rate_init\"],\n",
    "        }\n",
    "        for e in elites\n",
    "    ]\n",
    "\n",
    "    while len(new_population) < population_size:\n",
    "        parent_a, parent_b = np.random.choice(elites, size=2, replace=True)\n",
    "        child = crossover_candidate(parent_a, parent_b)\n",
    "        child = mutate_candidate(child)\n",
    "        new_population.append(child)\n",
    "\n",
    "    population = new_population\n",
    "\n",
    "evolution_df = pd.DataFrame(search_rows).sort_values([\"rmse_val\", \"mae_val\", \"r2_val\"], ascending=[True, True, False]).reset_index(drop=True)\n",
    "best_candidate = evolution_df.iloc[0].to_dict()\n",
    "display(evolution_df.head(10))\n",
    "\n",
    "selected_result = fit_mlp_regression(\n",
    "    X_train_sel, y_train,\n",
    "    X_val_sel, y_val,\n",
    "    numeric_features_sel, categorical_features_sel,\n",
    "    hidden_layers=best_candidate[\"hidden_layers\"],\n",
    "    alpha=float(best_candidate[\"alpha\"]),\n",
    "    learning_rate_init=float(best_candidate[\"learning_rate_init\"]),\n",
    "    max_epochs=180,\n",
    "    batch_size=64,\n",
    "    random_state=RANDOM_STATE,\n",
    ")\n",
    "selected_eval = evaluate_result(\n",
    "    selected_result, X_train_sel, y_train, X_val_sel, y_val, X_test_sel, y_test\n",
    ")\n",
    "\n",
    "metrics_selected_df = pd.DataFrame([\n",
    "    {\"split\": \"train\", **selected_eval[\"metrics_train\"]},\n",
    "    {\"split\": \"validation\", **selected_eval[\"metrics_val\"]},\n",
    "    {\"split\": \"test\", **selected_eval[\"metrics_test\"]},\n",
    "])\n",
    "\n",
    "display(metrics_selected_df)\n",
    "display(selected_result[\"history\"].tail(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277a5a31",
   "metadata": {},
   "source": [
    "## Сравнение моделей"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c7050b6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>r2_test</th>\n",
       "      <th>rmse_test</th>\n",
       "      <th>mae_test</th>\n",
       "      <th>r2_val</th>\n",
       "      <th>rmse_val</th>\n",
       "      <th>mae_val</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Нейросеть по отобранным признакам + эволюционн...</td>\n",
       "      <td>0.933225</td>\n",
       "      <td>2.275202</td>\n",
       "      <td>1.291376</td>\n",
       "      <td>0.885422</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "      <td>8</td>\n",
       "      <td>126</td>\n",
       "      <td>(64,)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Базовая нейросеть по всем признакам</td>\n",
       "      <td>0.902929</td>\n",
       "      <td>2.743203</td>\n",
       "      <td>1.755315</td>\n",
       "      <td>0.870878</td>\n",
       "      <td>2.828249</td>\n",
       "      <td>1.775249</td>\n",
       "      <td>15</td>\n",
       "      <td>133</td>\n",
       "      <td>(64, 32)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               model   r2_test  rmse_test  mae_test    r2_val  rmse_val   mae_val  n_original_features  n_processed_features hidden_layers\n",
       "0  Нейросеть по отобранным признакам + эволюционн...  0.933225   2.275202  1.291376  0.885422  2.664205  1.501977                    8                   126         (64,)\n",
       "1                Базовая нейросеть по всем признакам  0.902929   2.743203  1.755315  0.870878  2.828249  1.775249                   15                   133      (64, 32)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Лучшая модель:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>r2_test</th>\n",
       "      <th>rmse_test</th>\n",
       "      <th>mae_test</th>\n",
       "      <th>r2_val</th>\n",
       "      <th>rmse_val</th>\n",
       "      <th>mae_val</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Нейросеть по отобранным признакам + эволюционн...</td>\n",
       "      <td>0.933225</td>\n",
       "      <td>2.275202</td>\n",
       "      <td>1.291376</td>\n",
       "      <td>0.885422</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "      <td>8</td>\n",
       "      <td>126</td>\n",
       "      <td>(64,)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               model   r2_test  rmse_test  mae_test    r2_val  rmse_val   mae_val  n_original_features  n_processed_features hidden_layers\n",
       "0  Нейросеть по отобранным признакам + эволюционн...  0.933225   2.275202  1.291376  0.885422  2.664205  1.501977                    8                   126         (64,)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "comparison_df = pd.DataFrame([\n",
    "    {\n",
    "        \"model\": \"Базовая нейросеть по всем признакам\",\n",
    "        \"r2_test\": baseline_eval[\"metrics_test\"][\"r2\"],\n",
    "        \"rmse_test\": baseline_eval[\"metrics_test\"][\"rmse\"],\n",
    "        \"mae_test\": baseline_eval[\"metrics_test\"][\"mae\"],\n",
    "        \"r2_val\": baseline_eval[\"metrics_val\"][\"r2\"],\n",
    "        \"rmse_val\": baseline_eval[\"metrics_val\"][\"rmse\"],\n",
    "        \"mae_val\": baseline_eval[\"metrics_val\"][\"mae\"],\n",
    "        \"n_original_features\": X.shape[1],\n",
    "        \"n_processed_features\": len(baseline_result[\"feature_names\"]),\n",
    "        \"hidden_layers\": (64, 32),\n",
    "    },\n",
    "    {\n",
    "        \"model\": \"Нейросеть по отобранным признакам + эволюционный поиск\",\n",
    "        \"r2_test\": selected_eval[\"metrics_test\"][\"r2\"],\n",
    "        \"rmse_test\": selected_eval[\"metrics_test\"][\"rmse\"],\n",
    "        \"mae_test\": selected_eval[\"metrics_test\"][\"mae\"],\n",
    "        \"r2_val\": selected_eval[\"metrics_val\"][\"r2\"],\n",
    "        \"rmse_val\": selected_eval[\"metrics_val\"][\"rmse\"],\n",
    "        \"mae_val\": selected_eval[\"metrics_val\"][\"mae\"],\n",
    "        \"n_original_features\": len(selected_features),\n",
    "        \"n_processed_features\": len(selected_result[\"feature_names\"]),\n",
    "        \"hidden_layers\": best_candidate[\"hidden_layers\"],\n",
    "    },\n",
    "]).sort_values([\"r2_test\", \"rmse_test\", \"mae_test\"], ascending=[False, True, True]).reset_index(drop=True)\n",
    "\n",
    "display(comparison_df)\n",
    "best_model_name = comparison_df.iloc[0][\"model\"]\n",
    "print(\"Лучшая модель:\")\n",
    "display(comparison_df.head(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4188955f",
   "metadata": {},
   "source": [
    "## Прототипы по лучшей сети\n",
    "\n",
    "Прототипы определяются в пространстве признаков лучшей модели. Для этого:\n",
    "- берётся обучающая и валидационная выборка;\n",
    "- данные переводятся в пространcтво признаков модели;\n",
    "- признаки взвешиваются по важности;\n",
    "- далее выполняется кластеризация `KMeans`, а центры кластеров рассматриваются как прототипы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1d4e7f1f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "      <th>permutation_importance_mean</th>\n",
       "      <th>permutation_importance_std</th>\n",
       "      <th>processed_weight_importance</th>\n",
       "      <th>rank_permutation</th>\n",
       "      <th>rank_weight</th>\n",
       "      <th>rank_mean</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>rainfall</td>\n",
       "      <td>6.018178</td>\n",
       "      <td>0.223739</td>\n",
       "      <td>0.209641</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>crop_type</td>\n",
       "      <td>1.903834</td>\n",
       "      <td>0.796544</td>\n",
       "      <td>3.011422</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ndvi</td>\n",
       "      <td>2.013407</td>\n",
       "      <td>0.170873</td>\n",
       "      <td>0.116551</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>field_id</td>\n",
       "      <td>1.117020</td>\n",
       "      <td>0.848319</td>\n",
       "      <td>6.782297</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>savi</td>\n",
       "      <td>1.920353</td>\n",
       "      <td>0.161670</td>\n",
       "      <td>0.107952</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ndwi</td>\n",
       "      <td>1.043480</td>\n",
       "      <td>0.081542</td>\n",
       "      <td>0.093077</td>\n",
       "      <td>6.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>latitude</td>\n",
       "      <td>0.126535</td>\n",
       "      <td>0.049307</td>\n",
       "      <td>0.101311</td>\n",
       "      <td>8.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>7.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>soil_moisture</td>\n",
       "      <td>0.481983</td>\n",
       "      <td>0.055921</td>\n",
       "      <td>0.087355</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         feature  permutation_importance_mean  permutation_importance_std  processed_weight_importance  rank_permutation  rank_weight  rank_mean\n",
       "0       rainfall                     6.018178                    0.223739                     0.209641               1.0          3.0        2.0\n",
       "1      crop_type                     1.903834                    0.796544                     3.011422               4.0          2.0        3.0\n",
       "2           ndvi                     2.013407                    0.170873                     0.116551               2.0          4.0        3.0\n",
       "3       field_id                     1.117020                    0.848319                     6.782297               5.0          1.0        3.0\n",
       "4           savi                     1.920353                    0.161670                     0.107952               3.0          5.0        4.0\n",
       "5           ndwi                     1.043480                    0.081542                     0.093077               6.0          7.0        6.5\n",
       "6       latitude                     0.126535                    0.049307                     0.101311               8.0          6.0        7.0\n",
       "7  soil_moisture                     0.481983                    0.055921                     0.087355               7.0          8.0        7.5"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prototype_id</th>\n",
       "      <th>support_n</th>\n",
       "      <th>support_share_trainval</th>\n",
       "      <th>representative_target</th>\n",
       "      <th>assigned_mean_target_trainval</th>\n",
       "      <th>rainfall</th>\n",
       "      <th>crop_type</th>\n",
       "      <th>ndvi</th>\n",
       "      <th>field_id</th>\n",
       "      <th>savi</th>\n",
       "      <th>ndwi</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>453</td>\n",
       "      <td>0.348462</td>\n",
       "      <td>31.892742</td>\n",
       "      <td>32.555298</td>\n",
       "      <td>5.990291</td>\n",
       "      <td>Sorghum</td>\n",
       "      <td>0.281223</td>\n",
       "      <td>Field_47</td>\n",
       "      <td>0.421783</td>\n",
       "      <td>-0.340460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>312</td>\n",
       "      <td>0.240000</td>\n",
       "      <td>39.800723</td>\n",
       "      <td>40.323829</td>\n",
       "      <td>8.126362</td>\n",
       "      <td>Millets</td>\n",
       "      <td>0.582184</td>\n",
       "      <td>Field_3</td>\n",
       "      <td>0.873146</td>\n",
       "      <td>-0.565245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4</td>\n",
       "      <td>254</td>\n",
       "      <td>0.195385</td>\n",
       "      <td>44.005098</td>\n",
       "      <td>44.097473</td>\n",
       "      <td>13.176073</td>\n",
       "      <td>Saffron</td>\n",
       "      <td>0.254225</td>\n",
       "      <td>Field_5</td>\n",
       "      <td>0.381289</td>\n",
       "      <td>-0.290785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>206</td>\n",
       "      <td>0.158462</td>\n",
       "      <td>52.382550</td>\n",
       "      <td>52.766332</td>\n",
       "      <td>16.352781</td>\n",
       "      <td>Mustard</td>\n",
       "      <td>0.546054</td>\n",
       "      <td>Field_1</td>\n",
       "      <td>0.818966</td>\n",
       "      <td>-0.531147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>75</td>\n",
       "      <td>0.057692</td>\n",
       "      <td>46.216829</td>\n",
       "      <td>44.511292</td>\n",
       "      <td>6.718947</td>\n",
       "      <td>Coconut</td>\n",
       "      <td>-0.141769</td>\n",
       "      <td>Field_86</td>\n",
       "      <td>-0.212330</td>\n",
       "      <td>0.300656</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   prototype_id  support_n  support_share_trainval  representative_target  assigned_mean_target_trainval   rainfall crop_type      ndvi  field_id      savi      ndwi\n",
       "0             1        453                0.348462              31.892742                      32.555298   5.990291   Sorghum  0.281223  Field_47  0.421783 -0.340460\n",
       "1             0        312                0.240000              39.800723                      40.323829   8.126362   Millets  0.582184   Field_3  0.873146 -0.565245\n",
       "2             4        254                0.195385              44.005098                      44.097473  13.176073   Saffron  0.254225   Field_5  0.381289 -0.290785\n",
       "3             2        206                0.158462              52.382550                      52.766332  16.352781   Mustard  0.546054   Field_1  0.818966 -0.531147\n",
       "4             3         75                0.057692              46.216829                      44.511292   6.718947   Coconut -0.141769  Field_86 -0.212330  0.300656"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prototype_id</th>\n",
       "      <th>feature</th>\n",
       "      <th>contribution_score</th>\n",
       "      <th>interpreted_value</th>\n",
       "      <th>rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.808172</td>\n",
       "      <td>Field_3</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.703522</td>\n",
       "      <td>Millets</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>2.114486</td>\n",
       "      <td>0.582184</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>savi</td>\n",
       "      <td>2.013780</td>\n",
       "      <td>0.873146</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>1.961811</td>\n",
       "      <td>8.126362</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.808172</td>\n",
       "      <td>Field_47</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>4.822911</td>\n",
       "      <td>5.990291</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.718646</td>\n",
       "      <td>Sorghum</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>0.825673</td>\n",
       "      <td>0.281223</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>savi</td>\n",
       "      <td>0.786483</td>\n",
       "      <td>0.421783</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>2</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>9.056828</td>\n",
       "      <td>16.352781</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>2</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.808172</td>\n",
       "      <td>Field_1</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>2</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.707303</td>\n",
       "      <td>Mustard</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>2</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>1.761516</td>\n",
       "      <td>0.546054</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>2</td>\n",
       "      <td>savi</td>\n",
       "      <td>1.677645</td>\n",
       "      <td>0.818966</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>3</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.820325</td>\n",
       "      <td>Field_86</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>3</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>4.957980</td>\n",
       "      <td>-0.141769</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>3</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.801827</td>\n",
       "      <td>Coconut</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>3</td>\n",
       "      <td>savi</td>\n",
       "      <td>4.720528</td>\n",
       "      <td>-0.21233</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>3</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>3.846934</td>\n",
       "      <td>6.718947</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>4</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.777790</td>\n",
       "      <td>Field_5</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>4</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>4.801879</td>\n",
       "      <td>13.176073</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>4</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.707303</td>\n",
       "      <td>Saffron</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>4</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>1.089423</td>\n",
       "      <td>0.254225</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>4</td>\n",
       "      <td>savi</td>\n",
       "      <td>1.037709</td>\n",
       "      <td>0.381289</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    prototype_id    feature  contribution_score interpreted_value  rank\n",
       "0              0   field_id            7.808172           Field_3   1.0\n",
       "1              0  crop_type            4.703522           Millets   2.0\n",
       "2              0       ndvi            2.114486          0.582184   3.0\n",
       "3              0       savi            2.013780          0.873146   4.0\n",
       "4              0   rainfall            1.961811          8.126362   5.0\n",
       "5              1   field_id            7.808172          Field_47   1.0\n",
       "6              1   rainfall            4.822911          5.990291   2.0\n",
       "7              1  crop_type            4.718646           Sorghum   3.0\n",
       "8              1       ndvi            0.825673          0.281223   4.0\n",
       "9              1       savi            0.786483          0.421783   5.0\n",
       "10             2   rainfall            9.056828         16.352781   1.0\n",
       "11             2   field_id            7.808172           Field_1   2.0\n",
       "12             2  crop_type            4.707303           Mustard   3.0\n",
       "13             2       ndvi            1.761516          0.546054   4.0\n",
       "14             2       savi            1.677645          0.818966   5.0\n",
       "15             3   field_id            7.820325          Field_86   1.0\n",
       "16             3       ndvi            4.957980         -0.141769   2.0\n",
       "17             3  crop_type            4.801827           Coconut   3.0\n",
       "18             3       savi            4.720528          -0.21233   4.0\n",
       "19             3   rainfall            3.846934          6.718947   5.0\n",
       "20             4   field_id            7.777790           Field_5   1.0\n",
       "21             4   rainfall            4.801879         13.176073   2.0\n",
       "22             4  crop_type            4.707303           Saffron   3.0\n",
       "23             4       ndvi            1.089423          0.254225   4.0\n",
       "24             4       savi            1.037709          0.381289   5.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def compute_processed_weight_importance(model, feature_names):\n",
    "    first_layer = np.abs(model.coefs_[0]).mean(axis=1)\n",
    "    return pd.DataFrame({\n",
    "        \"processed_feature\": feature_names,\n",
    "        \"processed_weight_importance\": first_layer\n",
    "    })\n",
    "\n",
    "def aggregate_processed_importance(processed_df, numeric_features, categorical_features):\n",
    "    processed_df = processed_df.copy()\n",
    "    processed_df[\"feature\"] = processed_df[\"processed_feature\"].apply(\n",
    "        lambda x: processed_to_original(x, numeric_features, categorical_features)\n",
    "    )\n",
    "    agg = (\n",
    "        processed_df.groupby(\"feature\", as_index=False)[\"processed_weight_importance\"]\n",
    "        .sum()\n",
    "        .sort_values(\"processed_weight_importance\", ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "    return agg\n",
    "\n",
    "if best_model_name == \"Базовая нейросеть по всем признакам\":\n",
    "    best_result = baseline_result\n",
    "    best_eval = baseline_eval\n",
    "    X_trainval_best = pd.concat([X_train, X_val], axis=0).reset_index(drop=True)\n",
    "    y_trainval_best = pd.concat([y_train, y_val], axis=0).reset_index(drop=True)\n",
    "    used_numeric = numeric_features\n",
    "    used_categorical = categorical_features\n",
    "    used_original_features = X.columns.tolist()\n",
    "    perm_source_df = importance_df.copy()\n",
    "else:\n",
    "    best_result = selected_result\n",
    "    best_eval = selected_eval\n",
    "    X_trainval_best = pd.concat([X_train_sel, X_val_sel], axis=0).reset_index(drop=True)\n",
    "    y_trainval_best = pd.concat([y_train, y_val], axis=0).reset_index(drop=True)\n",
    "    used_numeric = numeric_features_sel\n",
    "    used_categorical = categorical_features_sel\n",
    "    used_original_features = selected_features\n",
    "\n",
    "    wrapped_selected = InverseScaledRegressorWrapper(\n",
    "        selected_result[\"model\"],\n",
    "        selected_result[\"y_scaler\"]\n",
    "    )\n",
    "    perm_sel = permutation_importance(\n",
    "        estimator=wrapped_selected,\n",
    "        X=selected_eval[\"X_val_proc\"],\n",
    "        y=y_val,\n",
    "        scoring=\"neg_root_mean_squared_error\",\n",
    "        n_repeats=10,\n",
    "        random_state=RANDOM_STATE\n",
    "    )\n",
    "    perm_sel_df = pd.DataFrame({\n",
    "        \"processed_feature\": selected_result[\"feature_names\"],\n",
    "        \"permutation_importance_mean\": perm_sel.importances_mean,\n",
    "        \"permutation_importance_std\": perm_sel.importances_std\n",
    "    })\n",
    "    perm_sel_df[\"feature\"] = perm_sel_df[\"processed_feature\"].apply(\n",
    "        lambda x: processed_to_original(x, used_numeric, used_categorical)\n",
    "    )\n",
    "    perm_source_df = (\n",
    "        perm_sel_df.groupby(\"feature\", as_index=False)[[\"permutation_importance_mean\", \"permutation_importance_std\"]]\n",
    "        .sum()\n",
    "        .sort_values(\"permutation_importance_mean\", ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "\n",
    "processed_weight_df = compute_processed_weight_importance(best_result[\"model\"], best_result[\"feature_names\"])\n",
    "weight_importance_df = aggregate_processed_importance(processed_weight_df, used_numeric, used_categorical)\n",
    "\n",
    "importance_merge = perm_source_df.merge(weight_importance_df, on=\"feature\", how=\"outer\").fillna(0)\n",
    "importance_merge[\"rank_permutation\"] = importance_merge[\"permutation_importance_mean\"].rank(ascending=False, method=\"average\")\n",
    "importance_merge[\"rank_weight\"] = importance_merge[\"processed_weight_importance\"].rank(ascending=False, method=\"average\")\n",
    "importance_merge[\"rank_mean\"] = importance_merge[[\"rank_permutation\", \"rank_weight\"]].mean(axis=1)\n",
    "importance_merge = importance_merge.sort_values(\"rank_mean\").reset_index(drop=True)\n",
    "\n",
    "display(importance_merge.head(10))\n",
    "\n",
    "X_trainval_best_proc = best_result[\"preprocessor\"].transform(X_trainval_best)\n",
    "proc_importance = processed_weight_df.copy()\n",
    "proc_importance[\"weight_norm\"] = proc_importance[\"processed_weight_importance\"] / (proc_importance[\"processed_weight_importance\"].sum() + 1e-12)\n",
    "weights_vector = proc_importance[\"weight_norm\"].values\n",
    "X_weighted = X_trainval_best_proc * np.sqrt(weights_vector + 1e-8)\n",
    "\n",
    "n_prototypes = min(5, max(3, len(X_trainval_best) // 100))\n",
    "kmeans = KMeans(n_clusters=n_prototypes, random_state=RANDOM_STATE, n_init=20)\n",
    "cluster_labels = kmeans.fit_predict(X_weighted)\n",
    "\n",
    "prototypes_rows = []\n",
    "top_proto_features = importance_merge[\"feature\"].head(min(6, len(importance_merge))).tolist()\n",
    "\n",
    "for proto_id in range(n_prototypes):\n",
    "    idx = np.where(cluster_labels == proto_id)[0]\n",
    "    cluster_X = X_trainval_best.iloc[idx]\n",
    "    cluster_y = y_trainval_best.iloc[idx]\n",
    "\n",
    "    row = {\n",
    "        \"prototype_id\": proto_id,\n",
    "        \"support_n\": int(len(idx)),\n",
    "        \"support_share_trainval\": float(len(idx) / len(X_trainval_best)),\n",
    "        \"representative_target\": float(np.median(cluster_y)),\n",
    "        \"assigned_mean_target_trainval\": float(np.mean(cluster_y)),\n",
    "    }\n",
    "    for feat in top_proto_features:\n",
    "        if feat in cluster_X.columns:\n",
    "            if pd.api.types.is_numeric_dtype(cluster_X[feat]):\n",
    "                row[feat] = float(cluster_X[feat].mean())\n",
    "            else:\n",
    "                mode_vals = cluster_X[feat].mode(dropna=True)\n",
    "                row[feat] = mode_vals.iloc[0] if len(mode_vals) else np.nan\n",
    "    prototypes_rows.append(row)\n",
    "\n",
    "prototypes_df = pd.DataFrame(prototypes_rows).sort_values(\"support_n\", ascending=False).reset_index(drop=True)\n",
    "\n",
    "global_stats = {}\n",
    "for feat in used_original_features:\n",
    "    if feat not in X_trainval_best.columns:\n",
    "        continue\n",
    "    if pd.api.types.is_numeric_dtype(X_trainval_best[feat]):\n",
    "        global_stats[feat] = {\n",
    "            \"type\": \"num\",\n",
    "            \"mean\": float(X_trainval_best[feat].mean()),\n",
    "            \"std\": float(X_trainval_best[feat].std(ddof=0) + 1e-9)\n",
    "        }\n",
    "    else:\n",
    "        freq = X_trainval_best[feat].astype(str).value_counts(normalize=True, dropna=False).to_dict()\n",
    "        global_stats[feat] = {\n",
    "            \"type\": \"cat\",\n",
    "            \"freq\": freq\n",
    "        }\n",
    "\n",
    "# feature name -> {permutation_importance_mean, processed_weight_importance}\n",
    "importance_lookup = importance_merge.set_index(\"feature\")[[\"permutation_importance_mean\", \"processed_weight_importance\"]].to_dict(\"index\")\n",
    "\n",
    "# Score how strongly every used feature sets a prototype apart from the data as a whole:\n",
    "#   numeric feature     -> importance * |z-score of the prototype value|\n",
    "#   categorical feature -> importance * (1 - global share of the prototype category)\n",
    "# where importance = permutation importance clipped at zero + processed weight importance.\n",
    "contrib_rows = []\n",
    "for _, prow in prototypes_df.iterrows():\n",
    "    proto_id = int(prow[\"prototype_id\"])\n",
    "    for feat in used_original_features:\n",
    "        if feat not in X_trainval_best.columns:\n",
    "            continue\n",
    "        # Features without an importance record are treated as having zero importance.\n",
    "        imp = importance_lookup.get(feat, {\"permutation_importance_mean\": 0.0, \"processed_weight_importance\": 0.0})\n",
    "        feature_importance = float(max(imp[\"permutation_importance_mean\"], 0) + imp[\"processed_weight_importance\"])\n",
    "\n",
    "        if global_stats[feat][\"type\"] == \"num\":\n",
    "            # A feature missing from the prototype row falls back to the global mean, i.e. z = 0.\n",
    "            value = float(prow.get(feat, X_trainval_best[feat].mean()))\n",
    "            z = abs((value - global_stats[feat][\"mean\"]) / global_stats[feat][\"std\"])\n",
    "            score = feature_importance * z\n",
    "            interpreted_value = value\n",
    "        elif feat in prow.index:\n",
    "            value = str(prow[feat])\n",
    "            freq = global_stats[feat][\"freq\"].get(value, 0.0)\n",
    "            score = feature_importance * (1 - freq)\n",
    "            interpreted_value = value\n",
    "        else:\n",
    "            # The prototype row has no value for this categorical feature. Looking up the\n",
    "            # placeholder string 'nan' would find frequency 0 and hand the feature its\n",
    "            # maximum score, so mirror the numeric fallback instead: report the globally\n",
    "            # most frequent category and score it as no deviation.\n",
    "            freq_table = global_stats[feat][\"freq\"]\n",
    "            interpreted_value = max(freq_table, key=freq_table.get) if freq_table else np.nan\n",
    "            score = 0.0\n",
    "\n",
    "        contrib_rows.append({\n",
    "            \"prototype_id\": proto_id,\n",
    "            \"feature\": feat,\n",
    "            \"contribution_score\": float(score),\n",
    "            \"interpreted_value\": interpreted_value\n",
    "        })\n",
    "\n",
    "# Keep the five highest-scoring features of every prototype.\n",
    "contrib_sorted = pd.DataFrame(contrib_rows).sort_values(\n",
    "    by=[\"prototype_id\", \"contribution_score\"],\n",
    "    ascending=[True, False],\n",
    ")\n",
    "proto_contrib_df = contrib_sorted.groupby(\"prototype_id\").head(5).reset_index(drop=True)\n",
    "\n",
    "# Rank inside each prototype (1 = largest score); ties are broken by row order.\n",
    "score_by_proto = proto_contrib_df.groupby(\"prototype_id\")[\"contribution_score\"]\n",
    "proto_contrib_df[\"rank\"] = score_by_proto.rank(ascending=False, method=\"first\")\n",
    "\n",
    "display(prototypes_df, proto_contrib_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9a42c0e",
   "metadata": {},
   "source": [
    "## Итоговая служебная ячейка для вывода\n",
    "\n",
    "Эта ячейка печатает все ключевые результаты в компактном виде."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cad8cdcc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== КЛЮЧЕВАЯ ИНФОРМАЦИЯ ДЛЯ ВЫВОДА ===\n",
      "Зависимая переменная: yield\n",
      "\n",
      "Число исходных признаков: 15\n",
      "Число признаков после кодирования (все признаки): 133\n",
      "Число отобранных исходных признаков: 8\n",
      "Число признаков после кодирования (отобранные признаки): 126\n",
      "Фиксированная архитектура первой модели: (64, 32)\n",
      "Архитектура, найденная эволюционным поиском: (64,)\n",
      "alpha найденной модели: 3.774676323193703e-06\n",
      "learning_rate_init найденной модели: 0.009822263016087532\n",
      "\n",
      "Отобранные признаки:\n",
      "['rainfall', 'ndvi', 'savi', 'crop_type', 'ndwi', 'latitude', 'field_id', 'soil_moisture']\n",
      "\n",
      "Метрики моделей на тестовой выборке:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>r2_test</th>\n",
       "      <th>rmse_test</th>\n",
       "      <th>mae_test</th>\n",
       "      <th>r2_val</th>\n",
       "      <th>rmse_val</th>\n",
       "      <th>mae_val</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Нейросеть по отобранным признакам + эволюционн...</td>\n",
       "      <td>0.933225</td>\n",
       "      <td>2.275202</td>\n",
       "      <td>1.291376</td>\n",
       "      <td>0.885422</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "      <td>8</td>\n",
       "      <td>126</td>\n",
       "      <td>(64,)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Базовая нейросеть по всем признакам</td>\n",
       "      <td>0.902929</td>\n",
       "      <td>2.743203</td>\n",
       "      <td>1.755315</td>\n",
       "      <td>0.870878</td>\n",
       "      <td>2.828249</td>\n",
       "      <td>1.775249</td>\n",
       "      <td>15</td>\n",
       "      <td>133</td>\n",
       "      <td>(64, 32)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               model   r2_test  rmse_test  mae_test    r2_val  rmse_val   mae_val  n_original_features  n_processed_features hidden_layers\n",
       "0  Нейросеть по отобранным признакам + эволюционн...  0.933225   2.275202  1.291376  0.885422  2.664205  1.501977                    8                   126         (64,)\n",
       "1                Базовая нейросеть по всем признакам  0.902929   2.743203  1.755315  0.870878  2.828249  1.775249                   15                   133      (64, 32)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Лучшая модель:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>r2_test</th>\n",
       "      <th>rmse_test</th>\n",
       "      <th>mae_test</th>\n",
       "      <th>r2_val</th>\n",
       "      <th>rmse_val</th>\n",
       "      <th>mae_val</th>\n",
       "      <th>n_original_features</th>\n",
       "      <th>n_processed_features</th>\n",
       "      <th>hidden_layers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Нейросеть по отобранным признакам + эволюционн...</td>\n",
       "      <td>0.933225</td>\n",
       "      <td>2.275202</td>\n",
       "      <td>1.291376</td>\n",
       "      <td>0.885422</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "      <td>8</td>\n",
       "      <td>126</td>\n",
       "      <td>(64,)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               model   r2_test  rmse_test  mae_test    r2_val  rmse_val   mae_val  n_original_features  n_processed_features hidden_layers\n",
       "0  Нейросеть по отобранным признакам + эволюционн...  0.933225   2.275202  1.291376  0.885422  2.664205  1.501977                    8                   126         (64,)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Топ-10 признаков по важности:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "      <th>permutation_importance_mean</th>\n",
       "      <th>permutation_importance_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>rainfall</td>\n",
       "      <td>5.739523</td>\n",
       "      <td>0.193029</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ndvi</td>\n",
       "      <td>2.208541</td>\n",
       "      <td>0.169049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>savi</td>\n",
       "      <td>1.538535</td>\n",
       "      <td>0.151130</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>crop_type</td>\n",
       "      <td>1.415012</td>\n",
       "      <td>0.608425</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ndwi</td>\n",
       "      <td>0.890150</td>\n",
       "      <td>0.098273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>latitude</td>\n",
       "      <td>0.798413</td>\n",
       "      <td>0.087298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>field_id</td>\n",
       "      <td>0.692456</td>\n",
       "      <td>0.718612</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>soil_moisture</td>\n",
       "      <td>0.609792</td>\n",
       "      <td>0.092161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>gndvi</td>\n",
       "      <td>0.568960</td>\n",
       "      <td>0.064340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>date_dayofyear</td>\n",
       "      <td>0.462667</td>\n",
       "      <td>0.074318</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          feature  permutation_importance_mean  permutation_importance_std\n",
       "0        rainfall                     5.739523                    0.193029\n",
       "1            ndvi                     2.208541                    0.169049\n",
       "2            savi                     1.538535                    0.151130\n",
       "3       crop_type                     1.415012                    0.608425\n",
       "4            ndwi                     0.890150                    0.098273\n",
       "5        latitude                     0.798413                    0.087298\n",
       "6        field_id                     0.692456                    0.718612\n",
       "7   soil_moisture                     0.609792                    0.092161\n",
       "8           gndvi                     0.568960                    0.064340\n",
       "9  date_dayofyear                     0.462667                    0.074318"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Топ-10 конфигураций эволюционного поиска:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>hidden_layers</th>\n",
       "      <th>alpha</th>\n",
       "      <th>learning_rate_init</th>\n",
       "      <th>rmse_val</th>\n",
       "      <th>mae_val</th>\n",
       "      <th>r2_val</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000004</td>\n",
       "      <td>0.009822</td>\n",
       "      <td>2.664205</td>\n",
       "      <td>1.501977</td>\n",
       "      <td>0.885422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>(96,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>2.672635</td>\n",
       "      <td>1.421672</td>\n",
       "      <td>0.884696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>(64,)</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>2.675607</td>\n",
       "      <td>1.394640</td>\n",
       "      <td>0.884440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>(80, 40, 20)</td>\n",
       "      <td>0.000005</td>\n",
       "      <td>0.005015</td>\n",
       "      <td>2.683304</td>\n",
       "      <td>1.368121</td>\n",
       "      <td>0.883774</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  hidden_layers     alpha  learning_rate_init  rmse_val   mae_val    r2_val\n",
       "0         (64,)  0.000004            0.009822  2.664205  1.501977  0.885422\n",
       "1         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "2         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "3         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "4         (96,)  0.000003            0.009648  2.672635  1.421672  0.884696\n",
       "5         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "6         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "7         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "8         (64,)  0.000003            0.010000  2.675607  1.394640  0.884440\n",
       "9  (80, 40, 20)  0.000005            0.005015  2.683304  1.368121  0.883774"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Последние эпохи обучения базовой нейросети:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>val_r2</th>\n",
       "      <th>val_rmse</th>\n",
       "      <th>val_mae</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>130</th>\n",
       "      <td>131</td>\n",
       "      <td>0.002935</td>\n",
       "      <td>0.869096</td>\n",
       "      <td>2.847705</td>\n",
       "      <td>1.784982</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>132</td>\n",
       "      <td>0.003854</td>\n",
       "      <td>0.869229</td>\n",
       "      <td>2.846251</td>\n",
       "      <td>1.777142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>133</td>\n",
       "      <td>0.003546</td>\n",
       "      <td>0.867845</td>\n",
       "      <td>2.861276</td>\n",
       "      <td>1.812772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>134</td>\n",
       "      <td>0.003689</td>\n",
       "      <td>0.867460</td>\n",
       "      <td>2.865439</td>\n",
       "      <td>1.822093</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>135</td>\n",
       "      <td>0.002960</td>\n",
       "      <td>0.870878</td>\n",
       "      <td>2.828249</td>\n",
       "      <td>1.775249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>136</td>\n",
       "      <td>0.004358</td>\n",
       "      <td>0.870413</td>\n",
       "      <td>2.833335</td>\n",
       "      <td>1.799200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>136</th>\n",
       "      <td>137</td>\n",
       "      <td>0.003298</td>\n",
       "      <td>0.869790</td>\n",
       "      <td>2.840137</td>\n",
       "      <td>1.761328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>138</td>\n",
       "      <td>0.003451</td>\n",
       "      <td>0.868436</td>\n",
       "      <td>2.854873</td>\n",
       "      <td>1.820676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>138</th>\n",
       "      <td>139</td>\n",
       "      <td>0.003324</td>\n",
       "      <td>0.869806</td>\n",
       "      <td>2.839963</td>\n",
       "      <td>1.781549</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139</th>\n",
       "      <td>140</td>\n",
       "      <td>0.003620</td>\n",
       "      <td>0.870025</td>\n",
       "      <td>2.837579</td>\n",
       "      <td>1.791714</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     epoch  train_loss    val_r2  val_rmse   val_mae\n",
       "130    131    0.002935  0.869096  2.847705  1.784982\n",
       "131    132    0.003854  0.869229  2.846251  1.777142\n",
       "132    133    0.003546  0.867845  2.861276  1.812772\n",
       "133    134    0.003689  0.867460  2.865439  1.822093\n",
       "134    135    0.002960  0.870878  2.828249  1.775249\n",
       "135    136    0.004358  0.870413  2.833335  1.799200\n",
       "136    137    0.003298  0.869790  2.840137  1.761328\n",
       "137    138    0.003451  0.868436  2.854873  1.820676\n",
       "138    139    0.003324  0.869806  2.839963  1.781549\n",
       "139    140    0.003620  0.870025  2.837579  1.791714"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Последние эпохи обучения модели с найденной архитектурой:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>val_r2</th>\n",
       "      <th>val_rmse</th>\n",
       "      <th>val_mae</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>41</td>\n",
       "      <td>0.019114</td>\n",
       "      <td>0.877719</td>\n",
       "      <td>2.752308</td>\n",
       "      <td>1.439200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>42</td>\n",
       "      <td>0.020414</td>\n",
       "      <td>0.871173</td>\n",
       "      <td>2.825014</td>\n",
       "      <td>1.502276</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>43</td>\n",
       "      <td>0.020595</td>\n",
       "      <td>0.869560</td>\n",
       "      <td>2.842648</td>\n",
       "      <td>1.457928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>44</td>\n",
       "      <td>0.020510</td>\n",
       "      <td>0.866645</td>\n",
       "      <td>2.874235</td>\n",
       "      <td>1.521252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>45</td>\n",
       "      <td>0.020608</td>\n",
       "      <td>0.871446</td>\n",
       "      <td>2.822021</td>\n",
       "      <td>1.508952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>46</td>\n",
       "      <td>0.020722</td>\n",
       "      <td>0.869417</td>\n",
       "      <td>2.844210</td>\n",
       "      <td>1.433634</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>47</td>\n",
       "      <td>0.018139</td>\n",
       "      <td>0.875383</td>\n",
       "      <td>2.778477</td>\n",
       "      <td>1.497307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>48</td>\n",
       "      <td>0.020332</td>\n",
       "      <td>0.864971</td>\n",
       "      <td>2.892221</td>\n",
       "      <td>1.498261</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>49</td>\n",
       "      <td>0.021497</td>\n",
       "      <td>0.863142</td>\n",
       "      <td>2.911740</td>\n",
       "      <td>1.533948</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>50</td>\n",
       "      <td>0.019801</td>\n",
       "      <td>0.870129</td>\n",
       "      <td>2.836441</td>\n",
       "      <td>1.450026</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    epoch  train_loss    val_r2  val_rmse   val_mae\n",
       "40     41    0.019114  0.877719  2.752308  1.439200\n",
       "41     42    0.020414  0.871173  2.825014  1.502276\n",
       "42     43    0.020595  0.869560  2.842648  1.457928\n",
       "43     44    0.020510  0.866645  2.874235  1.521252\n",
       "44     45    0.020608  0.871446  2.822021  1.508952\n",
       "45     46    0.020722  0.869417  2.844210  1.433634\n",
       "46     47    0.018139  0.875383  2.778477  1.497307\n",
       "47     48    0.020332  0.864971  2.892221  1.498261\n",
       "48     49    0.021497  0.863142  2.911740  1.533948\n",
       "49     50    0.019801  0.870129  2.836441  1.450026"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Основные прототипы по лучшей сети:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prototype_id</th>\n",
       "      <th>support_n</th>\n",
       "      <th>support_share_trainval</th>\n",
       "      <th>representative_target</th>\n",
       "      <th>assigned_mean_target_trainval</th>\n",
       "      <th>rainfall</th>\n",
       "      <th>crop_type</th>\n",
       "      <th>ndvi</th>\n",
       "      <th>field_id</th>\n",
       "      <th>savi</th>\n",
       "      <th>ndwi</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>453</td>\n",
       "      <td>0.348462</td>\n",
       "      <td>31.892742</td>\n",
       "      <td>32.555298</td>\n",
       "      <td>5.990291</td>\n",
       "      <td>Sorghum</td>\n",
       "      <td>0.281223</td>\n",
       "      <td>Field_47</td>\n",
       "      <td>0.421783</td>\n",
       "      <td>-0.340460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>312</td>\n",
       "      <td>0.240000</td>\n",
       "      <td>39.800723</td>\n",
       "      <td>40.323829</td>\n",
       "      <td>8.126362</td>\n",
       "      <td>Millets</td>\n",
       "      <td>0.582184</td>\n",
       "      <td>Field_3</td>\n",
       "      <td>0.873146</td>\n",
       "      <td>-0.565245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4</td>\n",
       "      <td>254</td>\n",
       "      <td>0.195385</td>\n",
       "      <td>44.005098</td>\n",
       "      <td>44.097473</td>\n",
       "      <td>13.176073</td>\n",
       "      <td>Saffron</td>\n",
       "      <td>0.254225</td>\n",
       "      <td>Field_5</td>\n",
       "      <td>0.381289</td>\n",
       "      <td>-0.290785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>206</td>\n",
       "      <td>0.158462</td>\n",
       "      <td>52.382550</td>\n",
       "      <td>52.766332</td>\n",
       "      <td>16.352781</td>\n",
       "      <td>Mustard</td>\n",
       "      <td>0.546054</td>\n",
       "      <td>Field_1</td>\n",
       "      <td>0.818966</td>\n",
       "      <td>-0.531147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>75</td>\n",
       "      <td>0.057692</td>\n",
       "      <td>46.216829</td>\n",
       "      <td>44.511292</td>\n",
       "      <td>6.718947</td>\n",
       "      <td>Coconut</td>\n",
       "      <td>-0.141769</td>\n",
       "      <td>Field_86</td>\n",
       "      <td>-0.212330</td>\n",
       "      <td>0.300656</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   prototype_id  support_n  support_share_trainval  representative_target  assigned_mean_target_trainval   rainfall crop_type      ndvi  field_id      savi      ndwi\n",
       "0             1        453                0.348462              31.892742                      32.555298   5.990291   Sorghum  0.281223  Field_47  0.421783 -0.340460\n",
       "1             0        312                0.240000              39.800723                      40.323829   8.126362   Millets  0.582184   Field_3  0.873146 -0.565245\n",
       "2             4        254                0.195385              44.005098                      44.097473  13.176073   Saffron  0.254225   Field_5  0.381289 -0.290785\n",
       "3             2        206                0.158462              52.382550                      52.766332  16.352781   Mustard  0.546054   Field_1  0.818966 -0.531147\n",
       "4             3         75                0.057692              46.216829                      44.511292   6.718947   Coconut -0.141769  Field_86 -0.212330  0.300656"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Топ-5 вкладов признаков для прототипов:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prototype_id</th>\n",
       "      <th>feature</th>\n",
       "      <th>contribution_score</th>\n",
       "      <th>interpreted_value</th>\n",
       "      <th>rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.808172</td>\n",
       "      <td>Field_3</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.703522</td>\n",
       "      <td>Millets</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>2.114486</td>\n",
       "      <td>0.582184</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>savi</td>\n",
       "      <td>2.013780</td>\n",
       "      <td>0.873146</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>1.961811</td>\n",
       "      <td>8.126362</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.808172</td>\n",
       "      <td>Field_47</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>4.822911</td>\n",
       "      <td>5.990291</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.718646</td>\n",
       "      <td>Sorghum</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>0.825673</td>\n",
       "      <td>0.281223</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>savi</td>\n",
       "      <td>0.786483</td>\n",
       "      <td>0.421783</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>2</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>9.056828</td>\n",
       "      <td>16.352781</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>2</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.808172</td>\n",
       "      <td>Field_1</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>2</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.707303</td>\n",
       "      <td>Mustard</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>2</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>1.761516</td>\n",
       "      <td>0.546054</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>2</td>\n",
       "      <td>savi</td>\n",
       "      <td>1.677645</td>\n",
       "      <td>0.818966</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>3</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.820325</td>\n",
       "      <td>Field_86</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>3</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>4.957980</td>\n",
       "      <td>-0.141769</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>3</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.801827</td>\n",
       "      <td>Coconut</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>3</td>\n",
       "      <td>savi</td>\n",
       "      <td>4.720528</td>\n",
       "      <td>-0.21233</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>3</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>3.846934</td>\n",
       "      <td>6.718947</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>4</td>\n",
       "      <td>field_id</td>\n",
       "      <td>7.777790</td>\n",
       "      <td>Field_5</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>4</td>\n",
       "      <td>rainfall</td>\n",
       "      <td>4.801879</td>\n",
       "      <td>13.176073</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>4</td>\n",
       "      <td>crop_type</td>\n",
       "      <td>4.707303</td>\n",
       "      <td>Saffron</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>4</td>\n",
       "      <td>ndvi</td>\n",
       "      <td>1.089423</td>\n",
       "      <td>0.254225</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>4</td>\n",
       "      <td>savi</td>\n",
       "      <td>1.037709</td>\n",
       "      <td>0.381289</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    prototype_id    feature  contribution_score interpreted_value  rank\n",
       "0              0   field_id            7.808172           Field_3   1.0\n",
       "1              0  crop_type            4.703522           Millets   2.0\n",
       "2              0       ndvi            2.114486          0.582184   3.0\n",
       "3              0       savi            2.013780          0.873146   4.0\n",
       "4              0   rainfall            1.961811          8.126362   5.0\n",
       "5              1   field_id            7.808172          Field_47   1.0\n",
       "6              1   rainfall            4.822911          5.990291   2.0\n",
       "7              1  crop_type            4.718646           Sorghum   3.0\n",
       "8              1       ndvi            0.825673          0.281223   4.0\n",
       "9              1       savi            0.786483          0.421783   5.0\n",
       "10             2   rainfall            9.056828         16.352781   1.0\n",
       "11             2   field_id            7.808172           Field_1   2.0\n",
       "12             2  crop_type            4.707303           Mustard   3.0\n",
       "13             2       ndvi            1.761516          0.546054   4.0\n",
       "14             2       savi            1.677645          0.818966   5.0\n",
       "15             3   field_id            7.820325          Field_86   1.0\n",
       "16             3       ndvi            4.957980         -0.141769   2.0\n",
       "17             3  crop_type            4.801827           Coconut   3.0\n",
       "18             3       savi            4.720528          -0.21233   4.0\n",
       "19             3   rainfall            3.846934          6.718947   5.0\n",
       "20             4   field_id            7.777790           Field_5   1.0\n",
       "21             4   rainfall            4.801879         13.176073   2.0\n",
       "22             4  crop_type            4.707303           Saffron   3.0\n",
       "23             4       ndvi            1.089423          0.254225   4.0\n",
       "24             4       savi            1.037709          0.381289   5.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Первые 10 фактических и предсказанных значений на тесте:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>y_true</th>\n",
       "      <th>y_pred_baseline</th>\n",
       "      <th>y_pred_selected_evolution</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>42.940963</td>\n",
       "      <td>42.939202</td>\n",
       "      <td>43.146026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>54.260435</td>\n",
       "      <td>54.585298</td>\n",
       "      <td>55.617393</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>36.368367</td>\n",
       "      <td>35.403129</td>\n",
       "      <td>38.181089</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>30.515626</td>\n",
       "      <td>34.520034</td>\n",
       "      <td>33.731471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>34.354857</td>\n",
       "      <td>40.373399</td>\n",
       "      <td>35.667811</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>43.712983</td>\n",
       "      <td>43.366476</td>\n",
       "      <td>43.688931</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>37.526396</td>\n",
       "      <td>38.296231</td>\n",
       "      <td>38.164426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>29.785338</td>\n",
       "      <td>29.048048</td>\n",
       "      <td>30.262982</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>45.288319</td>\n",
       "      <td>45.136464</td>\n",
       "      <td>45.805447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>41.616011</td>\n",
       "      <td>43.522796</td>\n",
       "      <td>42.239683</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      y_true  y_pred_baseline  y_pred_selected_evolution\n",
       "0  42.940963        42.939202                  43.146026\n",
       "1  54.260435        54.585298                  55.617393\n",
       "2  36.368367        35.403129                  38.181089\n",
       "3  30.515626        34.520034                  33.731471\n",
       "4  34.354857        40.373399                  35.667811\n",
       "5  43.712983        43.366476                  43.688931\n",
       "6  37.526396        38.296231                  38.164426\n",
       "7  29.785338        29.048048                  30.262982\n",
       "8  45.288319        45.136464                  45.805447\n",
       "9  41.616011        43.522796                  42.239683"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pred_compare_df = pd.DataFrame({\n",
    "    \"y_true\": y_test.reset_index(drop=True),\n",
    "    \"y_pred_baseline\": baseline_eval[\"pred_test\"],\n",
    "    \"y_pred_selected_evolution\": selected_eval[\"pred_test\"],\n",
    "}).head(10)\n",
    "\n",
    "print(\"=== КЛЮЧЕВАЯ ИНФОРМАЦИЯ ДЛЯ ВЫВОДА ===\")\n",
    "print(\"Зависимая переменная:\", target_col)\n",
    "print()\n",
    "\n",
    "print(\"Число исходных признаков:\", X.shape[1])\n",
    "print(\"Число признаков после кодирования (все признаки):\", len(baseline_result[\"feature_names\"]))\n",
    "print(\"Число отобранных исходных признаков:\", len(selected_features))\n",
    "print(\"Число признаков после кодирования (отобранные признаки):\", len(selected_result[\"feature_names\"]))\n",
    "print(\"Фиксированная архитектура первой модели:\", (64, 32))\n",
    "print(\"Архитектура, найденная эволюционным поиском:\", best_candidate[\"hidden_layers\"])\n",
    "print(\"alpha найденной модели:\", float(best_candidate[\"alpha\"]))\n",
    "print(\"learning_rate_init найденной модели:\", float(best_candidate[\"learning_rate_init\"]))\n",
    "print()\n",
    "\n",
    "print(\"Отобранные признаки:\")\n",
    "print(selected_features)\n",
    "print()\n",
    "\n",
    "print(\"Метрики моделей на тестовой выборке:\")\n",
    "display(comparison_df)\n",
    "\n",
    "print(\"Лучшая модель:\")\n",
    "display(comparison_df.head(1))\n",
    "\n",
    "print(\"Топ-10 признаков по важности:\")\n",
    "display(importance_df.head(10))\n",
    "\n",
    "print(\"Топ-10 конфигураций эволюционного поиска:\")\n",
    "display(evolution_df.head(10)[[\"hidden_layers\", \"alpha\", \"learning_rate_init\", \"rmse_val\", \"mae_val\", \"r2_val\"]])\n",
    "\n",
    "print(\"Последние эпохи обучения базовой нейросети:\")\n",
    "display(baseline_result[\"history\"].tail(10))\n",
    "\n",
    "print(\"Последние эпохи обучения модели с найденной архитектурой:\")\n",
    "display(selected_result[\"history\"].tail(10))\n",
    "\n",
    "print(\"Основные прототипы по лучшей сети:\")\n",
    "display(prototypes_df)\n",
    "\n",
    "print(\"Топ-5 вкладов признаков для прототипов:\")\n",
    "display(proto_contrib_df)\n",
    "\n",
    "print(\"Первые 10 фактических и предсказанных значений на тесте:\")\n",
    "display(pred_compare_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "363e2b40",
   "metadata": {},
   "source": [
    "Итог\n",
    "\n",
    "Зависимой переменной выбрана yield. Наиболее важными признаками оказались rainfall, ndvi, savi, crop_type, ndwi, latitude, field_id, soil_moisture; по ним построена вторая нейросеть с архитектурой, найденной эволюционным поиском (64,).\n",
    "\n",
    "На тестовой выборке лучшей стала модель по отобранным признакам: R² = 0.933, RMSE = 2.275, MAE = 1.291. Базовая нейросеть по всем признакам показала более слабый результат (R² = 0.903, RMSE = 2.743, MAE = 1.755).\n",
    "\n",
    "Прототипы лучшей сети соответствуют интерпретируемым агрономическим профилям: группы с разными сочетаниями осадков, вегетационных индексов (NDVI/SAVI), типа культуры и поля. Наиболее продуктивные прототипы связаны с более высокими rainfall и выраженными вегетационными индексами, тогда как менее продуктивные — с более низкими значениями этих показателей.\n",
    "\n",
    "Следовательно, ключевой вклад в прогноз урожайности вносят осадки, спектральные индексы и тип культуры, а модель по отобранным признакам с эволюционно найденной архитектурой является наиболее качественной и интерпретируемой."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
