{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6979c756",
   "metadata": {},
   "source": [
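    "# Car recommendation notebook — H1 variants, tuned top models, batched evaluation\n",
    "\n",
    "This notebook evaluates only the richer H1-style formulations and drops the old, slow screening path.\n",
    "\n",
    "An H1 formulation turns every sale into a (user, item) interaction. The user key concatenates Country, a price bin, a mileage bin, Condition and a car-age bin (plus Color in the `_color` variants). The item key concatenates Brand, Model, the same car-age bin and Condition (plus Color). The `_base` variants use 12 price / 12 mileage / 10 age bins; the `_fine` variants use 16 / 16 / 14.\n",
    "\n",
    "Goals:\n",
    "- try several harder H1 variants\n",
    "- tune the current top model families: RP3beta, EASE, TwoTower, MultVAE\n",
    "- batch all heavy evaluation paths\n",
    "- use the GPU where it helps and multiple CPU threads elsewhere"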
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8fb2968a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU thread target: 16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Runtime setup: native thread caps, imports, global seeds.\n",
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "# Cap native thread pools at 16 (or fewer if the machine has fewer cores). The env vars are set\n",
    "# before numpy/scipy are imported below, since BLAS/OpenMP runtimes typically read them only at load time.\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "for var in [\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "]:\n",
    "    os.environ[var] = str(CPU_THREADS)\n",
    "\n",
    "import gc\n",
    "import math\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy import linalg as sla\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# threadpoolctl is optional; without it only the env vars above limit the native thread pools.\n",
    "try:\n",
    "    from threadpoolctl import threadpool_limits\n",
    "except Exception:\n",
    "    threadpool_limits = None\n",
    "\n",
    "# NOTE(review): this silences every warning for the whole kernel, including numeric\n",
    "# RuntimeWarnings (e.g. divide-by-zero) that could point at real problems.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Seed Python's and NumPy's global RNGs; torch is seeded in the PyTorch setup cell.\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "\n",
    "# Also cap the already-loaded native thread pools (best-effort: any error is ignored).\n",
    "# The returned limiter is discarded; presumably the limits stay in effect for the session.\n",
    "if threadpool_limits is not None:\n",
    "    try:\n",
    "        threadpool_limits(limits=CPU_THREADS)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6731d71d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['H1_base', 'H1_color', 'H1_fine', 'H1_fine_color']\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "# Project folder holding the CSV. Defaults to the author's machine; set the CAR_REC_BASE_DIR\n",
    "# environment variable to run the notebook elsewhere without editing this cell.\n",
    "BASE_DIR = Path(os.environ.get(\"CAR_REC_BASE_DIR\", \"/home/konnilol/Documents/uni/kursovaya-sem5\")).expanduser()\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH} (set CAR_REC_BASE_DIR to the folder that contains it)\")\n",
    "\n",
    "# Reference year for car age (Age = CURRENT_YEAR - Year, clipped at 0 in the loading cell).\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "# Number of quantile bins (pd.qcut) behind the PriceBin*/MileageBin*/AgeBin* columns. Coinciding\n",
    "# bin edges are merged, so a column can end up with fewer bins than requested.\n",
    "PRICE_BINS_BASE = 12\n",
    "PRICE_BINS_FINE = 16\n",
    "MILEAGE_BINS_BASE = 12\n",
    "MILEAGE_BINS_FINE = 16\n",
    "AGE_BINS_BASE = 10\n",
    "AGE_BINS_FINE = 14\n",
    "\n",
    "# Filtering\n",
    "# (min_items_per_user, min_users_per_item) pairs tried in order, strictest first; the first pair\n",
    "# that leaves any interactions is the one used for a formulation.\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "# Maximum number of alternating user/item filtering rounds per threshold pair.\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# Evaluation\n",
    "# HR@k cutoffs; every user has exactly one held-out item (count-weighted leave-one-out).\n",
    "TOP_KS = [5, 10, 20]\n",
    "# None evaluates every test user; an integer evaluates a seeded random sample of that many users.\n",
    "EVAL_MAX_USERS = None\n",
    "# Users scored per batch by the batched evaluators (presumably linear vs. neural models).\n",
    "LINEAR_EVAL_BATCH = 4096\n",
    "NEURAL_EVAL_BATCH = 4096\n",
    "\n",
    "# Manual richer H1 variants only\n",
    "# user_cols are joined into the user key (\" | \"), item_cols into the item key (\" :: \").\n",
    "FORMULATIONS = {\n",
    "    \"H1_base\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H1_fine\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_fine_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "# Cheap screen all formulations first, then run the heavy regular/neural search only on the best ones\n",
    "TOP_FORMULATIONS_FOR_FULL_REGULAR = 2\n",
    "TOP_FORMULATIONS_FOR_NEURAL = 2\n",
    "\n",
    "# Best current families only\n",
    "# One cheap RP3 config for formulation screening, then a tighter grid for the full search\n",
    "# Each entry is an (alpha, beta) pair, presumably passed to fit_p3alpha(alpha=..., beta=...).\n",
    "SCREEN_RP3 = (1.10, 0.70)\n",
    "\n",
    "RP3_GRID = [\n",
    "    (1.00, 0.60),\n",
    "    (1.10, 0.70),\n",
    "    (1.15, 0.70),\n",
    "]\n",
    "\n",
    "# Keep only the EASE lambdas that were already near the top in previous runs.\n",
    "# Presumably the `lam` passed to fit_ease_from_cache, which adds it to the Gram diagonal.\n",
    "EASE_BINARY_LAMBDAS = [1000.0, 1200.0, 1600.0]\n",
    "EASE_COUNT_LAMBDAS = [1200.0, 1600.0]\n",
    "\n",
    "# Skip very large EASE formulations instead of spending tens of minutes building/factoring huge Gram matrices.\n",
    "MAX_EASE_ITEMS = 12000\n",
    "# Only try GPU EASE for comparatively small item spaces; larger ones fall back to CPU torch.\n",
    "# Measured in GiB of the dense float32 Gram (n_items**2 * 4 bytes); 12000 items is ~0.54 GiB.\n",
    "EASE_GPU_MAX_GRAM_GB = 0.60\n",
    "\n",
    "# Neural model configurations and batch sizes, presumably consumed by later training cells.\n",
    "TWOTOWER_CONFIGS = [\n",
    "    {\"name\": \"TwoTower_t2\", \"emb_dim\": 96,  \"hidden_dims\": (768, 384, 192),  \"out_dim\": 160, \"epochs\": 28, \"lr\": 1.5e-3, \"wd\": 1e-5, \"temperature\": 0.05},\n",
    "    {\"name\": \"TwoTower_t3\", \"emb_dim\": 128, \"hidden_dims\": (1024, 512, 256), \"out_dim\": 192, \"epochs\": 32, \"lr\": 1.2e-3, \"wd\": 1e-5, \"temperature\": 0.04},\n",
    "]\n",
    "\n",
    "MULTVAE_CONFIGS = [\n",
    "    {\"name\": \"MultVAE_v3\", \"hidden_dim\": 2048, \"latent_dim\": 512, \"dropout\": 0.30, \"epochs\": 100, \"lr\": 6e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "    {\"name\": \"MultVAE_v4\", \"hidden_dim\": 2560, \"latent_dim\": 640, \"dropout\": 0.35, \"epochs\": 110, \"lr\": 5e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "]\n",
    "\n",
    "NEUMF_CONFIGS = [\n",
    "    {\"name\": \"NeuMF_n2\", \"mf_dim\": 96, \"mlp_dim\": 192, \"hidden_dims\": (384, 192, 96), \"epochs\": 26, \"lr\": 1.5e-3, \"wd\": 1e-6},\n",
    "]\n",
    "\n",
    "TWOTOWER_BATCH = 32768\n",
    "PAIRWISE_BATCH = 32768\n",
    "AUTOENC_BATCH = 4096\n",
    "NEUMF_EVAL_USER_BATCH = 512\n",
    "NEUMF_EVAL_ITEM_BATCH = 1024\n",
    "\n",
    "# Score fusion switches. FUSION_FETCH_N is presumably the candidates fetched per model and\n",
    "# FUSION_RRF_K presumably the k in reciprocal-rank fusion, 1 / (k + rank) - confirm in the fusion cell.\n",
    "RUN_NEURAL = True\n",
    "RUN_FUSION = True\n",
    "FUSION_FETCH_N = 100\n",
    "FUSION_RRF_K = 60\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "349eba60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch version: 2.11.0\n",
      "Torch device : cuda\n",
      "USE_AMP      : True\n"
     ]
    }
   ],
   "source": [
    "# PyTorch runtime setup\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Seed torch's CPU generator and, if present, every CUDA generator with the notebook-wide seed.\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "\n",
    "# Best-effort intra-op thread count; if torch refuses, the current setting is silently kept.\n",
    "try:\n",
    "    torch.set_num_threads(CPU_THREADS)\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "# Inter-op pool gets half the CPU budget (at least 1). torch can reject this call once parallel\n",
    "# work has already started (e.g. when the cell is re-run), hence the silent try/except.\n",
    "try:\n",
    "    torch.set_num_interop_threads(max(1, CPU_THREADS // 2))\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Mixed precision is enabled only on CUDA.\n",
    "USE_AMP = (device.type == \"cuda\")\n",
    "\n",
    "# NOTE(review): TF32 trades float32 matmul precision for speed; it may also affect GPU\n",
    "# linear-algebra paths such as the EASE solve - confirm the accuracy is acceptable.\n",
    "if device.type == \"cuda\":\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    # \"high\" lets float32 matmuls use TF32; wrapped in try in case the call is unavailable.\n",
    "    try:\n",
    "        torch.set_float32_matmul_precision(\"high\")\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"Torch version:\", torch.__version__)\n",
    "print(\"Torch device :\", device)\n",
    "print(\"USE_AMP      :\", USE_AMP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "145329d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows: 1000000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin12</th>\n",
       "      <th>PriceBin16</th>\n",
       "      <th>MileageBin12</th>\n",
       "      <th>MileageBin16</th>\n",
       "      <th>AgeBin10</th>\n",
       "      <th>AgeBin14</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>P12_3</td>\n",
       "      <td>P16_4</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_0</td>\n",
       "      <td>A14_0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_9</td>\n",
       "      <td>A14_13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>P12_7</td>\n",
       "      <td>P16_9</td>\n",
       "      <td>M12_0</td>\n",
       "      <td>M16_0</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_1</td>\n",
       "      <td>M16_2</td>\n",
       "      <td>A10_8</td>\n",
       "      <td>A14_11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>P12_0</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_4</td>\n",
       "      <td>M16_6</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    3   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   26   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   12   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   23   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   14   \n",
       "\n",
       "  PriceBin12 PriceBin16 MileageBin12 MileageBin16 AgeBin10 AgeBin14  \n",
       "0      P12_3      P16_4        M12_3        M16_4    A10_0    A14_0  \n",
       "1      P12_1      P16_1        M12_3        M16_4    A10_9   A14_13  \n",
       "2      P12_7      P16_9        M12_0        M16_0    A10_4    A14_5  \n",
       "3      P12_1      P16_1        M12_1        M16_2    A10_8   A14_11  \n",
       "4      P12_0      P16_1        M12_4        M16_6    A10_4    A14_6  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean the dataset\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    # Strip whitespace but keep missing values (and blank strings) as NaN so the dropna below\n",
    "    # removes them; a plain astype(str) would turn NaN into the literal string \"nan\".\n",
    "    stripped = df[col].astype(str).str.strip()\n",
    "    df[col] = stripped.where(df[col].notna() & (stripped != \"\"))\n",
    "\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "# Unparseable numbers were coerced to NaN above, so this drops malformed rows as well as missing ones.\n",
    "df = df.dropna(subset=[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]).copy()\n",
    "assert len(df) > 0, \"no rows left after cleaning\"\n",
    "\n",
    "df[\"Year\"] = df[\"Year\"].astype(int)\n",
    "# Car age in years relative to CURRENT_YEAR; model years after CURRENT_YEAR are clipped to 0.\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0)\n",
    "\n",
    "def make_qbin_labels(series, q, prefix):\n",
    "    cats = pd.qcut(series, q=q, labels=False, duplicates=\"drop\")\n",
    "    cats = cats.astype(int)\n",
    "    return cats.map(lambda x: f\"{prefix}{int(x)}\")\n",
    "\n",
    "# Quantile-bin the numeric columns at both granularities used by the formulations.\n",
    "for bin_col, source_col, n_bins, label_prefix in [\n",
    "    (\"PriceBin12\", \"Price\", PRICE_BINS_BASE, \"P12_\"),\n",
    "    (\"PriceBin16\", \"Price\", PRICE_BINS_FINE, \"P16_\"),\n",
    "    (\"MileageBin12\", \"Mileage\", MILEAGE_BINS_BASE, \"M12_\"),\n",
    "    (\"MileageBin16\", \"Mileage\", MILEAGE_BINS_FINE, \"M16_\"),\n",
    "    (\"AgeBin10\", \"Age\", AGE_BINS_BASE, \"A10_\"),\n",
    "    (\"AgeBin14\", \"Age\", AGE_BINS_FINE, \"A14_\"),\n",
    "]:\n",
    "    df[bin_col] = make_qbin_labels(df[source_col], n_bins, label_prefix)\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "450426c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user, min_users_per_item, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = len(out)\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if len(out) == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def _filter_with_schedule(raw_interactions):\n",
    "    \"\"\"Try each FILTER_SCHEDULE threshold pair in order and keep the first non-empty result.\n",
    "\n",
    "    Returns (interactions, (min_items_per_user, min_users_per_item)) with a fresh\n",
    "    RangeIndex, or (None, None) when every pair filters everything away.\n",
    "    \"\"\"\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS,\n",
    "        ).reset_index(drop=True)\n",
    "        if not candidate.empty:\n",
    "            return candidate, (min_items_per_user, min_users_per_item)\n",
    "    return None, None\n",
    "\n",
    "\n",
    "def _entity_feature_frame(tmp, key_col, feature_cols, kept_keys, encoder, idx_col):\n",
    "    \"\"\"One row per kept key with its feature columns and encoded index, sorted by that index.\"\"\"\n",
    "    features = (\n",
    "        tmp[tmp[key_col].isin(kept_keys)]\n",
    "        [[key_col] + list(feature_cols)]\n",
    "        .drop_duplicates(key_col)\n",
    "        .copy()\n",
    "    )\n",
    "    features[idx_col] = encoder.transform(features[key_col])\n",
    "    return features.sort_values(idx_col).reset_index(drop=True)\n",
    "\n",
    "\n",
    "def _weighted_leave_one_out(interactions, seed):\n",
    "    \"\"\"Pick one held-out row label per user, with probability proportional to its `count`.\n",
    "\n",
    "    Users are visited in ascending user_idx order with a single generator seeded by\n",
    "    `seed`, so the split is reproducible.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    test_indices = []\n",
    "    for _, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=weights / weights.sum())[0]\n",
    "        test_indices.append(int(picked))\n",
    "    return test_indices\n",
    "\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    \"\"\"Turn the cleaned sales frame into a recommendation dataset for one formulation.\n",
    "\n",
    "    Each row's `user_cols` values are joined into a user key and its `item_cols` values\n",
    "    into an item key; every distinct (user, item) pair is one interaction whose `count`\n",
    "    is the number of matching rows. Interactions are filtered with FILTER_SCHEDULE,\n",
    "    label-encoded, and split by a count-weighted leave-one-out (one test item per user).\n",
    "\n",
    "    Returns a dict with the interaction frames, per-user/per-item feature frames, the\n",
    "    encoders, CSR train matrices (`X_counts`, `X_binary`), per-user sorted seen-item\n",
    "    arrays, the held-out item per user, a popularity ranking of training items, and\n",
    "    the filter thresholds that were used.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if every FILTER_SCHEDULE entry leaves no interactions.\n",
    "    \"\"\"\n",
    "    tmp = work_df.copy()\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "        .size()\n",
    "        .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions, used_thresholds = _filter_with_schedule(raw_interactions)\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(f\"{formulation_name} produced no interactions after filtering\")\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    user_feature_df = _entity_feature_frame(\n",
    "        tmp, \"user_id\", user_cols, interactions[\"user_id\"], user_encoder, \"user_idx\"\n",
    "    )\n",
    "    item_feature_df = _entity_feature_frame(\n",
    "        tmp, \"item_id\", item_cols, interactions[\"item_id\"], item_encoder, \"item_idx\"\n",
    "    )\n",
    "\n",
    "    test_indices = _weighted_leave_one_out(interactions, RANDOM_STATE)\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    num_users = int(interactions[\"user_idx\"].nunique())\n",
    "    num_items = int(interactions[\"item_idx\"].nunique())\n",
    "\n",
    "    # Shapes cover every encoded user/item, so an item whose only interaction was held out\n",
    "    # still has a (all-zero) column.\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(u): int(i)\n",
    "        for u, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    user_seen_arrays = {\n",
    "        int(uid): np.sort(g[\"item_idx\"].astype(np.int32).to_numpy())\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    # Items ordered by total training count, most popular first (test-only items are absent).\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy(dtype=np.int32)\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": list(user_cols),\n",
    "        \"item_cols\": list(item_cols),\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen_arrays\": user_seen_arrays,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "    avg_train = len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle[\"used_thresholds\"])\n",
    "    print(\"Train rows :\", len(bundle[\"train_interactions\"]))\n",
    "    print(\"Test rows  :\", len(bundle[\"test_interactions\"]))\n",
    "    print(\"Avg train interactions/user:\", round(avg_train, 3))\n",
    "\n",
    "def get_eval_users(bundle):\n",
    "    \"\"\"Return the sorted int32 array of user indices that have a held-out test item.\n",
    "\n",
    "    When EVAL_MAX_USERS is set and smaller than that population, a seeded random subset\n",
    "    of that size is drawn from the sorted candidates and returned sorted.\n",
    "    \"\"\"\n",
    "    candidates = np.sort(np.fromiter(bundle[\"test_item_by_user\"], dtype=np.int32))\n",
    "    limit = EVAL_MAX_USERS\n",
    "    if limit is not None and candidates.size > limit:\n",
    "        sampler = np.random.default_rng(RANDOM_STATE)\n",
    "        candidates = sampler.choice(candidates, size=limit, replace=False)\n",
    "    return np.sort(candidates)\n",
    "\n",
    "def batch_metric_update(topk_idx, true_items, acc, ks):\n",
    "    matches = (topk_idx == true_items[:, None])\n",
    "    for k in ks:\n",
    "        acc[f\"HR@{k}\"] += matches[:, :k].any(axis=1).sum()\n",
    "\n",
    "    top10 = matches[:, :10]\n",
    "    found10 = top10.any(axis=1)\n",
    "    first_rank10 = np.where(found10, top10.argmax(axis=1) + 1, 0)\n",
    "\n",
    "    acc[\"MRR@10\"] += np.where(found10, 1.0 / first_rank10, 0.0).sum()\n",
    "    acc[\"NDCG@10\"] += np.where(found10, 1.0 / np.log2(first_rank10 + 1), 0.0).sum()\n",
    "    acc[\"UsersEval\"] += len(true_items)\n",
    "\n",
    "def finalize_metrics(acc, model_name, formulation_name):\n",
    "    users_eval = max(int(acc[\"UsersEval\"]), 1)\n",
    "    out = {\n",
    "        \"Formulation\": formulation_name,\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": users_eval,\n",
    "        \"HR@5\": float(acc[\"HR@5\"]) / users_eval,\n",
    "        \"HR@10\": float(acc[\"HR@10\"]) / users_eval,\n",
    "        \"HR@20\": float(acc[\"HR@20\"]) / users_eval,\n",
    "        \"MRR@10\": float(acc[\"MRR@10\"]) / users_eval,\n",
    "        \"NDCG@10\": float(acc[\"NDCG@10\"]) / users_eval,\n",
    "    }\n",
    "    return out\n",
    "\n",
    "def popularity_recommend_one(global_pop_rank, seen_array, n):\n",
    "    if seen_array is None or len(seen_array) == 0:\n",
    "        return global_pop_rank[:n].tolist()\n",
    "    seen_set = set(int(x) for x in seen_array)\n",
    "    out = []\n",
    "    for item in global_pop_rank:\n",
    "        item = int(item)\n",
    "        if item not in seen_set:\n",
    "            out.append(item)\n",
    "            if len(out) >= n:\n",
    "                break\n",
    "    return out\n",
    "\n",
    "def evaluate_popularity(bundle, model_name=\"Popularity\"):\n",
    "    \"\"\"Evaluate the most-popular-unseen baseline on every evaluation user.\n",
    "\n",
    "    Recommendations are generated per user (each excludes that user's training items)\n",
    "    and then scored in one vectorized batch. Lists shorter than max(TOP_KS) are padded\n",
    "    with -1, which never matches a real item index, so missing slots count as misses.\n",
    "    \"\"\"\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    global_pop_rank = bundle[\"global_pop_rank\"]\n",
    "    user_seen_arrays = bundle[\"user_seen_arrays\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    # HR keys follow TOP_KS so changing the cutoffs does not require editing this function.\n",
    "    acc = {f\"HR@{k}\": 0.0 for k in TOP_KS}\n",
    "    acc.update({\"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0})\n",
    "\n",
    "    recs_matrix = np.full((len(user_indices), max_k), -1, dtype=np.int32)\n",
    "    true_items = np.empty(len(user_indices), dtype=np.int32)\n",
    "    for row, uid in enumerate(tqdm(user_indices, desc=f\"{bundle['name']} / {model_name}\")):\n",
    "        uid = int(uid)\n",
    "        recs = popularity_recommend_one(global_pop_rank, user_seen_arrays.get(uid), max_k)\n",
    "        recs_matrix[row, :len(recs)] = recs\n",
    "        true_items[row] = test_item_by_user[uid]\n",
    "\n",
    "    batch_metric_update(recs_matrix, true_items, acc, TOP_KS)\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def clear_memory():\n",
    "    gc.collect()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e001d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model fitting helpers and batched evaluators\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr(copy=True)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel().astype(np.float32)\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    \"\"\"Item-item RP3beta similarity from a user-item matrix.\n",
    "\n",
    "    S = P_iu**alpha @ P_ui**alpha, where P_ui / P_iu are the row-normalized user->item and\n",
    "    item->user matrices and the power is element-wise. The diagonal is zeroed and, when\n",
    "    beta > 0, column j is multiplied by deg(j)**-beta (items with degree 0 are left as-is)\n",
    "    to damp popular items. Returns a float32 CSR matrix.\n",
    "    \"\"\"\n",
    "    user_to_item = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    item_to_user = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "\n",
    "    similarity = (item_to_user @ user_to_item).astype(np.float32).tocsr()\n",
    "    similarity.setdiag(0.0)\n",
    "    similarity.eliminate_zeros()\n",
    "\n",
    "    if not beta > 0:\n",
    "        return similarity.tocsr()\n",
    "\n",
    "    degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "    column_scale = np.ones_like(degree, dtype=np.float32)\n",
    "    has_degree = degree > 0\n",
    "    column_scale[has_degree] = np.power(degree[has_degree], -beta)\n",
    "    return (similarity @ sp.diags(column_scale.astype(np.float32))).tocsr()\n",
    "\n",
    "\n",
    "def prepare_ease_cache(X_matrix, source_name=\"binary\", prefer_gpu=True):\n",
    "    t0 = time.time()\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32, copy=False)\n",
    "    n_items = int(G.shape[0])\n",
    "    gram_gb = (G.nbytes / (1024 ** 3))\n",
    "\n",
    "    backend = \"cpu_torch\"\n",
    "    prep = {\n",
    "        \"source_name\": source_name,\n",
    "        \"backend\": backend,\n",
    "        \"gram_np\": G,\n",
    "        \"n_items\": n_items,\n",
    "        \"gram_gb\": gram_gb,\n",
    "        \"prep_sec\": time.time() - t0,\n",
    "    }\n",
    "\n",
    "    # GPU EASE fit needs room for the Gram matrix, a factorization workspace, and the solved matrix.\n",
    "    # Use it only when the matrix is small enough to fit comfortably.\n",
    "    if prefer_gpu and device.type == \"cuda\" and gram_gb <= EASE_GPU_MAX_GRAM_GB:\n",
    "        try:\n",
    "            free_bytes, total_bytes = torch.cuda.mem_get_info(device=device)\n",
    "            free_gb = free_bytes / (1024 ** 3)\n",
    "            needed_gb = gram_gb * 3.2\n",
    "            if free_gb >= needed_gb:\n",
    "                gram_t = torch.from_numpy(G).to(device=device, dtype=torch.float32, non_blocking=True)\n",
    "                prep.update({\n",
    "                    \"backend\": \"gpu\",\n",
    "                    \"gram_t\": gram_t,\n",
    "                })\n",
    "            else:\n",
    "                prep[\"backend\"] = \"cpu_torch\"\n",
    "        except Exception:\n",
    "            prep[\"backend\"] = \"cpu_torch\"\n",
    "\n",
    "    return prep\n",
    "\n",
    "def _ease_inverse_torch(gram, lam):\n",
    "    \"\"\"Return ((gram + lam*I)^-1, 0) via Cholesky on gram's device, or (None, info) on failure.\n",
    "\n",
    "    `gram` is modified in place (lam is added to its diagonal), so callers pass a copy.\n",
    "    \"\"\"\n",
    "    gram.diagonal().add_(float(lam))\n",
    "    L, info = torch.linalg.cholesky_ex(gram, upper=False)\n",
    "    info_val = int(info.item())\n",
    "    if info_val != 0:\n",
    "        return None, info_val\n",
    "    eye = torch.eye(gram.shape[0], dtype=gram.dtype, device=gram.device)\n",
    "    return torch.cholesky_solve(eye, L, upper=False), 0\n",
    "\n",
    "\n",
    "def _ease_weights_from_inverse_torch(P):\n",
    "    \"\"\"EASE closed form from the inverse P: B[i, j] = -P[i, j] / P[j, j], zero diagonal.\"\"\"\n",
    "    B = -P / torch.diagonal(P).clone().unsqueeze(0)\n",
    "    B.fill_diagonal_(0.0)\n",
    "    return B\n",
    "\n",
    "\n",
    "def fit_ease_from_cache(cache, lam):\n",
    "    \"\"\"Fit EASE item-item weights for regularization `lam` from a prepare_ease_cache() dict.\n",
    "\n",
    "    Solves (G + lam*I)^-1 by Cholesky and returns B as a float32 numpy array of shape\n",
    "    (n_items, n_items) with a zero diagonal. Backends are tried in order: the cached GPU\n",
    "    Gram (when cache[\"backend\"] == \"gpu\"), CPU torch, then SciPy when the CPU torch\n",
    "    factorization reports failure. cache[\"gram_np\"] is never modified.\n",
    "\n",
    "    Raises:\n",
    "        scipy.linalg.LinAlgError: if even the SciPy Cholesky factorization fails.\n",
    "    \"\"\"\n",
    "    if cache.get(\"backend\") == \"gpu\":\n",
    "        try:\n",
    "            P, info_val = _ease_inverse_torch(cache[\"gram_t\"].clone(), lam)\n",
    "            if P is not None:\n",
    "                out = _ease_weights_from_inverse_torch(P).detach().cpu().numpy().astype(np.float32, copy=False)\n",
    "                del P\n",
    "                clear_memory()\n",
    "                return out\n",
    "            print(f\"    GPU Cholesky failed at lambda={lam} (info={info_val}), falling back to CPU torch.\")\n",
    "        except RuntimeError as e:\n",
    "            print(f\"    GPU EASE failed at lambda={lam}, falling back to CPU torch: {e}\")\n",
    "        # Release whatever the failed GPU attempt allocated before doing the CPU solve.\n",
    "        clear_memory()\n",
    "\n",
    "    # CPU torch path: threaded linear algebra without needing GPU memory.\n",
    "    # .clone() so adding lam to the diagonal never touches cache[\"gram_np\"].\n",
    "    G_t = torch.from_numpy(cache[\"gram_np\"]).to(dtype=torch.float32, device=\"cpu\").clone()\n",
    "    P, info_val = _ease_inverse_torch(G_t, lam)\n",
    "    del G_t\n",
    "    if P is not None:\n",
    "        out = _ease_weights_from_inverse_torch(P).numpy().astype(np.float32, copy=False)\n",
    "        del P\n",
    "        clear_memory()\n",
    "        return out\n",
    "\n",
    "    # Final fallback to SciPy if the CPU torch Cholesky reports numerical trouble.\n",
    "    clear_memory()\n",
    "    G = cache[\"gram_np\"].copy()\n",
    "    G[np.diag_indices_from(G)] += np.float32(lam)\n",
    "    c, lower = sla.cho_factor(G, lower=True, overwrite_a=True, check_finite=False)\n",
    "    I = np.eye(G.shape[0], dtype=np.float32)\n",
    "    P_np = sla.cho_solve((c, lower), I, overwrite_b=True, check_finite=False)\n",
    "    B = -P_np / np.diag(P_np).copy()[None, :]\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32, copy=False)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_dense_operator(bundle, operator_matrix, model_name, source=\"binary\", batch_size=4096):\n",
    "    \"\"\"Rank items for every evaluation user with a dense item-item operator and accumulate hit metrics.\n",
    "\n",
    "    Each batch of user rows from the selected interaction matrix is multiplied by\n",
    "    `operator_matrix` on `device`. Items already present in a user's row are pushed to\n",
    "    the bottom (-1e9) before the top max(TOP_KS) items go to batch_metric_update.\n",
    "\n",
    "    Args:\n",
    "        bundle: bundle dict; reads \"X_binary\" or \"X_counts\", \"test_item_by_user\" and \"name\".\n",
    "        operator_matrix: torch tensor or array-like, cast to float32 on `device`.\n",
    "        model_name: label passed to the progress bar and to finalize_metrics.\n",
    "        source: \"binary\" selects bundle[\"X_binary\"]; any other value selects bundle[\"X_counts\"].\n",
    "        batch_size: number of users scored per matmul.\n",
    "\n",
    "    Returns:\n",
    "        Whatever finalize_metrics returns for the accumulated counters.\n",
    "    \"\"\"\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    interactions = bundle[\"X_binary\"] if source == \"binary\" else bundle[\"X_counts\"]\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    k_max = max(TOP_KS)\n",
    "\n",
    "    # Move (or cast) the operator once; it is reused for every batch.\n",
    "    if isinstance(operator_matrix, torch.Tensor):\n",
    "        weights = operator_matrix.to(device=device, dtype=torch.float32)\n",
    "    else:\n",
    "        weights = torch.as_tensor(operator_matrix, dtype=torch.float32, device=device)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    batch_starts = range(0, len(eval_users), batch_size)\n",
    "    for lo in tqdm(batch_starts, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        uids = eval_users[lo:lo + batch_size]\n",
    "        dense_rows = interactions[uids].toarray().astype(np.float32, copy=False)\n",
    "        targets = np.array([held_out[int(uid)] for uid in uids], dtype=np.int32)\n",
    "\n",
    "        rows_t = torch.from_numpy(dense_rows).to(device, non_blocking=True)\n",
    "        # The matmul output is a fresh tensor, so masking it in place is safe.\n",
    "        scores = rows_t @ weights\n",
    "        scores.masked_fill_(rows_t > 0, -1e9)\n",
    "\n",
    "        _, top_idx = torch.topk(scores, k=k_max, dim=1)\n",
    "        ranked = top_idx.cpu().numpy().astype(np.int32, copy=False)\n",
    "        batch_metric_update(ranked, targets, acc, TOP_KS)\n",
    "\n",
    "        del rows_t, scores, top_idx, ranked\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def top_model_name(df, prefix):\n",
    "    subset = df[df[\"Model\"].str.startswith(prefix)].copy()\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    subset = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return subset.iloc[0][\"Model\"]\n",
    "\n",
    "def make_rrf_scores(score_matrices, rrf_k=60):\n",
    "    # score_matrices: list of top-k item index arrays with shape [batch, fetch_n]\n",
    "    batch, fetch_n = score_matrices[0].shape\n",
    "    fused = np.zeros((batch, fetch_n * len(score_matrices)), dtype=np.float32)\n",
    "    del fused\n",
    "    # kept only as a placeholder helper; fusion is done in a simple Python routine later\n",
    "def top_model_name(df, prefix):\n",
    "    subset = df[df[\"Model\"].str.startswith(prefix)].copy()\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    subset = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return subset.iloc[0][\"Model\"]\n",
    "\n",
    "def make_rrf_scores(score_matrices, rrf_k=60):\n",
    "    # score_matrices: list of top-k item index arrays with shape [batch, fetch_n]\n",
    "    batch, fetch_n = score_matrices[0].shape\n",
    "    fused = np.zeros((batch, fetch_n * len(score_matrices)), dtype=np.float32)\n",
    "    del fused\n",
    "    # kept only as a placeholder helper; fusion is done in a simple Python routine later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "33f155c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: H1_base\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "====================================================================================================\n",
      "Formulation: H1_color\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition', 'Color']\n",
      "Users      : 62921\n",
      "Items      : 13198\n",
      "Thresholds : (4, 8)\n",
      "Train rows : 237108\n",
      "Test rows  : 62921\n",
      "Avg train interactions/user: 3.768\n",
      "====================================================================================================\n",
      "Formulation: H1_fine\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition']\n",
      "Users      : 95529\n",
      "Items      : 3696\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 812918\n",
      "Test rows  : 95529\n",
      "Avg train interactions/user: 8.51\n",
      "====================================================================================================\n",
      "Formulation: H1_fine_color\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition', 'Color']\n",
      "Users      : 59346\n",
      "Items      : 23514\n",
      "Thresholds : (3, 5)\n",
      "Train rows : 134712\n",
      "Test rows  : 59346\n",
      "Avg train interactions/user: 2.27\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>TrainRows</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Thresholds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_fine_color</td>\n",
       "      <td>59346</td>\n",
       "      <td>23514</td>\n",
       "      <td>134712</td>\n",
       "      <td>2.269942</td>\n",
       "      <td>(3, 5)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_color</td>\n",
       "      <td>62921</td>\n",
       "      <td>13198</td>\n",
       "      <td>237108</td>\n",
       "      <td>3.768344</td>\n",
       "      <td>(4, 8)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>95529</td>\n",
       "      <td>3696</td>\n",
       "      <td>812918</td>\n",
       "      <td>8.509646</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>831152</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Formulation  Users  Items  TrainRows  AvgTrainPerUser Thresholds\n",
       "0  H1_fine_color  59346  23514     134712         2.269942     (3, 5)\n",
       "1       H1_color  62921  13198     237108         3.768344     (4, 8)\n",
       "2        H1_fine  95529   3696     812918         8.509646    (5, 10)\n",
       "3        H1_base  43199   2640     831152        19.240075    (5, 10)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build all richer H1 formulations\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    bundles[name] = bundle\n",
    "    print_bundle_summary(bundle)\n",
    "\n",
    "    summary_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"TrainRows\": len(bundle[\"train_interactions\"]),\n",
    "        \"AvgTrainPerUser\": len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1),\n",
    "        \"Thresholds\": bundle[\"used_thresholds\"],\n",
    "    })\n",
    "\n",
    "summary_df = pd.DataFrame(summary_rows).sort_values([\"Items\", \"Users\"], ascending=False).reset_index(drop=True)\n",
    "display(summary_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a61a7425",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_base\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / Popularity: 100%|██████████| 43199/43199 [00:01<00:00, 31965.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_screen_a1_1_b0_7: 100%|██████████| 11/11 [00:00<00:00, 36.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.60s\n",
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / Popularity: 100%|██████████| 62921/62921 [00:01<00:00, 33816.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_screen_a1_1_b0_7: 100%|██████████| 16/16 [00:01<00:00, 12.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.42s\n",
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_fine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / Popularity: 100%|██████████| 95529/95529 [00:02<00:00, 33229.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_screen_a1_1_b0_7: 100%|██████████| 24/24 [00:00<00:00, 50.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.48s\n",
      "========================================================================================================================\n",
      "Cheap screen on formulation: H1_fine_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / Popularity: 100%|██████████| 59346/59346 [00:01<00:00, 33767.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_screen_a1_1_b0_7\n",
      "    fit_time=0.15s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / RP3beta_screen_a1_1_b0_7: 100%|██████████| 15/15 [00:03<00:00,  4.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=3.34s\n",
      "========================================================================================================================\n",
      "Formulation cheap-screen summary\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>screen_best_hr10</th>\n",
       "      <th>screen_best_ndcg10</th>\n",
       "      <th>screen_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>0.191070</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>0.159298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_fine_color</td>\n",
       "      <td>0.127877</td>\n",
       "      <td>0.060050</td>\n",
       "      <td>0.142890</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_color</td>\n",
       "      <td>0.119626</td>\n",
       "      <td>0.054071</td>\n",
       "      <td>0.133144</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Formulation  screen_best_hr10  screen_best_ndcg10  screen_score\n",
       "0        H1_base          0.169518            0.086209      0.191070\n",
       "1        H1_fine          0.142208            0.068361      0.159298\n",
       "2  H1_fine_color          0.127877            0.060050      0.142890\n",
       "3       H1_color          0.119626            0.054071      0.133144"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulations for full regular search: ['H1_base', 'H1_fine']\n",
      "========================================================================================================================\n",
      "Full regular search on formulation: H1_base\n",
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_0_b0_6: 100%|██████████| 11/11 [00:00<00:00, 97.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_1_b0_7: 100%|██████████| 11/11 [00:00<00:00, 96.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_15_b0_7: 100%|██████████| 11/11 [00:00<00:00, 99.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=2640 gram_gb=0.03 prep_time=0.06s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=2640 gram_gb=0.03 prep_time=0.06s\n",
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.15s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1000: 100%|██████████| 11/11 [00:00<00:00, 99.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.10s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1200: 100%|██████████| 11/11 [00:00<00:00, 87.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n",
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.11s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1600: 100%|██████████| 11/11 [00:00<00:00, 85.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.11s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l1200: 100%|██████████| 11/11 [00:00<00:00, 82.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.12s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l1600: 100%|██████████| 11/11 [00:00<00:00, 80.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Full regular search on formulation: H1_fine\n",
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_0_b0_6: 100%|██████████| 24/24 [00:00<00:00, 59.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.41s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_1_b0_7: 100%|██████████| 24/24 [00:00<00:00, 69.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_15_b0_7: 100%|██████████| 24/24 [00:00<00:00, 67.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.36s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=3696 gram_gb=0.05 prep_time=0.05s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=3696 gram_gb=0.05 prep_time=0.04s\n",
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.12s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l1000: 100%|██████████| 24/24 [00:00<00:00, 68.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.36s\n",
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l1200: 100%|██████████| 24/24 [00:00<00:00, 68.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.36s\n",
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.12s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l1600: 100%|██████████| 24/24 [00:00<00:00, 68.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.36s\n",
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_count_l1200: 100%|██████████| 24/24 [00:00<00:00, 69.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.36s\n",
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_count_l1600: 100%|██████████| 24/24 [00:00<00:00, 66.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.37s\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "      <th>Phase</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099354</td>\n",
       "      <td>0.170050</td>\n",
       "      <td>0.312947</td>\n",
       "      <td>0.061000</td>\n",
       "      <td>0.086145</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099956</td>\n",
       "      <td>0.169958</td>\n",
       "      <td>0.312577</td>\n",
       "      <td>0.061645</td>\n",
       "      <td>0.086629</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099470</td>\n",
       "      <td>0.169842</td>\n",
       "      <td>0.313618</td>\n",
       "      <td>0.061275</td>\n",
       "      <td>0.086323</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099262</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.313109</td>\n",
       "      <td>0.061218</td>\n",
       "      <td>0.086209</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.099169</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.312299</td>\n",
       "      <td>0.060961</td>\n",
       "      <td>0.085968</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098567</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.312716</td>\n",
       "      <td>0.061023</td>\n",
       "      <td>0.085960</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098289</td>\n",
       "      <td>0.168221</td>\n",
       "      <td>0.310216</td>\n",
       "      <td>0.060397</td>\n",
       "      <td>0.085273</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098127</td>\n",
       "      <td>0.168129</td>\n",
       "      <td>0.311003</td>\n",
       "      <td>0.060514</td>\n",
       "      <td>0.085340</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083137</td>\n",
       "      <td>0.146018</td>\n",
       "      <td>0.268986</td>\n",
       "      <td>0.050754</td>\n",
       "      <td>0.072685</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082300</td>\n",
       "      <td>0.145338</td>\n",
       "      <td>0.268913</td>\n",
       "      <td>0.050044</td>\n",
       "      <td>0.071969</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.084278</td>\n",
       "      <td>0.145055</td>\n",
       "      <td>0.269562</td>\n",
       "      <td>0.050897</td>\n",
       "      <td>0.072607</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.083472</td>\n",
       "      <td>0.144731</td>\n",
       "      <td>0.269552</td>\n",
       "      <td>0.050221</td>\n",
       "      <td>0.071991</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.081159</td>\n",
       "      <td>0.144532</td>\n",
       "      <td>0.269740</td>\n",
       "      <td>0.048689</td>\n",
       "      <td>0.070727</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.082457</td>\n",
       "      <td>0.144522</td>\n",
       "      <td>0.268892</td>\n",
       "      <td>0.049771</td>\n",
       "      <td>0.071576</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_15_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.076961</td>\n",
       "      <td>0.142292</td>\n",
       "      <td>0.268903</td>\n",
       "      <td>0.046265</td>\n",
       "      <td>0.068288</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_screen_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.077170</td>\n",
       "      <td>0.142208</td>\n",
       "      <td>0.269060</td>\n",
       "      <td>0.046382</td>\n",
       "      <td>0.068361</td>\n",
       "      <td>full</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.002847</td>\n",
       "      <td>0.006250</td>\n",
       "      <td>0.012037</td>\n",
       "      <td>0.001778</td>\n",
       "      <td>0.002793</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>Popularity</td>\n",
       "      <td>95529</td>\n",
       "      <td>0.002450</td>\n",
       "      <td>0.004878</td>\n",
       "      <td>0.009882</td>\n",
       "      <td>0.001397</td>\n",
       "      <td>0.002194</td>\n",
       "      <td>screen</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Formulation                     Model  UsersEval      HR@5     HR@10  \\\n",
       "0      H1_base         EASE_binary_l1000      43199  0.099354  0.170050   \n",
       "1      H1_base         RP3beta_a1_0_b0_6      43199  0.099956  0.169958   \n",
       "2      H1_base        RP3beta_a1_15_b0_7      43199  0.099470  0.169842   \n",
       "3      H1_base  RP3beta_screen_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "4      H1_base         RP3beta_a1_1_b0_7      43199  0.099262  0.169518   \n",
       "5      H1_base         EASE_binary_l1200      43199  0.099169  0.169356   \n",
       "6      H1_base         EASE_binary_l1600      43199  0.098567  0.169147   \n",
       "7      H1_base          EASE_count_l1200      43199  0.098289  0.168221   \n",
       "8      H1_base          EASE_count_l1600      43199  0.098127  0.168129   \n",
       "9      H1_fine          EASE_count_l1600      95529  0.083137  0.146018   \n",
       "10     H1_fine          EASE_count_l1200      95529  0.082300  0.145338   \n",
       "11     H1_fine         EASE_binary_l1600      95529  0.084278  0.145055   \n",
       "12     H1_fine         EASE_binary_l1200      95529  0.083472  0.144731   \n",
       "13     H1_fine         RP3beta_a1_0_b0_6      95529  0.081159  0.144532   \n",
       "14     H1_fine         EASE_binary_l1000      95529  0.082457  0.144522   \n",
       "15     H1_fine        RP3beta_a1_15_b0_7      95529  0.076961  0.142292   \n",
       "16     H1_fine  RP3beta_screen_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "17     H1_fine         RP3beta_a1_1_b0_7      95529  0.077170  0.142208   \n",
       "18     H1_base                Popularity      43199  0.002847  0.006250   \n",
       "19     H1_fine                Popularity      95529  0.002450  0.004878   \n",
       "\n",
       "       HR@20    MRR@10   NDCG@10   Phase  \n",
       "0   0.312947  0.061000  0.086145    full  \n",
       "1   0.312577  0.061645  0.086629    full  \n",
       "2   0.313618  0.061275  0.086323    full  \n",
       "3   0.313109  0.061218  0.086209  screen  \n",
       "4   0.313109  0.061218  0.086209    full  \n",
       "5   0.312299  0.060961  0.085968    full  \n",
       "6   0.312716  0.061023  0.085960    full  \n",
       "7   0.310216  0.060397  0.085273    full  \n",
       "8   0.311003  0.060514  0.085340    full  \n",
       "9   0.268986  0.050754  0.072685    full  \n",
       "10  0.268913  0.050044  0.071969    full  \n",
       "11  0.269562  0.050897  0.072607    full  \n",
       "12  0.269552  0.050221  0.071991    full  \n",
       "13  0.269740  0.048689  0.070727    full  \n",
       "14  0.268892  0.049771  0.071576    full  \n",
       "15  0.268903  0.046265  0.068288    full  \n",
       "16  0.269060  0.046382  0.068361  screen  \n",
       "17  0.269060  0.046382  0.068361    full  \n",
       "18  0.012037  0.001778  0.002793  screen  \n",
       "19  0.009882  0.001397  0.002194  screen  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulations for neural search: ['H1_base', 'H1_fine']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Regular model search, phase 1: cheap screen.\n",
    "# Every formulation is scored with only the two fastest models (Popularity and the\n",
    "# single RP3beta configuration SCREEN_RP3); the summary below ranks formulations on\n",
    "# these scores and only the best ones get the full regular search.\n",
    "screen_rows = []\n",
    "\n",
    "for formulation_name, bundle in bundles.items():\n",
    "    print(\"=\" * 120)\n",
    "    print(\"Cheap screen on formulation:\", formulation_name)\n",
    "\n",
    "    # Popularity baseline: nothing to fit.\n",
    "    pop_res = evaluate_popularity(bundle, model_name=\"Popularity\")\n",
    "    pop_res[\"Phase\"] = \"screen\"\n",
    "    screen_rows.append(pop_res)\n",
    "\n",
    "    # One RP3beta fit; the fitted similarity is densified to float32 for the batched evaluator.\n",
    "    alpha, beta = SCREEN_RP3\n",
    "    clear_memory()\n",
    "    alpha_tag = str(alpha).replace(\".\", \"_\")\n",
    "    beta_tag = str(beta).replace(\".\", \"_\")\n",
    "    model_name = f\"RP3beta_screen_a{alpha_tag}_b{beta_tag}\"\n",
    "    print(\"- Fitting\", model_name)\n",
    "    fit_start = time.time()\n",
    "    S_dense = fit_p3alpha(bundle[\"X_binary\"], alpha=alpha, beta=beta).toarray().astype(np.float32, copy=False)\n",
    "    fit_sec = time.time() - fit_start\n",
    "    print(f\"    fit_time={fit_sec:.2f}s\")\n",
    "\n",
    "    eval_start = time.time()\n",
    "    rp3_res = evaluate_dense_operator(\n",
    "        bundle,\n",
    "        S_dense,\n",
    "        model_name=model_name,\n",
    "        source=\"binary\",\n",
    "        batch_size=LINEAR_EVAL_BATCH,\n",
    "    )\n",
    "    rp3_res[\"Phase\"] = \"screen\"\n",
    "    screen_rows.append(rp3_res)\n",
    "    print(f\"    eval_time={time.time() - eval_start:.2f}s\")\n",
    "    del S_dense\n",
    "    clear_memory()\n",
    "\n",
    "# Rank formulations by their best screen result: HR@10 + 0.25 * NDCG@10, with HR@10 as tie-break.\n",
    "SCREEN_NDCG_WEIGHT = 0.25\n",
    "screen_results_df = pd.DataFrame(screen_rows)\n",
    "screen_summary = (\n",
    "    screen_results_df.groupby(\"Formulation\")\n",
    "    .agg(\n",
    "        screen_best_hr10=(\"HR@10\", \"max\"),\n",
    "        screen_best_ndcg10=(\"NDCG@10\", \"max\"),\n",
    "    )\n",
    "    .reset_index()\n",
    "    .assign(\n",
    "        screen_score=lambda df: df[\"screen_best_hr10\"] + SCREEN_NDCG_WEIGHT * df[\"screen_best_ndcg10\"]\n",
    "    )\n",
    "    .sort_values([\"screen_score\", \"screen_best_hr10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(\"=\" * 120)\n",
    "print(\"Formulation cheap-screen summary\")\n",
    "display(screen_summary)\n",
    "\n",
    "# The top TOP_FORMULATIONS_FOR_FULL_REGULAR formulations go on to the full regular search.\n",
    "selected_regular_formulations = screen_summary[\"Formulation\"].head(TOP_FORMULATIONS_FOR_FULL_REGULAR).tolist()\n",
    "print(\"Selected formulations for full regular search:\", selected_regular_formulations)\n",
    "\n",
    "# Regular model search, phase 2: full RP3beta grid plus EASE on the selected formulations.\n",
    "# Screen rows of the selected formulations are carried over so that every regular\n",
    "# result ends up in one table.\n",
    "regular_results = []\n",
    "regular_results.extend(\n",
    "    screen_results_df[screen_results_df[\"Formulation\"].isin(selected_regular_formulations)].to_dict(\"records\")\n",
    ")\n",
    "\n",
    "\n",
    "def _rp3_model_name(alpha, beta, prefix=\"RP3beta\"):\n",
    "    \"\"\"Model label for an RP3beta config, e.g. (1.0, 0.6) -> 'RP3beta_a1_0_b0_6'.\"\"\"\n",
    "    return f\"{prefix}_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "\n",
    "\n",
    "def _evaluate_full(bundle, operator, model_name, source):\n",
    "    \"\"\"Evaluate a dense item-item operator, tag the row Phase='full' and print the eval time.\"\"\"\n",
    "    t_eval = time.time()\n",
    "    res = evaluate_dense_operator(bundle, operator, model_name=model_name, source=source, batch_size=LINEAR_EVAL_BATCH)\n",
    "    res[\"Phase\"] = \"full\"\n",
    "    print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "    return res\n",
    "\n",
    "\n",
    "def _run_rp3_grid(bundle, formulation_name, screen_df, out_rows):\n",
    "    \"\"\"Fit and evaluate every RP3_GRID config on one formulation, appending rows to out_rows.\n",
    "\n",
    "    The SCREEN_RP3 config was already fitted and evaluated on this formulation in the\n",
    "    cheap screen (same fit_p3alpha call, same evaluator and batch size), so its screen\n",
    "    row is copied under the full-phase name instead of being refitted. If no such\n",
    "    screen row exists, the config is fitted like any other.\n",
    "    \"\"\"\n",
    "    screen_name = _rp3_model_name(*SCREEN_RP3, prefix=\"RP3beta_screen\")\n",
    "    for alpha, beta in RP3_GRID:\n",
    "        model_name = _rp3_model_name(alpha, beta)\n",
    "        if (alpha, beta) == tuple(SCREEN_RP3):\n",
    "            cached = screen_df[\n",
    "                (screen_df[\"Formulation\"] == formulation_name) & (screen_df[\"Model\"] == screen_name)\n",
    "            ]\n",
    "            if not cached.empty:\n",
    "                print(\"- Reusing screen result for\", model_name)\n",
    "                row = cached.iloc[0].to_dict()\n",
    "                row[\"Model\"] = model_name\n",
    "                row[\"Phase\"] = \"full\"\n",
    "                out_rows.append(row)\n",
    "                continue\n",
    "\n",
    "        clear_memory()\n",
    "        print(\"- Fitting\", model_name)\n",
    "        t_fit = time.time()\n",
    "        S = fit_p3alpha(bundle[\"X_binary\"], alpha=alpha, beta=beta)\n",
    "        S_dense = S.toarray().astype(np.float32, copy=False)\n",
    "        fit_sec = time.time() - t_fit\n",
    "        del S\n",
    "        print(f\"    fit_time={fit_sec:.2f}s\")\n",
    "        out_rows.append(_evaluate_full(bundle, S_dense, model_name, \"binary\"))\n",
    "        del S_dense\n",
    "        clear_memory()\n",
    "\n",
    "\n",
    "def _run_ease_family(bundle, X, source_name, lambdas, out_rows):\n",
    "    \"\"\"Fit and evaluate EASE for every lambda on one input matrix, appending rows to out_rows.\n",
    "\n",
    "    All lambdas share one Gram cache that is built, used and released inside this call,\n",
    "    so the binary and count caches are never held at the same time. This lowers peak\n",
    "    memory compared with preparing both caches up front.\n",
    "    \"\"\"\n",
    "    clear_memory()\n",
    "    print(f\"- Preparing EASE {source_name} Gram cache\")\n",
    "    cache = prepare_ease_cache(X, source_name=source_name, prefer_gpu=True)\n",
    "    print(\n",
    "        f\"    backend={cache['backend']} \"\n",
    "        f\"n_items={cache['n_items']} \"\n",
    "        f\"gram_gb={cache['gram_gb']:.2f} \"\n",
    "        f\"prep_time={cache['prep_sec']:.2f}s\"\n",
    "    )\n",
    "    try:\n",
    "        for lam in lambdas:\n",
    "            clear_memory()\n",
    "            model_name = f\"EASE_{source_name}_l{int(lam)}\"\n",
    "            print(\"- Fitting\", model_name)\n",
    "            t_fit = time.time()\n",
    "            B = fit_ease_from_cache(cache, lam=lam)\n",
    "            print(f\"    fit_time={time.time() - t_fit:.2f}s\")\n",
    "            out_rows.append(_evaluate_full(bundle, B, model_name, source_name))\n",
    "            del B\n",
    "            clear_memory()\n",
    "    finally:\n",
    "        # Release the Gram cache (GPU-backed when prepare_ease_cache chose the GPU) even if a fit fails.\n",
    "        del cache\n",
    "        clear_memory()\n",
    "\n",
    "\n",
    "for formulation_name in selected_regular_formulations:\n",
    "    bundle = bundles[formulation_name]\n",
    "    print(\"=\" * 120)\n",
    "    print(\"Full regular search on formulation:\", formulation_name)\n",
    "\n",
    "    _run_rp3_grid(bundle, formulation_name, screen_results_df, regular_results)\n",
    "\n",
    "    # EASE only if the item space is small enough for its item-item Gram cache.\n",
    "    n_items = int(bundle[\"num_items\"])\n",
    "    if n_items > MAX_EASE_ITEMS:\n",
    "        print(f\"- Skipping EASE on {formulation_name}: n_items={n_items} > MAX_EASE_ITEMS={MAX_EASE_ITEMS}\")\n",
    "        continue\n",
    "\n",
    "    _run_ease_family(bundle, bundle[\"X_binary\"], \"binary\", EASE_BINARY_LAMBDAS, regular_results)\n",
    "    _run_ease_family(bundle, bundle[\"X_counts\"], \"count\", EASE_COUNT_LAMBDAS, regular_results)\n",
    "\n",
    "# Rank every regular result (carried-over screen rows included) by HR@10, then NDCG@10, then MRR@10.\n",
    "regular_results_df = pd.DataFrame(regular_results)\n",
    "regular_results_df = regular_results_df.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "regular_results_df = regular_results_df.reset_index(drop=True)\n",
    "display(regular_results_df.head(40))\n",
    "\n",
    "# Neural search only runs on the formulations whose best regular model has the highest HR@10.\n",
    "best_hr10_by_formulation = regular_results_df.groupby(\"Formulation\")[\"HR@10\"].max()\n",
    "selected_neural_formulations = list(\n",
    "    best_hr10_by_formulation.sort_values(ascending=False).head(TOP_FORMULATIONS_FOR_NEURAL).index\n",
    ")\n",
    "print(\"Selected formulations for neural search:\", selected_neural_formulations)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ec3cf5b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing neural tensors: H1_base\n",
      "Preparing neural tensors: H1_fine\n"
     ]
    }
   ],
   "source": [
    "# Prepare formulation-specific tensors for neural models\n",
    "\n",
    "def prepare_neural_data(bundle, dense_on_device=True):\n",
    "    \"\"\"Label-encode side features and move the neural training tensors to `device`.\n",
    "\n",
    "    Args:\n",
    "        bundle: formulation bundle; reads \"user_feature_df\", \"item_feature_df\",\n",
    "            \"train_interactions\" (user_idx, item_idx, count columns), \"user_cols\",\n",
    "            \"item_cols\" and \"X_binary\" (densified here with .toarray()).\n",
    "        dense_on_device: if True (default) also try to keep the dense X_binary copy on\n",
    "            `device`; if that allocation fails the host copy is used instead. Pass\n",
    "            False to skip the device copy and leave GPU memory for training.\n",
    "\n",
    "    Returns:\n",
    "        dict with the encoded feature matrices (user_feat_t, item_feat_t), training\n",
    "        positives (pos_user_t, pos_item_t, pos_weight_t), the dense matrix on host\n",
    "        (x_binary_np) and, when it fits, on device (x_binary_t / x_binary_on_device),\n",
    "        per-column embedding cardinalities, and the fitted LabelEncoders\n",
    "        (feature_encoders, keyed \"user::<col>\" / \"item::<col>\").\n",
    "    \"\"\"\n",
    "    user_cols = bundle[\"user_cols\"]\n",
    "    item_cols = bundle[\"item_cols\"]\n",
    "\n",
    "    # The feature frames are re-encoded column by column, so work on copies.\n",
    "    user_feature_df = bundle[\"user_feature_df\"].copy()\n",
    "    item_feature_df = bundle[\"item_feature_df\"].copy()\n",
    "    # Only read below, and to_numpy(copy=True) already detaches the arrays: no frame copy needed.\n",
    "    train_interactions = bundle[\"train_interactions\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "    for side, frame, cols in ((\"user\", user_feature_df, user_cols), (\"item\", item_feature_df, item_cols)):\n",
    "        for col in cols:\n",
    "            le = LabelEncoder()\n",
    "            # astype(str) gives the encoder one comparable dtype (missing values become \"nan\").\n",
    "            frame[col] = le.fit_transform(frame[col].astype(str))\n",
    "            feature_encoders[f\"{side}::{col}\"] = le\n",
    "\n",
    "    user_feat_np = user_feature_df[user_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "    item_feat_np = item_feature_df[item_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "\n",
    "    pos_user_np = train_interactions[\"user_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_item_np = train_interactions[\"item_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_weight_np = train_interactions[\"count\"].to_numpy(dtype=np.float32, copy=True)\n",
    "\n",
    "    # Dense float32 copy of X_binary. It is always kept on the host, because\n",
    "    # host-side batching slices it whenever no device copy exists.\n",
    "    x_binary_np = bundle[\"X_binary\"].toarray().astype(np.float32, copy=False)\n",
    "\n",
    "    out = {\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"user_feat_t\": torch.from_numpy(user_feat_np).to(device),\n",
    "        \"item_feat_t\": torch.from_numpy(item_feat_np).to(device),\n",
    "        \"pos_user_t\": torch.from_numpy(pos_user_np).to(device),\n",
    "        \"pos_item_t\": torch.from_numpy(pos_item_np).to(device),\n",
    "        \"pos_weight_t\": torch.from_numpy(pos_weight_np).to(device),\n",
    "        \"x_binary_np\": x_binary_np,\n",
    "        \"x_binary_t\": None,\n",
    "        \"x_binary_on_device\": False,\n",
    "    }\n",
    "\n",
    "    if dense_on_device:\n",
    "        try:\n",
    "            out[\"x_binary_t\"] = torch.from_numpy(x_binary_np).to(device)\n",
    "            out[\"x_binary_on_device\"] = True\n",
    "        except RuntimeError as exc:\n",
    "            # CUDA OOM is a RuntimeError subclass; report the fallback instead of hiding it.\n",
    "            print(f\"    dense X_binary kept on host ({type(exc).__name__})\")\n",
    "            clear_memory()\n",
    "\n",
    "    # Embedding table sizes: one row per class seen by the column's encoder.\n",
    "    out[\"user_feature_cards\"] = {col: len(feature_encoders[f\"user::{col}\"].classes_) for col in user_cols}\n",
    "    out[\"item_feature_cards\"] = {col: len(feature_encoders[f\"item::{col}\"].classes_) for col in item_cols}\n",
    "    out[\"feature_encoders\"] = feature_encoders\n",
    "\n",
    "    return out\n",
    "\n",
    "# One prepared tensor set per formulation selected for the neural search.\n",
    "neural_data_by_formulation = dict()\n",
    "for formulation_name in selected_neural_formulations:\n",
    "    print(\"Preparing neural tensors:\", formulation_name)\n",
    "    formulation_bundle = bundles[formulation_name]\n",
    "    neural_data_by_formulation[formulation_name] = prepare_neural_data(formulation_bundle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7c9a604a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural architectures and batched evaluators\n",
    "\n",
    "class FeatureEmbeddingBlock(nn.Module):\n",
    "    \"\"\"One embedding table per categorical column; the per-column lookups are concatenated.\n",
    "\n",
    "    `cardinalities` maps column name -> table size. Its key order defines the table\n",
    "    order, so it must match the column order of the integer matrix passed to\n",
    "    forward(), whose output has len(cardinalities) * emb_dim features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cardinalities, emb_dim):\n",
    "        super().__init__()\n",
    "        self.cols = list(cardinalities.keys())\n",
    "        self.embs = nn.ModuleList(\n",
    "            [nn.Embedding(int(cardinalities[col]), emb_dim) for col in self.cols]\n",
    "        )\n",
    "        # N(0, 0.02) init, applied once every table has been created.\n",
    "        for table in self.embs:\n",
    "            nn.init.normal_(table.weight, std=0.02)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x[:, j] holds the integer codes for column j.\n",
    "        return torch.cat([table(x[:, j]) for j, table in enumerate(self.embs)], dim=1)\n",
    "\n",
    "class MLPBlock(nn.Module):\n",
    "    \"\"\"Linear -> ReLU -> Dropout for each width in hidden_dims, plus an optional final Linear.\n",
    "\n",
    "    `output_dim` records the width of the last layer (input_dim when there are no layers).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None):\n",
    "        super().__init__()\n",
    "        widths = [input_dim, *hidden_dims]\n",
    "        layers = []\n",
    "        for width_in, width_out in zip(widths[:-1], widths[1:]):\n",
    "            layers += [nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(dropout)]\n",
    "        if final_dim is not None:\n",
    "            layers.append(nn.Linear(widths[-1], final_dim))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "        self.output_dim = widths[-1] if final_dim is None else final_dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class TowerEncoder(nn.Module):\n",
    "    \"\"\"Categorical features -> concatenated embeddings -> MLP -> L2-normalised vector of size out_dim.\n",
    "\n",
    "    Outputs have unit norm, so the dot product of two tower outputs is their cosine similarity.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.embed = FeatureEmbeddingBlock(cards, emb_dim)\n",
    "        self.mlp = MLPBlock(\n",
    "            len(cards) * emb_dim,\n",
    "            hidden_dims,\n",
    "            dropout=dropout,\n",
    "            final_dim=out_dim,\n",
    "        )\n",
    "\n",
    "    def forward(self, feat_batch):\n",
    "        return F.normalize(self.mlp(self.embed(feat_batch)), dim=-1)\n",
    "\n",
    "class TwoTowerModel(nn.Module):\n",
    "    \"\"\"Separate user and item TowerEncoders mapping side features into a shared embedding space.\n",
    "\n",
    "    Args:\n",
    "        user_cards / item_cards: column -> cardinality maps for each tower's features.\n",
    "        emb_dim, hidden_dims, out_dim: architecture shared by both towers.\n",
    "        dropout: MLP dropout in both towers; the default 0.1 matches TowerEncoder's\n",
    "            own default, so calls that omit it build the same model.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, user_cards, item_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        tower_kwargs = dict(emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim, dropout=dropout)\n",
    "        self.user_tower = TowerEncoder(user_cards, **tower_kwargs)\n",
    "        self.item_tower = TowerEncoder(item_cards, **tower_kwargs)\n",
    "\n",
    "    def forward(self, user_feats, item_feats):\n",
    "        \"\"\"Encode a user feature batch and an item feature batch; returns (user_vec, item_vec).\"\"\"\n",
    "        return self.user_tower(user_feats), self.item_tower(item_feats)\n",
    "\n",
    "class MultVAE(nn.Module):\n",
    "    \"\"\"Multinomial variational autoencoder over a user's item vector (Mult-VAE, Liang et al. 2018).\n",
    "\n",
    "    Args:\n",
    "        num_items: width of the input and output item vectors.\n",
    "        hidden_dim: width of the encoder layers and of the decoder hidden layer.\n",
    "        latent_dim: size of the latent code z.\n",
    "        dropout: input dropout rate (nn.Dropout, so only active in train mode).\n",
    "        normalize_input: L2-normalise each input row before dropout, as the reference\n",
    "            Mult-VAE does. Without it the encoder input grows with the number of items\n",
    "            a user has, which can push the tanh layers into saturation for heavy users.\n",
    "            Pass False for an un-normalised encoder.\n",
    "\n",
    "    forward(x) returns (logits, mu, logvar); in eval mode z = mu (no sampling).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_items, hidden_dim=2048, latent_dim=512, dropout=0.3, normalize_input=True):\n",
    "        super().__init__()\n",
    "        self.normalize_input = normalize_input\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Return (mu, logvar) of the approximate posterior for the input rows x.\"\"\"\n",
    "        if self.normalize_input:\n",
    "            # F.normalize clamps the norm at eps, so all-zero rows stay zero.\n",
    "            x = F.normalize(x, dim=-1)\n",
    "        h = self.encoder(self.dropout(x))\n",
    "        return self.mu(h), self.logvar(h)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        \"\"\"Sample z = mu + eps * exp(logvar / 2) in train mode; return mu in eval mode.\"\"\"\n",
    "        if self.training:\n",
    "            std = torch.exp(0.5 * logvar)\n",
    "            eps = torch.randn_like(std)\n",
    "            return mu + eps * std\n",
    "        return mu\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        logits = self.decoder(z)\n",
    "        return logits, mu, logvar\n",
    "\n",
    "class NeuMF(nn.Module):\n",
    "    \"\"\"Neural matrix factorisation: a GMF branch and an MLP branch fused by one output layer.\n",
    "\n",
    "    The GMF branch multiplies user and item embeddings element-wise; the MLP branch\n",
    "    runs its own concatenated user/item embeddings through Linear-ReLU-Dropout layers.\n",
    "    forward() returns one score per (user_idx, item_idx) pair.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "        super().__init__()\n",
    "        self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "        self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "        self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "        self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "        widths = [2 * mlp_dim, *hidden_dims]\n",
    "        mlp_layers = []\n",
    "        for width_in, width_out in zip(widths[:-1], widths[1:]):\n",
    "            mlp_layers += [nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(0.1)]\n",
    "        self.mlp = nn.Sequential(*mlp_layers)\n",
    "        self.out = nn.Linear(widths[-1] + mf_dim, 1)\n",
    "\n",
    "        for table in (self.user_mf, self.item_mf, self.user_mlp, self.item_mlp):\n",
    "            nn.init.normal_(table.weight, std=0.02)\n",
    "\n",
    "    def forward(self, user_idx, item_idx):\n",
    "        gmf = self.user_mf(user_idx) * self.item_mf(item_idx)\n",
    "        mlp_out = self.mlp(torch.cat([self.user_mlp(user_idx), self.item_mlp(item_idx)], dim=1))\n",
    "        return self.out(torch.cat([gmf, mlp_out], dim=1)).squeeze(-1)\n",
    "\n",
    "def iter_positive_batches(ndata, batch_size):\n",
    "    \"\"\"Yield shuffled (user_idx, item_idx, weight) mini-batches of the training positives.\n",
    "\n",
    "    A new permutation is drawn on `device` on every call; the last batch may be smaller.\n",
    "    \"\"\"\n",
    "    users = ndata[\"pos_user_t\"]\n",
    "    items = ndata[\"pos_item_t\"]\n",
    "    weights = ndata[\"pos_weight_t\"]\n",
    "    order = torch.randperm(users.shape[0], device=device)\n",
    "    for lo in range(0, order.numel(), batch_size):\n",
    "        chunk = order[lo:lo + batch_size]\n",
    "        yield users[chunk], items[chunk], weights[chunk]\n",
    "\n",
    "def iter_dense_user_batches(ndata, batch_size):\n",
    "    \"\"\"Yield shuffled (dense_rows, user_idx) batches from the dense binary matrix in ndata.\n",
    "\n",
    "    Rows are gathered from the device copy (x_binary_t) when one exists, otherwise\n",
    "    sliced from the host array and copied to `device` per batch. The shuffle uses\n",
    "    NumPy's global RNG.\n",
    "    \"\"\"\n",
    "    dense_host = ndata[\"x_binary_np\"]\n",
    "    order = np.random.permutation(dense_host.shape[0])\n",
    "    for lo in range(0, len(order), batch_size):\n",
    "        rows = order[lo:lo + batch_size]\n",
    "        if ndata[\"x_binary_on_device\"]:\n",
    "            dense_rows = ndata[\"x_binary_t\"][rows]\n",
    "        else:\n",
    "            dense_rows = torch.from_numpy(dense_host[rows]).to(device)\n",
    "        yield dense_rows, torch.as_tensor(rows, device=device, dtype=torch.long)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_twotower(bundle, ndata, model, model_name):\n",
    "    \"\"\"Full-catalogue ranking metrics for a TwoTowerModel on the bundle's eval users.\n",
    "\n",
    "    All item vectors are encoded once (in chunks of 4096 items). Each user batch is\n",
    "    scored against every item, items already in the user's X_binary row are masked\n",
    "    out, and the top max(TOP_KS) items are scored against the user's test item.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    max_k = max(TOP_KS)\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    seen_matrix = bundle[\"X_binary\"]\n",
    "\n",
    "    item_feats = ndata[\"item_feat_t\"]\n",
    "    item_matrix = torch.cat(\n",
    "        [model.item_tower(item_feats[lo:lo + 4096]) for lo in range(0, item_feats.shape[0], 4096)],\n",
    "        dim=0,\n",
    "    )\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    batch_starts = range(0, len(eval_users), NEURAL_EVAL_BATCH)\n",
    "    for lo in tqdm(batch_starts, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        uids = eval_users[lo:lo + NEURAL_EVAL_BATCH]\n",
    "        user_vec = model.user_tower(ndata[\"user_feat_t\"][uids])\n",
    "        seen = torch.from_numpy(seen_matrix[uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        scores = (user_vec @ item_matrix.T).masked_fill(seen > 0, -1e9)\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        truth = np.array([held_out[int(u)] for u in uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, truth, acc, TOP_KS)\n",
    "        del user_vec, seen, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_multvae(bundle, ndata, model, model_name):\n",
    "    \"\"\"Full-catalogue ranking metrics for a MultVAE on the bundle's eval users.\n",
    "\n",
    "    Each user's X_binary row is both the model input and the mask of items excluded\n",
    "    from the ranking; the top max(TOP_KS) logits are scored against the user's test\n",
    "    item. `ndata` is not used by this evaluator.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEURAL_EVAL_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEURAL_EVAL_BATCH]\n",
    "        x_batch = torch.from_numpy(x_source[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        logits, _, _ = model(x_batch)\n",
    "        # Items already present in the user's X_binary row are never recommended.\n",
    "        logits = logits.masked_fill(x_batch > 0, -1e9)\n",
    "        topk = torch.topk(logits, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del x_batch, logits, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_neumf(bundle, model, model_name):\n",
    "    \"\"\"Full-catalogue ranking metrics for NeuMF on the bundle's eval users.\n",
    "\n",
    "    NeuMF scores individual (user, item) pairs, so each user batch is scored against\n",
    "    the catalogue in item chunks of NEUMF_EVAL_ITEM_BATCH. The chunks are concatenated,\n",
    "    items already in the user's X_binary row are masked out, and the top max(TOP_KS)\n",
    "    items are scored against the user's test item.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    max_k = max(TOP_KS)\n",
    "    num_items = bundle[\"num_items\"]\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    seen_matrix = bundle[\"X_binary\"]\n",
    "    catalogue = torch.arange(num_items, device=device, dtype=torch.long)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    batch_starts = range(0, len(eval_users), NEUMF_EVAL_USER_BATCH)\n",
    "    for lo in tqdm(batch_starts, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        uids = eval_users[lo:lo + NEUMF_EVAL_USER_BATCH]\n",
    "        users_t = torch.as_tensor(uids, device=device, dtype=torch.long)\n",
    "        seen = torch.from_numpy(seen_matrix[uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        n_users_batch = users_t.shape[0]\n",
    "\n",
    "        chunks = []\n",
    "        for item_lo in range(0, num_items, NEUMF_EVAL_ITEM_BATCH):\n",
    "            items_t = catalogue[item_lo:item_lo + NEUMF_EVAL_ITEM_BATCH]\n",
    "            n_items_chunk = items_t.shape[0]\n",
    "            # Every (user, item) pair of this user batch x item chunk, flattened row-major.\n",
    "            pair_users = users_t.repeat_interleave(n_items_chunk)\n",
    "            pair_items = items_t.repeat(n_users_batch)\n",
    "            chunks.append(model(pair_users, pair_items).reshape(n_users_batch, n_items_chunk))\n",
    "\n",
    "        scores = torch.cat(chunks, dim=1).masked_fill(seen > 0, -1e9)\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        truth = np.array([held_out[int(u)] for u in uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, truth, acc, TOP_KS)\n",
    "        del users_t, seen, chunks, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7f000e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural training functions\n",
    "\n",
    "def train_two_tower(bundle, ndata, cfg):\n",
    "    \"\"\"Train a TwoTowerModel with a count-weighted in-batch softmax (InfoNCE) loss.\n",
    "\n",
    "    Row i of a batch is a training positive (user_i, item_i). Its target is item_i,\n",
    "    and the other items in the batch act as negatives. Some columns are not real\n",
    "    negatives for row i and are masked out of the softmax:\n",
    "      - another copy of item_i, and\n",
    "      - any item whose row has the same user (that item is also one of user_i's positives).\n",
    "    Without the mask, repeated popular items in a batch push users away from items\n",
    "    they actually interacted with.\n",
    "\n",
    "    cfg keys read: name, emb_dim, hidden_dims, out_dim, lr, wd, epochs, temperature.\n",
    "    Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    model = TwoTowerModel(\n",
    "        user_cards=ndata[\"user_feature_cards\"],\n",
    "        item_cards=ndata[\"item_feature_cards\"],\n",
    "        emb_dim=cfg[\"emb_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "        out_dim=cfg[\"out_dim\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in iter_positive_batches(ndata, TWOTOWER_BATCH):\n",
    "            user_batch = ndata[\"user_feat_t\"][user_idx_batch]\n",
    "            item_batch = ndata[\"item_feat_t\"][item_idx_batch]\n",
    "\n",
    "            # (B, B) mask of in-batch false negatives; the diagonal (each row's target) is kept.\n",
    "            n_rows = item_idx_batch.shape[0]\n",
    "            false_neg = (item_idx_batch[:, None] == item_idx_batch[None, :]) | (\n",
    "                user_idx_batch[:, None] == user_idx_batch[None, :]\n",
    "            )\n",
    "            false_neg &= ~torch.eye(n_rows, dtype=torch.bool, device=false_neg.device)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                user_vec = model.user_tower(user_batch)\n",
    "                item_vec = model.item_tower(item_batch)\n",
    "                logits = (user_vec @ item_vec.T) / cfg[\"temperature\"]\n",
    "                # Most negative finite value of the logits dtype (fp16 under AMP): masked\n",
    "                # columns get ~zero softmax mass without introducing inf/NaN.\n",
    "                logits = logits.masked_fill(false_neg, torch.finfo(logits.dtype).min)\n",
    "                targets = torch.arange(logits.shape[0], device=device)\n",
    "                loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            batch_w = float(weight_batch.sum().item())\n",
    "            total_loss += float(loss.item()) * batch_w\n",
    "            total_w += batch_w\n",
    "\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def train_multvae(bundle, ndata, cfg):\n",
    "    \"\"\"Train a MultVAE on shuffled dense binary user rows with a linearly annealed KL weight.\n",
    "\n",
    "    Per-user loss: multinomial negative log-likelihood of the binary row under\n",
    "    softmax(logits), plus kl_weight * KL(q(z|x) || N(0, I)). kl_weight grows linearly\n",
    "    with the global step, would reach 1.0 after 30% of all steps, and is capped at\n",
    "    cfg[\"anneal_cap\"].\n",
    "\n",
    "    cfg keys read: name, hidden_dim, latent_dim, dropout, lr, wd, epochs, anneal_cap.\n",
    "    Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    model = MultVAE(\n",
    "        num_items=bundle[\"num_items\"],\n",
    "        hidden_dim=cfg[\"hidden_dim\"],\n",
    "        latent_dim=cfg[\"latent_dim\"],\n",
    "        dropout=cfg[\"dropout\"],\n",
    "    ).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    n_epochs = cfg[\"epochs\"]\n",
    "    steps_per_epoch = max(1, math.ceil(bundle[\"num_users\"] / AUTOENC_BATCH))\n",
    "    total_steps = max(1, n_epochs * steps_per_epoch)\n",
    "    warmup_steps = max(total_steps * 0.3, 1)\n",
    "    step = 0\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(n_epochs):\n",
    "        loss_sum = 0.0\n",
    "        row_count = 0\n",
    "\n",
    "        for batch_x, _ in iter_dense_user_batches(ndata, AUTOENC_BATCH):\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            kl_weight = min(cfg[\"anneal_cap\"], step / warmup_steps)\n",
    "\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits, mu, logvar = model(batch_x)\n",
    "                neg_log_lik = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                loss = (neg_log_lik + kl_weight * kl).mean()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            loss_sum += float(loss.item()) * batch_x.shape[0]\n",
    "            row_count += batch_x.shape[0]\n",
    "            step += 1\n",
    "\n",
    "        epoch_no = epoch + 1\n",
    "        # Log the first epoch, every 10th epoch and the last one.\n",
    "        if epoch_no == 1 or epoch_no % 10 == 0 or epoch_no == n_epochs:\n",
    "            print(f\"{bundle['name']} / {cfg['name']} epoch {epoch_no}/{n_epochs} - loss: {loss_sum / max(row_count, 1):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def train_neumf(bundle, cfg):\n",
    "    \"\"\"Train NeuMF with a count-weighted BPR loss and one uniform negative per positive.\n",
    "\n",
    "    Negatives are drawn uniformly from all items at every step and are not checked\n",
    "    against the user's positives. Each pair's loss is weighted by its interaction count.\n",
    "\n",
    "    cfg keys read: name, mf_dim, mlp_dim, hidden_dims, lr, wd, epochs.\n",
    "    Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    n_items = bundle[\"num_items\"]\n",
    "    interactions = bundle[\"train_interactions\"]\n",
    "\n",
    "    def _column_on_device(column, dtype):\n",
    "        return torch.as_tensor(interactions[column].to_numpy(dtype=dtype), device=device)\n",
    "\n",
    "    pos_user = _column_on_device(\"user_idx\", np.int64)\n",
    "    pos_item = _column_on_device(\"item_idx\", np.int64)\n",
    "    pos_weight = _column_on_device(\"count\", np.float32)\n",
    "\n",
    "    model = NeuMF(\n",
    "        num_users=bundle[\"num_users\"],\n",
    "        num_items=n_items,\n",
    "        mf_dim=cfg[\"mf_dim\"],\n",
    "        mlp_dim=cfg[\"mlp_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "    ).to(device)\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    n_pairs = pos_user.shape[0]\n",
    "    model.train()\n",
    "\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        order = torch.randperm(n_pairs, device=device)\n",
    "        loss_sum = 0.0\n",
    "        weight_sum = 0.0\n",
    "\n",
    "        for lo in range(0, n_pairs, PAIRWISE_BATCH):\n",
    "            idx = order[lo:lo + PAIRWISE_BATCH]\n",
    "            users, positives, weights = pos_user[idx], pos_item[idx], pos_weight[idx]\n",
    "            negatives = torch.randint(0, n_items, size=positives.shape, device=device)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                margin = model(users, positives) - model(users, negatives)\n",
    "                loss = (-F.logsigmoid(margin) * weights).sum() / weights.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            batch_w = float(weights.sum().item())\n",
    "            loss_sum += float(loss.item()) * batch_w\n",
    "            weight_sum += batch_w\n",
    "\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {loss_sum / max(weight_sum, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "38b2ee8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Neural search on formulation: H1_base\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 4.00 GiB. GPU 0 has a total capacity of 9.64 GiB of which 2.34 GiB is free. Process 1192 has 22.22 MiB memory in use. Including non-PyTorch memory, this process has 6.64 GiB memory in use. Of the allocated memory 6.35 GiB is allocated by PyTorch, and 17.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mOutOfMemoryError\u001b[39m                          Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 15\u001b[39m\n\u001b[32m     11\u001b[39m         ndata = neural_data_by_formulation[formulation_name]\n\u001b[32m     12\u001b[39m \n\u001b[32m     13\u001b[39m         \u001b[38;5;28;01mfor\u001b[39;00m cfg \u001b[38;5;28;01min\u001b[39;00m TWOTOWER_CONFIGS:\n\u001b[32m     14\u001b[39m             clear_memory()\n\u001b[32m---> \u001b[39m\u001b[32m15\u001b[39m             model = train_two_tower(bundle, ndata, cfg)\n\u001b[32m     16\u001b[39m             neural_results.append(evaluate_twotower(bundle, ndata, model, cfg[\u001b[33m\"name\"\u001b[39m]))\n\u001b[32m     17\u001b[39m             \u001b[38;5;28;01mdel\u001b[39;00m model\n\u001b[32m     18\u001b[39m             clear_memory()\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[11]\u001b[39m\u001b[32m, line 29\u001b[39m, in \u001b[36mtrain_two_tower\u001b[39m\u001b[34m(bundle, ndata, cfg)\u001b[39m\n\u001b[32m     25\u001b[39m                 user_vec = model.user_tower(user_batch)\n\u001b[32m     26\u001b[39m                 item_vec = model.item_tower(item_batch)\n\u001b[32m     27\u001b[39m                 logits = (user_vec @ item_vec.T) / cfg[\u001b[33m\"temperature\"\u001b[39m]\n\u001b[32m     28\u001b[39m                 targets = torch.arange(logits.shape[\u001b[32m0\u001b[39m], device=device)\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m                 loss_vec = F.cross_entropy(logits, targets, reduction=\u001b[33m\"none\"\u001b[39m)\n\u001b[32m     30\u001b[39m                 loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n\u001b[32m     31\u001b[39m \n\u001b[32m     32\u001b[39m             scaler.scale(loss).backward()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/site-packages/torch/nn/functional.py:3504\u001b[39m, in \u001b[36mcross_entropy\u001b[39m\u001b[34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[39m\n\u001b[32m   3502\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m   3503\u001b[39m     reduction = _Reduction.legacy_get_string(size_average, reduce)\n\u001b[32m-> \u001b[39m\u001b[32m3504\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_C\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_nn\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcross_entropy_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m   3505\u001b[39m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m   3506\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   3507\u001b[39m \u001b[43m    \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   3508\u001b[39m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# pyrefly: ignore [bad-argument-type]\u001b[39;49;00m\n\u001b[32m   3509\u001b[39m \u001b[43m    \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   3510\u001b[39m \u001b[43m    \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   3511\u001b[39m \u001b[43m    \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   3512\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[31mOutOfMemoryError\u001b[39m: CUDA out of memory. Tried to allocate 4.00 GiB. GPU 0 has a total capacity of 9.64 GiB of which 2.34 GiB is free. Process 1192 has 22.22 MiB memory in use. Including non-PyTorch memory, this process has 6.64 GiB memory in use. Of the allocated memory 6.35 GiB is allocated by PyTorch, and 17.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)"
     ]
    }
   ],
   "source": [
    "# Neural search on the best formulations only\n",
    "#\n",
    "# Every (formulation, config) pair is trained and evaluated on its own. A config that\n",
    "# runs out of GPU memory is reported and skipped, so one oversized config no longer\n",
    "# aborts the remaining search.\n",
    "\n",
    "def _train_and_evaluate(formulation_name, cfg, train_fn, eval_fn):\n",
    "    \"\"\"Train and evaluate one config, freeing GPU memory before and after.\n",
    "\n",
    "    Returns the row produced by ``eval_fn(model, cfg)``, or None when CUDA ran out of\n",
    "    memory during training or evaluation (a message is printed and the config skipped).\n",
    "    \"\"\"\n",
    "    clear_memory()\n",
    "    model = None\n",
    "    row = None\n",
    "    try:\n",
    "        model = train_fn(cfg)\n",
    "        row = eval_fn(model, cfg)\n",
    "    except torch.cuda.OutOfMemoryError:\n",
    "        print(f\"Skipping {formulation_name} / {cfg['name']}: CUDA out of memory\")\n",
    "    # Cleanup runs after the handler has exited, so the caught exception no longer pins the failed step's tensors.\n",
    "    del model\n",
    "    clear_memory()\n",
    "    return row\n",
    "\n",
    "neural_results = []\n",
    "\n",
    "if RUN_NEURAL:\n",
    "    for formulation_name in selected_neural_formulations:\n",
    "        print(\"=\" * 120)\n",
    "        print(\"Neural search on formulation:\", formulation_name)\n",
    "\n",
    "        bundle = bundles[formulation_name]\n",
    "        ndata = neural_data_by_formulation[formulation_name]\n",
    "\n",
    "        # (configs, train(cfg) -> model, evaluate(model, cfg) -> result row). The lambdas\n",
    "        # close over this iteration's bundle/ndata and are only called within it.\n",
    "        searches = [\n",
    "            (\n",
    "                TWOTOWER_CONFIGS,\n",
    "                lambda cfg: train_two_tower(bundle, ndata, cfg),\n",
    "                lambda model, cfg: evaluate_twotower(bundle, ndata, model, cfg[\"name\"]),\n",
    "            ),\n",
    "            (\n",
    "                MULTVAE_CONFIGS,\n",
    "                lambda cfg: train_multvae(bundle, ndata, cfg),\n",
    "                lambda model, cfg: evaluate_multvae(bundle, ndata, model, cfg[\"name\"]),\n",
    "            ),\n",
    "            (\n",
    "                NEUMF_CONFIGS,\n",
    "                lambda cfg: train_neumf(bundle, cfg),\n",
    "                lambda model, cfg: evaluate_neumf(bundle, model, cfg[\"name\"]),\n",
    "            ),\n",
    "        ]\n",
    "        for configs, train_fn, eval_fn in searches:\n",
    "            for cfg in configs:\n",
    "                row = _train_and_evaluate(formulation_name, cfg, train_fn, eval_fn)\n",
    "                if row is not None:\n",
    "                    neural_results.append(row)\n",
    "\n",
    "if neural_results:\n",
    "    neural_results_df = (\n",
    "        pd.DataFrame(neural_results)\n",
    "        .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "else:\n",
    "    # Covers RUN_NEURAL=False and a search where every config was skipped; sorting an\n",
    "    # empty frame would fail because the metric columns do not exist.\n",
    "    neural_results_df = pd.DataFrame()\n",
    "\n",
    "print(\"Top neural results:\")\n",
    "display(neural_results_df.head(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d0e73c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional fusion on the best formulation\n",
    "\n",
    "def make_python_recommenders_for_fusion(bundle, formulation_name, regular_results_df, neural_results_df):\n",
    "    \"\"\"Return the best model name per fusion family for one formulation.\n",
    "\n",
    "    Despite the name, no recommenders are built: the result is the tuple\n",
    "    (best_ease, best_rp3, best_tt, best_vae) of names picked by ``top_model_name``.\n",
    "    EASE prefers the binary variant and falls back to the count variant when no\n",
    "    binary one is found (``top_model_name`` returned None). The neural entries are\n",
    "    None when ``neural_results_df`` is empty. ``bundle`` is unused and kept for\n",
    "    call compatibility.\n",
    "    \"\"\"\n",
    "    reg_sub = regular_results_df[regular_results_df[\"Formulation\"] == formulation_name]\n",
    "    best_rp3 = top_model_name(reg_sub, \"RP3beta_\")\n",
    "    best_ease = top_model_name(reg_sub, \"EASE_binary_\")\n",
    "    if best_ease is None:\n",
    "        best_ease = top_model_name(reg_sub, \"EASE_count_\")\n",
    "\n",
    "    best_tt = best_vae = None\n",
    "    # An empty neural frame may lack the \"Formulation\" column entirely, so skip the filter.\n",
    "    if not neural_results_df.empty:\n",
    "        neu_sub = neural_results_df[neural_results_df[\"Formulation\"] == formulation_name]\n",
    "        best_tt = top_model_name(neu_sub, \"TwoTower_\")\n",
    "        best_vae = top_model_name(neu_sub, \"MultVAE_\")\n",
    "\n",
    "    return best_ease, best_rp3, best_tt, best_vae\n",
    "\n",
    "# NOTE(review): not filled by this cell; kept for later fusion experiments.\n",
    "fusion_results = []\n",
    "\n",
    "if RUN_FUSION and not neural_results_df.empty:\n",
    "    # Fusion needs regular and neural candidates from the same formulation, so use the\n",
    "    # highest-ranked regular formulation that was also part of the neural search.\n",
    "    # Assumes regular_results_df is already sorted best-first.\n",
    "    ranked_formulations = regular_results_df[\"Formulation\"].drop_duplicates()\n",
    "    with_neural = ranked_formulations[ranked_formulations.isin(neural_results_df[\"Formulation\"].unique())]\n",
    "    best_formulation = (with_neural if not with_neural.empty else ranked_formulations).iloc[0]\n",
    "    bundle = bundles[best_formulation]\n",
    "    print(\"Fusion on formulation:\", best_formulation)\n",
    "\n",
    "    ease_name, rp3_name, tt_name, vae_name = make_python_recommenders_for_fusion(\n",
    "        bundle, best_formulation, regular_results_df, neural_results_df\n",
    "    )\n",
    "\n",
    "    print(\"Best for fusion:\", ease_name, rp3_name, tt_name, vae_name)\n",
    "else:\n",
    "    print(\"Fusion skipped or no neural results.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb491a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final comparison tables\n",
    "#\n",
    "# Ranks regular and (if present) neural results together, then shows the top rows per\n",
    "# formulation and the best metric reached by each model family.\n",
    "\n",
    "all_frames = [regular_results_df]\n",
    "if not neural_results_df.empty:\n",
    "    all_frames.append(neural_results_df)\n",
    "\n",
    "all_results_df = (\n",
    "    pd.concat(all_frames, ignore_index=True)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(\"Top overall results:\")\n",
    "display(all_results_df.head(30))\n",
    "\n",
    "print(\"Best per formulation:\")\n",
    "# all_results_df is already ranked, so head(10) keeps each formulation's top 10 rows;\n",
    "# the stable sort then groups them by formulation without disturbing that ranking.\n",
    "best_per_formulation = (\n",
    "    all_results_df.groupby(\"Formulation\")\n",
    "    .head(10)\n",
    "    .sort_values(\"Formulation\", kind=\"stable\")\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(best_per_formulation)\n",
    "\n",
    "print(\"Best per model family:\")\n",
    "family_df = all_results_df.copy()\n",
    "# Known families are matched by their full prefix: the generic pattern below stops at the\n",
    "# first digit (\"RP3beta_...\" became \"RP\") and may absorb a config token of names like\n",
    "# \"TwoTower_<cfg>\". Any other model name keeps the generic pattern.\n",
    "known_family = family_df[\"Model\"].str.extract(\n",
    "    r\"^(EASE_binary|EASE_count|RP3beta|TwoTower|MultVAE)(?=_|$)\"\n",
    ")[0]\n",
    "generic_family = family_df[\"Model\"].str.extract(r\"^([A-Za-z]+(?:_[A-Za-z]+)?)\")[0]\n",
    "family_df[\"Family\"] = known_family.fillna(generic_family)\n",
    "display(\n",
    "    family_df.groupby(\"Family\")[[\"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]]\n",
    "    .max()\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\"], ascending=False)\n",
    "    .reset_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc27b61",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "This notebook removes the old automatic screening stage that previously caused very long EASE runs and high RAM usage.\n",
    "\n",
    "What is threaded or batched here:\n",
    "- BLAS / LAPACK thread env vars are set before importing NumPy and SciPy\n",
    "- `threadpoolctl` is applied when available\n",
    "- EASE uses a Cholesky-based dense solve\n",
    "- EASE and RP3 evaluation are fully batched\n",
    "- heavy score computation is moved to the GPU when available\n",
    "- neural training uses manual on-device batching instead of multiprocessing DataLoaders\n",
    "\n",
    "If you still need more speed:\n",
    "- lower `TOP_FORMULATIONS_FOR_NEURAL`\n",
    "- lower `TWOTOWER_CONFIGS` / `MULTVAE_CONFIGS`\n",
    "- lower `EVAL_MAX_USERS`\n",
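    "- lower `LINEAR_EVAL_BATCH` or `NEURAL_EVAL_BATCH` if GPU memory becomes the bottleneck during evaluation\n",
    "- for TwoTower training OOMs, use a smaller training batch: the in-batch softmax builds a batch × batch logits matrix, so its memory grows quadratically with the batch size\n",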
    "\n",
    "Additional EASE fit changes in this version:\n",
    "- Gram matrices are built once per formulation and reused across all EASE lambdas\n",
    "- EASE fitting can use GPU Cholesky solve when CUDA is available\n",
    "- fit time and eval time are printed separately, so long stalls before the progress bar appears are visible\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
