{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6979c756",
   "metadata": {},
   "source": [
    "# Car recommendation notebook — H1 variants, tuned top models, batched evaluation\n",
    "\n",
    "This notebook focuses on richer H1-style formulations only and removes the old slow screening path.\n",
    "\n",
    "Goals:\n",
    "- try several harder H1 variants\n",
    "- improve the current top families: RP3beta, EASE, TwoTower, MultVAE\n",
    "- batch all heavy evaluation paths\n",
    "- use GPU where it helps and CPU threads where possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8fb2968a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU thread target: 16\n",
      "PYTORCH_CUDA_ALLOC_CONF: expandable_segments:True,max_split_size_mb:256\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "for var in [\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "]:\n",
    "    os.environ[var] = str(CPU_THREADS)\n",
    "\n",
    "# Helps reduce CUDA allocator fragmentation in long notebook runs.\n",
    "os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\n",
    "\n",
    "import gc\n",
    "import math\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy import linalg as sla\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "try:\n",
    "    from threadpoolctl import threadpool_limits\n",
    "except Exception:\n",
    "    threadpool_limits = None\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "\n",
    "if threadpool_limits is not None:\n",
    "    try:\n",
    "        threadpool_limits(limits=CPU_THREADS)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)\n",
    "print(\"PYTORCH_CUDA_ALLOC_CONF:\", os.environ.get(\"PYTORCH_CUDA_ALLOC_CONF\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6731d71d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['H1_base', 'H1_color', 'H1_fine', 'H1_fine_color']\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "PRICE_BINS_BASE = 12\n",
    "PRICE_BINS_FINE = 16\n",
    "MILEAGE_BINS_BASE = 12\n",
    "MILEAGE_BINS_FINE = 16\n",
    "AGE_BINS_BASE = 10\n",
    "AGE_BINS_FINE = 14\n",
    "\n",
    "# Filtering\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# Evaluation\n",
    "TOP_KS = [5, 10, 20]\n",
    "EVAL_MAX_USERS = None\n",
    "LINEAR_EVAL_BATCH = 4096\n",
    "NEURAL_EVAL_BATCH = 2048\n",
    "\n",
    "# Manual richer H1 variants only\n",
    "FORMULATIONS = {\n",
    "    \"H1_base\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H1_fine\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_fine_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "# Cheap screen all formulations first, then run the heavy regular/neural search only on the best ones\n",
    "TOP_FORMULATIONS_FOR_FULL_REGULAR = 2\n",
    "TOP_FORMULATIONS_FOR_NEURAL = 2\n",
    "\n",
    "# Best current families only\n",
    "# One cheap RP3 config for formulation screening, then a tighter grid for the full search\n",
    "SCREEN_RP3 = (1.10, 0.70)\n",
    "\n",
    "RP3_GRID = [\n",
    "    (1.00, 0.60),\n",
    "    (1.10, 0.70),\n",
    "    (1.15, 0.70),\n",
    "]\n",
    "\n",
    "# Keep only the EASE lambdas that were already near the top in previous runs.\n",
    "EASE_BINARY_LAMBDAS = [1000.0, 1200.0, 1600.0]\n",
    "EASE_COUNT_LAMBDAS = [1200.0, 1600.0]\n",
    "\n",
    "# Skip very large EASE formulations instead of spending tens of minutes building/factoring huge Gram matrices.\n",
    "MAX_EASE_ITEMS = 12000\n",
    "# Only try GPU EASE for comparatively small item spaces; larger ones fall back to CPU torch.\n",
    "EASE_GPU_MAX_GRAM_GB = 0.60\n",
    "\n",
    "TWOTOWER_CONFIGS = [\n",
    "    {\"name\": \"TwoTower_t2\", \"emb_dim\": 96,  \"hidden_dims\": (768, 384, 192),  \"out_dim\": 160, \"epochs\": 28, \"lr\": 1.5e-3, \"wd\": 1e-5, \"temperature\": 0.05},\n",
    "    {\"name\": \"TwoTower_t3\", \"emb_dim\": 128, \"hidden_dims\": (1024, 512, 256), \"out_dim\": 192, \"epochs\": 32, \"lr\": 1.2e-3, \"wd\": 1e-5, \"temperature\": 0.04},\n",
    "]\n",
    "\n",
    "MULTVAE_CONFIGS = [\n",
    "    {\"name\": \"MultVAE_v3\", \"hidden_dim\": 2048, \"latent_dim\": 512, \"dropout\": 0.30, \"epochs\": 100, \"lr\": 6e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "    {\"name\": \"MultVAE_v4\", \"hidden_dim\": 2560, \"latent_dim\": 640, \"dropout\": 0.35, \"epochs\": 110, \"lr\": 5e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "]\n",
    "\n",
    "NEUMF_CONFIGS = [\n",
    "    {\"name\": \"NeuMF_n2\", \"mf_dim\": 96, \"mlp_dim\": 192, \"hidden_dims\": (384, 192, 96), \"epochs\": 26, \"lr\": 1.5e-3, \"wd\": 1e-6},\n",
    "]\n",
    "\n",
    "# Start large, but allow the training functions to shrink automatically on OOM.\n",
    "TWOTOWER_BATCH_START = 8192\n",
    "TWOTOWER_BATCH_MIN = 1024\n",
    "PAIRWISE_BATCH = 16384\n",
    "AUTOENC_BATCH_START = 2048\n",
    "AUTOENC_BATCH_MIN = 512\n",
    "NEUMF_EVAL_USER_BATCH = 512\n",
    "NEUMF_EVAL_ITEM_BATCH = 1024\n",
    "\n",
    "RUN_NEURAL = True\n",
    "RUN_FUSION = True\n",
    "FUSION_FETCH_N = 100\n",
    "FUSION_RRF_K = 60\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "349eba60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch version: 2.11.0\n",
      "Torch device : cuda\n",
      "USE_AMP      : True\n"
     ]
    }
   ],
   "source": [
    "# PyTorch runtime setup\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "\n",
    "try:\n",
    "    torch.set_num_threads(CPU_THREADS)\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "try:\n",
    "    torch.set_num_interop_threads(max(1, CPU_THREADS // 2))\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "USE_AMP = (device.type == \"cuda\")\n",
    "\n",
    "if device.type == \"cuda\":\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    try:\n",
    "        torch.set_float32_matmul_precision(\"high\")\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"Torch version:\", torch.__version__)\n",
    "print(\"Torch device :\", device)\n",
    "print(\"USE_AMP      :\", USE_AMP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "145329d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows: 1000000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin12</th>\n",
       "      <th>PriceBin16</th>\n",
       "      <th>MileageBin12</th>\n",
       "      <th>MileageBin16</th>\n",
       "      <th>AgeBin10</th>\n",
       "      <th>AgeBin14</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>P12_3</td>\n",
       "      <td>P16_4</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_0</td>\n",
       "      <td>A14_0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_9</td>\n",
       "      <td>A14_13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>P12_7</td>\n",
       "      <td>P16_9</td>\n",
       "      <td>M12_0</td>\n",
       "      <td>M16_0</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_1</td>\n",
       "      <td>M16_2</td>\n",
       "      <td>A10_8</td>\n",
       "      <td>A14_11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>P12_0</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_4</td>\n",
       "      <td>M16_6</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    3   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   26   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   12   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   23   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   14   \n",
       "\n",
       "  PriceBin12 PriceBin16 MileageBin12 MileageBin16 AgeBin10 AgeBin14  \n",
       "0      P12_3      P16_4        M12_3        M16_4    A10_0    A14_0  \n",
       "1      P12_1      P16_1        M12_3        M16_4    A10_9   A14_13  \n",
       "2      P12_7      P16_9        M12_0        M16_0    A10_4    A14_5  \n",
       "3      P12_1      P16_1        M12_1        M16_2    A10_8   A14_11  \n",
       "4      P12_0      P16_1        M12_4        M16_6    A10_4    A14_6  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean the dataset\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    df[col] = df[col].astype(str).str.strip()\n",
    "\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "df = df.dropna(subset=[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]).copy()\n",
    "\n",
    "df[\"Year\"] = df[\"Year\"].astype(int)\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0)\n",
    "\n",
    "def make_qbin_labels(series, q, prefix):\n",
    "    cats = pd.qcut(series, q=q, labels=False, duplicates=\"drop\")\n",
    "    cats = cats.astype(int)\n",
    "    return cats.map(lambda x: f\"{prefix}{int(x)}\")\n",
    "\n",
    "df[\"PriceBin12\"] = make_qbin_labels(df[\"Price\"], PRICE_BINS_BASE, \"P12_\")\n",
    "df[\"PriceBin16\"] = make_qbin_labels(df[\"Price\"], PRICE_BINS_FINE, \"P16_\")\n",
    "df[\"MileageBin12\"] = make_qbin_labels(df[\"Mileage\"], MILEAGE_BINS_BASE, \"M12_\")\n",
    "df[\"MileageBin16\"] = make_qbin_labels(df[\"Mileage\"], MILEAGE_BINS_FINE, \"M16_\")\n",
    "df[\"AgeBin10\"] = make_qbin_labels(df[\"Age\"], AGE_BINS_BASE, \"A10_\")\n",
    "df[\"AgeBin14\"] = make_qbin_labels(df[\"Age\"], AGE_BINS_FINE, \"A14_\")\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "450426c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user, min_users_per_item, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = len(out)\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if len(out) == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "        .size()\n",
    "        .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = None\n",
    "    used_thresholds = None\n",
    "\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS,\n",
    "        ).reset_index(drop=True)\n",
    "        if not candidate.empty:\n",
    "            interactions = candidate\n",
    "            used_thresholds = (min_items_per_user, min_users_per_item)\n",
    "            break\n",
    "\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(f\"{formulation_name} produced no interactions after filtering\")\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    user_feature_df = (\n",
    "        tmp[tmp[\"user_id\"].isin(interactions[\"user_id\"])]\n",
    "        [[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "    user_feature_df = user_feature_df.sort_values(\"user_idx\").reset_index(drop=True)\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[tmp[\"item_id\"].isin(interactions[\"item_id\"])]\n",
    "        [[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "    item_feature_df = item_feature_df.sort_values(\"item_idx\").reset_index(drop=True)\n",
    "\n",
    "    # Weighted leave-one-out\n",
    "    rng = np.random.default_rng(RANDOM_STATE)\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(int(picked))\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    num_users = int(interactions[\"user_idx\"].nunique())\n",
    "    num_items = int(interactions[\"item_idx\"].nunique())\n",
    "\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(u): int(i)\n",
    "        for u, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    user_seen_arrays = {}\n",
    "    for uid, g in train_interactions.groupby(\"user_idx\"):\n",
    "        user_seen_arrays[int(uid)] = np.sort(g[\"item_idx\"].astype(np.int32).to_numpy())\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy(dtype=np.int32)\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": list(user_cols),\n",
    "        \"item_cols\": list(item_cols),\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen_arrays\": user_seen_arrays,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "    avg_train = len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle[\"used_thresholds\"])\n",
    "    print(\"Train rows :\", len(bundle[\"train_interactions\"]))\n",
    "    print(\"Test rows  :\", len(bundle[\"test_interactions\"]))\n",
    "    print(\"Avg train interactions/user:\", round(avg_train, 3))\n",
    "\n",
    "def get_eval_users(bundle):\n",
    "    user_indices = np.array(sorted(bundle[\"test_item_by_user\"].keys()), dtype=np.int32)\n",
    "    if EVAL_MAX_USERS is not None and len(user_indices) > EVAL_MAX_USERS:\n",
    "        rng = np.random.default_rng(RANDOM_STATE)\n",
    "        user_indices = rng.choice(user_indices, size=EVAL_MAX_USERS, replace=False)\n",
    "    return np.sort(user_indices)\n",
    "\n",
    "def batch_metric_update(topk_idx, true_items, acc, ks):\n",
    "    matches = (topk_idx == true_items[:, None])\n",
    "    for k in ks:\n",
    "        acc[f\"HR@{k}\"] += matches[:, :k].any(axis=1).sum()\n",
    "\n",
    "    top10 = matches[:, :10]\n",
    "    found10 = top10.any(axis=1)\n",
    "    first_rank10 = np.where(found10, top10.argmax(axis=1) + 1, 0)\n",
    "\n",
    "    acc[\"MRR@10\"] += np.where(found10, 1.0 / first_rank10, 0.0).sum()\n",
    "    acc[\"NDCG@10\"] += np.where(found10, 1.0 / np.log2(first_rank10 + 1), 0.0).sum()\n",
    "    acc[\"UsersEval\"] += len(true_items)\n",
    "\n",
    "def finalize_metrics(acc, model_name, formulation_name):\n",
    "    users_eval = max(int(acc[\"UsersEval\"]), 1)\n",
    "    out = {\n",
    "        \"Formulation\": formulation_name,\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": users_eval,\n",
    "        \"HR@5\": float(acc[\"HR@5\"]) / users_eval,\n",
    "        \"HR@10\": float(acc[\"HR@10\"]) / users_eval,\n",
    "        \"HR@20\": float(acc[\"HR@20\"]) / users_eval,\n",
    "        \"MRR@10\": float(acc[\"MRR@10\"]) / users_eval,\n",
    "        \"NDCG@10\": float(acc[\"NDCG@10\"]) / users_eval,\n",
    "    }\n",
    "    return out\n",
    "\n",
    "def popularity_recommend_one(global_pop_rank, seen_array, n):\n",
    "    if seen_array is None or len(seen_array) == 0:\n",
    "        return global_pop_rank[:n].tolist()\n",
    "    seen_set = set(int(x) for x in seen_array)\n",
    "    out = []\n",
    "    for item in global_pop_rank:\n",
    "        item = int(item)\n",
    "        if item not in seen_set:\n",
    "            out.append(item)\n",
    "            if len(out) >= n:\n",
    "                break\n",
    "    return out\n",
    "\n",
    "def evaluate_popularity(bundle, model_name=\"Popularity\"):\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    global_pop_rank = bundle[\"global_pop_rank\"]\n",
    "    user_seen_arrays = bundle[\"user_seen_arrays\"]\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "    for uid in tqdm(user_indices, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        true_item = test_item_by_user[int(uid)]\n",
    "        recs = popularity_recommend_one(global_pop_rank, user_seen_arrays.get(int(uid)), max(TOP_KS))\n",
    "        recs_arr = np.array(recs, dtype=np.int32)[None, :]\n",
    "        true_arr = np.array([true_item], dtype=np.int32)\n",
    "        batch_metric_update(recs_arr, true_arr, acc, TOP_KS)\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def clear_memory():\n",
    "    gc.collect()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e001d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model fitting helpers and batched evaluators\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr(copy=True)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel().astype(np.float32)\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    Pui = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    Piu = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "\n",
    "    S = (Piu @ Pui).astype(np.float32).tocsr()\n",
    "    S.setdiag(0.0)\n",
    "    S.eliminate_zeros()\n",
    "\n",
    "    if beta > 0:\n",
    "        item_degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "        penalty = np.ones_like(item_degree, dtype=np.float32)\n",
    "        mask = item_degree > 0\n",
    "        penalty[mask] = np.power(item_degree[mask], -beta)\n",
    "        S = S @ sp.diags(penalty.astype(np.float32))\n",
    "\n",
    "    return S.tocsr()\n",
    "\n",
    "\n",
    "def prepare_ease_cache(X_matrix, source_name=\"binary\", prefer_gpu=True):\n",
    "    t0 = time.time()\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32, copy=False)\n",
    "    n_items = int(G.shape[0])\n",
    "    gram_gb = (G.nbytes / (1024 ** 3))\n",
    "\n",
    "    backend = \"cpu_torch\"\n",
    "    prep = {\n",
    "        \"source_name\": source_name,\n",
    "        \"backend\": backend,\n",
    "        \"gram_np\": G,\n",
    "        \"n_items\": n_items,\n",
    "        \"gram_gb\": gram_gb,\n",
    "        \"prep_sec\": time.time() - t0,\n",
    "    }\n",
    "\n",
    "    # GPU EASE fit needs room for the Gram matrix, a factorization workspace, and the solved matrix.\n",
    "    # Use it only when the matrix is small enough to fit comfortably.\n",
    "    if prefer_gpu and device.type == \"cuda\" and gram_gb <= EASE_GPU_MAX_GRAM_GB:\n",
    "        try:\n",
    "            free_bytes, total_bytes = torch.cuda.mem_get_info(device=device)\n",
    "            free_gb = free_bytes / (1024 ** 3)\n",
    "            needed_gb = gram_gb * 3.2\n",
    "            if free_gb >= needed_gb:\n",
    "                gram_t = torch.from_numpy(G).to(device=device, dtype=torch.float32, non_blocking=True)\n",
    "                prep.update({\n",
    "                    \"backend\": \"gpu\",\n",
    "                    \"gram_t\": gram_t,\n",
    "                })\n",
    "            else:\n",
    "                prep[\"backend\"] = \"cpu_torch\"\n",
    "        except Exception:\n",
    "            prep[\"backend\"] = \"cpu_torch\"\n",
    "\n",
    "    return prep\n",
    "\n",
    "def fit_ease_from_cache(cache, lam):\n",
    "    n_items = int(cache[\"n_items\"])\n",
    "\n",
    "    if cache.get(\"backend\") == \"gpu\":\n",
    "        try:\n",
    "            gram_t = cache[\"gram_t\"].clone()\n",
    "            gram_t.diagonal().add_(float(lam))\n",
    "            L, info = torch.linalg.cholesky_ex(gram_t, upper=False)\n",
    "            info_val = int(info.item()) if hasattr(info, \"item\") else int(info)\n",
    "            if info_val == 0:\n",
    "                eye_t = torch.eye(n_items, dtype=torch.float32, device=gram_t.device)\n",
    "                P = torch.cholesky_solve(eye_t, L, upper=False)\n",
    "                diag = torch.diagonal(P).clone()\n",
    "                B = -P / diag.unsqueeze(0)\n",
    "                B.fill_diagonal_(0.0)\n",
    "                out = B.detach().cpu().numpy().astype(np.float32, copy=False)\n",
    "                del gram_t, L, eye_t, P, diag, B\n",
    "                clear_memory()\n",
    "                return out\n",
    "            else:\n",
    "                print(f\"    GPU Cholesky failed at lambda={lam} (info={info_val}), falling back to CPU torch.\")\n",
    "        except RuntimeError as e:\n",
    "            print(f\"    GPU EASE failed at lambda={lam}, falling back to CPU torch: {e}\")\n",
    "\n",
    "    # CPU torch path: uses torch's threaded linear algebra without requiring huge GPU memory.\n",
    "    G_t = torch.from_numpy(cache[\"gram_np\"]).to(dtype=torch.float32, device=\"cpu\").clone()\n",
    "    G_t.diagonal().add_(float(lam))\n",
    "    eye_t = torch.eye(n_items, dtype=torch.float32, device=\"cpu\")\n",
    "    L, info = torch.linalg.cholesky_ex(G_t, upper=False)\n",
    "    info_val = int(info.item()) if hasattr(info, \"item\") else int(info)\n",
    "    if info_val != 0:\n",
    "        # Final fallback to SciPy if CPU torch cholesky reports numerical trouble.\n",
    "        del G_t, eye_t, L, info\n",
    "        clear_memory()\n",
    "        G = cache[\"gram_np\"].copy()\n",
    "        G[np.diag_indices_from(G)] += np.float32(lam)\n",
    "        c, lower = sla.cho_factor(G, lower=True, overwrite_a=True, check_finite=False)\n",
    "        I = np.eye(G.shape[0], dtype=np.float32)\n",
    "        P = sla.cho_solve((c, lower), I, overwrite_b=True, check_finite=False)\n",
    "        diag = np.diag(P).copy()\n",
    "        B = -P / diag[None, :]\n",
    "        np.fill_diagonal(B, 0.0)\n",
    "        return B.astype(np.float32, copy=False)\n",
    "\n",
    "    P = torch.cholesky_solve(eye_t, L, upper=False)\n",
    "    diag = torch.diagonal(P).clone()\n",
    "    B = -P / diag.unsqueeze(0)\n",
    "    B.fill_diagonal_(0.0)\n",
    "    out = B.numpy().astype(np.float32, copy=False)\n",
    "    del G_t, eye_t, L, info, P, diag, B\n",
    "    clear_memory()\n",
    "    return out\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_dense_operator(bundle, operator_matrix, model_name, source=\"binary\", batch_size=4096):\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    X_source = bundle[\"X_binary\"] if source == \"binary\" else bundle[\"X_counts\"]\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    if isinstance(operator_matrix, torch.Tensor):\n",
    "        op_t = operator_matrix.to(device=device, dtype=torch.float32)\n",
    "    else:\n",
    "        op_t = torch.as_tensor(operator_matrix, dtype=torch.float32, device=device)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), batch_size), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + batch_size]\n",
    "        X_np = X_source[batch_uids].toarray().astype(np.float32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "\n",
    "        X_t = torch.from_numpy(X_np).to(device, non_blocking=True)\n",
    "        scores = X_t @ op_t\n",
    "        scores = scores.masked_fill(X_t > 0, -1e9)\n",
    "\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "\n",
    "        del X_t, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def top_model_name(df, prefix):\n",
    "    subset = df[df[\"Model\"].str.startswith(prefix)].copy()\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    subset = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return subset.iloc[0][\"Model\"]\n",
    "\n",
    "def make_rrf_scores(score_matrices, rrf_k=60):\n",
    "    # score_matrices: list of top-k item index arrays with shape [batch, fetch_n]\n",
    "    batch, fetch_n = score_matrices[0].shape\n",
    "    fused = np.zeros((batch, fetch_n * len(score_matrices)), dtype=np.float32)\n",
    "    del fused\n",
    "    # kept only as a placeholder helper; fusion is done in a simple Python routine later\n",
    "def top_model_name(df, prefix):\n",
    "    subset = df[df[\"Model\"].str.startswith(prefix)].copy()\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    subset = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return subset.iloc[0][\"Model\"]\n",
    "\n",
    "def make_rrf_scores(score_matrices, rrf_k=60):\n",
    "    # score_matrices: list of top-k item index arrays with shape [batch, fetch_n]\n",
    "    batch, fetch_n = score_matrices[0].shape\n",
    "    fused = np.zeros((batch, fetch_n * len(score_matrices)), dtype=np.float32)\n",
    "    del fused\n",
    "    # kept only as a placeholder helper; fusion is done in a simple Python routine later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "33f155c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: H1_base\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "====================================================================================================\n",
      "Formulation: H1_color\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition', 'Color']\n",
      "Users      : 62921\n",
      "Items      : 13198\n",
      "Thresholds : (4, 8)\n",
      "Train rows : 237108\n",
      "Test rows  : 62921\n",
      "Avg train interactions/user: 3.768\n",
      "====================================================================================================\n",
      "Formulation: H1_fine\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition']\n",
      "Users      : 95529\n",
      "Items      : 3696\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 812918\n",
      "Test rows  : 95529\n",
      "Avg train interactions/user: 8.51\n",
      "====================================================================================================\n",
      "Formulation: H1_fine_color\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition', 'Color']\n",
      "Users      : 59346\n",
      "Items      : 23514\n",
      "Thresholds : (3, 5)\n",
      "Train rows : 134712\n",
      "Test rows  : 59346\n",
      "Avg train interactions/user: 2.27\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>TrainRows</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Thresholds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_fine_color</td>\n",
       "      <td>59346</td>\n",
       "      <td>23514</td>\n",
       "      <td>134712</td>\n",
       "      <td>2.269942</td>\n",
       "      <td>(3, 5)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_color</td>\n",
       "      <td>62921</td>\n",
       "      <td>13198</td>\n",
       "      <td>237108</td>\n",
       "      <td>3.768344</td>\n",
       "      <td>(4, 8)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>95529</td>\n",
       "      <td>3696</td>\n",
       "      <td>812918</td>\n",
       "      <td>8.509646</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>831152</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Formulation  Users  Items  TrainRows  AvgTrainPerUser Thresholds\n",
       "0  H1_fine_color  59346  23514     134712         2.269942     (3, 5)\n",
       "1       H1_color  62921  13198     237108         3.768344     (4, 8)\n",
       "2        H1_fine  95529   3696     812918         8.509646    (5, 10)\n",
       "3        H1_base  43199   2640     831152        19.240075    (5, 10)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build all richer H1 formulations\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    bundles[name] = bundle\n",
    "    print_bundle_summary(bundle)\n",
    "\n",
    "    summary_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"TrainRows\": len(bundle[\"train_interactions\"]),\n",
    "        \"AvgTrainPerUser\": len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1),\n",
    "        \"Thresholds\": bundle[\"used_thresholds\"],\n",
    "    })\n",
    "\n",
    "summary_df = pd.DataFrame(summary_rows).sort_values([\"Items\", \"Users\"], ascending=False).reset_index(drop=True)\n",
    "display(summary_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a61a7425",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_base\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8101ae854b1d459196b353508d6acd44",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / Popularity:   0%|          | 0/43199 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.10s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a9fc0e39d144452acd92f25969ba8ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / RP3beta_screen_a1_1_b0_7:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.61s\n",
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_color\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4673b4da404c44019107abbec1da029c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_color / Popularity:   0%|          | 0/62921 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "649787f0eca447719388d8bd78d2105b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_color / RP3beta_screen_a1_1_b0_7:   0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.44s\n",
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_fine\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48987d047bda4d7da192757e77a98b2a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / Popularity:   0%|          | 0/95529 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2447b48716b249c089ee91f08cfb94e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / RP3beta_screen_a1_1_b0_7:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.53s\n",
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_fine_color\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d08b4e2ddf54d5ea5980d02a326800a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine_color / Popularity:   0%|          | 0/59346 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.15s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b2af53a7fa24ce28250077909a36e30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine_color / RP3beta_screen_a1_1_b0_7:   0%|          | 0/15 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=3.19s\n",
      "========================================================================================================================\n",
      "Formulation cheap-screen summary\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>screen_best_hr10</th>\n",
       "      <th>screen_best_ndcg10</th>\n",
       "      <th>screen_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>0.191070</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>0.159298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_fine_color</td>\n",
       "      <td>0.127877</td>\n",
       "      <td>0.060050</td>\n",
       "      <td>0.142890</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_color</td>\n",
       "      <td>0.119626</td>\n",
       "      <td>0.054071</td>\n",
       "      <td>0.133144</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Formulation  screen_best_hr10  screen_best_ndcg10  screen_score\n",
       "0        H1_base          0.169518            0.086209      0.191070\n",
       "1        H1_fine          0.142208            0.068361      0.159298\n",
       "2  H1_fine_color          0.127877            0.060050      0.142890\n",
       "3       H1_color          0.119626            0.054071      0.133144"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulations for full regular search: ['H1_base', 'H1_fine']\n",
      "========================================================================================================================\n",
      "Full regular search on formulation: H1_base\n",
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "474a065cfa9245dc811f99140e0dc711",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / RP3beta_a1_0_b0_6:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d8f2f496c35436897319092541db9f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / RP3beta_a1_1_b0_7:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.11s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "830a73bbec2d474ba9a331fc6a7490cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / RP3beta_a1_15_b0_7:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.11s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=2640 gram_gb=0.03 prep_time=0.06s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=2640 gram_gb=0.03 prep_time=0.06s\n",
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.14s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd783fe7ea2a428790d2b9c3a7e59547",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / EASE_binary_l1000:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.11s\n",
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00cd3a63200b48b885f455281bcc052b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / EASE_binary_l1200:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n",
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b4c64d4387e498993f8b24ccffcf450",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / EASE_binary_l1600:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n",
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca7a1597aca8462c9b5764ed9047b94d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / EASE_count_l1200:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n",
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90abf0ee97664103bde22a30cbd0c464",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / EASE_count_l1600:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n",
      "========================================================================================================================\n",
      "Full regular search on formulation: H1_fine\n",
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b3ac6e21c84c4083bdd2cda3b9c5856f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / RP3beta_a1_0_b0_6:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.43s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1cea72ccb24442f81dfa54f675a9d76",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / RP3beta_a1_1_b0_7:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0a8db6d8b91f47abbec43717d98138d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / RP3beta_a1_15_b0_7:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=3696 gram_gb=0.05 prep_time=0.05s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=3696 gram_gb=0.05 prep_time=0.05s\n",
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.10s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ff041e6bc504d7ba2643551ba2f056b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / EASE_binary_l1000:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.10s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b402b5e356541468250919ca4fa1637",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / EASE_binary_l1200:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.11s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c264202f44f4ea2bf956b4d6596f406",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / EASE_binary_l1600:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.10s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b6ab2e0451374451afda7d7d74b2589c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / EASE_count_l1200:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.10s\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c858bc740b2e4f0589f792bd2442bf14",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / EASE_count_l1600:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "      <th>Phase</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099354</td>\n",
       "      <td>0.170050</td>\n",
       "      <td>0.312947</td>\n",
       "      <td>0.061000</td>\n",
       "      <td>0.086145</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099956</td>\n",
       "      <td>0.169958</td>\n",
       "      <td>0.312577</td>\n",
       "      <td>0.061645</td>\n",
       "      <td>0.086629</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099470</td>\n",
       "      <td>0.169842</td>\n",
       "      <td>0.313618</td>\n",
       "      <td>0.061275</td>\n",
       "      <td>0.086323</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099169</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.312299</td>\n",
       "      <td>0.060961</td>\n",
       "      <td>0.085968</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098567</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.312716</td>\n",
       "      <td>0.061023</td>\n",
       "      <td>0.085960</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098289</td>\n",
       "      <td>0.168221</td>\n",
       "      <td>0.310216</td>\n",
       "      <td>0.060397</td>\n",
       "      <td>0.085273</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098127</td>\n",
       "      <td>0.168129</td>\n",
       "      <td>0.311003</td>\n",
       "      <td>0.060514</td>\n",
       "      <td>0.085340</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083137</td>\n",
       "      <td>0.146018</td>\n",
       "      <td>0.268986</td>\n",
       "      <td>0.050754</td>\n",
       "      <td>0.072685</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082300</td>\n",
       "      <td>0.145338</td>\n",
       "      <td>0.268913</td>\n",
       "      <td>0.050044</td>\n",
       "      <td>0.071969</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.084278</td>\n",
       "      <td>0.145055</td>\n",
       "      <td>0.269562</td>\n",
       "      <td>0.050897</td>\n",
       "      <td>0.072607</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083472</td>\n",
       "      <td>0.144731</td>\n",
       "      <td>0.269552</td>\n",
       "      <td>0.050221</td>\n",
       "      <td>0.071991</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.081159</td>\n",
       "      <td>0.144532</td>\n",
       "      <td>0.269740</td>\n",
       "      <td>0.048689</td>\n",
       "      <td>0.070727</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082457</td>\n",
       "      <td>0.144522</td>\n",
       "      <td>0.268892</td>\n",
       "      <td>0.049771</td>\n",
       "      <td>0.071576</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076961</td>\n",
       "      <td>0.142292</td>\n",
       "      <td>0.268903</td>\n",
       "      <td>0.046265</td>\n",
       "      <td>0.068288</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.002847</td>\n",
       "      <td>0.006250</td>\n",
       "      <td>0.012037</td>\n",
       "      <td>0.001778</td>\n",
       "      <td>0.002793</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>Popularity</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.002450</td>\n",
       "      <td>0.004878</td>\n",
       "      <td>0.009882</td>\n",
       "      <td>0.001397</td>\n",
       "      <td>0.002194</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Formulation                     Model  UsersEval      HR@5     HR@10  \\\n",
       "0      H1_base         EASE_binary_l1000      43199  0.099354  0.170050   \n",
       "1      H1_base         RP3beta_a1_0_b0_6      43199  0.099956  0.169958   \n",
       "2      H1_base        RP3beta_a1_15_b0_7      43199  0.099470  0.169842   \n",
       "3      H1_base  RP3beta_screen_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "4      H1_base         RP3beta_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "5      H1_base         EASE_binary_l1200      43199  0.099169  0.169356   \n",
       "6      H1_base         EASE_binary_l1600      43199  0.098567  0.169147   \n",
       "7      H1_base          EASE_count_l1200      43199  0.098289  0.168221   \n",
       "8      H1_base          EASE_count_l1600      43199  0.098127  0.168129   \n",
       "9      H1_fine          EASE_count_l1600      95529  0.083137  0.146018   \n",
       "10     H1_fine          EASE_count_l1200      95529  0.082300  0.145338   \n",
       "11     H1_fine         EASE_binary_l1600      95529  0.084278  0.145055   \n",
       "12     H1_fine         EASE_binary_l1200      95529  0.083472  0.144731   \n",
       "13     H1_fine         RP3beta_a1_0_b0_6      95529  0.081159  0.144532   \n",
       "14     H1_fine         EASE_binary_l1000      95529  0.082457  0.144522   \n",
       "15     H1_fine        RP3beta_a1_15_b0_7      95529  0.076961  0.142292   \n",
       "16     H1_fine  RP3beta_screen_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "17     H1_fine         RP3beta_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "18     H1_base                Popularity      43199  0.002847  0.006250   \n",
       "19     H1_fine                Popularity      95529  0.002450  0.004878   \n",
       "\n",
       "       HR@20    MRR@10   NDCG@10   Phase  \n",
       "0   0.312947  0.061000  0.086145    full  \n",
       "1   0.312577  0.061645  0.086629    full  \n",
       "2   0.313618  0.061275  0.086323    full  \n",
       "3   0.313109  0.061218  0.086209  screen  \n",
       "4   0.313109  0.061218  0.086209    full  \n",
       "5   0.312299  0.060961  0.085968    full  \n",
       "6   0.312716  0.061023  0.085960    full  \n",
       "7   0.310216  0.060397  0.085273    full  \n",
       "8   0.311003  0.060514  0.085340    full  \n",
       "9   0.268986  0.050754  0.072685    full  \n",
       "10  0.268913  0.050044  0.071969    full  \n",
       "11  0.269562  0.050897  0.072607    full  \n",
       "12  0.269552  0.050221  0.071991    full  \n",
       "13  0.269740  0.048689  0.070727    full  \n",
       "14  0.268892  0.049771  0.071576    full  \n",
       "15  0.268903  0.046265  0.068288    full  \n",
       "16  0.269060  0.046382  0.068361  screen  \n",
       "17  0.269060  0.046382  0.068361    full  \n",
       "18  0.012037  0.001778  0.002793  screen  \n",
       "19  0.009882  0.001397  0.002194  screen  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulations for neural search: ['H1_base', 'H1_fine']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Regular model search:\n",
    "# 1) Cheap screen on all formulations using only fast models\n",
    "# 2) Full regular search only on the strongest formulations\n",
    "screen_rows = []\n",
    "\n",
    "for formulation_name, bundle in bundles.items():\n",
    "    print(\"=\" * 120)\n",
    "    print(\"Cheap screen on formulation:\", formulation_name)\n",
    "\n",
    "    pop_res = evaluate_popularity(bundle, model_name=\"Popularity\")\n",
    "    pop_res[\"Phase\"] = \"screen\"\n",
    "    screen_rows.append(pop_res)\n",
    "\n",
    "    alpha, beta = SCREEN_RP3\n",
    "    clear_memory()\n",
    "    model_name = f\"RP3beta_screen_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "    print(\"- Fitting\", model_name)\n",
    "    t_fit = time.time()\n",
    "    S = fit_p3alpha(bundle[\"X_binary\"], alpha=alpha, beta=beta)\n",
    "    S_dense = S.toarray().astype(np.float32, copy=False)\n",
    "    fit_sec = time.time() - t_fit\n",
    "    print(f\"    fit_time={fit_sec:.2f}s\")\n",
    "    t_eval = time.time()\n",
    "    rp3_res = evaluate_dense_operator(bundle, S_dense, model_name=model_name, source=\"binary\", batch_size=LINEAR_EVAL_BATCH)\n",
    "    rp3_res[\"Phase\"] = \"screen\"\n",
    "    screen_rows.append(rp3_res)\n",
    "    print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "    del S, S_dense\n",
    "    clear_memory()\n",
    "\n",
    "screen_results_df = pd.DataFrame(screen_rows)\n",
    "screen_summary = (\n",
    "    screen_results_df.groupby(\"Formulation\")\n",
    "    .agg(\n",
    "        screen_best_hr10=(\"HR@10\", \"max\"),\n",
    "        screen_best_ndcg10=(\"NDCG@10\", \"max\"),\n",
    "    )\n",
    "    .reset_index()\n",
    ")\n",
    "screen_summary[\"screen_score\"] = screen_summary[\"screen_best_hr10\"] + 0.25 * screen_summary[\"screen_best_ndcg10\"]\n",
    "screen_summary = screen_summary.sort_values([\"screen_score\", \"screen_best_hr10\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "print(\"=\" * 120)\n",
    "print(\"Formulation cheap-screen summary\")\n",
    "display(screen_summary)\n",
    "\n",
    "selected_regular_formulations = screen_summary.head(TOP_FORMULATIONS_FOR_FULL_REGULAR)[\"Formulation\"].tolist()\n",
    "print(\"Selected formulations for full regular search:\", selected_regular_formulations)\n",
    "\n",
    "regular_results = []\n",
    "# Keep the screen results too, but only for selected formulations\n",
    "regular_results.extend(\n",
    "    screen_results_df[screen_results_df[\"Formulation\"].isin(selected_regular_formulations)].to_dict(\"records\")\n",
    ")\n",
    "\n",
    "for formulation_name in selected_regular_formulations:\n",
    "    bundle = bundles[formulation_name]\n",
    "    print(\"=\" * 120)\n",
    "    print(\"Full regular search on formulation:\", formulation_name)\n",
    "\n",
    "    # RP3beta grid\n",
    "    for alpha, beta in RP3_GRID:\n",
    "        clear_memory()\n",
    "        model_name = f\"RP3beta_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "        print(\"- Fitting\", model_name)\n",
    "        t_fit = time.time()\n",
    "        S = fit_p3alpha(bundle[\"X_binary\"], alpha=alpha, beta=beta)\n",
    "        S_dense = S.toarray().astype(np.float32, copy=False)\n",
    "        fit_sec = time.time() - t_fit\n",
    "        print(f\"    fit_time={fit_sec:.2f}s\")\n",
    "        t_eval = time.time()\n",
    "        res = evaluate_dense_operator(bundle, S_dense, model_name=model_name, source=\"binary\", batch_size=LINEAR_EVAL_BATCH)\n",
    "        res[\"Phase\"] = \"full\"\n",
    "        regular_results.append(res)\n",
    "        print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "        del S, S_dense\n",
    "        clear_memory()\n",
    "\n",
    "    # EASE only if the item space is not too large\n",
    "    n_items = int(bundle[\"num_items\"])\n",
    "    if n_items > MAX_EASE_ITEMS:\n",
    "        print(f\"- Skipping EASE on {formulation_name}: n_items={n_items} > MAX_EASE_ITEMS={MAX_EASE_ITEMS}\")\n",
    "        continue\n",
    "\n",
    "    clear_memory()\n",
    "    print(\"- Preparing EASE binary Gram cache\")\n",
    "    ease_bin_cache = prepare_ease_cache(bundle[\"X_binary\"], source_name=\"binary\", prefer_gpu=True)\n",
    "    print(\n",
    "        f\"    backend={ease_bin_cache['backend']} \"\n",
    "        f\"n_items={ease_bin_cache['n_items']} \"\n",
    "        f\"gram_gb={ease_bin_cache['gram_gb']:.2f} \"\n",
    "        f\"prep_time={ease_bin_cache['prep_sec']:.2f}s\"\n",
    "    )\n",
    "\n",
    "    print(\"- Preparing EASE count Gram cache\")\n",
    "    ease_cnt_cache = prepare_ease_cache(bundle[\"X_counts\"], source_name=\"count\", prefer_gpu=True)\n",
    "    print(\n",
    "        f\"    backend={ease_cnt_cache['backend']} \"\n",
    "        f\"n_items={ease_cnt_cache['n_items']} \"\n",
    "        f\"gram_gb={ease_cnt_cache['gram_gb']:.2f} \"\n",
    "        f\"prep_time={ease_cnt_cache['prep_sec']:.2f}s\"\n",
    "    )\n",
    "\n",
    "    for lam in EASE_BINARY_LAMBDAS:\n",
    "        clear_memory()\n",
    "        model_name = f\"EASE_binary_l{int(lam)}\"\n",
    "        print(\"- Fitting\", model_name)\n",
    "        t_fit = time.time()\n",
    "        B = fit_ease_from_cache(ease_bin_cache, lam=lam)\n",
    "        fit_sec = time.time() - t_fit\n",
    "        print(f\"    fit_time={fit_sec:.2f}s\")\n",
    "        t_eval = time.time()\n",
    "        res = evaluate_dense_operator(bundle, B, model_name=model_name, source=\"binary\", batch_size=LINEAR_EVAL_BATCH)\n",
    "        res[\"Phase\"] = \"full\"\n",
    "        regular_results.append(res)\n",
    "        print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "        del B\n",
    "        clear_memory()\n",
    "\n",
    "    for lam in EASE_COUNT_LAMBDAS:\n",
    "        clear_memory()\n",
    "        model_name = f\"EASE_count_l{int(lam)}\"\n",
    "        print(\"- Fitting\", model_name)\n",
    "        t_fit = time.time()\n",
    "        B = fit_ease_from_cache(ease_cnt_cache, lam=lam)\n",
    "        fit_sec = time.time() - t_fit\n",
    "        print(f\"    fit_time={fit_sec:.2f}s\")\n",
    "        t_eval = time.time()\n",
    "        res = evaluate_dense_operator(bundle, B, model_name=model_name, source=\"count\", batch_size=LINEAR_EVAL_BATCH)\n",
    "        res[\"Phase\"] = \"full\"\n",
    "        regular_results.append(res)\n",
    "        print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "        del B\n",
    "        clear_memory()\n",
    "\n",
    "    del ease_bin_cache, ease_cnt_cache\n",
    "    clear_memory()\n",
    "\n",
    "regular_results_df = (\n",
    "    pd.DataFrame(regular_results)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "display(regular_results_df.head(40))\n",
    "\n",
    "\n",
    "# Choose formulations for neural search from the strongest regular results.\n",
    "selected_neural_formulations = (\n",
    "    regular_results_df.groupby(\"Formulation\")[\"HR@10\"]\n",
    "    .max()\n",
    "    .sort_values(ascending=False)\n",
    "    .head(TOP_FORMULATIONS_FOR_NEURAL)\n",
    "    .index.tolist()\n",
    ")\n",
    "print(\"Selected formulations for neural search:\", selected_neural_formulations)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ec3cf5b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing neural tensors: H1_base\n",
      "Preparing neural tensors: H1_fine\n"
     ]
    }
   ],
   "source": [
    "# Prepare formulation-specific tensors for neural models\n",
    "\n",
    "def prepare_neural_data(bundle):\n",
    "    user_feature_df = bundle[\"user_feature_df\"].copy()\n",
    "    item_feature_df = bundle[\"item_feature_df\"].copy()\n",
    "    train_interactions = bundle[\"train_interactions\"].copy()\n",
    "\n",
    "    user_cols = bundle[\"user_cols\"]\n",
    "    item_cols = bundle[\"item_cols\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "\n",
    "    for col in user_cols:\n",
    "        le = LabelEncoder()\n",
    "        user_feature_df[col] = le.fit_transform(user_feature_df[col].astype(str))\n",
    "        feature_encoders[f\"user::{col}\"] = le\n",
    "\n",
    "    for col in item_cols:\n",
    "        le = LabelEncoder()\n",
    "        item_feature_df[col] = le.fit_transform(item_feature_df[col].astype(str))\n",
    "        feature_encoders[f\"item::{col}\"] = le\n",
    "\n",
    "    user_feat_np = user_feature_df[user_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "    item_feat_np = item_feature_df[item_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "\n",
    "    pos_user_np = train_interactions[\"user_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_item_np = train_interactions[\"item_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_weight_np = train_interactions[\"count\"].to_numpy(dtype=np.float32, copy=True)\n",
    "\n",
    "    x_binary_np = bundle[\"X_binary\"].toarray().astype(np.float32, copy=False)\n",
    "\n",
    "    out = {\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"user_feat_t\": torch.from_numpy(user_feat_np).to(device),\n",
    "        \"item_feat_t\": torch.from_numpy(item_feat_np).to(device),\n",
    "        \"pos_user_t\": torch.from_numpy(pos_user_np).to(device),\n",
    "        \"pos_item_t\": torch.from_numpy(pos_item_np).to(device),\n",
    "        \"pos_weight_t\": torch.from_numpy(pos_weight_np).to(device),\n",
    "        \"x_binary_np\": x_binary_np,\n",
    "        # Keep the dense user-item matrix on CPU. It is only needed for the autoencoder\n",
    "        # models and for seen-item masking, and preloading it to CUDA costs too much VRAM.\n",
    "        \"x_binary_t\": None,\n",
    "        \"x_binary_on_device\": False,\n",
    "    }\n",
    "\n",
    "    out[\"user_feature_cards\"] = {\n",
    "        col: int(user_feature_df[col].nunique()) for col in user_cols\n",
    "    }\n",
    "    out[\"item_feature_cards\"] = {\n",
    "        col: int(item_feature_df[col].nunique()) for col in item_cols\n",
    "    }\n",
    "\n",
    "    return out\n",
    "\n",
    "neural_data_by_formulation = {}\n",
    "for formulation_name in selected_neural_formulations:\n",
    "    print(\"Preparing neural tensors:\", formulation_name)\n",
    "    neural_data_by_formulation[formulation_name] = prepare_neural_data(bundles[formulation_name])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7c9a604a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural architectures and batched evaluators\n",
    "\n",
    "class FeatureEmbeddingBlock(nn.Module):\n",
    "    def __init__(self, cardinalities, emb_dim):\n",
    "        super().__init__()\n",
    "        self.cols = list(cardinalities.keys())\n",
    "        self.embs = nn.ModuleList([\n",
    "            nn.Embedding(int(cardinalities[c]), emb_dim) for c in self.cols\n",
    "        ])\n",
    "        for emb in self.embs:\n",
    "            nn.init.normal_(emb.weight, std=0.02)\n",
    "\n",
    "    def forward(self, x):\n",
    "        parts = [emb(x[:, i]) for i, emb in enumerate(self.embs)]\n",
    "        return torch.cat(parts, dim=1)\n",
    "\n",
    "class MLPBlock(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None):\n",
    "        super().__init__()\n",
    "        layers = []\n",
    "        prev = input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)])\n",
    "            prev = h\n",
    "        if final_dim is not None:\n",
    "            layers.append(nn.Linear(prev, final_dim))\n",
    "            prev = final_dim\n",
    "        self.net = nn.Sequential(*layers)\n",
    "        self.output_dim = prev\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class TowerEncoder(nn.Module):\n",
    "    def __init__(self, cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.embed = FeatureEmbeddingBlock(cards, emb_dim)\n",
    "        input_dim = len(cards) * emb_dim\n",
    "        self.mlp = MLPBlock(input_dim, hidden_dims, dropout=dropout, final_dim=out_dim)\n",
    "\n",
    "    def forward(self, feat_batch):\n",
    "        x = self.embed(feat_batch)\n",
    "        x = self.mlp(x)\n",
    "        return F.normalize(x, dim=-1)\n",
    "\n",
    "class TwoTowerModel(nn.Module):\n",
    "    def __init__(self, user_cards, item_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "        super().__init__()\n",
    "        self.user_tower = TowerEncoder(user_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "        self.item_tower = TowerEncoder(item_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "\n",
    "class MultVAE(nn.Module):\n",
    "    def __init__(self, num_items, hidden_dim=2048, latent_dim=512, dropout=0.3):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        h = self.encoder(self.dropout(x))\n",
    "        return self.mu(h), self.logvar(h)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        if self.training:\n",
    "            std = torch.exp(0.5 * logvar)\n",
    "            eps = torch.randn_like(std)\n",
    "            return mu + eps * std\n",
    "        return mu\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        logits = self.decoder(z)\n",
    "        return logits, mu, logvar\n",
    "\n",
    "class NeuMF(nn.Module):\n",
    "    def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "        super().__init__()\n",
    "        self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "        self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "        self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "        self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "        mlp_layers = []\n",
    "        prev = 2 * mlp_dim\n",
    "        for h in hidden_dims:\n",
    "            mlp_layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(0.1)])\n",
    "            prev = h\n",
    "        self.mlp = nn.Sequential(*mlp_layers)\n",
    "        self.out = nn.Linear(prev + mf_dim, 1)\n",
    "\n",
    "        for emb in [self.user_mf, self.item_mf, self.user_mlp, self.item_mlp]:\n",
    "            nn.init.normal_(emb.weight, std=0.02)\n",
    "\n",
    "    def forward(self, user_idx, item_idx):\n",
    "        mf_part = self.user_mf(user_idx) * self.item_mf(item_idx)\n",
    "        mlp_in = torch.cat([self.user_mlp(user_idx), self.item_mlp(item_idx)], dim=1)\n",
    "        mlp_part = self.mlp(mlp_in)\n",
    "        x = torch.cat([mf_part, mlp_part], dim=1)\n",
    "        return self.out(x).squeeze(-1)\n",
    "\n",
    "def iter_positive_batches(ndata, batch_size):\n",
    "    n = ndata[\"pos_user_t\"].shape[0]\n",
    "    perm = torch.randperm(n, device=device)\n",
    "    for start in range(0, n, batch_size):\n",
    "        idx = perm[start:start + batch_size]\n",
    "        yield (\n",
    "            ndata[\"pos_user_t\"][idx],\n",
    "            ndata[\"pos_item_t\"][idx],\n",
    "            ndata[\"pos_weight_t\"][idx],\n",
    "        )\n",
    "\n",
    "def iter_dense_user_batches(ndata, batch_size):\n",
    "    n_users = ndata[\"x_binary_np\"].shape[0]\n",
    "    perm = np.random.permutation(n_users)\n",
    "    for start in range(0, n_users, batch_size):\n",
    "        idx = perm[start:start + batch_size]\n",
    "        if ndata[\"x_binary_on_device\"]:\n",
    "            yield ndata[\"x_binary_t\"][idx], torch.as_tensor(idx, device=device, dtype=torch.long)\n",
    "        else:\n",
    "            batch = torch.from_numpy(ndata[\"x_binary_np\"][idx]).to(device)\n",
    "            yield batch, torch.as_tensor(idx, device=device, dtype=torch.long)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_twotower(bundle, ndata, model, model_name):\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    item_feats = ndata[\"item_feat_t\"]\n",
    "    item_vecs = []\n",
    "    for start in range(0, item_feats.shape[0], 4096):\n",
    "        item_vecs.append(model.item_tower(item_feats[start:start + 4096]))\n",
    "    item_matrix = torch.cat(item_vecs, dim=0)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEURAL_EVAL_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEURAL_EVAL_BATCH]\n",
    "        feat_batch = ndata[\"user_feat_t\"][batch_uids]\n",
    "        x_seen = torch.from_numpy(x_source[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        user_vec = model.user_tower(feat_batch)\n",
    "        scores = user_vec @ item_matrix.T\n",
    "        scores = scores.masked_fill(x_seen > 0, -1e9)\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del feat_batch, x_seen, user_vec, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_multvae(bundle, ndata, model, model_name):\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEURAL_EVAL_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEURAL_EVAL_BATCH]\n",
    "        x_batch = torch.from_numpy(x_source[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        logits, _, _ = model(x_batch)\n",
    "        logits = logits.masked_fill(x_batch > 0, -1e9)\n",
    "        topk = torch.topk(logits, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([bundle[\"test_item_by_user\"][int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del x_batch, logits, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_neumf(bundle, model, model_name):\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    num_items = bundle[\"num_items\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "    all_items = torch.arange(num_items, device=device, dtype=torch.long)\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEUMF_EVAL_USER_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEUMF_EVAL_USER_BATCH]\n",
    "        user_batch = torch.as_tensor(batch_uids, device=device, dtype=torch.long)\n",
    "        x_seen = torch.from_numpy(x_source[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "\n",
    "        parts = []\n",
    "        for istart in range(0, num_items, NEUMF_EVAL_ITEM_BATCH):\n",
    "            item_batch = all_items[istart:istart + NEUMF_EVAL_ITEM_BATCH]\n",
    "            u_expand = user_batch[:, None].expand(-1, item_batch.shape[0]).reshape(-1)\n",
    "            i_expand = item_batch[None, :].expand(user_batch.shape[0], -1).reshape(-1)\n",
    "            scores_part = model(u_expand, i_expand).reshape(user_batch.shape[0], -1)\n",
    "            parts.append(scores_part)\n",
    "\n",
    "        scores = torch.cat(parts, dim=1)\n",
    "        scores = scores.masked_fill(x_seen > 0, -1e9)\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del user_batch, x_seen, parts, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7f000e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural training functions\n",
    "\n",
    "def is_oom_error(exc):\n",
    "    text = str(exc).lower()\n",
    "    return (\"out of memory\" in text) or (\"cuda error: out of memory\" in text)\n",
    "\n",
    "\n",
    "def train_two_tower(bundle, ndata, cfg):\n",
    "    \"\"\"Train a TwoTowerModel with an in-batch softmax loss and OOM back-off.\n",
    "\n",
    "    Each batch of positive (user, item) pairs is scored as a square logit\n",
    "    matrix (user vectors @ item vectors.T, divided by ``cfg['temperature']``).\n",
    "    Row r is trained to pick column r, and the per-row losses are averaged\n",
    "    using the batch weights.\n",
    "\n",
    "    If an epoch fails with an out-of-memory RuntimeError, the batch size is\n",
    "    halved (never below the minimum) and the whole epoch is run again.\n",
    "    Parameter updates applied before the failure are kept.\n",
    "\n",
    "    Args:\n",
    "        bundle: formulation dict; only ``bundle['name']`` is read (for logging).\n",
    "        ndata: neural data dict providing 'user_feature_cards',\n",
    "            'item_feature_cards', 'user_feat_t' and 'item_feat_t'; it is also\n",
    "            handed to ``iter_positive_batches``.\n",
    "        cfg: config dict with 'name', 'emb_dim', 'hidden_dims', 'out_dim', 'lr',\n",
    "            'wd', 'epochs', 'temperature' and optional 'batch_size' /\n",
    "            'min_batch_size' (defaults TWOTOWER_BATCH_START / TWOTOWER_BATCH_MIN).\n",
    "\n",
    "    Returns:\n",
    "        The trained model, switched to eval mode.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: any non-OOM runtime error, or an OOM error raised while\n",
    "            the batch size is already at its minimum.\n",
    "    \"\"\"\n",
    "    model = TwoTowerModel(\n",
    "        user_cards=ndata[\"user_feature_cards\"],\n",
    "        item_cards=ndata[\"item_feature_cards\"],\n",
    "        emb_dim=cfg[\"emb_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "        out_dim=cfg[\"out_dim\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    current_batch = int(cfg.get(\"batch_size\", TWOTOWER_BATCH_START))\n",
    "    min_batch = int(cfg.get(\"min_batch_size\", TWOTOWER_BATCH_MIN))\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        # Retry loop: left via `break` once a full epoch completes.\n",
    "        while True:\n",
    "            hit_oom = False\n",
    "            try:\n",
    "                total_loss = 0.0\n",
    "                total_w = 0.0\n",
    "\n",
    "                for user_idx_batch, item_idx_batch, weight_batch in iter_positive_batches(ndata, current_batch):\n",
    "                    user_batch = ndata[\"user_feat_t\"][user_idx_batch]\n",
    "                    item_batch = ndata[\"item_feat_t\"][item_idx_batch]\n",
    "\n",
    "                    optimizer.zero_grad(set_to_none=True)\n",
    "                    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                        user_vec = model.user_tower(user_batch)\n",
    "                        item_vec = model.item_tower(item_batch)\n",
    "                        logits = (user_vec @ item_vec.T) / cfg[\"temperature\"]\n",
    "                        # In-batch negatives: the diagonal entry is each row's positive.\n",
    "                        targets = torch.arange(logits.shape[0], device=device)\n",
    "                        loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                        loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                    scaler.scale(loss).backward()\n",
    "                    scaler.step(optimizer)\n",
    "                    scaler.update()\n",
    "\n",
    "                    # Read the batch weight back to the host once and reuse it.\n",
    "                    batch_w = float(weight_batch.sum().item())\n",
    "                    total_loss += float(loss.item()) * batch_w\n",
    "                    total_w += batch_w\n",
    "\n",
    "                    del user_batch, item_batch, user_vec, item_vec, logits, targets, loss_vec, loss\n",
    "\n",
    "            except RuntimeError as exc:\n",
    "                if not is_oom_error(exc) or current_batch <= min_batch:\n",
    "                    raise\n",
    "                # Only record the failure here. While the handler is active the\n",
    "                # exception's traceback keeps the failing frames (and the tensors\n",
    "                # they reference) alive, so memory is released after it ends.\n",
    "                hit_oom = True\n",
    "\n",
    "            if not hit_oom:\n",
    "                print(\n",
    "                    f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} \"\n",
    "                    f\"- batch: {current_batch} - loss: {total_loss / max(total_w, 1e-8):.6f}\"\n",
    "                )\n",
    "                break\n",
    "\n",
    "            # Drop tensors left bound by the aborted step so they do not stay\n",
    "            # allocated while the first batch of the retry is being processed.\n",
    "            user_idx_batch = item_idx_batch = weight_batch = None\n",
    "            user_batch = item_batch = user_vec = item_vec = None\n",
    "            logits = targets = loss_vec = loss = None\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            clear_memory()\n",
    "            current_batch = max(min_batch, current_batch // 2)\n",
    "            print(\n",
    "                f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}: OOM, \"\n",
    "                f\"reducing batch to {current_batch} and retrying the epoch\"\n",
    "            )\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "\n",
    "def train_multvae(bundle, ndata, cfg):\n",
    "    \"\"\"Train a MultVAE with KL annealing and OOM back-off.\n",
    "\n",
    "    The per-row loss is the multinomial reconstruction term\n",
    "    ``-(log_softmax(logits) * x).sum()`` plus ``anneal * KL``. ``anneal`` grows\n",
    "    linearly with the global step over the first 30% of the planned steps and\n",
    "    is capped at ``cfg['anneal_cap']``.\n",
    "\n",
    "    If an epoch fails with an out-of-memory RuntimeError, the batch size is\n",
    "    halved (never below the minimum), the step counter is rewound to its value\n",
    "    at the start of the epoch, and the epoch is run again. Parameter updates\n",
    "    applied before the failure are kept.\n",
    "\n",
    "    Args:\n",
    "        bundle: formulation dict; reads 'num_items', 'num_users' and 'name'.\n",
    "        ndata: neural data dict handed to ``iter_dense_user_batches``.\n",
    "        cfg: config dict with 'name', 'hidden_dim', 'latent_dim', 'dropout',\n",
    "            'lr', 'wd', 'epochs', 'anneal_cap' and optional 'batch_size' /\n",
    "            'min_batch_size' (defaults AUTOENC_BATCH_START / AUTOENC_BATCH_MIN).\n",
    "\n",
    "    Returns:\n",
    "        The trained model, switched to eval mode.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: any non-OOM runtime error, or an OOM error raised while\n",
    "            the batch size is already at its minimum.\n",
    "    \"\"\"\n",
    "    model = MultVAE(\n",
    "        num_items=bundle[\"num_items\"],\n",
    "        hidden_dim=cfg[\"hidden_dim\"],\n",
    "        latent_dim=cfg[\"latent_dim\"],\n",
    "        dropout=cfg[\"dropout\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    n_users = bundle[\"num_users\"]\n",
    "    current_batch = int(cfg.get(\"batch_size\", AUTOENC_BATCH_START))\n",
    "    min_batch = int(cfg.get(\"min_batch_size\", AUTOENC_BATCH_MIN))\n",
    "    step = 0\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        # Snapshot so a retried epoch does not count its aborted steps twice\n",
    "        # in the annealing schedule.\n",
    "        epoch_start_step = step\n",
    "        # Retry loop: left via `break` once a full epoch completes.\n",
    "        while True:\n",
    "            hit_oom = False\n",
    "            try:\n",
    "                total_loss = 0.0\n",
    "                total_rows = 0\n",
    "                steps_per_epoch = max(1, math.ceil(n_users / current_batch))\n",
    "                total_steps = max(1, cfg[\"epochs\"] * steps_per_epoch)\n",
    "                # NOTE(review): after a batch-size reduction total_steps is based on\n",
    "                # the new batch size while `step` still counts earlier epochs at the\n",
    "                # old one, so the ramp is only approximate from then on.\n",
    "\n",
    "                for batch_x, _ in iter_dense_user_batches(ndata, current_batch):\n",
    "                    optimizer.zero_grad(set_to_none=True)\n",
    "                    anneal = min(cfg[\"anneal_cap\"], step / max(total_steps * 0.3, 1))\n",
    "\n",
    "                    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                        logits, mu, logvar = model(batch_x)\n",
    "                        recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                        loss = (recon + anneal * kl).mean()\n",
    "\n",
    "                    scaler.scale(loss).backward()\n",
    "                    scaler.step(optimizer)\n",
    "                    scaler.update()\n",
    "\n",
    "                    total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                    total_rows += batch_x.shape[0]\n",
    "                    step += 1\n",
    "\n",
    "                    del batch_x, logits, mu, logvar, recon, kl, loss\n",
    "\n",
    "            except RuntimeError as exc:\n",
    "                if not is_oom_error(exc) or current_batch <= min_batch:\n",
    "                    raise\n",
    "                # Only record the failure here. While the handler is active the\n",
    "                # exception's traceback keeps the failing frames (and the tensors\n",
    "                # they reference) alive, so memory is released after it ends.\n",
    "                hit_oom = True\n",
    "\n",
    "            if not hit_oom:\n",
    "                # Log the first epoch, every 10th epoch and the last epoch.\n",
    "                if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == cfg[\"epochs\"]:\n",
    "                    print(\n",
    "                        f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} \"\n",
    "                        f\"- batch: {current_batch} - loss: {total_loss / max(total_rows, 1):.6f}\"\n",
    "                    )\n",
    "                break\n",
    "\n",
    "            # Drop tensors left bound by the aborted step so they do not stay\n",
    "            # allocated while the first batch of the retry is being processed.\n",
    "            batch_x = logits = mu = logvar = recon = kl = loss = None\n",
    "            step = epoch_start_step\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            clear_memory()\n",
    "            current_batch = max(min_batch, current_batch // 2)\n",
    "            print(\n",
    "                f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}: OOM, \"\n",
    "                f\"reducing batch to {current_batch} and retrying the epoch\"\n",
    "            )\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "\n",
    "def train_neumf(bundle, cfg):\n",
    "    \"\"\"Train a NeuMF model with a weighted pairwise (BPR-style) loss.\n",
    "\n",
    "    Every training interaction (user, item, count) is paired with one randomly\n",
    "    drawn negative item per epoch. The per-pair loss is\n",
    "    ``softplus(-(pos_score - neg_score))`` and the batch loss is its average\n",
    "    weighted by the interaction counts.\n",
    "\n",
    "    Args:\n",
    "        bundle: formulation dict; reads 'num_users', 'num_items', 'name' and\n",
    "            'train_interactions' (columns 'user_idx', 'item_idx', 'count').\n",
    "        cfg: config dict with 'name', 'mf_dim', 'mlp_dim', 'hidden_dims', 'lr',\n",
    "            'wd' and 'epochs'.\n",
    "\n",
    "    Returns:\n",
    "        The trained model, switched to eval mode.\n",
    "    \"\"\"\n",
    "    n_users = bundle[\"num_users\"]\n",
    "    n_items = bundle[\"num_items\"]\n",
    "    train_df = bundle[\"train_interactions\"]\n",
    "\n",
    "    pos_user = torch.as_tensor(train_df[\"user_idx\"].to_numpy(dtype=np.int64), device=device)\n",
    "    pos_item = torch.as_tensor(train_df[\"item_idx\"].to_numpy(dtype=np.int64), device=device)\n",
    "    pos_weight = torch.as_tensor(train_df[\"count\"].to_numpy(dtype=np.float32), device=device)\n",
    "\n",
    "    model = NeuMF(\n",
    "        num_users=n_users,\n",
    "        num_items=n_items,\n",
    "        mf_dim=cfg[\"mf_dim\"],\n",
    "        mlp_dim=cfg[\"mlp_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "\n",
    "        # Fresh shuffle of the positive pairs every epoch.\n",
    "        perm = torch.randperm(pos_user.shape[0], device=device)\n",
    "        for start in range(0, pos_user.shape[0], PAIRWISE_BATCH):\n",
    "            idx = perm[start:start + PAIRWISE_BATCH]\n",
    "            u = pos_user[idx]\n",
    "            i = pos_item[idx]\n",
    "            w = pos_weight[idx]\n",
    "\n",
    "            if n_items > 1:\n",
    "                # Sample uniformly among the n_items - 1 items other than the\n",
    "                # positive: draw from [0, n_items - 1) and shift values >= i up\n",
    "                # by one. A negative equal to the positive would give a constant\n",
    "                # softplus(0) loss with zero gradient.\n",
    "                # The negative may still be another item this user interacted with.\n",
    "                j = torch.randint(0, n_items - 1, size=(u.shape[0],), device=device)\n",
    "                j = j + (j >= i).to(j.dtype)\n",
    "            else:\n",
    "                # Single-item catalogue: no distinct negative exists.\n",
    "                j = torch.randint(0, n_items, size=(u.shape[0],), device=device)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                pos_scores = model(u, i)\n",
    "                neg_scores = model(u, j)\n",
    "                loss_vec = F.softplus(-(pos_scores - neg_scores))\n",
    "                loss = (loss_vec * w).sum() / w.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            # Read the batch weight back to the host once and reuse it.\n",
    "            batch_w = float(w.sum().item())\n",
    "            total_loss += float(loss.item()) * batch_w\n",
    "            total_w += batch_w\n",
    "\n",
    "            del u, i, w, j, pos_scores, neg_scores, loss_vec, loss\n",
    "\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "38b2ee8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Neural search on formulation: H1_base\n",
      "H1_base / TwoTower_t2 epoch 1/28 - batch: 8192 - loss: 7.066663\n",
      "H1_base / TwoTower_t2 epoch 2/28 - batch: 8192 - loss: 5.734326\n",
      "H1_base / TwoTower_t2 epoch 3/28 - batch: 8192 - loss: 5.658068\n",
      "H1_base / TwoTower_t2 epoch 4/28 - batch: 8192 - loss: 5.647869\n",
      "H1_base / TwoTower_t2 epoch 5/28 - batch: 8192 - loss: 5.643628\n",
      "H1_base / TwoTower_t2 epoch 6/28 - batch: 8192 - loss: 5.640994\n",
      "H1_base / TwoTower_t2 epoch 7/28 - batch: 8192 - loss: 5.639192\n",
      "H1_base / TwoTower_t2 epoch 8/28 - batch: 8192 - loss: 5.638070\n",
      "H1_base / TwoTower_t2 epoch 9/28 - batch: 8192 - loss: 5.637313\n",
      "H1_base / TwoTower_t2 epoch 10/28 - batch: 8192 - loss: 5.636331\n",
      "H1_base / TwoTower_t2 epoch 11/28 - batch: 8192 - loss: 5.635828\n",
      "H1_base / TwoTower_t2 epoch 12/28 - batch: 8192 - loss: 5.635172\n",
      "H1_base / TwoTower_t2 epoch 13/28 - batch: 8192 - loss: 5.634857\n",
      "H1_base / TwoTower_t2 epoch 14/28 - batch: 8192 - loss: 5.634489\n",
      "H1_base / TwoTower_t2 epoch 15/28 - batch: 8192 - loss: 5.634136\n",
      "H1_base / TwoTower_t2 epoch 16/28 - batch: 8192 - loss: 5.633702\n",
      "H1_base / TwoTower_t2 epoch 17/28 - batch: 8192 - loss: 5.633578\n",
      "H1_base / TwoTower_t2 epoch 18/28 - batch: 8192 - loss: 5.633255\n",
      "H1_base / TwoTower_t2 epoch 19/28 - batch: 8192 - loss: 5.632926\n",
      "H1_base / TwoTower_t2 epoch 20/28 - batch: 8192 - loss: 5.632799\n",
      "H1_base / TwoTower_t2 epoch 21/28 - batch: 8192 - loss: 5.632720\n",
      "H1_base / TwoTower_t2 epoch 22/28 - batch: 8192 - loss: 5.632377\n",
      "H1_base / TwoTower_t2 epoch 23/28 - batch: 8192 - loss: 5.632387\n",
      "H1_base / TwoTower_t2 epoch 24/28 - batch: 8192 - loss: 5.632149\n",
      "H1_base / TwoTower_t2 epoch 25/28 - batch: 8192 - loss: 5.632171\n",
      "H1_base / TwoTower_t2 epoch 26/28 - batch: 8192 - loss: 5.631890\n",
      "H1_base / TwoTower_t2 epoch 27/28 - batch: 8192 - loss: 5.632047\n",
      "H1_base / TwoTower_t2 epoch 28/28 - batch: 8192 - loss: 5.631732\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12e96d50d8374793afafcd2190ee1478",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / TwoTower_t2:   0%|          | 0/22 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_base / TwoTower_t3 epoch 1/32 - batch: 8192 - loss: 6.735120\n",
      "H1_base / TwoTower_t3 epoch 2/32 - batch: 8192 - loss: 5.650604\n",
      "H1_base / TwoTower_t3 epoch 3/32 - batch: 8192 - loss: 5.640766\n",
      "H1_base / TwoTower_t3 epoch 4/32 - batch: 8192 - loss: 5.637974\n",
      "H1_base / TwoTower_t3 epoch 5/32 - batch: 8192 - loss: 5.636767\n",
      "H1_base / TwoTower_t3 epoch 6/32 - batch: 8192 - loss: 5.635574\n",
      "H1_base / TwoTower_t3 epoch 7/32 - batch: 8192 - loss: 5.634801\n",
      "H1_base / TwoTower_t3 epoch 8/32 - batch: 8192 - loss: 5.634283\n",
      "H1_base / TwoTower_t3 epoch 9/32 - batch: 8192 - loss: 5.633769\n",
      "H1_base / TwoTower_t3 epoch 10/32 - batch: 8192 - loss: 5.633430\n",
      "H1_base / TwoTower_t3 epoch 11/32 - batch: 8192 - loss: 5.633109\n",
      "H1_base / TwoTower_t3 epoch 12/32 - batch: 8192 - loss: 5.632697\n",
      "H1_base / TwoTower_t3 epoch 13/32 - batch: 8192 - loss: 5.632772\n",
      "H1_base / TwoTower_t3 epoch 14/32 - batch: 8192 - loss: 5.632514\n",
      "H1_base / TwoTower_t3 epoch 15/32 - batch: 8192 - loss: 5.632210\n",
      "H1_base / TwoTower_t3 epoch 16/32 - batch: 8192 - loss: 5.632038\n",
      "H1_base / TwoTower_t3 epoch 17/32 - batch: 8192 - loss: 5.632091\n",
      "H1_base / TwoTower_t3 epoch 18/32 - batch: 8192 - loss: 5.632020\n",
      "H1_base / TwoTower_t3 epoch 19/32 - batch: 8192 - loss: 5.631778\n",
      "H1_base / TwoTower_t3 epoch 20/32 - batch: 8192 - loss: 5.631920\n",
      "H1_base / TwoTower_t3 epoch 21/32 - batch: 8192 - loss: 5.631652\n",
      "H1_base / TwoTower_t3 epoch 22/32 - batch: 8192 - loss: 5.631503\n",
      "H1_base / TwoTower_t3 epoch 23/32 - batch: 8192 - loss: 5.631486\n",
      "H1_base / TwoTower_t3 epoch 24/32 - batch: 8192 - loss: 5.631517\n",
      "H1_base / TwoTower_t3 epoch 25/32 - batch: 8192 - loss: 5.631330\n",
      "H1_base / TwoTower_t3 epoch 26/32 - batch: 8192 - loss: 5.631440\n",
      "H1_base / TwoTower_t3 epoch 27/32 - batch: 8192 - loss: 5.631247\n",
      "H1_base / TwoTower_t3 epoch 28/32 - batch: 8192 - loss: 5.631046\n",
      "H1_base / TwoTower_t3 epoch 29/32 - batch: 8192 - loss: 5.631009\n",
      "H1_base / TwoTower_t3 epoch 30/32 - batch: 8192 - loss: 5.631025\n",
      "H1_base / TwoTower_t3 epoch 31/32 - batch: 8192 - loss: 5.630961\n",
      "H1_base / TwoTower_t3 epoch 32/32 - batch: 8192 - loss: 5.631033\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0abef958d9da4c8893256c7fa94666df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / TwoTower_t3:   0%|          | 0/22 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_base / MultVAE_v3 epoch 1/100 - batch: 2048 - loss: 127.802556\n",
      "H1_base / MultVAE_v3 epoch 10/100 - batch: 2048 - loss: 92.121343\n",
      "H1_base / MultVAE_v3 epoch 20/100 - batch: 2048 - loss: 95.629602\n",
      "H1_base / MultVAE_v3 epoch 30/100 - batch: 2048 - loss: 98.333383\n",
      "H1_base / MultVAE_v3 epoch 40/100 - batch: 2048 - loss: 98.070144\n",
      "H1_base / MultVAE_v3 epoch 50/100 - batch: 2048 - loss: 97.927085\n",
      "H1_base / MultVAE_v3 epoch 60/100 - batch: 2048 - loss: 97.725223\n",
      "H1_base / MultVAE_v3 epoch 70/100 - batch: 2048 - loss: 97.747511\n",
      "H1_base / MultVAE_v3 epoch 80/100 - batch: 2048 - loss: 97.627432\n",
      "H1_base / MultVAE_v3 epoch 90/100 - batch: 2048 - loss: 97.454212\n",
      "H1_base / MultVAE_v3 epoch 100/100 - batch: 2048 - loss: 97.197594\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b00b29df0c4149c286495a5ec4e54e93",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / MultVAE_v3:   0%|          | 0/22 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_base / MultVAE_v4 epoch 1/110 - batch: 2048 - loss: 128.534698\n",
      "H1_base / MultVAE_v4 epoch 10/110 - batch: 2048 - loss: 91.873938\n",
      "H1_base / MultVAE_v4 epoch 20/110 - batch: 2048 - loss: 95.151020\n",
      "H1_base / MultVAE_v4 epoch 30/110 - batch: 2048 - loss: 98.066548\n",
      "H1_base / MultVAE_v4 epoch 40/110 - batch: 2048 - loss: 98.301732\n",
      "H1_base / MultVAE_v4 epoch 50/110 - batch: 2048 - loss: 98.052017\n",
      "H1_base / MultVAE_v4 epoch 60/110 - batch: 2048 - loss: 98.011184\n",
      "H1_base / MultVAE_v4 epoch 70/110 - batch: 2048 - loss: 97.890835\n",
      "H1_base / MultVAE_v4 epoch 80/110 - batch: 2048 - loss: 97.763473\n",
      "H1_base / MultVAE_v4 epoch 90/110 - batch: 2048 - loss: 97.842719\n",
      "H1_base / MultVAE_v4 epoch 100/110 - batch: 2048 - loss: 97.821111\n",
      "H1_base / MultVAE_v4 epoch 110/110 - batch: 2048 - loss: 97.650121\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86a2c0eb536f44409d7e7450b9b2faab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / MultVAE_v4:   0%|          | 0/22 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_base / NeuMF_n2 epoch 1/26 - loss: 0.632809\n",
      "H1_base / NeuMF_n2 epoch 2/26 - loss: 0.361154\n",
      "H1_base / NeuMF_n2 epoch 3/26 - loss: 0.278561\n",
      "H1_base / NeuMF_n2 epoch 4/26 - loss: 0.146170\n",
      "H1_base / NeuMF_n2 epoch 5/26 - loss: 0.073701\n",
      "H1_base / NeuMF_n2 epoch 6/26 - loss: 0.047303\n",
      "H1_base / NeuMF_n2 epoch 7/26 - loss: 0.036484\n",
      "H1_base / NeuMF_n2 epoch 8/26 - loss: 0.031460\n",
      "H1_base / NeuMF_n2 epoch 9/26 - loss: 0.029175\n",
      "H1_base / NeuMF_n2 epoch 10/26 - loss: 0.027569\n",
      "H1_base / NeuMF_n2 epoch 11/26 - loss: 0.026375\n",
      "H1_base / NeuMF_n2 epoch 12/26 - loss: 0.025737\n",
      "H1_base / NeuMF_n2 epoch 13/26 - loss: 0.025545\n",
      "H1_base / NeuMF_n2 epoch 14/26 - loss: 0.025251\n",
      "H1_base / NeuMF_n2 epoch 15/26 - loss: 0.024822\n",
      "H1_base / NeuMF_n2 epoch 16/26 - loss: 0.024732\n",
      "H1_base / NeuMF_n2 epoch 17/26 - loss: 0.024117\n",
      "H1_base / NeuMF_n2 epoch 18/26 - loss: 0.024336\n",
      "H1_base / NeuMF_n2 epoch 19/26 - loss: 0.023767\n",
      "H1_base / NeuMF_n2 epoch 20/26 - loss: 0.023682\n",
      "H1_base / NeuMF_n2 epoch 21/26 - loss: 0.023763\n",
      "H1_base / NeuMF_n2 epoch 22/26 - loss: 0.023856\n",
      "H1_base / NeuMF_n2 epoch 23/26 - loss: 0.023581\n",
      "H1_base / NeuMF_n2 epoch 24/26 - loss: 0.023461\n",
      "H1_base / NeuMF_n2 epoch 25/26 - loss: 0.023444\n",
      "H1_base / NeuMF_n2 epoch 26/26 - loss: 0.023232\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6314e7553f754975b52c81e86fc83f08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_base / NeuMF_n2:   0%|          | 0/85 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Neural search on formulation: H1_fine\n",
      "H1_fine / TwoTower_t2 epoch 1/28 - batch: 8192 - loss: 7.052134\n",
      "H1_fine / TwoTower_t2 epoch 2/28 - batch: 8192 - loss: 5.614614\n",
      "H1_fine / TwoTower_t2 epoch 3/28 - batch: 8192 - loss: 5.378916\n",
      "H1_fine / TwoTower_t2 epoch 4/28 - batch: 8192 - loss: 5.359276\n",
      "H1_fine / TwoTower_t2 epoch 5/28 - batch: 8192 - loss: 5.352427\n",
      "H1_fine / TwoTower_t2 epoch 6/28 - batch: 8192 - loss: 5.348926\n",
      "H1_fine / TwoTower_t2 epoch 7/28 - batch: 8192 - loss: 5.346229\n",
      "H1_fine / TwoTower_t2 epoch 8/28 - batch: 8192 - loss: 5.344596\n",
      "H1_fine / TwoTower_t2 epoch 9/28 - batch: 8192 - loss: 5.343014\n",
      "H1_fine / TwoTower_t2 epoch 10/28 - batch: 8192 - loss: 5.341697\n",
      "H1_fine / TwoTower_t2 epoch 11/28 - batch: 8192 - loss: 5.340783\n",
      "H1_fine / TwoTower_t2 epoch 12/28 - batch: 8192 - loss: 5.339844\n",
      "H1_fine / TwoTower_t2 epoch 13/28 - batch: 8192 - loss: 5.339013\n",
      "H1_fine / TwoTower_t2 epoch 14/28 - batch: 8192 - loss: 5.338448\n",
      "H1_fine / TwoTower_t2 epoch 15/28 - batch: 8192 - loss: 5.337947\n",
      "H1_fine / TwoTower_t2 epoch 16/28 - batch: 8192 - loss: 5.337350\n",
      "H1_fine / TwoTower_t2 epoch 17/28 - batch: 8192 - loss: 5.337147\n",
      "H1_fine / TwoTower_t2 epoch 18/28 - batch: 8192 - loss: 5.336655\n",
      "H1_fine / TwoTower_t2 epoch 19/28 - batch: 8192 - loss: 5.336393\n",
      "H1_fine / TwoTower_t2 epoch 20/28 - batch: 8192 - loss: 5.335974\n",
      "H1_fine / TwoTower_t2 epoch 21/28 - batch: 8192 - loss: 5.335633\n",
      "H1_fine / TwoTower_t2 epoch 22/28 - batch: 8192 - loss: 5.335420\n",
      "H1_fine / TwoTower_t2 epoch 23/28 - batch: 8192 - loss: 5.335161\n",
      "H1_fine / TwoTower_t2 epoch 24/28 - batch: 8192 - loss: 5.334941\n",
      "H1_fine / TwoTower_t2 epoch 25/28 - batch: 8192 - loss: 5.334824\n",
      "H1_fine / TwoTower_t2 epoch 26/28 - batch: 8192 - loss: 5.334649\n",
      "H1_fine / TwoTower_t2 epoch 27/28 - batch: 8192 - loss: 5.334399\n",
      "H1_fine / TwoTower_t2 epoch 28/28 - batch: 8192 - loss: 5.334117\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7937c9a8cd5c418a9cfdd07024789fea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / TwoTower_t2:   0%|          | 0/47 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_fine / TwoTower_t3 epoch 1/32 - batch: 8192 - loss: 6.587047\n",
      "H1_fine / TwoTower_t3 epoch 2/32 - batch: 8192 - loss: 5.377902\n",
      "H1_fine / TwoTower_t3 epoch 3/32 - batch: 8192 - loss: 5.348970\n",
      "H1_fine / TwoTower_t3 epoch 4/32 - batch: 8192 - loss: 5.343229\n",
      "H1_fine / TwoTower_t3 epoch 5/32 - batch: 8192 - loss: 5.340990\n",
      "H1_fine / TwoTower_t3 epoch 6/32 - batch: 8192 - loss: 5.339294\n",
      "H1_fine / TwoTower_t3 epoch 7/32 - batch: 8192 - loss: 5.338309\n",
      "H1_fine / TwoTower_t3 epoch 8/32 - batch: 8192 - loss: 5.337404\n",
      "H1_fine / TwoTower_t3 epoch 9/32 - batch: 8192 - loss: 5.336630\n",
      "H1_fine / TwoTower_t3 epoch 10/32 - batch: 8192 - loss: 5.336284\n",
      "H1_fine / TwoTower_t3 epoch 11/32 - batch: 8192 - loss: 5.335932\n",
      "H1_fine / TwoTower_t3 epoch 12/32 - batch: 8192 - loss: 5.335663\n",
      "H1_fine / TwoTower_t3 epoch 13/32 - batch: 8192 - loss: 5.335109\n",
      "H1_fine / TwoTower_t3 epoch 14/32 - batch: 8192 - loss: 5.335027\n",
      "H1_fine / TwoTower_t3 epoch 15/32 - batch: 8192 - loss: 5.334673\n",
      "H1_fine / TwoTower_t3 epoch 16/32 - batch: 8192 - loss: 5.334594\n",
      "H1_fine / TwoTower_t3 epoch 17/32 - batch: 8192 - loss: 5.334405\n",
      "H1_fine / TwoTower_t3 epoch 18/32 - batch: 8192 - loss: 5.334398\n",
      "H1_fine / TwoTower_t3 epoch 19/32 - batch: 8192 - loss: 5.334109\n",
      "H1_fine / TwoTower_t3 epoch 20/32 - batch: 8192 - loss: 5.334008\n",
      "H1_fine / TwoTower_t3 epoch 21/32 - batch: 8192 - loss: 5.333882\n",
      "H1_fine / TwoTower_t3 epoch 22/32 - batch: 8192 - loss: 5.333743\n",
      "H1_fine / TwoTower_t3 epoch 23/32 - batch: 8192 - loss: 5.333715\n",
      "H1_fine / TwoTower_t3 epoch 24/32 - batch: 8192 - loss: 5.333595\n",
      "H1_fine / TwoTower_t3 epoch 25/32 - batch: 8192 - loss: 5.333374\n",
      "H1_fine / TwoTower_t3 epoch 26/32 - batch: 8192 - loss: 5.333389\n",
      "H1_fine / TwoTower_t3 epoch 27/32 - batch: 8192 - loss: 5.333323\n",
      "H1_fine / TwoTower_t3 epoch 28/32 - batch: 8192 - loss: 5.333191\n",
      "H1_fine / TwoTower_t3 epoch 29/32 - batch: 8192 - loss: 5.333135\n",
      "H1_fine / TwoTower_t3 epoch 30/32 - batch: 8192 - loss: 5.333098\n",
      "H1_fine / TwoTower_t3 epoch 31/32 - batch: 8192 - loss: 5.333075\n",
      "H1_fine / TwoTower_t3 epoch 32/32 - batch: 8192 - loss: 5.333007\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "983e2260edf44224b87f63b70e0decc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / TwoTower_t3:   0%|          | 0/47 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_fine / MultVAE_v3 epoch 1/100 - batch: 2048 - loss: 54.371517\n",
      "H1_fine / MultVAE_v3 epoch 10/100 - batch: 2048 - loss: 42.181241\n",
      "H1_fine / MultVAE_v3 epoch 20/100 - batch: 2048 - loss: 45.944527\n",
      "H1_fine / MultVAE_v3 epoch 30/100 - batch: 2048 - loss: 48.342049\n",
      "H1_fine / MultVAE_v3 epoch 40/100 - batch: 2048 - loss: 48.113153\n",
      "H1_fine / MultVAE_v3 epoch 50/100 - batch: 2048 - loss: 47.509765\n",
      "H1_fine / MultVAE_v3 epoch 60/100 - batch: 2048 - loss: 46.736373\n",
      "H1_fine / MultVAE_v3 epoch 70/100 - batch: 2048 - loss: 45.897734\n",
      "H1_fine / MultVAE_v3 epoch 80/100 - batch: 2048 - loss: 45.143274\n",
      "H1_fine / MultVAE_v3 epoch 90/100 - batch: 2048 - loss: 44.818130\n",
      "H1_fine / MultVAE_v3 epoch 100/100 - batch: 2048 - loss: 44.687635\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4bd85b4b5ccf47b8a516f02e8ed1c74c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / MultVAE_v3:   0%|          | 0/47 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_fine / MultVAE_v4 epoch 1/110 - batch: 2048 - loss: 54.954806\n",
      "H1_fine / MultVAE_v4 epoch 10/110 - batch: 2048 - loss: 41.999466\n",
      "H1_fine / MultVAE_v4 epoch 20/110 - batch: 2048 - loss: 45.531393\n",
      "H1_fine / MultVAE_v4 epoch 30/110 - batch: 2048 - loss: 47.843527\n",
      "H1_fine / MultVAE_v4 epoch 40/110 - batch: 2048 - loss: 48.340765\n",
      "H1_fine / MultVAE_v4 epoch 50/110 - batch: 2048 - loss: 48.095178\n",
      "H1_fine / MultVAE_v4 epoch 60/110 - batch: 2048 - loss: 47.863671\n",
      "H1_fine / MultVAE_v4 epoch 70/110 - batch: 2048 - loss: 47.316821\n",
      "H1_fine / MultVAE_v4 epoch 80/110 - batch: 2048 - loss: 46.458057\n",
      "H1_fine / MultVAE_v4 epoch 90/110 - batch: 2048 - loss: 45.813502\n",
      "H1_fine / MultVAE_v4 epoch 100/110 - batch: 2048 - loss: 45.237868\n",
      "H1_fine / MultVAE_v4 epoch 110/110 - batch: 2048 - loss: 44.919116\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e55ffc601f1b4b3f8cd7695373469783",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / MultVAE_v4:   0%|          | 0/47 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H1_fine / NeuMF_n2 epoch 1/26 - loss: 0.648538\n",
      "H1_fine / NeuMF_n2 epoch 2/26 - loss: 0.535616\n",
      "H1_fine / NeuMF_n2 epoch 3/26 - loss: 0.501414\n",
      "H1_fine / NeuMF_n2 epoch 4/26 - loss: 0.331117\n",
      "H1_fine / NeuMF_n2 epoch 5/26 - loss: 0.138894\n",
      "H1_fine / NeuMF_n2 epoch 6/26 - loss: 0.072931\n",
      "H1_fine / NeuMF_n2 epoch 7/26 - loss: 0.049571\n",
      "H1_fine / NeuMF_n2 epoch 8/26 - loss: 0.038407\n",
      "H1_fine / NeuMF_n2 epoch 9/26 - loss: 0.031989\n",
      "H1_fine / NeuMF_n2 epoch 10/26 - loss: 0.028042\n",
      "H1_fine / NeuMF_n2 epoch 11/26 - loss: 0.025400\n",
      "H1_fine / NeuMF_n2 epoch 12/26 - loss: 0.023307\n",
      "H1_fine / NeuMF_n2 epoch 13/26 - loss: 0.021769\n",
      "H1_fine / NeuMF_n2 epoch 14/26 - loss: 0.020643\n",
      "H1_fine / NeuMF_n2 epoch 15/26 - loss: 0.019752\n",
      "H1_fine / NeuMF_n2 epoch 16/26 - loss: 0.019316\n",
      "H1_fine / NeuMF_n2 epoch 17/26 - loss: 0.018653\n",
      "H1_fine / NeuMF_n2 epoch 18/26 - loss: 0.018106\n",
      "H1_fine / NeuMF_n2 epoch 19/26 - loss: 0.017656\n",
      "H1_fine / NeuMF_n2 epoch 20/26 - loss: 0.017504\n",
      "H1_fine / NeuMF_n2 epoch 21/26 - loss: 0.017083\n",
      "H1_fine / NeuMF_n2 epoch 22/26 - loss: 0.016782\n",
      "H1_fine / NeuMF_n2 epoch 23/26 - loss: 0.016483\n",
      "H1_fine / NeuMF_n2 epoch 24/26 - loss: 0.016228\n",
      "H1_fine / NeuMF_n2 epoch 25/26 - loss: 0.015998\n",
      "H1_fine / NeuMF_n2 epoch 26/26 - loss: 0.015803\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a25a2561c32e4391aa7f4d9add71cfc0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "H1_fine / NeuMF_n2:   0%|          | 0/187 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top neural results:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097919</td>\n",
       "      <td>0.171115</td>\n",
       "      <td>0.313688</td>\n",
       "      <td>0.061253</td>\n",
       "      <td>0.086542</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.091183</td>\n",
       "      <td>0.167157</td>\n",
       "      <td>0.311489</td>\n",
       "      <td>0.056135</td>\n",
       "      <td>0.081583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096530</td>\n",
       "      <td>0.166323</td>\n",
       "      <td>0.310331</td>\n",
       "      <td>0.059995</td>\n",
       "      <td>0.084505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092363</td>\n",
       "      <td>0.165027</td>\n",
       "      <td>0.307229</td>\n",
       "      <td>0.054289</td>\n",
       "      <td>0.079736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>MultVAE_v4</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.090372</td>\n",
       "      <td>0.164958</td>\n",
       "      <td>0.310979</td>\n",
       "      <td>0.053616</td>\n",
       "      <td>0.079191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076186</td>\n",
       "      <td>0.141549</td>\n",
       "      <td>0.268390</td>\n",
       "      <td>0.045482</td>\n",
       "      <td>0.067511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077442</td>\n",
       "      <td>0.140711</td>\n",
       "      <td>0.267259</td>\n",
       "      <td>0.046376</td>\n",
       "      <td>0.068040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.074606</td>\n",
       "      <td>0.139748</td>\n",
       "      <td>0.263250</td>\n",
       "      <td>0.044784</td>\n",
       "      <td>0.066552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>MultVAE_v4</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.075160</td>\n",
       "      <td>0.139696</td>\n",
       "      <td>0.265459</td>\n",
       "      <td>0.045227</td>\n",
       "      <td>0.066881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.073883</td>\n",
       "      <td>0.138712</td>\n",
       "      <td>0.265207</td>\n",
       "      <td>0.044334</td>\n",
       "      <td>0.065945</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Formulation        Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0     H1_base  TwoTower_t2      43199  0.097919  0.171115  0.313688  0.061253   \n",
       "1     H1_base     NeuMF_n2      43199  0.091183  0.167157  0.311489  0.056135   \n",
       "2     H1_base  TwoTower_t3      43199  0.096530  0.166323  0.310331  0.059995   \n",
       "3     H1_base   MultVAE_v3      43199  0.092363  0.165027  0.307229  0.054289   \n",
       "4     H1_base   MultVAE_v4      43199  0.090372  0.164958  0.310979  0.053616   \n",
       "5     H1_fine   MultVAE_v3      95529  0.076186  0.141549  0.268390  0.045482   \n",
       "6     H1_fine  TwoTower_t2      95529  0.077442  0.140711  0.267259  0.046376   \n",
       "7     H1_fine  TwoTower_t3      95529  0.074606  0.139748  0.263250  0.044784   \n",
       "8     H1_fine   MultVAE_v4      95529  0.075160  0.139696  0.265459  0.045227   \n",
       "9     H1_fine     NeuMF_n2      95529  0.073883  0.138712  0.265207  0.044334   \n",
       "\n",
       "    NDCG@10  \n",
       "0  0.086542  \n",
       "1  0.081583  \n",
       "2  0.084505  \n",
       "3  0.079736  \n",
       "4  0.079191  \n",
       "5  0.067511  \n",
       "6  0.068040  \n",
       "7  0.066552  \n",
       "8  0.066881  \n",
       "9  0.065945  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Neural search on the best formulations only\n",
    "#\n",
    "# For each selected formulation, every TwoTower, MultVAE and NeuMF config is\n",
    "# trained and evaluated in turn. clear_memory() runs before and after each\n",
    "# model, and the model is deleted once its evaluation result has been appended.\n",
    "\n",
    "neural_results = []\n",
    "\n",
    "if RUN_NEURAL:\n",
    "    for formulation_name in selected_neural_formulations:\n",
    "        print(\"=\" * 120)\n",
    "        print(\"Neural search on formulation:\", formulation_name)\n",
    "\n",
    "        bundle = bundles[formulation_name]\n",
    "        ndata = neural_data_by_formulation[formulation_name]\n",
    "\n",
    "        for cfg in TWOTOWER_CONFIGS:\n",
    "            clear_memory()\n",
    "            model = train_two_tower(bundle, ndata, cfg)\n",
    "            neural_results.append(evaluate_twotower(bundle, ndata, model, cfg[\"name\"]))\n",
    "            del model\n",
    "            clear_memory()\n",
    "\n",
    "        for cfg in MULTVAE_CONFIGS:\n",
    "            clear_memory()\n",
    "            model = train_multvae(bundle, ndata, cfg)\n",
    "            neural_results.append(evaluate_multvae(bundle, ndata, model, cfg[\"name\"]))\n",
    "            del model\n",
    "            clear_memory()\n",
    "\n",
    "        # NeuMF trains and evaluates from the bundle alone (no ndata argument).\n",
    "        for cfg in NEUMF_CONFIGS:\n",
    "            clear_memory()\n",
    "            model = train_neumf(bundle, cfg)\n",
    "            neural_results.append(evaluate_neumf(bundle, model, cfg[\"name\"]))\n",
    "            del model\n",
    "            clear_memory()\n",
    "\n",
    "# A frame built from an empty list has no metric columns, so sorting it would\n",
    "# raise KeyError. Rank only when at least one model was evaluated; this also\n",
    "# covers RUN_NEURAL being off, where the list is never filled.\n",
    "if neural_results:\n",
    "    neural_results_df = (\n",
    "        pd.DataFrame(neural_results)\n",
    "        .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "else:\n",
    "    neural_results_df = pd.DataFrame()\n",
    "\n",
    "print(\"Top neural results:\")\n",
    "display(neural_results_df.head(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "63d0e73c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fusion on formulation: H1_base\n",
      "Best for fusion: EASE_binary_l1000 RP3beta_a1_0_b0_6 TwoTower_t2 MultVAE_v3\n"
     ]
    }
   ],
   "source": [
    "# Optional fusion on the best formulation\n",
    "\n",
    "def make_python_recommenders_for_fusion(bundle, formulation_name, regular_results_df, neural_results_df):\n",
    "    \"\"\"Pick the best model name of each family for one formulation.\n",
    "\n",
    "    Both result frames are filtered to rows whose Formulation equals\n",
    "    formulation_name, then top_model_name is asked for each family prefix.\n",
    "    EASE tries the EASE_binary_ prefix first and falls back to EASE_count_\n",
    "    when that lookup returns None. The TwoTower and MultVAE names are None\n",
    "    when neural_results_df is empty.\n",
    "\n",
    "    bundle is accepted to keep the existing signature; it is not read here.\n",
    "\n",
    "    Returns the tuple (best_ease, best_rp3, best_tt, best_vae).\n",
    "    \"\"\"\n",
    "    # Filter each frame once instead of once per lookup.\n",
    "    regular_rows = regular_results_df[regular_results_df[\"Formulation\"] == formulation_name]\n",
    "\n",
    "    best_rp3 = top_model_name(regular_rows, \"RP3beta_\")\n",
    "    best_ease = top_model_name(regular_rows, \"EASE_binary_\")\n",
    "    if best_ease is None:\n",
    "        best_ease = top_model_name(regular_rows, \"EASE_count_\")\n",
    "\n",
    "    if neural_results_df.empty:\n",
    "        # An empty frame may have no Formulation column at all, so skip the lookup.\n",
    "        best_tt = None\n",
    "        best_vae = None\n",
    "    else:\n",
    "        neural_rows = neural_results_df[neural_results_df[\"Formulation\"] == formulation_name]\n",
    "        best_tt = top_model_name(neural_rows, \"TwoTower_\")\n",
    "        best_vae = top_model_name(neural_rows, \"MultVAE_\")\n",
    "\n",
    "    return best_ease, best_rp3, best_tt, best_vae\n",
    "\n",
    "fusion_results = []\n",
    "\n",
    "if RUN_FUSION and not neural_results_df.empty:\n",
    "    # NOTE(review): row 0 is treated as the winner, which assumes\n",
    "    # regular_results_df is already sorted best-first; confirm where it is built.\n",
    "    best_formulation = regular_results_df.iloc[0][\"Formulation\"]\n",
    "    bundle = bundles[best_formulation]\n",
    "    print(\"Fusion on formulation:\", best_formulation)\n",
    "\n",
    "    reg_sub = regular_results_df[regular_results_df[\"Formulation\"] == best_formulation]\n",
    "    neu_sub = neural_results_df[neural_results_df[\"Formulation\"] == best_formulation]\n",
    "\n",
    "    # Reuse the helper rather than repeating the per-family lookups inline.\n",
    "    # If the winning formulation has no neural rows, the helper returns None\n",
    "    # for the TwoTower and MultVAE names.\n",
    "    ease_name, rp3_name, tt_name, vae_name = make_python_recommenders_for_fusion(\n",
    "        bundle, best_formulation, reg_sub, neu_sub\n",
    "    )\n",
    "\n",
    "    print(\"Best for fusion:\", ease_name, rp3_name, tt_name, vae_name)\n",
    "else:\n",
    "    print(\"Fusion skipped or no neural results.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fb491a3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top overall results:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "      <th>Phase</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097919</td>\n",
       "      <td>0.171115</td>\n",
       "      <td>0.313688</td>\n",
       "      <td>0.061253</td>\n",
       "      <td>0.086542</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099354</td>\n",
       "      <td>0.170050</td>\n",
       "      <td>0.312947</td>\n",
       "      <td>0.061000</td>\n",
       "      <td>0.086145</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099956</td>\n",
       "      <td>0.169958</td>\n",
       "      <td>0.312577</td>\n",
       "      <td>0.061645</td>\n",
       "      <td>0.086629</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099470</td>\n",
       "      <td>0.169842</td>\n",
       "      <td>0.313618</td>\n",
       "      <td>0.061275</td>\n",
       "      <td>0.086323</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099169</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.312299</td>\n",
       "      <td>0.060961</td>\n",
       "      <td>0.085968</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098567</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.312716</td>\n",
       "      <td>0.061023</td>\n",
       "      <td>0.085960</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098289</td>\n",
       "      <td>0.168221</td>\n",
       "      <td>0.310216</td>\n",
       "      <td>0.060397</td>\n",
       "      <td>0.085273</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098127</td>\n",
       "      <td>0.168129</td>\n",
       "      <td>0.311003</td>\n",
       "      <td>0.060514</td>\n",
       "      <td>0.085340</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.091183</td>\n",
       "      <td>0.167157</td>\n",
       "      <td>0.311489</td>\n",
       "      <td>0.056135</td>\n",
       "      <td>0.081583</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096530</td>\n",
       "      <td>0.166323</td>\n",
       "      <td>0.310331</td>\n",
       "      <td>0.059995</td>\n",
       "      <td>0.084505</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092363</td>\n",
       "      <td>0.165027</td>\n",
       "      <td>0.307229</td>\n",
       "      <td>0.054289</td>\n",
       "      <td>0.079736</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>MultVAE_v4</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.090372</td>\n",
       "      <td>0.164958</td>\n",
       "      <td>0.310979</td>\n",
       "      <td>0.053616</td>\n",
       "      <td>0.079191</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083137</td>\n",
       "      <td>0.146018</td>\n",
       "      <td>0.268986</td>\n",
       "      <td>0.050754</td>\n",
       "      <td>0.072685</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082300</td>\n",
       "      <td>0.145338</td>\n",
       "      <td>0.268913</td>\n",
       "      <td>0.050044</td>\n",
       "      <td>0.071969</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.084278</td>\n",
       "      <td>0.145055</td>\n",
       "      <td>0.269562</td>\n",
       "      <td>0.050897</td>\n",
       "      <td>0.072607</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083472</td>\n",
       "      <td>0.144731</td>\n",
       "      <td>0.269552</td>\n",
       "      <td>0.050221</td>\n",
       "      <td>0.071991</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.081159</td>\n",
       "      <td>0.144532</td>\n",
       "      <td>0.269740</td>\n",
       "      <td>0.048689</td>\n",
       "      <td>0.070727</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082457</td>\n",
       "      <td>0.144522</td>\n",
       "      <td>0.268892</td>\n",
       "      <td>0.049771</td>\n",
       "      <td>0.071576</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076961</td>\n",
       "      <td>0.142292</td>\n",
       "      <td>0.268903</td>\n",
       "      <td>0.046265</td>\n",
       "      <td>0.068288</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076186</td>\n",
       "      <td>0.141549</td>\n",
       "      <td>0.268390</td>\n",
       "      <td>0.045482</td>\n",
       "      <td>0.067511</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077442</td>\n",
       "      <td>0.140711</td>\n",
       "      <td>0.267259</td>\n",
       "      <td>0.046376</td>\n",
       "      <td>0.068040</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.074606</td>\n",
       "      <td>0.139748</td>\n",
       "      <td>0.263250</td>\n",
       "      <td>0.044784</td>\n",
       "      <td>0.066552</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>MultVAE_v4</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.075160</td>\n",
       "      <td>0.139696</td>\n",
       "      <td>0.265459</td>\n",
       "      <td>0.045227</td>\n",
       "      <td>0.066881</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.073883</td>\n",
       "      <td>0.138712</td>\n",
       "      <td>0.265207</td>\n",
       "      <td>0.044334</td>\n",
       "      <td>0.065945</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.002847</td>\n",
       "      <td>0.006250</td>\n",
       "      <td>0.012037</td>\n",
       "      <td>0.001778</td>\n",
       "      <td>0.002793</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>Popularity</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.002450</td>\n",
       "      <td>0.004878</td>\n",
       "      <td>0.009882</td>\n",
       "      <td>0.001397</td>\n",
       "      <td>0.002194</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Formulation                     Model  UsersEval      HR@5     HR@10  \\\n",
       "0      H1_base               TwoTower_t2      43199  0.097919  0.171115   \n",
       "1      H1_base         EASE_binary_l1000      43199  0.099354  0.170050   \n",
       "2      H1_base         RP3beta_a1_0_b0_6      43199  0.099956  0.169958   \n",
       "3      H1_base        RP3beta_a1_15_b0_7      43199  0.099470  0.169842   \n",
       "4      H1_base  RP3beta_screen_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "5      H1_base         RP3beta_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "6      H1_base         EASE_binary_l1200      43199  0.099169  0.169356   \n",
       "7      H1_base         EASE_binary_l1600      43199  0.098567  0.169147   \n",
       "8      H1_base          EASE_count_l1200      43199  0.098289  0.168221   \n",
       "9      H1_base          EASE_count_l1600      43199  0.098127  0.168129   \n",
       "10     H1_base                  NeuMF_n2      43199  0.091183  0.167157   \n",
       "11     H1_base               TwoTower_t3      43199  0.096530  0.166323   \n",
       "12     H1_base                MultVAE_v3      43199  0.092363  0.165027   \n",
       "13     H1_base                MultVAE_v4      43199  0.090372  0.164958   \n",
       "14     H1_fine          EASE_count_l1600      95529  0.083137  0.146018   \n",
       "15     H1_fine          EASE_count_l1200      95529  0.082300  0.145338   \n",
       "16     H1_fine         EASE_binary_l1600      95529  0.084278  0.145055   \n",
       "17     H1_fine         EASE_binary_l1200      95529  0.083472  0.144731   \n",
       "18     H1_fine         RP3beta_a1_0_b0_6      95529  0.081159  0.144532   \n",
       "19     H1_fine         EASE_binary_l1000      95529  0.082457  0.144522   \n",
       "20     H1_fine        RP3beta_a1_15_b0_7      95529  0.076961  0.142292   \n",
       "21     H1_fine  RP3beta_screen_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "22     H1_fine         RP3beta_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "23     H1_fine                MultVAE_v3      95529  0.076186  0.141549   \n",
       "24     H1_fine               TwoTower_t2      95529  0.077442  0.140711   \n",
       "25     H1_fine               TwoTower_t3      95529  0.074606  0.139748   \n",
       "26     H1_fine                MultVAE_v4      95529  0.075160  0.139696   \n",
       "27     H1_fine                  NeuMF_n2      95529  0.073883  0.138712   \n",
       "28     H1_base                Popularity      43199  0.002847  0.006250   \n",
       "29     H1_fine                Popularity      95529  0.002450  0.004878   \n",
       "\n",
       "       HR@20    MRR@10   NDCG@10   Phase  \n",
       "0   0.313688  0.061253  0.086542     NaN  \n",
       "1   0.312947  0.061000  0.086145    full  \n",
       "2   0.312577  0.061645  0.086629    full  \n",
       "3   0.313618  0.061275  0.086323    full  \n",
       "4   0.313109  0.061218  0.086209  screen  \n",
       "5   0.313109  0.061218  0.086209    full  \n",
       "6   0.312299  0.060961  0.085968    full  \n",
       "7   0.312716  0.061023  0.085960    full  \n",
       "8   0.310216  0.060397  0.085273    full  \n",
       "9   0.311003  0.060514  0.085340    full  \n",
       "10  0.311489  0.056135  0.081583     NaN  \n",
       "11  0.310331  0.059995  0.084505     NaN  \n",
       "12  0.307229  0.054289  0.079736     NaN  \n",
       "13  0.310979  0.053616  0.079191     NaN  \n",
       "14  0.268986  0.050754  0.072685    full  \n",
       "15  0.268913  0.050044  0.071969    full  \n",
       "16  0.269562  0.050897  0.072607    full  \n",
       "17  0.269552  0.050221  0.071991    full  \n",
       "18  0.269740  0.048689  0.070727    full  \n",
       "19  0.268892  0.049771  0.071576    full  \n",
       "20  0.268903  0.046265  0.068288    full  \n",
       "21  0.269060  0.046382  0.068361  screen  \n",
       "22  0.269060  0.046382  0.068361    full  \n",
       "23  0.268390  0.045482  0.067511     NaN  \n",
       "24  0.267259  0.046376  0.068040     NaN  \n",
       "25  0.263250  0.044784  0.066552     NaN  \n",
       "26  0.265459  0.045227  0.066881     NaN  \n",
       "27  0.265207  0.044334  0.065945     NaN  \n",
       "28  0.012037  0.001778  0.002793  screen  \n",
       "29  0.009882  0.001397  0.002194  screen  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best per formulation:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "      <th>Phase</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097919</td>\n",
       "      <td>0.171115</td>\n",
       "      <td>0.313688</td>\n",
       "      <td>0.061253</td>\n",
       "      <td>0.086542</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099354</td>\n",
       "      <td>0.170050</td>\n",
       "      <td>0.312947</td>\n",
       "      <td>0.061000</td>\n",
       "      <td>0.086145</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099956</td>\n",
       "      <td>0.169958</td>\n",
       "      <td>0.312577</td>\n",
       "      <td>0.061645</td>\n",
       "      <td>0.086629</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099470</td>\n",
       "      <td>0.169842</td>\n",
       "      <td>0.313618</td>\n",
       "      <td>0.061275</td>\n",
       "      <td>0.086323</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099169</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.312299</td>\n",
       "      <td>0.060961</td>\n",
       "      <td>0.085968</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098567</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.312716</td>\n",
       "      <td>0.061023</td>\n",
       "      <td>0.085960</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098289</td>\n",
       "      <td>0.168221</td>\n",
       "      <td>0.310216</td>\n",
       "      <td>0.060397</td>\n",
       "      <td>0.085273</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098127</td>\n",
       "      <td>0.168129</td>\n",
       "      <td>0.311003</td>\n",
       "      <td>0.060514</td>\n",
       "      <td>0.085340</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083137</td>\n",
       "      <td>0.146018</td>\n",
       "      <td>0.268986</td>\n",
       "      <td>0.050754</td>\n",
       "      <td>0.072685</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082300</td>\n",
       "      <td>0.145338</td>\n",
       "      <td>0.268913</td>\n",
       "      <td>0.050044</td>\n",
       "      <td>0.071969</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.084278</td>\n",
       "      <td>0.145055</td>\n",
       "      <td>0.269562</td>\n",
       "      <td>0.050897</td>\n",
       "      <td>0.072607</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083472</td>\n",
       "      <td>0.144731</td>\n",
       "      <td>0.269552</td>\n",
       "      <td>0.050221</td>\n",
       "      <td>0.071991</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.081159</td>\n",
       "      <td>0.144532</td>\n",
       "      <td>0.269740</td>\n",
       "      <td>0.048689</td>\n",
       "      <td>0.070727</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082457</td>\n",
       "      <td>0.144522</td>\n",
       "      <td>0.268892</td>\n",
       "      <td>0.049771</td>\n",
       "      <td>0.071576</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076961</td>\n",
       "      <td>0.142292</td>\n",
       "      <td>0.268903</td>\n",
       "      <td>0.046265</td>\n",
       "      <td>0.068288</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076186</td>\n",
       "      <td>0.141549</td>\n",
       "      <td>0.268390</td>\n",
       "      <td>0.045482</td>\n",
       "      <td>0.067511</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Formulation                     Model  UsersEval      HR@5     HR@10  \\\n",
       "0      H1_base               TwoTower_t2      43199  0.097919  0.171115   \n",
       "1      H1_base         EASE_binary_l1000      43199  0.099354  0.170050   \n",
       "2      H1_base         RP3beta_a1_0_b0_6      43199  0.099956  0.169958   \n",
       "3      H1_base        RP3beta_a1_15_b0_7      43199  0.099470  0.169842   \n",
       "4      H1_base  RP3beta_screen_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "5      H1_base         RP3beta_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "6      H1_base         EASE_binary_l1200      43199  0.099169  0.169356   \n",
       "7      H1_base         EASE_binary_l1600      43199  0.098567  0.169147   \n",
       "8      H1_base          EASE_count_l1200      43199  0.098289  0.168221   \n",
       "9      H1_base          EASE_count_l1600      43199  0.098127  0.168129   \n",
       "10     H1_fine          EASE_count_l1600      95529  0.083137  0.146018   \n",
       "11     H1_fine          EASE_count_l1200      95529  0.082300  0.145338   \n",
       "12     H1_fine         EASE_binary_l1600      95529  0.084278  0.145055   \n",
       "13     H1_fine         EASE_binary_l1200      95529  0.083472  0.144731   \n",
       "14     H1_fine         RP3beta_a1_0_b0_6      95529  0.081159  0.144532   \n",
       "15     H1_fine         EASE_binary_l1000      95529  0.082457  0.144522   \n",
       "16     H1_fine        RP3beta_a1_15_b0_7      95529  0.076961  0.142292   \n",
       "17     H1_fine  RP3beta_screen_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "18     H1_fine         RP3beta_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "19     H1_fine                MultVAE_v3      95529  0.076186  0.141549   \n",
       "\n",
       "       HR@20    MRR@10   NDCG@10   Phase  \n",
       "0   0.313688  0.061253  0.086542     NaN  \n",
       "1   0.312947  0.061000  0.086145    full  \n",
       "2   0.312577  0.061645  0.086629    full  \n",
       "3   0.313618  0.061275  0.086323    full  \n",
       "4   0.313109  0.061218  0.086209  screen  \n",
       "5   0.313109  0.061218  0.086209    full  \n",
       "6   0.312299  0.060961  0.085968    full  \n",
       "7   0.312716  0.061023  0.085960    full  \n",
       "8   0.310216  0.060397  0.085273    full  \n",
       "9   0.311003  0.060514  0.085340    full  \n",
       "10  0.268986  0.050754  0.072685    full  \n",
       "11  0.268913  0.050044  0.071969    full  \n",
       "12  0.269562  0.050897  0.072607    full  \n",
       "13  0.269552  0.050221  0.071991    full  \n",
       "14  0.269740  0.048689  0.070727    full  \n",
       "15  0.268892  0.049771  0.071576    full  \n",
       "16  0.268903  0.046265  0.068288    full  \n",
       "17  0.269060  0.046382  0.068361  screen  \n",
       "18  0.269060  0.046382  0.068361    full  \n",
       "19  0.268390  0.045482  0.067511     NaN  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best per model family:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Family</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TwoTower_t</td>\n",
       "      <td>0.171115</td>\n",
       "      <td>0.313688</td>\n",
       "      <td>0.061253</td>\n",
       "      <td>0.086542</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>EASE_binary</td>\n",
       "      <td>0.170050</td>\n",
       "      <td>0.312947</td>\n",
       "      <td>0.061023</td>\n",
       "      <td>0.086145</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>RP</td>\n",
       "      <td>0.169958</td>\n",
       "      <td>0.313618</td>\n",
       "      <td>0.061645</td>\n",
       "      <td>0.086629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_count</td>\n",
       "      <td>0.168221</td>\n",
       "      <td>0.311003</td>\n",
       "      <td>0.060514</td>\n",
       "      <td>0.085340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NeuMF_n</td>\n",
       "      <td>0.167157</td>\n",
       "      <td>0.311489</td>\n",
       "      <td>0.056135</td>\n",
       "      <td>0.081583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>MultVAE_v</td>\n",
       "      <td>0.165027</td>\n",
       "      <td>0.310979</td>\n",
       "      <td>0.054289</td>\n",
       "      <td>0.079736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>0.006250</td>\n",
       "      <td>0.012037</td>\n",
       "      <td>0.001778</td>\n",
       "      <td>0.002793</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Family     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0   TwoTower_t  0.171115  0.313688  0.061253  0.086542\n",
       "1  EASE_binary  0.170050  0.312947  0.061023  0.086145\n",
       "2           RP  0.169958  0.313618  0.061645  0.086629\n",
       "3   EASE_count  0.168221  0.311003  0.060514  0.085340\n",
       "4      NeuMF_n  0.167157  0.311489  0.056135  0.081583\n",
       "5    MultVAE_v  0.165027  0.310979  0.054289  0.079736\n",
       "6   Popularity  0.006250  0.012037  0.001778  0.002793"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Final comparison tables\n",
    "#\n",
    "# Merges the regular and neural result frames, ranks every run, and shows three\n",
    "# views: overall top runs, top runs per formulation, best run per model family.\n",
    "\n",
    "all_frames = [regular_results_df]\n",
    "if not neural_results_df.empty:\n",
    "    all_frames.append(neural_results_df)\n",
    "\n",
    "# Rank every run by HR@10, breaking ties with NDCG@10 and then MRR@10.\n",
    "all_results_df = (\n",
    "    pd.concat(all_frames, ignore_index=True)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(\"Top overall results:\")\n",
    "display(all_results_df.head(30))\n",
    "\n",
    "# Up to 10 best-ranked runs of each formulation; rows keep the global ranking order.\n",
    "print(\"Best per formulation:\")\n",
    "best_per_formulation = (\n",
    "    all_results_df.groupby(\"Formulation\")\n",
    "    .head(10)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(best_per_formulation)\n",
    "\n",
    "print(\"Best per model family:\")\n",
    "family_df = all_results_df.copy()\n",
    "# Family = first name token (digits allowed, so \"RP3beta\" stays whole) plus an\n",
    "# optional purely alphabetic qualifier of 2+ letters such as \"binary\" / \"count\".\n",
    "# Hyper-parameter suffixes (\"_t2\", \"_l1000\", \"_a1_0_b0_6\") and the \"_screen\"\n",
    "# phase tag are dropped; the Phase column already distinguishes screen vs full.\n",
    "family_df[\"Family\"] = (\n",
    "    family_df[\"Model\"]\n",
    "    .str.replace(\"_screen\", \"\", regex=False)\n",
    "    .str.extract(r\"^([A-Za-z0-9]+(?:_[A-Za-z]{2,}(?=_|$))?)\")[0]\n",
    "    .fillna(family_df[\"Model\"])  # fall back to the raw name if the pattern does not match\n",
    ")\n",
    "# all_results_df is already ranked, so the first row of each family is its single\n",
    "# best run. Taking that row keeps the metrics coherent; a column-wise .max() would\n",
    "# combine numbers that come from different models and formulations.\n",
    "best_per_family = (\n",
    "    family_df.groupby(\"Family\", sort=False)\n",
    "    .head(1)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(\n",
    "    best_per_family[\n",
    "        [\"Family\", \"Formulation\", \"Model\", \"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc27b61",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "This notebook removes the old automatic screening stage that previously caused very long EASE runs and high RAM usage.\n",
    "\n",
    "What is threaded or batched here:\n",
    "- BLAS / LAPACK thread env vars are set before importing NumPy and SciPy\n",
    "- `threadpoolctl` is applied when available\n",
    "- EASE uses a Cholesky-based dense solve\n",
    "- EASE and RP3beta evaluation are fully batched\n",
    "- heavy score computation is moved to the GPU when available\n",
    "- neural training uses manual on-device batching instead of multiprocessing DataLoaders\n",
    "\n",
    "If you still need more speed:\n",
    "- lower `TOP_FORMULATIONS_FOR_NEURAL`\n",
    "- lower `TWOTOWER_CONFIGS` / `MULTVAE_CONFIGS`\n",
    "- lower `EVAL_MAX_USERS`\n",
    "- lower `LINEAR_EVAL_BATCH` or `NEURAL_EVAL_BATCH` if GPU memory becomes the bottleneck\n",
    "\n",
    "Additional EASE fit changes in this version:\n",
    "- Gram matrices are built once per formulation and reused across all EASE lambdas\n",
    "- EASE fitting can use GPU Cholesky solve when CUDA is available\n",
    "- fit time and eval time are printed separately, so long stalls before the progress bar appears are visible\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
