{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6979c756",
   "metadata": {},
   "source": [
    "# Car recommendation notebook — H1 variants, tuned top models, batched evaluation\n",
    "\n",
    "This notebook focuses on richer H1-style formulations only and removes the old slow screening path.\n",
    "\n",
    "Goals:\n",
    "- try several harder H1 variants\n",
    "- improve the current top families: RP3beta, EASE, TwoTower, MultVAE\n",
    "- batch all heavy evaluation paths\n",
    "- use GPU where it helps and CPU threads where possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8fb2968a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU thread target: 16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "for var in [\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "]:\n",
    "    os.environ[var] = str(CPU_THREADS)\n",
    "\n",
    "import gc\n",
    "import math\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy import linalg as sla\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "try:\n",
    "    from threadpoolctl import threadpool_limits\n",
    "except Exception:\n",
    "    threadpool_limits = None\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "random.seed(RANDOM_STATE)\n",
    "np.random.seed(RANDOM_STATE)\n",
    "\n",
    "# Silence every warning category for the rest of the session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Best effort: cap native thread pools when threadpoolctl was importable;\n",
    "# any failure while applying the cap is ignored.\n",
    "try:\n",
    "    if threadpool_limits is not None:\n",
    "        threadpool_limits(limits=CPU_THREADS)\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6731d71d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['H1_base', 'H1_color', 'H1_fine', 'H1_fine_color']\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "# The dataset folder can be redirected with the CAR_RECO_BASE_DIR environment variable.\n",
    "# When it is unset the original location is used, so existing runs are unchanged.\n",
    "DEFAULT_BASE_DIR = \"/home/konnilol/Documents/uni/kursovaya-sem5\"\n",
    "BASE_DIR = Path(os.environ.get(\"CAR_RECO_BASE_DIR\", DEFAULT_BASE_DIR)).expanduser()\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "# Fail early with an actionable message instead of a late read_csv error.\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(\n",
    "        f\"Dataset not found: {CSV_PATH} (set CAR_RECO_BASE_DIR to the folder containing the CSV)\"\n",
    "    )\n",
    "\n",
    "# Reference year for the derived vehicle age (Age = CURRENT_YEAR - Year, floored at 0).\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "# Quantile bin counts. The derived column names (PriceBin12, AgeBin10, ...) hard-code\n",
    "# these numbers, so changing a value here does not rename the column it feeds.\n",
    "PRICE_BINS_BASE = 12\n",
    "PRICE_BINS_FINE = 16\n",
    "MILEAGE_BINS_BASE = 12\n",
    "MILEAGE_BINS_FINE = 16\n",
    "AGE_BINS_BASE = 10\n",
    "AGE_BINS_FINE = 14\n",
    "\n",
    "# Filtering\n",
    "# Each pair is (min_items_per_user, min_users_per_item). build_formulation tries the\n",
    "# pairs in order and keeps the first one that leaves a non-empty interaction table.\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "# Upper bound on user/item filtering passes for a single threshold pair.\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# Evaluation\n",
    "# Hit-rate cut-offs. The metric accumulators hard-code the HR@5 / HR@10 / HR@20 keys,\n",
    "# so TOP_KS has to stay in sync with them.\n",
    "TOP_KS = [5, 10, 20]\n",
    "# None evaluates every test user; an integer subsamples that many users.\n",
    "EVAL_MAX_USERS = None\n",
    "LINEAR_EVAL_BATCH = 4096\n",
    "NEURAL_EVAL_BATCH = 4096\n",
    "\n",
    "# Manual richer H1 variants only\n",
    "# For each variant, build_formulation joins user_cols into a synthetic user_id and\n",
    "# item_cols into a synthetic item_id. The *Bin12/*Bin10 columns are the base binning,\n",
    "# the *Bin16/*Bin14 columns the fine binning.\n",
    "FORMULATIONS = {\n",
    "    \"H1_base\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H1_fine\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_fine_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "# Try all formulations, then keep the top two for neural training\n",
    "TOP_FORMULATIONS_FOR_NEURAL = 2\n",
    "\n",
    "# Best current families only\n",
    "# NOTE(review): presumably (alpha, beta) pairs passed to fit_p3alpha - confirm in the tuning cell.\n",
    "RP3_GRID = [\n",
    "    (1.00, 0.60),\n",
    "    (1.05, 0.60),\n",
    "    (1.10, 0.70),\n",
    "    (1.15, 0.70),\n",
    "]\n",
    "\n",
    "# NOTE(review): presumably `lam` values for fit_ease on the binary and the count matrix - confirm.\n",
    "EASE_BINARY_LAMBDAS = [800.0, 1000.0, 1200.0, 1600.0, 2200.0]\n",
    "EASE_COUNT_LAMBDAS = [800.0, 1200.0, 1600.0, 2200.0]\n",
    "\n",
    "# One dict per neural configuration; \"name\" identifies the configuration.\n",
    "TWOTOWER_CONFIGS = [\n",
    "    {\"name\": \"TwoTower_t2\", \"emb_dim\": 96,  \"hidden_dims\": (768, 384, 192),  \"out_dim\": 160, \"epochs\": 28, \"lr\": 1.5e-3, \"wd\": 1e-5, \"temperature\": 0.05},\n",
    "    {\"name\": \"TwoTower_t3\", \"emb_dim\": 128, \"hidden_dims\": (1024, 512, 256), \"out_dim\": 192, \"epochs\": 32, \"lr\": 1.2e-3, \"wd\": 1e-5, \"temperature\": 0.04},\n",
    "]\n",
    "\n",
    "MULTVAE_CONFIGS = [\n",
    "    {\"name\": \"MultVAE_v3\", \"hidden_dim\": 2048, \"latent_dim\": 512, \"dropout\": 0.30, \"epochs\": 100, \"lr\": 6e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "    {\"name\": \"MultVAE_v4\", \"hidden_dim\": 2560, \"latent_dim\": 640, \"dropout\": 0.35, \"epochs\": 110, \"lr\": 5e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "]\n",
    "\n",
    "NEUMF_CONFIGS = [\n",
    "    {\"name\": \"NeuMF_n2\", \"mf_dim\": 96, \"mlp_dim\": 192, \"hidden_dims\": (384, 192, 96), \"epochs\": 26, \"lr\": 1.5e-3, \"wd\": 1e-6},\n",
    "]\n",
    "\n",
    "# Batch sizes for the neural paths.\n",
    "TWOTOWER_BATCH = 32768\n",
    "PAIRWISE_BATCH = 32768\n",
    "AUTOENC_BATCH = 4096\n",
    "NEUMF_EVAL_USER_BATCH = 512\n",
    "NEUMF_EVAL_ITEM_BATCH = 1024\n",
    "\n",
    "# Stage switches and fusion settings.\n",
    "# NOTE(review): FUSION_RRF_K looks like the reciprocal-rank-fusion constant - confirm in the fusion cell.\n",
    "RUN_NEURAL = True\n",
    "RUN_FUSION = True\n",
    "FUSION_FETCH_N = 100\n",
    "FUSION_RRF_K = 60\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "349eba60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch version: 2.11.0\n",
      "Torch device : cuda\n",
      "USE_AMP      : True\n"
     ]
    }
   ],
   "source": [
    "# PyTorch runtime setup\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "_cuda_ready = torch.cuda.is_available()\n",
    "if _cuda_ready:\n",
    "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "\n",
    "# Thread-pool sizing is best effort: a setter that raises is skipped silently.\n",
    "for _setter_name, _thread_count in (\n",
    "    (\"set_num_threads\", CPU_THREADS),\n",
    "    (\"set_num_interop_threads\", max(1, CPU_THREADS // 2)),\n",
    "):\n",
    "    try:\n",
    "        getattr(torch, _setter_name)(_thread_count)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "device = torch.device(\"cuda\" if _cuda_ready else \"cpu\")\n",
    "USE_AMP = device.type == \"cuda\"\n",
    "\n",
    "if USE_AMP:\n",
    "    # Enable TF32 kernels and the cuDNN autotuner on the GPU path.\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    try:\n",
    "        torch.set_float32_matmul_precision(\"high\")\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "for _label, _value in (\n",
    "    (\"Torch version:\", torch.__version__),\n",
    "    (\"Torch device :\", device),\n",
    "    (\"USE_AMP      :\", USE_AMP),\n",
    "):\n",
    "    print(_label, _value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "145329d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows: 1000000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin12</th>\n",
       "      <th>PriceBin16</th>\n",
       "      <th>MileageBin12</th>\n",
       "      <th>MileageBin16</th>\n",
       "      <th>AgeBin10</th>\n",
       "      <th>AgeBin14</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>P12_3</td>\n",
       "      <td>P16_4</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_0</td>\n",
       "      <td>A14_0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_9</td>\n",
       "      <td>A14_13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>P12_7</td>\n",
       "      <td>P16_9</td>\n",
       "      <td>M12_0</td>\n",
       "      <td>M16_0</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_1</td>\n",
       "      <td>M16_2</td>\n",
       "      <td>A10_8</td>\n",
       "      <td>A14_11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>P12_0</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_4</td>\n",
       "      <td>M16_6</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    3   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   26   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   12   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   23   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   14   \n",
       "\n",
       "  PriceBin12 PriceBin16 MileageBin12 MileageBin16 AgeBin10 AgeBin14  \n",
       "0      P12_3      P16_4        M12_3        M16_4    A10_0    A14_0  \n",
       "1      P12_1      P16_1        M12_3        M16_4    A10_9   A14_13  \n",
       "2      P12_7      P16_9        M12_0        M16_0    A10_4    A14_5  \n",
       "3      P12_1      P16_1        M12_1        M16_2    A10_8   A14_11  \n",
       "4      P12_0      P16_1        M12_4        M16_6    A10_4    A14_6  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean the dataset\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "# Strip whitespace from the categorical text columns while keeping missing values\n",
    "# missing. A plain astype(str) can turn NaN into the literal string \"nan\", which the\n",
    "# dropna below would then no longer detect.\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    df[col] = df[col].where(df[col].isna(), df[col].astype(str).str.strip())\n",
    "\n",
    "# Unparseable numbers become NaN so that the same dropna removes them.\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "df = df.dropna(subset=[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]).copy()\n",
    "\n",
    "# Vehicle age in years relative to CURRENT_YEAR; future model years are floored at 0.\n",
    "df[\"Year\"] = df[\"Year\"].astype(int)\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0)\n",
    "\n",
    "def make_qbin_labels(series, q, prefix):\n",
    "    cats = pd.qcut(series, q=q, labels=False, duplicates=\"drop\")\n",
    "    cats = cats.astype(int)\n",
    "    return cats.map(lambda x: f\"{prefix}{int(x)}\")\n",
    "\n",
    "# (target column, source column, number of quantile bins, label prefix)\n",
    "_BIN_SPECS = [\n",
    "    (\"PriceBin12\", \"Price\", PRICE_BINS_BASE, \"P12_\"),\n",
    "    (\"PriceBin16\", \"Price\", PRICE_BINS_FINE, \"P16_\"),\n",
    "    (\"MileageBin12\", \"Mileage\", MILEAGE_BINS_BASE, \"M12_\"),\n",
    "    (\"MileageBin16\", \"Mileage\", MILEAGE_BINS_FINE, \"M16_\"),\n",
    "    (\"AgeBin10\", \"Age\", AGE_BINS_BASE, \"A10_\"),\n",
    "    (\"AgeBin14\", \"Age\", AGE_BINS_FINE, \"A14_\"),\n",
    "]\n",
    "for _target, _source, _n_bins, _prefix in _BIN_SPECS:\n",
    "    df[_target] = make_qbin_labels(df[_source], _n_bins, _prefix)\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "450426c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user, min_users_per_item, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = len(out)\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if len(out) == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    \"\"\"Build one user/item formulation: ids, filtered interactions, split and matrices.\n",
    "\n",
    "    Users and items are synthetic: ``user_id`` is the ' | '-joined ``user_cols`` and\n",
    "    ``item_id`` the ' :: '-joined ``item_cols`` of each row. Identical (user, item)\n",
    "    pairs are aggregated into a ``count`` column.\n",
    "\n",
    "    Args:\n",
    "        work_df: DataFrame containing every column named in ``user_cols``/``item_cols``.\n",
    "        user_cols: column labels that define a user.\n",
    "        item_cols: column labels that define an item.\n",
    "        formulation_name: label stored in the bundle and used in error messages.\n",
    "\n",
    "    Returns:\n",
    "        dict bundle with the full/train/test interaction tables, user and item feature\n",
    "        frames, the fitted label encoders, ``num_users``/``num_items``, the sparse\n",
    "        train matrices ``X_counts`` and ``X_binary``, per-user seen items, the held-out\n",
    "        item per user, a popularity ranking and the filter thresholds that were used.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if every threshold pair in FILTER_SCHEDULE leaves no interactions.\n",
    "    \"\"\"\n",
    "    tmp = work_df.copy()\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    # One row per distinct (user, item) pair; count = number of source rows behind it.\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "        .size()\n",
    "        .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = None\n",
    "    used_thresholds = None\n",
    "\n",
    "    # Try the threshold pairs in order and keep the first that leaves any interactions.\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS,\n",
    "        ).reset_index(drop=True)\n",
    "        if not candidate.empty:\n",
    "            interactions = candidate\n",
    "            used_thresholds = (min_items_per_user, min_users_per_item)\n",
    "            break\n",
    "\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(f\"{formulation_name} produced no interactions after filtering\")\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    # Dense integer ids for the users and items that survived filtering.\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    # One feature row per surviving user, ordered by user_idx.\n",
    "    user_feature_df = (\n",
    "        tmp[tmp[\"user_id\"].isin(interactions[\"user_id\"])]\n",
    "        [[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "    user_feature_df = user_feature_df.sort_values(\"user_idx\").reset_index(drop=True)\n",
    "\n",
    "    # One feature row per surviving item, ordered by item_idx.\n",
    "    item_feature_df = (\n",
    "        tmp[tmp[\"item_id\"].isin(interactions[\"item_id\"])]\n",
    "        [[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "    item_feature_df = item_feature_df.sort_values(\"item_idx\").reset_index(drop=True)\n",
    "\n",
    "    # Weighted leave-one-out\n",
    "    # One row per user is held out, drawn with probability proportional to its count.\n",
    "    # A single generator seeded with RANDOM_STATE is consumed in groupby order, so the\n",
    "    # statement order here determines the split.\n",
    "    rng = np.random.default_rng(RANDOM_STATE)\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(int(picked))\n",
    "\n",
    "    # Held-out rows form the test set; everything else is training data.\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    num_users = int(interactions[\"user_idx\"].nunique())\n",
    "    num_items = int(interactions[\"item_idx\"].nunique())\n",
    "\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    # Train matrices of shape (num_users, num_items): raw counts, and the same sparsity\n",
    "    # pattern with every stored value set to 1.\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(u): int(i)\n",
    "        for u, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    # Sorted training item indices per user; a user without training rows gets no entry.\n",
    "    user_seen_arrays = {}\n",
    "    for uid, g in train_interactions.groupby(\"user_idx\"):\n",
    "        user_seen_arrays[int(uid)] = np.sort(g[\"item_idx\"].astype(np.int32).to_numpy())\n",
    "\n",
    "    # Item indices ordered by total training count, most popular first.\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy(dtype=np.int32)\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": list(user_cols),\n",
    "        \"item_cols\": list(item_cols),\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen_arrays\": user_seen_arrays,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "    avg_train = len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle[\"used_thresholds\"])\n",
    "    print(\"Train rows :\", len(bundle[\"train_interactions\"]))\n",
    "    print(\"Test rows  :\", len(bundle[\"test_interactions\"]))\n",
    "    print(\"Avg train interactions/user:\", round(avg_train, 3))\n",
    "\n",
    "def get_eval_users(bundle):\n",
    "    \"\"\"Return the sorted int32 user indices to evaluate for ``bundle``.\n",
    "\n",
    "    Every user with a held-out item is included unless EVAL_MAX_USERS is set and\n",
    "    smaller than that population; then a RANDOM_STATE-seeded sample of that size\n",
    "    is drawn without replacement.\n",
    "    \"\"\"\n",
    "    eval_users = np.array(sorted(bundle[\"test_item_by_user\"]), dtype=np.int32)\n",
    "    needs_subsample = EVAL_MAX_USERS is not None and len(eval_users) > EVAL_MAX_USERS\n",
    "    if needs_subsample:\n",
    "        sampler = np.random.default_rng(RANDOM_STATE)\n",
    "        eval_users = sampler.choice(eval_users, size=EVAL_MAX_USERS, replace=False)\n",
    "    return np.sort(eval_users)\n",
    "\n",
    "def batch_metric_update(topk_idx, true_items, acc, ks):\n",
    "    matches = (topk_idx == true_items[:, None])\n",
    "    for k in ks:\n",
    "        acc[f\"HR@{k}\"] += matches[:, :k].any(axis=1).sum()\n",
    "\n",
    "    top10 = matches[:, :10]\n",
    "    found10 = top10.any(axis=1)\n",
    "    first_rank10 = np.where(found10, top10.argmax(axis=1) + 1, 0)\n",
    "\n",
    "    acc[\"MRR@10\"] += np.where(found10, 1.0 / first_rank10, 0.0).sum()\n",
    "    acc[\"NDCG@10\"] += np.where(found10, 1.0 / np.log2(first_rank10 + 1), 0.0).sum()\n",
    "    acc[\"UsersEval\"] += len(true_items)\n",
    "\n",
    "def finalize_metrics(acc, model_name, formulation_name):\n",
    "    users_eval = max(int(acc[\"UsersEval\"]), 1)\n",
    "    out = {\n",
    "        \"Formulation\": formulation_name,\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": users_eval,\n",
    "        \"HR@5\": float(acc[\"HR@5\"]) / users_eval,\n",
    "        \"HR@10\": float(acc[\"HR@10\"]) / users_eval,\n",
    "        \"HR@20\": float(acc[\"HR@20\"]) / users_eval,\n",
    "        \"MRR@10\": float(acc[\"MRR@10\"]) / users_eval,\n",
    "        \"NDCG@10\": float(acc[\"NDCG@10\"]) / users_eval,\n",
    "    }\n",
    "    return out\n",
    "\n",
    "def popularity_recommend_one(global_pop_rank, seen_array, n):\n",
    "    if seen_array is None or len(seen_array) == 0:\n",
    "        return global_pop_rank[:n].tolist()\n",
    "    seen_set = set(int(x) for x in seen_array)\n",
    "    out = []\n",
    "    for item in global_pop_rank:\n",
    "        item = int(item)\n",
    "        if item not in seen_set:\n",
    "            out.append(item)\n",
    "            if len(out) >= n:\n",
    "                break\n",
    "    return out\n",
    "\n",
    "def evaluate_popularity(bundle, model_name=\"Popularity\"):\n",
    "    \"\"\"Evaluate the popularity baseline one user at a time and return its metric row.\n",
    "\n",
    "    Each evaluated user is recommended the globally most popular items they have\n",
    "    not seen in training; the list is scored against the user's held-out item.\n",
    "    \"\"\"\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    pop_order = bundle[\"global_pop_rank\"]\n",
    "    seen_lookup = bundle[\"user_seen_arrays\"]\n",
    "\n",
    "    acc = dict.fromkeys((\"HR@5\", \"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"), 0.0)\n",
    "    acc[\"UsersEval\"] = 0\n",
    "\n",
    "    progress = tqdm(eval_users, desc=f\"{bundle['name']} / {model_name}\")\n",
    "    for raw_uid in progress:\n",
    "        user = int(raw_uid)\n",
    "        target = held_out[user]\n",
    "        ranked = popularity_recommend_one(pop_order, seen_lookup.get(user), max(TOP_KS))\n",
    "        # Shape the single user's list as a (1, k) batch for the shared metric helper.\n",
    "        batch_metric_update(\n",
    "            np.array(ranked, dtype=np.int32)[None, :],\n",
    "            np.array([target], dtype=np.int32),\n",
    "            acc,\n",
    "            TOP_KS,\n",
    "        )\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def clear_memory():\n",
    "    gc.collect()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e001d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model fitting helpers and batched evaluators\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr(copy=True)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel().astype(np.float32)\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    \"\"\"Build an item-item score matrix from two-step transition probabilities.\n",
    "\n",
    "    Args:\n",
    "        X_binary: sparse user-by-item interaction matrix.\n",
    "        alpha: exponent applied element-wise to both row-normalised transition matrices.\n",
    "        beta: when positive, column j of the result is scaled by ``degree(j) ** -beta``,\n",
    "            where degree is the column sum of ``X_binary``.\n",
    "\n",
    "    Returns:\n",
    "        CSR matrix (items x items) with a zeroed diagonal.\n",
    "    \"\"\"\n",
    "    user_to_item = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    item_to_user = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "\n",
    "    sim = (item_to_user @ user_to_item).astype(np.float32).tocsr()\n",
    "    # An item must not recommend itself.\n",
    "    sim.setdiag(0.0)\n",
    "    sim.eliminate_zeros()\n",
    "\n",
    "    if not beta > 0:\n",
    "        return sim.tocsr()\n",
    "\n",
    "    degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "    col_penalty = np.ones_like(degree, dtype=np.float32)\n",
    "    has_degree = degree > 0\n",
    "    col_penalty[has_degree] = np.power(degree[has_degree], -beta)\n",
    "    # Right-multiplying by a diagonal matrix rescales each target-item column.\n",
    "    penalised = sim @ sp.diags(col_penalty.astype(np.float32))\n",
    "    return penalised.tocsr()\n",
    "\n",
    "def fit_ease(X_matrix, lam):\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32, copy=False)\n",
    "    G[np.diag_indices_from(G)] += np.float32(lam)\n",
    "    c, lower = sla.cho_factor(G, lower=True, overwrite_a=True, check_finite=False)\n",
    "    I = np.eye(G.shape[0], dtype=np.float32)\n",
    "    P = sla.cho_solve((c, lower), I, overwrite_b=True, check_finite=False)\n",
    "    diag = np.diag(P).copy()\n",
    "    B = -P / diag[None, :]\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32, copy=False)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_dense_operator(bundle, operator_matrix, model_name, source=\"binary\", batch_size=4096):\n",
    "    \"\"\"Score users as ``history @ operator`` in batches and return the metric row.\n",
    "\n",
    "    Args:\n",
    "        bundle: formulation bundle providing the train matrices and held-out items.\n",
    "        operator_matrix: item-item operator; anything ``torch.as_tensor`` accepts\n",
    "            (assumed to be a dense items x items array - TODO confirm with callers).\n",
    "        model_name: label for the progress bar and the result row.\n",
    "        source: \"binary\" scores from ``X_binary``; any other value uses ``X_counts``.\n",
    "        batch_size: number of users scored per step.\n",
    "    \"\"\"\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    interactions = bundle[\"X_binary\"] if source == \"binary\" else bundle[\"X_counts\"]\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    fetch_n = max(TOP_KS)\n",
    "\n",
    "    operator = torch.as_tensor(operator_matrix, dtype=torch.float32, device=device)\n",
    "\n",
    "    acc = dict.fromkeys((\"HR@5\", \"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"), 0.0)\n",
    "    acc[\"UsersEval\"] = 0\n",
    "\n",
    "    batch_starts = range(0, len(eval_users), batch_size)\n",
    "    for offset in tqdm(batch_starts, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        uids = eval_users[offset:offset + batch_size]\n",
    "        history_np = interactions[uids].toarray().astype(np.float32, copy=False)\n",
    "        targets = np.array([held_out[int(u)] for u in uids], dtype=np.int32)\n",
    "\n",
    "        history = torch.from_numpy(history_np).to(device, non_blocking=True)\n",
    "        # Items already present in the user's training row can never be recommended.\n",
    "        scores = (history @ operator).masked_fill(history > 0, -1e9)\n",
    "\n",
    "        ranked = torch.topk(scores, k=fetch_n, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        batch_metric_update(ranked, targets, acc, TOP_KS)\n",
    "\n",
    "        # Drop the batch tensors before the next iteration allocates new ones.\n",
    "        del history, scores, ranked\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def top_model_name(df, prefix):\n",
    "    subset = df[df[\"Model\"].str.startswith(prefix)].copy()\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    subset = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return subset.iloc[0][\"Model\"]\n",
    "\n",
    "def make_rrf_scores(score_matrices, rrf_k=60):\n",
    "    # score_matrices: list of top-k item index arrays with shape [batch, fetch_n]\n",
    "    batch, fetch_n = score_matrices[0].shape\n",
    "    fused = np.zeros((batch, fetch_n * len(score_matrices)), dtype=np.float32)\n",
    "    del fused\n",
    "    # kept only as a placeholder helper; fusion is done in a simple Python routine later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "33f155c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: H1_base\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "====================================================================================================\n",
      "Formulation: H1_color\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition', 'Color']\n",
      "Users      : 62921\n",
      "Items      : 13198\n",
      "Thresholds : (4, 8)\n",
      "Train rows : 237108\n",
      "Test rows  : 62921\n",
      "Avg train interactions/user: 3.768\n",
      "====================================================================================================\n",
      "Formulation: H1_fine\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition']\n",
      "Users      : 95529\n",
      "Items      : 3696\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 812918\n",
      "Test rows  : 95529\n",
      "Avg train interactions/user: 8.51\n",
      "====================================================================================================\n",
      "Formulation: H1_fine_color\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition', 'Color']\n",
      "Users      : 59346\n",
      "Items      : 23514\n",
      "Thresholds : (3, 5)\n",
      "Train rows : 134712\n",
      "Test rows  : 59346\n",
      "Avg train interactions/user: 2.27\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>TrainRows</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Thresholds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_fine_color</td>\n",
       "      <td>59346</td>\n",
       "      <td>23514</td>\n",
       "      <td>134712</td>\n",
       "      <td>2.269942</td>\n",
       "      <td>(3, 5)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_color</td>\n",
       "      <td>62921</td>\n",
       "      <td>13198</td>\n",
       "      <td>237108</td>\n",
       "      <td>3.768344</td>\n",
       "      <td>(4, 8)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>95529</td>\n",
       "      <td>3696</td>\n",
       "      <td>812918</td>\n",
       "      <td>8.509646</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>831152</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Formulation  Users  Items  TrainRows  AvgTrainPerUser Thresholds\n",
       "0  H1_fine_color  59346  23514     134712         2.269942     (3, 5)\n",
       "1       H1_color  62921  13198     237108         3.768344     (4, 8)\n",
       "2        H1_fine  95529   3696     812918         8.509646    (5, 10)\n",
       "3        H1_base  43199   2640     831152        19.240075    (5, 10)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build every richer H1 formulation and collect one size-summary row per bundle\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    bundles[name] = bundle\n",
    "    print_bundle_summary(bundle)\n",
    "\n",
    "    # Size statistics shown in the comparison table below.\n",
    "    user_count = bundle[\"num_users\"]\n",
    "    train_row_count = len(bundle[\"train_interactions\"])\n",
    "    summary_row = {\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": user_count,\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"TrainRows\": train_row_count,\n",
    "        # max(..., 1) guards the division when a formulation has no users.\n",
    "        \"AvgTrainPerUser\": train_row_count / max(user_count, 1),\n",
    "        \"Thresholds\": bundle[\"used_thresholds\"],\n",
    "    }\n",
    "    summary_rows.append(summary_row)\n",
    "\n",
    "# Largest item catalogues first, ties broken by user count.\n",
    "summary_df = pd.DataFrame(summary_rows)\n",
    "summary_df = summary_df.sort_values(by=[\"Items\", \"Users\"], ascending=False)\n",
    "summary_df = summary_df.reset_index(drop=True)\n",
    "display(summary_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a61a7425",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Regular search on formulation: H1_base\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / Popularity: 100%|██████████| 43199/43199 [00:01<00:00, 31139.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_0_b0_6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_0_b0_6: 100%|██████████| 11/11 [00:00<00:00, 42.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_05_b0_6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_05_b0_6: 100%|██████████| 11/11 [00:00<00:00, 87.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_1_b0_7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_1_b0_7: 100%|██████████| 11/11 [00:00<00:00, 85.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_15_b0_7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_15_b0_7: 100%|██████████| 11/11 [00:00<00:00, 87.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l800\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l800: 100%|██████████| 11/11 [00:00<00:00, 26.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1000: 100%|██████████| 11/11 [00:00<00:00, 33.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1200: 100%|██████████| 11/11 [00:00<00:00, 28.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1600: 100%|██████████| 11/11 [00:00<00:00, 30.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l2200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l2200: 100%|██████████| 11/11 [00:00<00:00, 32.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l800\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l800: 100%|██████████| 11/11 [00:00<00:00, 26.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l1200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l1200: 100%|██████████| 11/11 [00:00<00:00, 41.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l1600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l1600: 100%|██████████| 11/11 [00:00<00:00, 41.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l2200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l2200: 100%|██████████| 11/11 [00:00<00:00, 27.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Regular search on formulation: H1_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / Popularity: 100%|██████████| 62921/62921 [00:01<00:00, 33405.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_0_b0_6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_0_b0_6: 100%|██████████| 16/16 [00:01<00:00, 11.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_05_b0_6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_05_b0_6: 100%|██████████| 16/16 [00:01<00:00, 13.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_1_b0_7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_1_b0_7: 100%|██████████| 16/16 [00:01<00:00, 13.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_15_b0_7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_15_b0_7: 100%|██████████| 16/16 [00:01<00:00, 13.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l800\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_binary_l800: 100%|██████████| 16/16 [00:01<00:00, 12.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1000\n"
     ]
    }
   ],
   "source": [
    "# Regular model search across all H1 variants\n",
    "#\n",
    "# Per formulation: score the popularity baseline, sweep the RP3beta (alpha, beta) grid,\n",
    "# then sweep the EASE lambda grids on the binary and count matrices. Each fitted operator\n",
    "# is evaluated and released before the next fit.\n",
    "\n",
    "regular_results = []\n",
    "\n",
    "for formulation_name, bundle in bundles.items():\n",
    "    print(\"=\" * 120)\n",
    "    print(\"Regular search on formulation:\", formulation_name)\n",
    "\n",
    "    # Popularity baseline\n",
    "    regular_results.append(evaluate_popularity(bundle, model_name=\"Popularity\"))\n",
    "\n",
    "    # RP3beta grid\n",
    "    for alpha, beta in RP3_GRID:\n",
    "        clear_memory()\n",
    "        model_name = f\"RP3beta_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "        print(\"- Fitting\", model_name)\n",
    "        S = fit_p3alpha(bundle[\"X_binary\"], alpha=alpha, beta=beta)\n",
    "        # Cast while still sparse: densifying first would materialise the full\n",
    "        # items x items array in S's own dtype before the float32 conversion.\n",
    "        S_dense = S.astype(np.float32, copy=False).toarray()\n",
    "        regular_results.append(\n",
    "            evaluate_dense_operator(bundle, S_dense, model_name=model_name, source=\"binary\", batch_size=LINEAR_EVAL_BATCH)\n",
    "        )\n",
    "        del S, S_dense\n",
    "        clear_memory()\n",
    "\n",
    "    # EASE on the binary matrix, then on the count matrix (same procedure, different input)\n",
    "    for ease_source, ease_matrix_key, ease_lambdas in (\n",
    "        (\"binary\", \"X_binary\", EASE_BINARY_LAMBDAS),\n",
    "        (\"count\", \"X_counts\", EASE_COUNT_LAMBDAS),\n",
    "    ):\n",
    "        for lam in ease_lambdas:\n",
    "            clear_memory()\n",
    "            model_name = f\"EASE_{ease_source}_l{int(lam)}\"\n",
    "            print(\"- Fitting\", model_name)\n",
    "            B = fit_ease(bundle[ease_matrix_key], lam=lam)\n",
    "            regular_results.append(\n",
    "                evaluate_dense_operator(bundle, B, model_name=model_name, source=ease_source, batch_size=LINEAR_EVAL_BATCH)\n",
    "            )\n",
    "            del B\n",
    "            clear_memory()\n",
    "\n",
    "regular_results_df = (\n",
    "    pd.DataFrame(regular_results)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(\"Top regular results overall:\")\n",
    "display(regular_results_df.head(20))\n",
    "\n",
    "# NOTE: .max() is taken per column, so one formulation's row can combine metric values\n",
    "# that were achieved by different models.\n",
    "best_formulations_df = (\n",
    "    regular_results_df.groupby(\"Formulation\")[[\"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]]\n",
    "    .max()\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "print(\"Best formulation scores:\")\n",
    "display(best_formulations_df)\n",
    "\n",
    "# Only the strongest formulations go on to the (more expensive) neural search.\n",
    "selected_formulations = best_formulations_df[\"Formulation\"].head(TOP_FORMULATIONS_FOR_NEURAL).tolist()\n",
    "print(\"Selected formulations for neural search:\", selected_formulations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3cf5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare formulation-specific tensors for neural models\n",
    "\n",
    "def prepare_neural_data(bundle):\n",
    "    \"\"\"Encode one formulation bundle into the arrays and tensors used by the neural models.\n",
    "\n",
    "    Reads ``user_feature_df``, ``item_feature_df``, ``train_interactions``, ``user_cols``,\n",
    "    ``item_cols`` and ``X_binary`` from ``bundle``. Each feature column is cast to ``str``\n",
    "    and label-encoded with its own ``LabelEncoder``; the bundle's frames are not modified.\n",
    "\n",
    "    Returns a dict with:\n",
    "      - ``user_cols`` / ``item_cols``: the column lists from the bundle\n",
    "      - ``user_feat_t`` / ``item_feat_t``: int64 encoded feature tensors on ``device``\n",
    "      - ``pos_user_t`` / ``pos_item_t`` / ``pos_weight_t``: training ``user_idx`` and\n",
    "        ``item_idx`` (int64) and ``count`` (float32) tensors on ``device``\n",
    "      - ``x_binary_np``: dense float32 copy of ``X_binary``\n",
    "      - ``x_binary_t`` / ``x_binary_on_device``: that matrix on ``device``, or ``None`` /\n",
    "        ``False`` when moving it raised ``RuntimeError``\n",
    "      - ``user_feature_cards`` / ``item_feature_cards``: distinct encoded values per column\n",
    "      - ``feature_encoders``: the fitted encoders keyed ``user::<col>`` / ``item::<col>``\n",
    "    \"\"\"\n",
    "    # Copies, because the encoded values overwrite the feature columns below.\n",
    "    user_feature_df = bundle[\"user_feature_df\"].copy()\n",
    "    item_feature_df = bundle[\"item_feature_df\"].copy()\n",
    "    # Only read from, so no defensive copy is needed.\n",
    "    train_interactions = bundle[\"train_interactions\"]\n",
    "\n",
    "    user_cols = bundle[\"user_cols\"]\n",
    "    item_cols = bundle[\"item_cols\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "    for side, frame, cols in (\n",
    "        (\"user\", user_feature_df, user_cols),\n",
    "        (\"item\", item_feature_df, item_cols),\n",
    "    ):\n",
    "        for col in cols:\n",
    "            le = LabelEncoder()\n",
    "            frame[col] = le.fit_transform(frame[col].astype(str))\n",
    "            feature_encoders[f\"{side}::{col}\"] = le\n",
    "\n",
    "    user_feat_np = user_feature_df[user_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "    item_feat_np = item_feature_df[item_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "\n",
    "    pos_user_np = train_interactions[\"user_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_item_np = train_interactions[\"item_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_weight_np = train_interactions[\"count\"].to_numpy(dtype=np.float32, copy=True)\n",
    "\n",
    "    # Cast while still sparse so the dense users x items array is created directly in\n",
    "    # float32 rather than in the matrix's own dtype first.\n",
    "    x_binary_np = bundle[\"X_binary\"].astype(np.float32, copy=False).toarray()\n",
    "\n",
    "    out = {\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"user_feat_t\": torch.from_numpy(user_feat_np).to(device),\n",
    "        \"item_feat_t\": torch.from_numpy(item_feat_np).to(device),\n",
    "        \"pos_user_t\": torch.from_numpy(pos_user_np).to(device),\n",
    "        \"pos_item_t\": torch.from_numpy(pos_item_np).to(device),\n",
    "        \"pos_weight_t\": torch.from_numpy(pos_weight_np).to(device),\n",
    "        \"x_binary_np\": x_binary_np,\n",
    "        # Previously built and then dropped; kept so encoded ids can be mapped back.\n",
    "        \"feature_encoders\": feature_encoders,\n",
    "    }\n",
    "\n",
    "    # Best effort: keep the dense matrix on the device if the transfer succeeds\n",
    "    # (presumably it fails when the device runs out of memory); otherwise batches\n",
    "    # are taken from the host copy.\n",
    "    try:\n",
    "        out[\"x_binary_t\"] = torch.from_numpy(x_binary_np).to(device)\n",
    "        out[\"x_binary_on_device\"] = True\n",
    "    except RuntimeError:\n",
    "        out[\"x_binary_t\"] = None\n",
    "        out[\"x_binary_on_device\"] = False\n",
    "        clear_memory()\n",
    "\n",
    "    out[\"user_feature_cards\"] = {\n",
    "        col: int(user_feature_df[col].nunique()) for col in user_cols\n",
    "    }\n",
    "    out[\"item_feature_cards\"] = {\n",
    "        col: int(item_feature_df[col].nunique()) for col in item_cols\n",
    "    }\n",
    "\n",
    "    return out\n",
    "\n",
    "# One prepared tensor set per formulation chosen by the regular search\n",
    "neural_data_by_formulation = {}\n",
    "for formulation_name in selected_formulations:\n",
    "    print(\"Preparing neural tensors:\", formulation_name)\n",
    "    formulation_bundle = bundles[formulation_name]\n",
    "    neural_data_by_formulation[formulation_name] = prepare_neural_data(formulation_bundle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c9a604a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural architectures and batched evaluators\n",
    "\n",
    "class FeatureEmbeddingBlock(nn.Module):\n",
    "    \"\"\"Embeds each categorical feature column separately and concatenates the results.\n",
    "\n",
    "    ``cardinalities`` maps column name -> number of categories; its key order decides\n",
    "    which input column feeds which embedding table.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cardinalities, emb_dim):\n",
    "        super().__init__()\n",
    "        self.cols = [name for name in cardinalities]\n",
    "        tables = []\n",
    "        for name in self.cols:\n",
    "            tables.append(nn.Embedding(int(cardinalities[name]), emb_dim))\n",
    "        self.embs = nn.ModuleList(tables)\n",
    "        # Re-initialise only after every table exists, with a small-variance normal.\n",
    "        for table in self.embs:\n",
    "            nn.init.normal_(table.weight, std=0.02)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Column i of ``x`` holds the category ids for the i-th embedding table.\n",
    "        embedded = []\n",
    "        for col_idx, table in enumerate(self.embs):\n",
    "            embedded.append(table(x[:, col_idx]))\n",
    "        return torch.cat(embedded, dim=1)\n",
    "\n",
    "class MLPBlock(nn.Module):\n",
    "    \"\"\"Linear -> ReLU -> Dropout stages followed by an optional final linear projection.\n",
    "\n",
    "    ``output_dim`` is the width of the last layer added, or ``input_dim`` if none was.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None):\n",
    "        super().__init__()\n",
    "        # Consecutive pairs of widths give each hidden stage's (in, out) sizes.\n",
    "        widths = [input_dim, *hidden_dims]\n",
    "        stages = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            stages.append(nn.Linear(fan_in, fan_out))\n",
    "            stages.append(nn.ReLU())\n",
    "            stages.append(nn.Dropout(dropout))\n",
    "        last_width = widths[-1]\n",
    "        if final_dim is not None:\n",
    "            # The projection has no activation or dropout after it.\n",
    "            stages.append(nn.Linear(last_width, final_dim))\n",
    "            last_width = final_dim\n",
    "        self.net = nn.Sequential(*stages)\n",
    "        self.output_dim = last_width\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class TowerEncoder(nn.Module):\n",
    "    \"\"\"Feature tower: per-column embeddings -> MLP -> vector normalised along the last dim.\"\"\"\n",
    "\n",
    "    def __init__(self, cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.embed = FeatureEmbeddingBlock(cards, emb_dim)\n",
    "        # The embedding block concatenates one ``emb_dim``-wide slice per feature column.\n",
    "        concat_width = len(cards) * emb_dim\n",
    "        self.mlp = MLPBlock(concat_width, hidden_dims, dropout=dropout, final_dim=out_dim)\n",
    "\n",
    "    def forward(self, feat_batch):\n",
    "        projected = self.mlp(self.embed(feat_batch))\n",
    "        return F.normalize(projected, dim=-1)\n",
    "\n",
    "class TwoTowerModel(nn.Module):\n",
    "    \"\"\"Container for a user tower and an item tower built with the same hyper-parameters.\n",
    "\n",
    "    Defines no ``forward``; callers use ``user_tower`` and ``item_tower`` directly.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, user_cards, item_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "        super().__init__()\n",
    "        shared_kwargs = {\"emb_dim\": emb_dim, \"hidden_dims\": hidden_dims, \"out_dim\": out_dim}\n",
    "        self.user_tower = TowerEncoder(user_cards, **shared_kwargs)\n",
    "        self.item_tower = TowerEncoder(item_cards, **shared_kwargs)\n",
    "\n",
    "class MultVAE(nn.Module):\n",
    "    \"\"\"Variational autoencoder over a user's item vector.\n",
    "\n",
    "    Encoder: input dropout -> two Tanh-activated linear layers -> ``mu`` / ``logvar`` heads.\n",
    "    Decoder: one Tanh-activated hidden layer -> one logit per item.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_items, hidden_dim=2048, latent_dim=512, dropout=0.3):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        encoder_layers = [\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        ]\n",
    "        self.encoder = nn.Sequential(*encoder_layers)\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "        decoder_layers = [\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        ]\n",
    "        self.decoder = nn.Sequential(*decoder_layers)\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Return ``(mu, logvar)`` for input ``x``; dropout is applied to ``x`` first.\"\"\"\n",
    "        hidden = self.encoder(self.dropout(x))\n",
    "        mean = self.mu(hidden)\n",
    "        log_variance = self.logvar(hidden)\n",
    "        return mean, log_variance\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        \"\"\"Sample ``mu + eps * std`` while training; return ``mu`` unchanged otherwise.\"\"\"\n",
    "        if not self.training:\n",
    "            return mu\n",
    "        sigma = torch.exp(0.5 * logvar)\n",
    "        noise = torch.randn_like(sigma)\n",
    "        return mu + noise * sigma\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(logits, mu, logvar)`` for a batch of input rows.\"\"\"\n",
    "        mu, logvar = self.encode(x)\n",
    "        latent = self.reparameterize(mu, logvar)\n",
    "        return self.decoder(latent), mu, logvar\n",
    "\n",
    "class NeuMF(nn.Module):\n",
    "    \"\"\"Neural matrix factorisation: an element-wise MF branch plus an MLP branch.\n",
    "\n",
    "    Each branch has its own user and item ID embeddings; the two branch outputs are\n",
    "    concatenated and mapped to one score per (user, item) pair.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "        super().__init__()\n",
    "        self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "        self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "        self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "        self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "        # The MLP input is the concatenated user and item MLP embeddings.\n",
    "        widths = [2 * mlp_dim, *hidden_dims]\n",
    "        stages = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.1)]\n",
    "        self.mlp = nn.Sequential(*stages)\n",
    "        self.out = nn.Linear(widths[-1] + mf_dim, 1)\n",
    "\n",
    "        for table in (self.user_mf, self.item_mf, self.user_mlp, self.item_mlp):\n",
    "            nn.init.normal_(table.weight, std=0.02)\n",
    "\n",
    "    def forward(self, user_idx, item_idx):\n",
    "        mf_vec = self.user_mf(user_idx) * self.item_mf(item_idx)\n",
    "        mlp_vec = self.mlp(torch.cat((self.user_mlp(user_idx), self.item_mlp(item_idx)), dim=1))\n",
    "        fused = torch.cat((mf_vec, mlp_vec), dim=1)\n",
    "        # Drop the trailing size-1 dimension: one score per pair.\n",
    "        return self.out(fused).squeeze(-1)\n",
    "\n",
    "def iter_positive_batches(ndata, batch_size):\n",
    "    \"\"\"Yield shuffled ``(user_idx, item_idx, weight)`` batches of positive interactions.\n",
    "\n",
    "    A new random permutation is drawn on ``device`` for every call; the last batch may\n",
    "    hold fewer than ``batch_size`` rows.\n",
    "    \"\"\"\n",
    "    total = ndata[\"pos_user_t\"].shape[0]\n",
    "    order = torch.randperm(total, device=device)\n",
    "    for offset in range(0, total, batch_size):\n",
    "        chosen = order[offset:offset + batch_size]\n",
    "        users = ndata[\"pos_user_t\"][chosen]\n",
    "        items = ndata[\"pos_item_t\"][chosen]\n",
    "        weights = ndata[\"pos_weight_t\"][chosen]\n",
    "        yield users, items, weights\n",
    "\n",
    "def iter_dense_user_batches(ndata, batch_size):\n",
    "    \"\"\"Yield ``(dense interaction rows, user index tensor)`` pairs in shuffled user order.\n",
    "\n",
    "    Rows come from the device-resident ``x_binary_t`` when ``x_binary_on_device`` is set;\n",
    "    otherwise they are sliced from ``x_binary_np`` and moved to ``device`` per batch.\n",
    "    \"\"\"\n",
    "    host_matrix = ndata[\"x_binary_np\"]\n",
    "    user_total = host_matrix.shape[0]\n",
    "    user_order = np.random.permutation(user_total)\n",
    "    for offset in range(0, user_total, batch_size):\n",
    "        idx = user_order[offset:offset + batch_size]\n",
    "        if not ndata[\"x_binary_on_device\"]:\n",
    "            rows = torch.from_numpy(ndata[\"x_binary_np\"][idx]).to(device)\n",
    "        else:\n",
    "            rows = ndata[\"x_binary_t\"][idx]\n",
    "        yield rows, torch.as_tensor(idx, device=device, dtype=torch.long)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_twotower(bundle, ndata, model, model_name):\n",
    "    \"\"\"Rank all items for every evaluation user with the two-tower model and score the lists.\n",
    "\n",
    "    Item vectors are encoded once, 4096 rows at a time; users are then processed in batches\n",
    "    of ``NEURAL_EVAL_BATCH``. Items with a positive entry in the user's ``X_binary`` row are\n",
    "    masked out before the top ``max(TOP_KS)`` items are taken.\n",
    "    Returns ``finalize_metrics(acc, model_name, bundle[\"name\"])``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    seen_matrix = bundle[\"X_binary\"]\n",
    "    cutoff = max(TOP_KS)\n",
    "\n",
    "    # Encode the whole item catalogue up front so each user batch is one matmul.\n",
    "    item_feats = ndata[\"item_feat_t\"]\n",
    "    item_chunk = 4096\n",
    "    item_matrix = torch.cat(\n",
    "        [model.item_tower(item_feats[lo:lo + item_chunk]) for lo in range(0, item_feats.shape[0], item_chunk)],\n",
    "        dim=0,\n",
    "    )\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    batch_starts = range(0, len(eval_users), NEURAL_EVAL_BATCH)\n",
    "    for lo in tqdm(batch_starts, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = eval_users[lo:lo + NEURAL_EVAL_BATCH]\n",
    "        user_vecs = model.user_tower(ndata[\"user_feat_t\"][batch_uids])\n",
    "        seen = torch.from_numpy(seen_matrix[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        # Push already-seen items to the bottom of the ranking.\n",
    "        scores = (user_vecs @ item_matrix.T).masked_fill(seen > 0, -1e9)\n",
    "        topk = torch.topk(scores, k=cutoff, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([held_out[int(uid)] for uid in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del user_vecs, seen, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_multvae(bundle, ndata, model, model_name):\n",
    "    \"\"\"Score the MultVAE's top-k recommendations for every evaluation user.\n",
    "\n",
    "    Each batch of users' ``X_binary`` rows is densified and passed through ``model``; items\n",
    "    already present in the input row are masked out before the top ``max(TOP_KS)`` logits\n",
    "    are taken. ``ndata`` is not read here.\n",
    "    Returns ``finalize_metrics(acc, model_name, bundle[\"name\"])``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEURAL_EVAL_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEURAL_EVAL_BATCH]\n",
    "        x_batch = torch.from_numpy(x_source[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        logits, _, _ = model(x_batch)\n",
    "        # Push already-seen items to the bottom of the ranking.\n",
    "        logits = logits.masked_fill(x_batch > 0, -1e9)\n",
    "        topk = torch.topk(logits, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        # Use the local lookup (it was assigned but unused), as the sibling evaluators do.\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del x_batch, logits, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_neumf(bundle, model, model_name):\n",
    "    \"\"\"Score NeuMF's top-k recommendations for every evaluation user.\n",
    "\n",
    "    Users are taken ``NEUMF_EVAL_USER_BATCH`` at a time and items ``NEUMF_EVAL_ITEM_BATCH``\n",
    "    at a time; each (user, item) pair in a tile is scored by ``model`` and the tiles are\n",
    "    joined into a full user x item score matrix. Items with a positive entry in the user's\n",
    "    ``X_binary`` row are masked out before the top ``max(TOP_KS)`` items are taken.\n",
    "    Returns ``finalize_metrics(acc, model_name, bundle[\"name\"])``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    seen_matrix = bundle[\"X_binary\"]\n",
    "    num_items = bundle[\"num_items\"]\n",
    "    cutoff = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "    all_items = torch.arange(num_items, device=device, dtype=torch.long)\n",
    "\n",
    "    batch_starts = range(0, len(eval_users), NEUMF_EVAL_USER_BATCH)\n",
    "    for lo in tqdm(batch_starts, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = eval_users[lo:lo + NEUMF_EVAL_USER_BATCH]\n",
    "        user_batch = torch.as_tensor(batch_uids, device=device, dtype=torch.long)\n",
    "        x_seen = torch.from_numpy(seen_matrix[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        n_batch_users = user_batch.shape[0]\n",
    "\n",
    "        parts = []\n",
    "        for item_lo in range(0, num_items, NEUMF_EVAL_ITEM_BATCH):\n",
    "            item_batch = all_items[item_lo:item_lo + NEUMF_EVAL_ITEM_BATCH]\n",
    "            # Flatten the user x item tile into aligned index vectors, users varying slowest.\n",
    "            flat_users = user_batch[:, None].expand(-1, item_batch.shape[0]).reshape(-1)\n",
    "            flat_items = item_batch[None, :].expand(n_batch_users, -1).reshape(-1)\n",
    "            tile_scores = model(flat_users, flat_items)\n",
    "            parts.append(tile_scores.reshape(n_batch_users, -1))\n",
    "\n",
    "        # Join the item tiles, then push already-seen items to the bottom of the ranking.\n",
    "        scores = torch.cat(parts, dim=1).masked_fill(x_seen > 0, -1e9)\n",
    "        topk = torch.topk(scores, k=cutoff, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([held_out[int(uid)] for uid in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del user_batch, x_seen, parts, scores, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f000e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural training functions\n",
    "\n",
    "def train_two_tower(bundle, ndata, cfg):\n",
    "    \"\"\"Train a ``TwoTowerModel`` with an in-batch softmax loss and return it in eval mode.\n",
    "\n",
    "    ``cfg`` supplies ``emb_dim``, ``hidden_dims``, ``out_dim``, ``lr``, ``wd``, ``epochs``,\n",
    "    ``temperature`` and ``name``. Positive (user, item) pairs come from\n",
    "    ``iter_positive_batches`` in batches of ``TWOTOWER_BATCH``. Within a batch, row ``i``'s\n",
    "    target is item ``i`` and every other item in the batch acts as a negative, so a\n",
    "    repeated item in the batch also counts as a negative. Each pair's loss is weighted by\n",
    "    its interaction weight. The weighted mean loss is printed after every epoch.\n",
    "    \"\"\"\n",
    "    model = TwoTowerModel(\n",
    "        user_cards=ndata[\"user_feature_cards\"],\n",
    "        item_cards=ndata[\"item_feature_cards\"],\n",
    "        emb_dim=cfg[\"emb_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "        out_dim=cfg[\"out_dim\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in iter_positive_batches(ndata, TWOTOWER_BATCH):\n",
    "            user_batch = ndata[\"user_feat_t\"][user_idx_batch]\n",
    "            item_batch = ndata[\"item_feat_t\"][item_idx_batch]\n",
    "            # Reduce the batch weights once; reused by the loss and the running totals.\n",
    "            weight_sum = weight_batch.sum()\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                user_vec = model.user_tower(user_batch)\n",
    "                item_vec = model.item_tower(item_batch)\n",
    "                logits = (user_vec @ item_vec.T) / cfg[\"temperature\"]\n",
    "                # The matching item for row i sits in column i.\n",
    "                targets = torch.arange(logits.shape[0], device=device)\n",
    "                loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                loss = (loss_vec * weight_batch).sum() / weight_sum\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            # One host transfer per quantity instead of repeated sum().item() calls.\n",
    "            batch_w = float(weight_sum.item())\n",
    "            total_loss += float(loss.item()) * batch_w\n",
    "            total_w += batch_w\n",
    "\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def train_multvae(bundle, ndata, cfg):\n",
    "    \"\"\"Train a ``MultVAE`` on dense user rows and return it in eval mode.\n",
    "\n",
    "    ``cfg`` supplies ``hidden_dim``, ``latent_dim``, ``dropout``, ``lr``, ``wd``, ``epochs``,\n",
    "    ``anneal_cap`` and ``name``. The loss is the softmax reconstruction term plus a KL\n",
    "    term whose weight rises linearly over the first 30% of all steps and is capped at\n",
    "    ``anneal_cap``. The mean loss is printed for the first, every tenth and the last epoch.\n",
    "    \"\"\"\n",
    "    model = MultVAE(\n",
    "        num_items=bundle[\"num_items\"],\n",
    "        hidden_dim=cfg[\"hidden_dim\"],\n",
    "        latent_dim=cfg[\"latent_dim\"],\n",
    "        dropout=cfg[\"dropout\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    batches_per_epoch = max(1, math.ceil(bundle[\"num_users\"] / AUTOENC_BATCH))\n",
    "    total_steps = max(1, cfg[\"epochs\"] * batches_per_epoch)\n",
    "    # Number of steps over which the KL weight ramps up.\n",
    "    warmup_steps = max(total_steps * 0.3, 1)\n",
    "    step = 0\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        loss_sum = 0.0\n",
    "        rows_seen = 0\n",
    "\n",
    "        for batch_x, _ in iter_dense_user_batches(ndata, AUTOENC_BATCH):\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            anneal = min(cfg[\"anneal_cap\"], step / warmup_steps)\n",
    "\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits, mu, logvar = model(batch_x)\n",
    "                recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                loss = (recon + anneal * kl).mean()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            batch_rows = batch_x.shape[0]\n",
    "            loss_sum += float(loss.item()) * batch_rows\n",
    "            rows_seen += batch_rows\n",
    "            step += 1\n",
    "\n",
    "        epoch_no = epoch + 1\n",
    "        if epoch_no == 1 or epoch_no % 10 == 0 or epoch_no == cfg[\"epochs\"]:\n",
    "            print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {loss_sum / max(rows_seen, 1):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def train_neumf(bundle, cfg):\n",
    "    \"\"\"Train a NeuMF model with a weighted pairwise loss and return it in eval mode.\n",
    "\n",
    "    Every training interaction is paired with one item index drawn uniformly at\n",
    "    random as the negative; the pair loss is -logsigmoid(pos - neg), weighted by\n",
    "    the interaction's 'count' value.\n",
    "\n",
    "    Args:\n",
    "        bundle: mapping read for 'num_users', 'num_items', 'name' and\n",
    "            'train_interactions' (a frame with 'user_idx', 'item_idx' and\n",
    "            'count' columns).\n",
    "        cfg: mapping with 'mf_dim', 'mlp_dim', 'hidden_dims', 'lr', 'wd',\n",
    "            'epochs' and 'name'.\n",
    "\n",
    "    Returns:\n",
    "        The trained model, after calling ``.eval()`` on it.\n",
    "    \"\"\"\n",
    "    n_users = bundle[\"num_users\"]\n",
    "    n_items = bundle[\"num_items\"]\n",
    "    train_df = bundle[\"train_interactions\"]\n",
    "\n",
    "    # Keep all positives on the device so batches are plain index lookups.\n",
    "    pos_user = torch.as_tensor(train_df[\"user_idx\"].to_numpy(dtype=np.int64), device=device)\n",
    "    pos_item = torch.as_tensor(train_df[\"item_idx\"].to_numpy(dtype=np.int64), device=device)\n",
    "    pos_weight = torch.as_tensor(train_df[\"count\"].to_numpy(dtype=np.float32), device=device)\n",
    "\n",
    "    model = NeuMF(\n",
    "        num_users=n_users,\n",
    "        num_items=n_items,\n",
    "        mf_dim=cfg[\"mf_dim\"],\n",
    "        mlp_dim=cfg[\"mlp_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    # torch.cuda.amp.GradScaler is deprecated; torch.amp.GradScaler('cuda', ...) is its replacement.\n",
    "    scaler = torch.amp.GradScaler(\"cuda\", enabled=USE_AMP)\n",
    "\n",
    "    n = pos_user.shape[0]\n",
    "    model.train()\n",
    "\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        # Fresh shuffle of the positives every epoch.\n",
    "        perm = torch.randperm(n, device=device)\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "\n",
    "        for start in range(0, n, PAIRWISE_BATCH):\n",
    "            idx = perm[start:start + PAIRWISE_BATCH]\n",
    "            u = pos_user[idx]\n",
    "            p = pos_item[idx]\n",
    "            w = pos_weight[idx]\n",
    "            # Uniform negatives; draws are not filtered, so one may coincide with the positive item.\n",
    "            neg_item = torch.randint(0, n_items, size=p.shape, device=device)\n",
    "            # Computed once and reused for normalisation and logging.\n",
    "            w_sum = w.sum()\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                pos_scores = model(u, p)\n",
    "                neg_scores = model(u, neg_item)\n",
    "                loss_vec = -F.logsigmoid(pos_scores - neg_scores)\n",
    "                loss = (loss_vec * w).sum() / w_sum\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            # Weight-scaled accumulation so the epoch figure is a weighted mean over pairs.\n",
    "            batch_w = float(w_sum.item())\n",
    "            total_loss += float(loss.item()) * batch_w\n",
    "            total_w += batch_w\n",
    "\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b2ee8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural search on the best formulations only\n",
    "\n",
    "# One result record is appended per (formulation, config) run.\n",
    "neural_results = []\n",
    "\n",
    "if RUN_NEURAL:\n",
    "    for formulation_name in selected_formulations:\n",
    "        print(\"=\" * 120)\n",
    "        print(\"Neural search on formulation:\", formulation_name)\n",
    "\n",
    "        bundle = bundles[formulation_name]\n",
    "        ndata = neural_data_by_formulation[formulation_name]\n",
    "\n",
    "        # Each run: clear memory, train, evaluate, drop the model reference, clear memory again.\n",
    "        for cfg in TWOTOWER_CONFIGS:\n",
    "            clear_memory()\n",
    "            model = train_two_tower(bundle, ndata, cfg)\n",
    "            neural_results.append(evaluate_twotower(bundle, ndata, model, cfg[\"name\"]))\n",
    "            del model\n",
    "            clear_memory()\n",
    "\n",
    "        for cfg in MULTVAE_CONFIGS:\n",
    "            clear_memory()\n",
    "            model = train_multvae(bundle, ndata, cfg)\n",
    "            neural_results.append(evaluate_multvae(bundle, ndata, model, cfg[\"name\"]))\n",
    "            del model\n",
    "            clear_memory()\n",
    "\n",
    "        # NeuMF trains from the bundle alone; it does not take ndata.\n",
    "        for cfg in NEUMF_CONFIGS:\n",
    "            clear_memory()\n",
    "            model = train_neumf(bundle, cfg)\n",
    "            neural_results.append(evaluate_neumf(bundle, model, cfg[\"name\"]))\n",
    "            del model\n",
    "            clear_memory()\n",
    "\n",
    "# Sort only when something was evaluated: an empty frame has no metric columns, so\n",
    "# sort_values would raise KeyError (e.g. RUN_NEURAL set but no formulations or configs).\n",
    "if neural_results:\n",
    "    neural_results_df = (\n",
    "        pd.DataFrame(neural_results)\n",
    "        .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "else:\n",
    "    neural_results_df = pd.DataFrame()\n",
    "\n",
    "print(\"Top neural results:\")\n",
    "display(neural_results_df.head(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d0e73c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional fusion on the best formulation\n",
    "\n",
    "def make_python_recommenders_for_fusion(bundle, formulation_name, regular_results_df, neural_results_df):\n",
    "    \"\"\"Pick the top model name of each fusion family for one formulation.\n",
    "\n",
    "    Args:\n",
    "        bundle: not used by this function; kept so the signature stays unchanged.\n",
    "        formulation_name: value matched against the 'Formulation' column.\n",
    "        regular_results_df: results frame searched for RP3beta and EASE names.\n",
    "        neural_results_df: results frame searched for TwoTower and MultVAE names;\n",
    "            may be empty.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(best_ease, best_rp3, best_tt, best_vae)``. The EASE entry uses the\n",
    "        'EASE_binary_' prefix and falls back to 'EASE_count_' when top_model_name\n",
    "        returns None; both neural entries are None when neural_results_df is empty.\n",
    "    \"\"\"\n",
    "    # Filter once per frame instead of once per lookup.\n",
    "    reg_sub = regular_results_df[regular_results_df[\"Formulation\"] == formulation_name]\n",
    "\n",
    "    best_rp3 = top_model_name(reg_sub, \"RP3beta_\")\n",
    "    best_ease = top_model_name(reg_sub, \"EASE_binary_\")\n",
    "    if best_ease is None:\n",
    "        best_ease = top_model_name(reg_sub, \"EASE_count_\")\n",
    "\n",
    "    if neural_results_df.empty:\n",
    "        # An empty frame may lack the 'Formulation' column, so it is never indexed.\n",
    "        best_tt = None\n",
    "        best_vae = None\n",
    "    else:\n",
    "        neu_sub = neural_results_df[neural_results_df[\"Formulation\"] == formulation_name]\n",
    "        best_tt = top_model_name(neu_sub, \"TwoTower_\")\n",
    "        best_vae = top_model_name(neu_sub, \"MultVAE_\")\n",
    "\n",
    "    return best_ease, best_rp3, best_tt, best_vae\n",
    "\n",
    "fusion_results = []\n",
    "\n",
    "if RUN_FUSION and not neural_results_df.empty:\n",
    "    # Takes the formulation of the first row - assumes regular_results_df is sorted best-first (TODO confirm).\n",
    "    best_formulation = regular_results_df.iloc[0][\"Formulation\"]\n",
    "    bundle = bundles[best_formulation]\n",
    "    print(\"Fusion on formulation:\", best_formulation)\n",
    "\n",
    "    # Reuse the helper defined above rather than repeating its selection logic inline.\n",
    "    # NOTE(review): if best_formulation was not part of the neural search, the neural names\n",
    "    # depend on how top_model_name treats an empty selection - verify.\n",
    "    ease_name, rp3_name, tt_name, vae_name = make_python_recommenders_for_fusion(\n",
    "        bundle, best_formulation, regular_results_df, neural_results_df\n",
    "    )\n",
    "\n",
    "    print(\"Best for fusion:\", ease_name, rp3_name, tt_name, vae_name)\n",
    "else:\n",
    "    print(\"Fusion skipped or no neural results.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb491a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final comparison tables\n",
    "\n",
    "# Regular results are always included; neural results only when that search produced rows.\n",
    "all_frames = [regular_results_df] if neural_results_df.empty else [regular_results_df, neural_results_df]\n",
    "\n",
    "# Stack the frames, then rank best-first by HR@10 with NDCG@10 and MRR@10 as tie-breakers.\n",
    "all_results_df = pd.concat(all_frames, ignore_index=True)\n",
    "all_results_df = all_results_df.sort_values(by=[\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "all_results_df = all_results_df.reset_index(drop=True)\n",
    "\n",
    "print(\"Top overall results:\")\n",
    "display(all_results_df.head(30))\n",
    "\n",
    "print(\"Best per formulation:\")\n",
    "# The frame is already ranked, so head(10) per group keeps each formulation's first 10 rows.\n",
    "best_per_formulation = all_results_df.groupby(\"Formulation\").head(10).reset_index(drop=True)\n",
    "display(best_per_formulation)\n",
    "\n",
    "print(\"Best per model family:\")\n",
    "# Family = leading letters-only token of the model name plus an optional '_letters' suffix.\n",
    "# NOTE(review): a digit ends the first token, so a name starting 'RP3beta_' maps to family 'RP' - confirm intended.\n",
    "family_df = all_results_df.assign(\n",
    "    Family=all_results_df[\"Model\"].str.extract(r\"^([A-Za-z]+(?:_[A-Za-z]+)?)\")[0]\n",
    ")\n",
    "# Column-wise maximum of each metric within a family, ranked by HR@10 then NDCG@10.\n",
    "display(\n",
    "    family_df.groupby(\"Family\")[[\"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]]\n",
    "    .max()\n",
    "    .sort_values(by=[\"HR@10\", \"NDCG@10\"], ascending=False)\n",
    "    .reset_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc27b61",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "This notebook removes the old automatic screening stage that previously caused very long EASE runs and high RAM usage.\n",
    "\n",
    "What is threaded or batched here:\n",
    "- BLAS / LAPACK thread env vars are set before importing NumPy and SciPy\n",
    "- `threadpoolctl` is applied when available\n",
    "- EASE uses a Cholesky-based dense solve\n",
    "- EASE and RP3 evaluation are fully batched\n",
    "- heavy score computation is moved to the GPU when available\n",
    "- neural training uses manual on-device batching instead of multiprocessing DataLoaders\n",
    "\n",
    "If you still need more speed:\n",
    "- lower `TOP_FORMULATIONS_FOR_NEURAL`\n",
    "- use fewer entries in `TWOTOWER_CONFIGS` / `MULTVAE_CONFIGS` / `NEUMF_CONFIGS`\n",
    "- lower `EVAL_MAX_USERS`\n",
    "- lower `LINEAR_EVAL_BATCH` or `NEURAL_EVAL_BATCH` if GPU memory becomes the bottleneck"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
