{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6979c756",
   "metadata": {},
   "source": [
    "# Car recommendation notebook — H1 variants, tuned top models, batched evaluation\n",
    "\n",
    "This notebook focuses on richer H1-style formulations only and removes the old slow screening path.\n",
    "\n",
    "Goals:\n",
    "- try several harder H1 variants\n",
    "- improve the current top families: RP3beta, EASE, TwoTower, MultVAE\n",
    "- batch all heavy evaluation paths\n",
    "- use GPU where it helps and CPU threads where possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8fb2968a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU thread target: 16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "for var in [\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "]:\n",
    "    os.environ[var] = str(CPU_THREADS)\n",
    "\n",
    "import gc\n",
    "import math\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy import linalg as sla\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "try:\n",
    "    from threadpoolctl import threadpool_limits\n",
    "except Exception:\n",
    "    threadpool_limits = None\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "\n",
    "if threadpool_limits is not None:\n",
    "    try:\n",
    "        threadpool_limits(limits=CPU_THREADS)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6731d71d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['H1_base', 'H1_color', 'H1_fine', 'H1_fine_color']\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "PRICE_BINS_BASE = 12\n",
    "PRICE_BINS_FINE = 16\n",
    "MILEAGE_BINS_BASE = 12\n",
    "MILEAGE_BINS_FINE = 16\n",
    "AGE_BINS_BASE = 10\n",
    "AGE_BINS_FINE = 14\n",
    "\n",
    "# Filtering\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# Evaluation\n",
    "TOP_KS = [5, 10, 20]\n",
    "EVAL_MAX_USERS = None\n",
    "LINEAR_EVAL_BATCH = 4096\n",
    "NEURAL_EVAL_BATCH = 4096\n",
    "\n",
    "# Manual richer H1 variants only\n",
    "FORMULATIONS = {\n",
    "    \"H1_base\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin12\", \"MileageBin12\", \"Condition\", \"AgeBin10\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin10\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H1_fine\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\"],\n",
    "    },\n",
    "    \"H1_fine_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin16\", \"MileageBin16\", \"Condition\", \"AgeBin14\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin14\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "# Try all formulations, then keep the top two for neural training\n",
    "TOP_FORMULATIONS_FOR_NEURAL = 2\n",
    "\n",
    "# Best current families only\n",
    "RP3_GRID = [\n",
    "    (1.00, 0.60),\n",
    "    (1.05, 0.60),\n",
    "    (1.10, 0.70),\n",
    "    (1.15, 0.70),\n",
    "]\n",
    "\n",
    "EASE_BINARY_LAMBDAS = [800.0, 1000.0, 1200.0, 1600.0, 2200.0]\n",
    "EASE_COUNT_LAMBDAS = [800.0, 1200.0, 1600.0, 2200.0]\n",
    "\n",
    "TWOTOWER_CONFIGS = [\n",
    "    {\"name\": \"TwoTower_t2\", \"emb_dim\": 96,  \"hidden_dims\": (768, 384, 192),  \"out_dim\": 160, \"epochs\": 28, \"lr\": 1.5e-3, \"wd\": 1e-5, \"temperature\": 0.05},\n",
    "    {\"name\": \"TwoTower_t3\", \"emb_dim\": 128, \"hidden_dims\": (1024, 512, 256), \"out_dim\": 192, \"epochs\": 32, \"lr\": 1.2e-3, \"wd\": 1e-5, \"temperature\": 0.04},\n",
    "]\n",
    "\n",
    "MULTVAE_CONFIGS = [\n",
    "    {\"name\": \"MultVAE_v3\", \"hidden_dim\": 2048, \"latent_dim\": 512, \"dropout\": 0.30, \"epochs\": 100, \"lr\": 6e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "    {\"name\": \"MultVAE_v4\", \"hidden_dim\": 2560, \"latent_dim\": 640, \"dropout\": 0.35, \"epochs\": 110, \"lr\": 5e-4, \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "]\n",
    "\n",
    "NEUMF_CONFIGS = [\n",
    "    {\"name\": \"NeuMF_n2\", \"mf_dim\": 96, \"mlp_dim\": 192, \"hidden_dims\": (384, 192, 96), \"epochs\": 26, \"lr\": 1.5e-3, \"wd\": 1e-6},\n",
    "]\n",
    "\n",
    "TWOTOWER_BATCH = 32768\n",
    "PAIRWISE_BATCH = 32768\n",
    "AUTOENC_BATCH = 4096\n",
    "NEUMF_EVAL_USER_BATCH = 512\n",
    "NEUMF_EVAL_ITEM_BATCH = 1024\n",
    "\n",
    "RUN_NEURAL = True\n",
    "RUN_FUSION = True\n",
    "FUSION_FETCH_N = 100\n",
    "FUSION_RRF_K = 60\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "349eba60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch version: 2.11.0\n",
      "Torch device : cuda\n",
      "USE_AMP      : True\n"
     ]
    }
   ],
   "source": [
    "# PyTorch runtime setup\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "\n",
    "try:\n",
    "    torch.set_num_threads(CPU_THREADS)\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "try:\n",
    "    torch.set_num_interop_threads(max(1, CPU_THREADS // 2))\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "USE_AMP = (device.type == \"cuda\")\n",
    "\n",
    "if device.type == \"cuda\":\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    try:\n",
    "        torch.set_float32_matmul_precision(\"high\")\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"Torch version:\", torch.__version__)\n",
    "print(\"Torch device :\", device)\n",
    "print(\"USE_AMP      :\", USE_AMP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "145329d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows: 1000000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin12</th>\n",
       "      <th>PriceBin16</th>\n",
       "      <th>MileageBin12</th>\n",
       "      <th>MileageBin16</th>\n",
       "      <th>AgeBin10</th>\n",
       "      <th>AgeBin14</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>P12_3</td>\n",
       "      <td>P16_4</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_0</td>\n",
       "      <td>A14_0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_3</td>\n",
       "      <td>M16_4</td>\n",
       "      <td>A10_9</td>\n",
       "      <td>A14_13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>P12_7</td>\n",
       "      <td>P16_9</td>\n",
       "      <td>M12_0</td>\n",
       "      <td>M16_0</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>P12_1</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_1</td>\n",
       "      <td>M16_2</td>\n",
       "      <td>A10_8</td>\n",
       "      <td>A14_11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>P12_0</td>\n",
       "      <td>P16_1</td>\n",
       "      <td>M12_4</td>\n",
       "      <td>M16_6</td>\n",
       "      <td>A10_4</td>\n",
       "      <td>A14_6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    3   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   26   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   12   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   23   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   14   \n",
       "\n",
       "  PriceBin12 PriceBin16 MileageBin12 MileageBin16 AgeBin10 AgeBin14  \n",
       "0      P12_3      P16_4        M12_3        M16_4    A10_0    A14_0  \n",
       "1      P12_1      P16_1        M12_3        M16_4    A10_9   A14_13  \n",
       "2      P12_7      P16_9        M12_0        M16_0    A10_4    A14_5  \n",
       "3      P12_1      P16_1        M12_1        M16_2    A10_8   A14_11  \n",
       "4      P12_0      P16_1        M12_4        M16_6    A10_4    A14_6  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean the dataset\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    df[col] = df[col].astype(str).str.strip()\n",
    "\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "df = df.dropna(subset=[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]).copy()\n",
    "\n",
    "df[\"Year\"] = df[\"Year\"].astype(int)\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0)\n",
    "\n",
    "def make_qbin_labels(series, q, prefix):\n",
    "    cats = pd.qcut(series, q=q, labels=False, duplicates=\"drop\")\n",
    "    cats = cats.astype(int)\n",
    "    return cats.map(lambda x: f\"{prefix}{int(x)}\")\n",
    "\n",
    "df[\"PriceBin12\"] = make_qbin_labels(df[\"Price\"], PRICE_BINS_BASE, \"P12_\")\n",
    "df[\"PriceBin16\"] = make_qbin_labels(df[\"Price\"], PRICE_BINS_FINE, \"P16_\")\n",
    "df[\"MileageBin12\"] = make_qbin_labels(df[\"Mileage\"], MILEAGE_BINS_BASE, \"M12_\")\n",
    "df[\"MileageBin16\"] = make_qbin_labels(df[\"Mileage\"], MILEAGE_BINS_FINE, \"M16_\")\n",
    "df[\"AgeBin10\"] = make_qbin_labels(df[\"Age\"], AGE_BINS_BASE, \"A10_\")\n",
    "df[\"AgeBin14\"] = make_qbin_labels(df[\"Age\"], AGE_BINS_FINE, \"A14_\")\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "450426c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user, min_users_per_item, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = len(out)\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if len(out) == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "        .size()\n",
    "        .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = None\n",
    "    used_thresholds = None\n",
    "\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS,\n",
    "        ).reset_index(drop=True)\n",
    "        if not candidate.empty:\n",
    "            interactions = candidate\n",
    "            used_thresholds = (min_items_per_user, min_users_per_item)\n",
    "            break\n",
    "\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(f\"{formulation_name} produced no interactions after filtering\")\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    user_feature_df = (\n",
    "        tmp[tmp[\"user_id\"].isin(interactions[\"user_id\"])]\n",
    "        [[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "    user_feature_df = user_feature_df.sort_values(\"user_idx\").reset_index(drop=True)\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[tmp[\"item_id\"].isin(interactions[\"item_id\"])]\n",
    "        [[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "    item_feature_df = item_feature_df.sort_values(\"item_idx\").reset_index(drop=True)\n",
    "\n",
    "    # Weighted leave-one-out\n",
    "    rng = np.random.default_rng(RANDOM_STATE)\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(int(picked))\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    num_users = int(interactions[\"user_idx\"].nunique())\n",
    "    num_items = int(interactions[\"item_idx\"].nunique())\n",
    "\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(u): int(i)\n",
    "        for u, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    user_seen_arrays = {}\n",
    "    for uid, g in train_interactions.groupby(\"user_idx\"):\n",
    "        user_seen_arrays[int(uid)] = np.sort(g[\"item_idx\"].astype(np.int32).to_numpy())\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy(dtype=np.int32)\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": list(user_cols),\n",
    "        \"item_cols\": list(item_cols),\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen_arrays\": user_seen_arrays,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "    avg_train = len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle[\"used_thresholds\"])\n",
    "    print(\"Train rows :\", len(bundle[\"train_interactions\"]))\n",
    "    print(\"Test rows  :\", len(bundle[\"test_interactions\"]))\n",
    "    print(\"Avg train interactions/user:\", round(avg_train, 3))\n",
    "\n",
    "def get_eval_users(bundle):\n",
    "    user_indices = np.array(sorted(bundle[\"test_item_by_user\"].keys()), dtype=np.int32)\n",
    "    if EVAL_MAX_USERS is not None and len(user_indices) > EVAL_MAX_USERS:\n",
    "        rng = np.random.default_rng(RANDOM_STATE)\n",
    "        user_indices = rng.choice(user_indices, size=EVAL_MAX_USERS, replace=False)\n",
    "    return np.sort(user_indices)\n",
    "\n",
    "def batch_metric_update(topk_idx, true_items, acc, ks):\n",
    "    matches = (topk_idx == true_items[:, None])\n",
    "    for k in ks:\n",
    "        acc[f\"HR@{k}\"] += matches[:, :k].any(axis=1).sum()\n",
    "\n",
    "    top10 = matches[:, :10]\n",
    "    found10 = top10.any(axis=1)\n",
    "    first_rank10 = np.where(found10, top10.argmax(axis=1) + 1, 0)\n",
    "\n",
    "    acc[\"MRR@10\"] += np.where(found10, 1.0 / first_rank10, 0.0).sum()\n",
    "    acc[\"NDCG@10\"] += np.where(found10, 1.0 / np.log2(first_rank10 + 1), 0.0).sum()\n",
    "    acc[\"UsersEval\"] += len(true_items)\n",
    "\n",
    "def finalize_metrics(acc, model_name, formulation_name):\n",
    "    users_eval = max(int(acc[\"UsersEval\"]), 1)\n",
    "    out = {\n",
    "        \"Formulation\": formulation_name,\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": users_eval,\n",
    "        \"HR@5\": float(acc[\"HR@5\"]) / users_eval,\n",
    "        \"HR@10\": float(acc[\"HR@10\"]) / users_eval,\n",
    "        \"HR@20\": float(acc[\"HR@20\"]) / users_eval,\n",
    "        \"MRR@10\": float(acc[\"MRR@10\"]) / users_eval,\n",
    "        \"NDCG@10\": float(acc[\"NDCG@10\"]) / users_eval,\n",
    "    }\n",
    "    return out\n",
    "\n",
    "def popularity_recommend_one(global_pop_rank, seen_array, n):\n",
    "    if seen_array is None or len(seen_array) == 0:\n",
    "        return global_pop_rank[:n].tolist()\n",
    "    seen_set = set(int(x) for x in seen_array)\n",
    "    out = []\n",
    "    for item in global_pop_rank:\n",
    "        item = int(item)\n",
    "        if item not in seen_set:\n",
    "            out.append(item)\n",
    "            if len(out) >= n:\n",
    "                break\n",
    "    return out\n",
    "\n",
    "def evaluate_popularity(bundle, model_name=\"Popularity\"):\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    global_pop_rank = bundle[\"global_pop_rank\"]\n",
    "    user_seen_arrays = bundle[\"user_seen_arrays\"]\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "    for uid in tqdm(user_indices, desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        true_item = test_item_by_user[int(uid)]\n",
    "        recs = popularity_recommend_one(global_pop_rank, user_seen_arrays.get(int(uid)), max(TOP_KS))\n",
    "        recs_arr = np.array(recs, dtype=np.int32)[None, :]\n",
    "        true_arr = np.array([true_item], dtype=np.int32)\n",
    "        batch_metric_update(recs_arr, true_arr, acc, TOP_KS)\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "def clear_memory():\n",
    "    gc.collect()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3e001d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model fitting helpers and batched evaluators\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr(copy=True)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel().astype(np.float32)\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    Pui = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    Piu = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "\n",
    "    S = (Piu @ Pui).astype(np.float32).tocsr()\n",
    "    S.setdiag(0.0)\n",
    "    S.eliminate_zeros()\n",
    "\n",
    "    if beta > 0:\n",
    "        item_degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "        penalty = np.ones_like(item_degree, dtype=np.float32)\n",
    "        mask = item_degree > 0\n",
    "        penalty[mask] = np.power(item_degree[mask], -beta)\n",
    "        S = S @ sp.diags(penalty.astype(np.float32))\n",
    "\n",
    "    return S.tocsr()\n",
    "\n",
    "def prepare_ease_cache(X_matrix, source_name=\"binary\", prefer_gpu=True):\n",
    "    t0 = time.time()\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32, copy=False)\n",
    "    prep = {\n",
    "        \"source_name\": source_name,\n",
    "        \"backend\": \"cpu\",\n",
    "        \"gram_np\": G,\n",
    "        \"n_items\": G.shape[0],\n",
    "        \"prep_sec\": time.time() - t0,\n",
    "    }\n",
    "\n",
    "    if prefer_gpu and device.type == \"cuda\":\n",
    "        try:\n",
    "            gram_t = torch.from_numpy(G).to(device, non_blocking=True)\n",
    "            eye_t = torch.eye(G.shape[0], dtype=torch.float32, device=device)\n",
    "            prep.update({\n",
    "                \"backend\": \"gpu\",\n",
    "                \"gram_t\": gram_t,\n",
    "                \"eye_t\": eye_t,\n",
    "            })\n",
    "        except Exception as e:\n",
    "            print(f\"    GPU cache setup failed for {source_name}, falling back to CPU: {e}\")\n",
    "\n",
    "    return prep\n",
    "\n",
    "def fit_ease_from_cache(cache, lam):\n",
    "    if cache.get(\"backend\") == \"gpu\":\n",
    "        gram_t = cache[\"gram_t\"].clone()\n",
    "        gram_t.diagonal().add_(float(lam))\n",
    "        L, info = torch.linalg.cholesky_ex(gram_t, upper=False)\n",
    "        info_val = int(info.item()) if hasattr(info, \"item\") else int(info)\n",
    "        if info_val == 0:\n",
    "            P = torch.cholesky_solve(cache[\"eye_t\"], L, upper=False)\n",
    "            diag = torch.diagonal(P).clone()\n",
    "            B = -P / diag.unsqueeze(0)\n",
    "            B.fill_diagonal_(0.0)\n",
    "            return B\n",
    "        else:\n",
    "            print(f\"    GPU Cholesky failed at lambda={lam} (info={info_val}), falling back to CPU.\")\n",
    "\n",
    "    G = cache[\"gram_np\"].copy()\n",
    "    G[np.diag_indices_from(G)] += np.float32(lam)\n",
    "    c, lower = sla.cho_factor(G, lower=True, overwrite_a=True, check_finite=False)\n",
    "    I = np.eye(G.shape[0], dtype=np.float32)\n",
    "    P = sla.cho_solve((c, lower), I, overwrite_b=True, check_finite=False)\n",
    "    diag = np.diag(P).copy()\n",
    "    B = -P / diag[None, :]\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32, copy=False)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_dense_operator(bundle, operator_matrix, model_name, source=\"binary\", batch_size=4096):\n",
    "    \"\"\"Rank items for every eval user with a dense item-item operator and accumulate top-k metrics.\n",
    "\n",
    "    Each batch of user rows X (bundle[\"X_binary\"] when source == \"binary\", otherwise\n",
    "    bundle[\"X_counts\"]) is densified to float32, moved to `device` and multiplied by\n",
    "    `operator_matrix`. Items the user already has (X > 0) are pushed to -1e9, and the\n",
    "    top max(TOP_KS) item indices are fed to batch_metric_update. The accumulated totals\n",
    "    are turned into the final result by finalize_metrics.\n",
    "\n",
    "    operator_matrix may be a torch.Tensor or anything torch.as_tensor accepts (e.g. a numpy\n",
    "    array); it is cast to float32. Its row count must match X's column count.\n",
    "    \"\"\"\n",
    "    eval_users = get_eval_users(bundle)\n",
    "    if source == \"binary\":\n",
    "        interactions = bundle[\"X_binary\"]\n",
    "    else:\n",
    "        interactions = bundle[\"X_counts\"]\n",
    "    held_out = bundle[\"test_item_by_user\"]\n",
    "    k_max = max(TOP_KS)\n",
    "\n",
    "    operator_t = (\n",
    "        operator_matrix.to(device=device, dtype=torch.float32)\n",
    "        if isinstance(operator_matrix, torch.Tensor)\n",
    "        else torch.as_tensor(operator_matrix, dtype=torch.float32, device=device)\n",
    "    )\n",
    "\n",
    "    totals = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "    progress_label = f\"{bundle['name']} / {model_name}\"\n",
    "\n",
    "    for offset in tqdm(range(0, len(eval_users), batch_size), desc=progress_label):\n",
    "        chunk = eval_users[offset:offset + batch_size]\n",
    "        rows_dense = interactions[chunk].toarray().astype(np.float32, copy=False)\n",
    "        targets = np.array([held_out[int(uid)] for uid in chunk], dtype=np.int32)\n",
    "\n",
    "        rows_t = torch.from_numpy(rows_dense).to(device, non_blocking=True)\n",
    "        # Keep already-interacted items out of the ranking.\n",
    "        masked_scores = (rows_t @ operator_t).masked_fill(rows_t > 0, -1e9)\n",
    "        ranked = masked_scores.topk(k_max, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        batch_metric_update(ranked, targets, totals, TOP_KS)\n",
    "\n",
    "        # Release per-batch tensors before the next iteration allocates new ones.\n",
    "        del rows_t, masked_scores, ranked\n",
    "\n",
    "    return finalize_metrics(totals, model_name, bundle[\"name\"])\n",
    "\n",
    "def top_model_name(df, prefix):\n",
    "    \"\"\"Return the best model name in `df` whose \"Model\" value starts with `prefix`.\n",
    "\n",
    "    Matching rows are ranked by HR@10, then NDCG@10, then MRR@10 (all descending) and the\n",
    "    first row after sorting wins. Missing (NaN/None) model names count as non-matching\n",
    "    instead of producing an invalid boolean mask. Returns None when nothing matches.\n",
    "    \"\"\"\n",
    "    matches = df[\"Model\"].str.startswith(prefix, na=False)\n",
    "    subset = df.loc[matches]\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    ranked = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return ranked.iloc[0][\"Model\"]\n",
    "\n",
    "def make_rrf_scores(score_matrices, rrf_k=60, top_n=None):\n",
    "    \"\"\"Fuse several ranked item lists per row with Reciprocal Rank Fusion (RRF).\n",
    "\n",
    "    score_matrices: non-empty list of integer item-index arrays, each [batch, fetch_n_m],\n",
    "        where column j holds the item ranked j-th (0-based) for that row. All arrays must\n",
    "        share the same batch size; negative entries are treated as padding and ignored.\n",
    "    rrf_k: smoothing constant; an item at 0-based rank r contributes 1 / (rrf_k + r + 1).\n",
    "    top_n: fused items kept per row (defaults to the width of the first array).\n",
    "\n",
    "    Returns (fused_items, fused_scores), both [batch, top_n] (int64 / float32), ordered by\n",
    "    descending fused score with ties broken by ascending item id. Rows with fewer than\n",
    "    top_n distinct candidates are padded with item -1 and score 0.0.\n",
    "\n",
    "    Raises IndexError for an empty list and ValueError for non-2-D or batch-mismatched arrays.\n",
    "    \"\"\"\n",
    "    if len(score_matrices) == 0:\n",
    "        raise IndexError(\"make_rrf_scores needs at least one ranking array\")\n",
    "    rankings = [np.asarray(m) for m in score_matrices]\n",
    "    if any(r.ndim != 2 for r in rankings):\n",
    "        raise ValueError(\"every ranking array must be 2-D [batch, fetch_n]\")\n",
    "    batch, fetch_n = rankings[0].shape\n",
    "    if any(r.shape[0] != batch for r in rankings):\n",
    "        raise ValueError(\"all ranking arrays must share the same batch size\")\n",
    "    top_n = fetch_n if top_n is None else int(top_n)\n",
    "\n",
    "    # One RRF weight per column, laid out to match the concatenated rankings below.\n",
    "    col_weights = np.concatenate(\n",
    "        [1.0 / (rrf_k + np.arange(1, r.shape[1] + 1, dtype=np.float64)) for r in rankings]\n",
    "    )\n",
    "    stacked = np.concatenate([r.astype(np.int64, copy=False) for r in rankings], axis=1)\n",
    "\n",
    "    fused_items = np.full((batch, top_n), -1, dtype=np.int64)\n",
    "    fused_scores = np.zeros((batch, top_n), dtype=np.float32)\n",
    "    for row in range(batch):\n",
    "        valid = stacked[row] >= 0\n",
    "        items, inverse = np.unique(stacked[row][valid], return_inverse=True)\n",
    "        if items.size == 0:\n",
    "            continue\n",
    "        # Sum each item's reciprocal-rank contributions across all lists.\n",
    "        scores = np.bincount(inverse.ravel(), weights=col_weights[valid], minlength=items.size)\n",
    "        order = np.lexsort((items, -scores))[:top_n]\n",
    "        fused_items[row, :order.size] = items[order]\n",
    "        fused_scores[row, :order.size] = scores[order]\n",
    "    return fused_items, fused_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "33f155c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: H1_base\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "====================================================================================================\n",
      "Formulation: H1_color\n",
      "User cols  : ['Country', 'PriceBin12', 'MileageBin12', 'Condition', 'AgeBin10', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin10', 'Condition', 'Color']\n",
      "Users      : 62921\n",
      "Items      : 13198\n",
      "Thresholds : (4, 8)\n",
      "Train rows : 237108\n",
      "Test rows  : 62921\n",
      "Avg train interactions/user: 3.768\n",
      "====================================================================================================\n",
      "Formulation: H1_fine\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition']\n",
      "Users      : 95529\n",
      "Items      : 3696\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 812918\n",
      "Test rows  : 95529\n",
      "Avg train interactions/user: 8.51\n",
      "====================================================================================================\n",
      "Formulation: H1_fine_color\n",
      "User cols  : ['Country', 'PriceBin16', 'MileageBin16', 'Condition', 'AgeBin14', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin14', 'Condition', 'Color']\n",
      "Users      : 59346\n",
      "Items      : 23514\n",
      "Thresholds : (3, 5)\n",
      "Train rows : 134712\n",
      "Test rows  : 59346\n",
      "Avg train interactions/user: 2.27\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>TrainRows</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Thresholds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_fine_color</td>\n",
       "      <td>59346</td>\n",
       "      <td>23514</td>\n",
       "      <td>134712</td>\n",
       "      <td>2.269942</td>\n",
       "      <td>(3, 5)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H1_color</td>\n",
       "      <td>62921</td>\n",
       "      <td>13198</td>\n",
       "      <td>237108</td>\n",
       "      <td>3.768344</td>\n",
       "      <td>(4, 8)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H1_fine</td>\n",
       "      <td>95529</td>\n",
       "      <td>3696</td>\n",
       "      <td>812918</td>\n",
       "      <td>8.509646</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>831152</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Formulation  Users  Items  TrainRows  AvgTrainPerUser Thresholds\n",
       "0  H1_fine_color  59346  23514     134712         2.269942     (3, 5)\n",
       "1       H1_color  62921  13198     237108         3.768344     (4, 8)\n",
       "2        H1_fine  95529   3696     812918         8.509646    (5, 10)\n",
       "3        H1_base  43199   2640     831152        19.240075    (5, 10)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build every richer H1 formulation and record one summary row per variant.\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    bundles[name] = bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    print_bundle_summary(bundle)\n",
    "    summary_rows.append(dict(\n",
    "        Formulation=name,\n",
    "        Users=bundle[\"num_users\"],\n",
    "        Items=bundle[\"num_items\"],\n",
    "        TrainRows=len(bundle[\"train_interactions\"]),\n",
    "        # max(..., 1) avoids a ZeroDivisionError for a formulation with no users.\n",
    "        AvgTrainPerUser=len(bundle[\"train_interactions\"]) / max(bundle[\"num_users\"], 1),\n",
    "        Thresholds=bundle[\"used_thresholds\"],\n",
    "    ))\n",
    "\n",
    "# Largest item catalogues first, ties broken by user count.\n",
    "summary_df = (\n",
    "    pd.DataFrame(summary_rows)\n",
    "    .sort_values([\"Items\", \"Users\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(summary_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a61a7425",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Regular search on formulation: H1_base\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / Popularity: 100%|██████████| 43199/43199 [00:01<00:00, 31871.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_0_b0_6: 100%|██████████| 11/11 [00:00<00:00, 33.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.62s\n",
      "- Fitting RP3beta_a1_05_b0_6\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_05_b0_6: 100%|██████████| 11/11 [00:00<00:00, 95.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_1_b0_7: 100%|██████████| 11/11 [00:00<00:00, 94.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / RP3beta_a1_15_b0_7: 100%|██████████| 11/11 [00:00<00:00, 94.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=2640 prep_time=0.06s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=2640 prep_time=0.06s\n",
      "- Fitting EASE_binary_l800\n",
      "    fit_time=0.04s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l800: 100%|██████████| 11/11 [00:00<00:00, 86.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1000: 100%|██████████| 11/11 [00:00<00:00, 79.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1200: 100%|██████████| 11/11 [00:00<00:00, 78.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l1600: 100%|██████████| 11/11 [00:00<00:00, 82.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_binary_l2200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_binary_l2200: 100%|██████████| 11/11 [00:00<00:00, 84.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l800\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l800: 100%|██████████| 11/11 [00:00<00:00, 84.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l1200: 100%|██████████| 11/11 [00:00<00:00, 90.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l1600: 100%|██████████| 11/11 [00:00<00:00, 90.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting EASE_count_l2200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_base / EASE_count_l2200: 100%|██████████| 11/11 [00:00<00:00, 91.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.12s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================================================================================\n",
      "Regular search on formulation: H1_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / Popularity: 100%|██████████| 62921/62921 [00:01<00:00, 32761.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_0_b0_6: 100%|██████████| 16/16 [00:01<00:00, 11.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.46s\n",
      "- Fitting RP3beta_a1_05_b0_6\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_05_b0_6: 100%|██████████| 16/16 [00:01<00:00, 12.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.32s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_1_b0_7: 100%|██████████| 16/16 [00:01<00:00, 12.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.30s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / RP3beta_a1_15_b0_7: 100%|██████████| 16/16 [00:01<00:00, 12.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.29s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=13198 prep_time=0.06s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=13198 prep_time=0.06s\n",
      "- Fitting EASE_binary_l800\n",
      "    fit_time=0.09s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_binary_l800: 100%|██████████| 16/16 [00:01<00:00, 10.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.57s\n",
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_binary_l1000: 100%|██████████| 16/16 [00:01<00:00, 10.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.58s\n",
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_binary_l1200: 100%|██████████| 16/16 [00:01<00:00, 10.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.52s\n",
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_binary_l1600: 100%|██████████| 16/16 [00:01<00:00, 10.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.53s\n",
      "- Fitting EASE_binary_l2200\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_binary_l2200: 100%|██████████| 16/16 [00:01<00:00, 10.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.51s\n",
      "- Fitting EASE_count_l800\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_count_l800: 100%|██████████| 16/16 [00:01<00:00, 10.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.51s\n",
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_count_l1200: 100%|██████████| 16/16 [00:01<00:00, 10.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.52s\n",
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_count_l1600: 100%|██████████| 16/16 [00:01<00:00, 10.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.56s\n",
      "- Fitting EASE_count_l2200\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_color / EASE_count_l2200: 100%|██████████| 16/16 [00:01<00:00, 10.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=1.55s\n",
      "========================================================================================================================\n",
      "Regular search on formulation: H1_fine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / Popularity: 100%|██████████| 95529/95529 [00:02<00:00, 33355.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_0_b0_6: 100%|██████████| 24/24 [00:00<00:00, 42.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.58s\n",
      "- Fitting RP3beta_a1_05_b0_6\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_05_b0_6: 100%|██████████| 24/24 [00:00<00:00, 67.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.36s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.07s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_1_b0_7: 100%|██████████| 24/24 [00:00<00:00, 70.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.08s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / RP3beta_a1_15_b0_7: 100%|██████████| 24/24 [00:00<00:00, 70.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=3696 prep_time=0.05s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=3696 prep_time=0.05s\n",
      "- Fitting EASE_binary_l800\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l800: 100%|██████████| 24/24 [00:00<00:00, 69.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_binary_l1000\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l1000: 100%|██████████| 24/24 [00:00<00:00, 70.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "- Fitting EASE_binary_l1200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l1200: 100%|██████████| 24/24 [00:00<00:00, 70.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_binary_l1600\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l1600: 100%|██████████| 24/24 [00:00<00:00, 70.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "- Fitting EASE_binary_l2200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_binary_l2200: 100%|██████████| 24/24 [00:00<00:00, 70.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.35s\n",
      "- Fitting EASE_count_l800\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_count_l800: 100%|██████████| 24/24 [00:00<00:00, 70.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "- Fitting EASE_count_l1200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_count_l1200: 100%|██████████| 24/24 [00:00<00:00, 70.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "- Fitting EASE_count_l1600\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_count_l1600: 100%|██████████| 24/24 [00:00<00:00, 70.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "- Fitting EASE_count_l2200\n",
      "    fit_time=0.00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine / EASE_count_l2200: 100%|██████████| 24/24 [00:00<00:00, 70.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=0.34s\n",
      "========================================================================================================================\n",
      "Regular search on formulation: H1_fine_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / Popularity: 100%|██████████| 59346/59346 [00:01<00:00, 34119.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Fitting RP3beta_a1_0_b0_6\n",
      "    fit_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / RP3beta_a1_0_b0_6: 100%|██████████| 15/15 [00:03<00:00,  4.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=3.28s\n",
      "- Fitting RP3beta_a1_05_b0_6\n",
      "    fit_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / RP3beta_a1_05_b0_6: 100%|██████████| 15/15 [00:02<00:00,  5.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=3.12s\n",
      "- Fitting RP3beta_a1_1_b0_7\n",
      "    fit_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / RP3beta_a1_1_b0_7: 100%|██████████| 15/15 [00:02<00:00,  5.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=3.12s\n",
      "- Fitting RP3beta_a1_15_b0_7\n",
      "    fit_time=0.14s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_fine_color / RP3beta_a1_15_b0_7: 100%|██████████| 15/15 [00:02<00:00,  5.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    eval_time=3.13s\n",
      "- Preparing EASE binary Gram cache\n",
      "    backend=gpu n_items=23514 prep_time=0.13s\n",
      "- Preparing EASE count Gram cache\n",
      "    backend=gpu n_items=23514 prep_time=0.13s\n",
      "- Fitting EASE_binary_l800\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 2.06 GiB. GPU 0 has a total capacity of 9.64 GiB of which 503.31 MiB is free. Process 1192 has 22.22 MiB memory in use. Including non-PyTorch memory, this process has 8.53 GiB memory in use. Of the allocated memory 8.25 GiB is allocated by PyTorch, and 11.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mOutOfMemoryError\u001b[39m                          Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 54\u001b[39m\n\u001b[32m     50\u001b[39m         clear_memory()\n\u001b[32m     51\u001b[39m         model_name = f\"EASE_binary_l{int(lam)}\"\n\u001b[32m     52\u001b[39m         print(\u001b[33m\"- Fitting\"\u001b[39m, model_name)\n\u001b[32m     53\u001b[39m         t_fit = time.time()\n\u001b[32m---> \u001b[39m\u001b[32m54\u001b[39m         B = fit_ease_from_cache(ease_bin_cache, lam=lam)\n\u001b[32m     55\u001b[39m         fit_sec = time.time() - t_fit\n\u001b[32m     56\u001b[39m         print(f\"    fit_time={fit_sec:.2f}s\")\n\u001b[32m     57\u001b[39m         t_eval = time.time()\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 56\u001b[39m, in \u001b[36mfit_ease_from_cache\u001b[39m\u001b[34m(cache, lam)\u001b[39m\n\u001b[32m     54\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m fit_ease_from_cache(cache, lam):\n\u001b[32m     55\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m cache.get(\u001b[33m\"backend\"\u001b[39m) == \u001b[33m\"gpu\"\u001b[39m:\n\u001b[32m---> \u001b[39m\u001b[32m56\u001b[39m         gram_t = cache[\u001b[33m\"gram_t\"\u001b[39m].clone()\n\u001b[32m     57\u001b[39m         gram_t.diagonal().add_(float(lam))\n\u001b[32m     58\u001b[39m         L, info = torch.linalg.cholesky_ex(gram_t, upper=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m     59\u001b[39m         info_val = int(info.item()) \u001b[38;5;28;01mif\u001b[39;00m hasattr(info, \u001b[33m\"item\"\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m int(info)\n",
      "\u001b[31mOutOfMemoryError\u001b[39m: CUDA out of memory. Tried to allocate 2.06 GiB. GPU 0 has a total capacity of 9.64 GiB of which 503.31 MiB is free. Process 1192 has 22.22 MiB memory in use. Including non-PyTorch memory, this process has 8.53 GiB memory in use. Of the allocated memory 8.25 GiB is allocated by PyTorch, and 11.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)"
     ]
    }
   ],
   "source": [
    "# Regular model search across all H1 variants\n",
    "#\n",
    "# EASE Gram caches are prepared for one interaction source at a time and released before\n",
    "# the next one is built. Keeping the binary and count caches resident together roughly\n",
    "# doubles peak GPU memory and left no room for the Gram copy that fit_ease_from_cache makes\n",
    "# per lambda (CUDA OOM in gram_t.clone()). If a GPU fit still runs out of memory, the cache\n",
    "# is rebuilt with prefer_gpu=False and the same lambda is retried.\n",
    "\n",
    "def release_ease_cache(cache):\n",
    "    \"\"\"Drop the large matrices held by an EASE cache dict and reclaim the freed memory.\"\"\"\n",
    "    for key in (\"gram_t\", \"eye_t\", \"gram_np\"):\n",
    "        cache.pop(key, None)\n",
    "    clear_memory()\n",
    "\n",
    "\n",
    "def prepare_ease_cache_logged(X, source_name, prefer_gpu):\n",
    "    \"\"\"Build an EASE Gram cache for one source and print its backend, size and build time.\"\"\"\n",
    "    print(f\"- Preparing EASE {source_name} Gram cache\")\n",
    "    cache = prepare_ease_cache(X, source_name=source_name, prefer_gpu=prefer_gpu)\n",
    "    print(\n",
    "        f\"    backend={cache['backend']} \"\n",
    "        f\"n_items={cache['n_items']} \"\n",
    "        f\"prep_time={cache['prep_sec']:.2f}s\"\n",
    "    )\n",
    "    return cache\n",
    "\n",
    "\n",
    "def run_rp3_grid(bundle, results):\n",
    "    \"\"\"Fit and evaluate RP3beta for every (alpha, beta) in RP3_GRID; append metric rows to results.\"\"\"\n",
    "    for alpha, beta in RP3_GRID:\n",
    "        clear_memory()\n",
    "        model_name = f\"RP3beta_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "        print(\"- Fitting\", model_name)\n",
    "        t_fit = time.time()\n",
    "        S = fit_p3alpha(bundle[\"X_binary\"], alpha=alpha, beta=beta)\n",
    "        S_dense = S.toarray().astype(np.float32, copy=False)\n",
    "        print(f\"    fit_time={time.time() - t_fit:.2f}s\")\n",
    "        t_eval = time.time()\n",
    "        results.append(\n",
    "            evaluate_dense_operator(bundle, S_dense, model_name=model_name, source=\"binary\", batch_size=LINEAR_EVAL_BATCH)\n",
    "        )\n",
    "        print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "        del S, S_dense\n",
    "        clear_memory()\n",
    "    return results\n",
    "\n",
    "\n",
    "def run_ease_grid(bundle, source_name, matrix_key, lambdas, results):\n",
    "    \"\"\"Fit and evaluate EASE for every lambda on bundle[matrix_key]; append metric rows to results.\n",
    "\n",
    "    source_name (\"binary\" / \"count\") is passed to the cache builder, used in the model name\n",
    "    and given to the evaluator as its source. The Gram cache is released on exit, including\n",
    "    when an exception propagates, so only one cache is ever resident at a time.\n",
    "    \"\"\"\n",
    "    cache = prepare_ease_cache_logged(bundle[matrix_key], source_name, prefer_gpu=True)\n",
    "    try:\n",
    "        for lam in lambdas:\n",
    "            clear_memory()\n",
    "            model_name = f\"EASE_{source_name}_l{int(lam)}\"\n",
    "            print(\"- Fitting\", model_name)\n",
    "            t_fit = time.time()\n",
    "            gpu_oom = False\n",
    "            try:\n",
    "                B = fit_ease_from_cache(cache, lam=lam)\n",
    "            except torch.cuda.OutOfMemoryError:\n",
    "                if cache.get(\"backend\") != \"gpu\":\n",
    "                    raise\n",
    "                gpu_oom = True\n",
    "            if gpu_oom:\n",
    "                # Retry outside the except block so the failed attempt's traceback (and the\n",
    "                # tensors its frames reference) can be released before rebuilding.\n",
    "                print(\"    CUDA out of memory during EASE fit; rebuilding the Gram cache on CPU\")\n",
    "                release_ease_cache(cache)\n",
    "                cache = prepare_ease_cache_logged(bundle[matrix_key], source_name, prefer_gpu=False)\n",
    "                t_fit = time.time()\n",
    "                B = fit_ease_from_cache(cache, lam=lam)\n",
    "            print(f\"    fit_time={time.time() - t_fit:.2f}s\")\n",
    "            t_eval = time.time()\n",
    "            results.append(\n",
    "                evaluate_dense_operator(bundle, B, model_name=model_name, source=source_name, batch_size=LINEAR_EVAL_BATCH)\n",
    "            )\n",
    "            print(f\"    eval_time={time.time() - t_eval:.2f}s\")\n",
    "            del B\n",
    "            clear_memory()\n",
    "    finally:\n",
    "        release_ease_cache(cache)\n",
    "    return results\n",
    "\n",
    "\n",
    "regular_results = []\n",
    "\n",
    "for formulation_name, bundle in bundles.items():\n",
    "    print(\"=\" * 120)\n",
    "    print(\"Regular search on formulation:\", formulation_name)\n",
    "\n",
    "    # Popularity baseline\n",
    "    regular_results.append(evaluate_popularity(bundle, model_name=\"Popularity\"))\n",
    "\n",
    "    run_rp3_grid(bundle, regular_results)\n",
    "    # Binary first; its cache is freed before the count cache is built.\n",
    "    run_ease_grid(bundle, \"binary\", \"X_binary\", EASE_BINARY_LAMBDAS, regular_results)\n",
    "    run_ease_grid(bundle, \"count\", \"X_counts\", EASE_COUNT_LAMBDAS, regular_results)\n",
    "\n",
    "regular_results_df = (\n",
    "    pd.DataFrame(regular_results)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(\"Top regular results overall:\")\n",
    "display(regular_results_df.head(20))\n",
    "\n",
    "best_formulations_df = (\n",
    "    regular_results_df.groupby(\"Formulation\")[[\"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]]\n",
    "    .max()\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "print(\"Best formulation scores:\")\n",
    "display(best_formulations_df)\n",
    "\n",
    "selected_formulations = best_formulations_df[\"Formulation\"].head(TOP_FORMULATIONS_FOR_NEURAL).tolist()\n",
    "print(\"Selected formulations for neural search:\", selected_formulations)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3cf5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare formulation-specific tensors for neural models\n",
    "\n",
    "def prepare_neural_data(bundle):\n",
    "    \"\"\"Encode side features and stage one formulation's training data on `device`.\n",
    "\n",
    "    Every user/item feature column is label-encoded (after casting to str) into ids\n",
    "    0..n_classes-1, so n_classes is the embedding cardinality for that column.\n",
    "\n",
    "    Returns a dict with:\n",
    "      user_cols / item_cols                   feature column names taken from the bundle\n",
    "      user_feat_t / item_feat_t               int64 encoded feature tensors on `device`\n",
    "      pos_user_t / pos_item_t                 int64 user/item indices of the training interactions\n",
    "      pos_weight_t                            float32 interaction counts\n",
    "      x_binary_np                             dense float32 copy of bundle[\"X_binary\"]\n",
    "      x_binary_t / x_binary_on_device         that matrix on `device`, or None / False when it\n",
    "                                              does not fit in GPU memory\n",
    "      user_feature_cards / item_feature_cards per-column cardinalities\n",
    "      feature_encoders                        fitted LabelEncoders keyed \"user::<col>\" / \"item::<col>\"\n",
    "    \"\"\"\n",
    "    user_feature_df = bundle[\"user_feature_df\"].copy()\n",
    "    item_feature_df = bundle[\"item_feature_df\"].copy()\n",
    "    # Only read below (to_numpy(copy=True)), so the interaction table needs no defensive copy.\n",
    "    train_interactions = bundle[\"train_interactions\"]\n",
    "\n",
    "    user_cols = bundle[\"user_cols\"]\n",
    "    item_cols = bundle[\"item_cols\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "    user_feature_cards = {}\n",
    "    item_feature_cards = {}\n",
    "    for side, frame, cols, cards in (\n",
    "        (\"user\", user_feature_df, user_cols, user_feature_cards),\n",
    "        (\"item\", item_feature_df, item_cols, item_feature_cards),\n",
    "    ):\n",
    "        for col in cols:\n",
    "            le = LabelEncoder()\n",
    "            frame[col] = le.fit_transform(frame[col].astype(str))\n",
    "            feature_encoders[f\"{side}::{col}\"] = le\n",
    "            # fit_transform on the whole column emits every id 0..len(classes_)-1.\n",
    "            cards[col] = int(len(le.classes_))\n",
    "\n",
    "    user_feat_np = user_feature_df[user_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "    item_feat_np = item_feature_df[item_cols].to_numpy(dtype=np.int64, copy=True)\n",
    "\n",
    "    pos_user_np = train_interactions[\"user_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_item_np = train_interactions[\"item_idx\"].to_numpy(dtype=np.int64, copy=True)\n",
    "    pos_weight_np = train_interactions[\"count\"].to_numpy(dtype=np.float32, copy=True)\n",
    "\n",
    "    x_binary_np = bundle[\"X_binary\"].toarray().astype(np.float32, copy=False)\n",
    "\n",
    "    out = {\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"user_feat_t\": torch.from_numpy(user_feat_np).to(device),\n",
    "        \"item_feat_t\": torch.from_numpy(item_feat_np).to(device),\n",
    "        \"pos_user_t\": torch.from_numpy(pos_user_np).to(device),\n",
    "        \"pos_item_t\": torch.from_numpy(pos_item_np).to(device),\n",
    "        \"pos_weight_t\": torch.from_numpy(pos_weight_np).to(device),\n",
    "        \"x_binary_np\": x_binary_np,\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        out[\"x_binary_t\"] = torch.from_numpy(x_binary_np).to(device)\n",
    "        out[\"x_binary_on_device\"] = True\n",
    "    except torch.cuda.OutOfMemoryError:\n",
    "        # Only running out of GPU memory means \"keep the matrix on the host\"; any other\n",
    "        # RuntimeError (e.g. a CUDA fault) now surfaces instead of being swallowed.\n",
    "        # iter_dense_user_batches then copies rows from x_binary_np per batch.\n",
    "        out[\"x_binary_t\"] = None\n",
    "        out[\"x_binary_on_device\"] = False\n",
    "        clear_memory()\n",
    "\n",
    "    out[\"user_feature_cards\"] = user_feature_cards\n",
    "    out[\"item_feature_cards\"] = item_feature_cards\n",
    "    out[\"feature_encoders\"] = feature_encoders\n",
    "\n",
    "    return out\n",
    "\n",
    "neural_data_by_formulation = {}\n",
    "for formulation_name in selected_formulations:\n",
    "    print(\"Preparing neural tensors:\", formulation_name)\n",
    "    neural_data_by_formulation[formulation_name] = prepare_neural_data(bundles[formulation_name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c9a604a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural architectures and batched evaluators\n",
    "\n",
    "class FeatureEmbeddingBlock(nn.Module):\n",
    "    \"\"\"Embed each categorical feature column with its own table and concatenate the results.\n",
    "\n",
    "    One nn.Embedding(cardinality, emb_dim) is created per key of `cardinalities`, in dict\n",
    "    order, each initialised from N(0, 0.02). Input column i is looked up in the i-th table,\n",
    "    so input columns must follow the order of the `cardinalities` keys.\n",
    "    \"\"\"\n",
    "    def __init__(self, cardinalities, emb_dim):\n",
    "        super().__init__()\n",
    "        # Table creation order fixes parameter registration (state_dict keys \"embs.<i>.*\")\n",
    "        # and the order in which the RNG is consumed for initialisation.\n",
    "        self.cols = list(cardinalities.keys())\n",
    "        self.embs = nn.ModuleList([\n",
    "            nn.Embedding(int(cardinalities[c]), emb_dim) for c in self.cols\n",
    "        ])\n",
    "        for emb in self.embs:\n",
    "            nn.init.normal_(emb.weight, std=0.02)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # For a 2-D id tensor (batch, n_cols) each lookup is (batch, emb_dim) and the\n",
    "        # concatenation is (batch, n_cols * emb_dim).\n",
    "        parts = [emb(x[:, i]) for i, emb in enumerate(self.embs)]\n",
    "        return torch.cat(parts, dim=1)\n",
    "\n",
    "class MLPBlock(nn.Module):\n",
    "    \"\"\"Stack of Linear -> ReLU -> Dropout layers, optionally followed by a plain Linear projection.\n",
    "\n",
    "    Args:\n",
    "        input_dim: width of the input features.\n",
    "        hidden_dims: widths of the hidden layers, in order (may be empty).\n",
    "        dropout: dropout probability applied after every hidden ReLU.\n",
    "        final_dim: if not None, a last nn.Linear(prev, final_dim) with no activation is appended.\n",
    "\n",
    "    Attributes:\n",
    "        output_dim: width of the block's output - final_dim if given, else the last hidden\n",
    "            width, else input_dim (an empty nn.Sequential passes the input through).\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None):\n",
    "        super().__init__()\n",
    "        # Layer order defines the nn.Sequential indices (state_dict keys \"net.<i>.*\") and the\n",
    "        # order in which layers draw their initial weights from the RNG.\n",
    "        layers = []\n",
    "        prev = input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)])\n",
    "            prev = h\n",
    "        if final_dim is not None:\n",
    "            layers.append(nn.Linear(prev, final_dim))\n",
    "            prev = final_dim\n",
    "        self.net = nn.Sequential(*layers)\n",
    "        self.output_dim = prev\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class TowerEncoder(nn.Module):\n",
    "    \"\"\"Encode rows of categorical feature ids into L2-normalised vectors of size out_dim.\n",
    "\n",
    "    Per-column embeddings (FeatureEmbeddingBlock) are concatenated into a\n",
    "    len(cards) * emb_dim vector, passed through an MLPBlock whose last layer is a Linear\n",
    "    to out_dim, and normalised along the last dimension.\n",
    "    \"\"\"\n",
    "    def __init__(self, cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        # Embeddings are created before the MLP; keep this order so initialisation is unchanged.\n",
    "        self.embed = FeatureEmbeddingBlock(cards, emb_dim)\n",
    "        input_dim = len(cards) * emb_dim\n",
    "        self.mlp = MLPBlock(input_dim, hidden_dims, dropout=dropout, final_dim=out_dim)\n",
    "\n",
    "    def forward(self, feat_batch):\n",
    "        x = self.embed(feat_batch)\n",
    "        x = self.mlp(x)\n",
    "        # Unit-length outputs make user/item dot products cosine similarities.\n",
    "        return F.normalize(x, dim=-1)\n",
    "\n",
    "class TwoTowerModel(nn.Module):\n",
    "    \"\"\"Separate user and item TowerEncoders projecting into a shared L2-normalised space.\n",
    "\n",
    "    Both towers use the same emb_dim / hidden_dims / out_dim, so user and item vectors are\n",
    "    compared with a dot product (a cosine similarity, since each tower normalises its output).\n",
    "    Training and evaluation call user_tower / item_tower directly; forward() lets the\n",
    "    module also be invoked as model(user_feats, item_feats) instead of raising\n",
    "    NotImplementedError.\n",
    "    \"\"\"\n",
    "    def __init__(self, user_cards, item_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "        super().__init__()\n",
    "        self.user_tower = TowerEncoder(user_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "        self.item_tower = TowerEncoder(item_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "\n",
    "    def forward(self, user_feats, item_feats):\n",
    "        \"\"\"Return (user_vecs, item_vecs) for a batch of user and a batch of item feature ids.\"\"\"\n",
    "        return self.user_tower(user_feats), self.item_tower(item_feats)\n",
    "\n",
    "class MultVAE(nn.Module):\n",
    "    \"\"\"Mult-VAE autoencoder over a user's item vector (Liang et al., 2018).\n",
    "\n",
    "    Encoder: [optional row L2-normalisation] -> dropout -> two Linear+Tanh layers -> (mu, logvar).\n",
    "    Decoder: Linear+Tanh -> Linear producing num_items unnormalised logits.\n",
    "\n",
    "    Args:\n",
    "        num_items: width of the input and output item vectors.\n",
    "        hidden_dim: width of the encoder and decoder hidden layers.\n",
    "        latent_dim: size of the latent code.\n",
    "        dropout: dropout probability applied to the encoder input.\n",
    "        normalize_input: L2-normalise each input row before dropout, as the reference\n",
    "            Mult-VAE does. Without it the encoder input grows with the number of items a\n",
    "            user has, which can push heavy users' Tanh units toward saturation.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_items, hidden_dim=2048, latent_dim=512, dropout=0.3, normalize_input=True):\n",
    "        super().__init__()\n",
    "        self.normalize_input = normalize_input\n",
    "        # Submodule creation order is unchanged so parameter initialisation is reproducible.\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Return (mu, logvar) of the approximate posterior for input rows x.\"\"\"\n",
    "        if self.normalize_input:\n",
    "            x = F.normalize(x, p=2, dim=1)\n",
    "        h = self.encoder(self.dropout(x))\n",
    "        return self.mu(h), self.logvar(h)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        \"\"\"Sample z ~ N(mu, exp(logvar)) in training mode; return mu unchanged in eval mode.\"\"\"\n",
    "        if self.training:\n",
    "            std = torch.exp(0.5 * logvar)\n",
    "            eps = torch.randn_like(std)\n",
    "            return mu + eps * std\n",
    "        return mu\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return (logits, mu, logvar); logits are unnormalised scores over all items.\"\"\"\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        logits = self.decoder(z)\n",
    "        return logits, mu, logvar\n",
    "\n",
    "class NeuMF(nn.Module):\n",
    "    \"\"\"Neural matrix factorisation (GMF + MLP branches) scoring (user, item) id pairs.\n",
    "\n",
    "    The GMF branch multiplies mf_dim user and item embeddings element-wise; the MLP branch\n",
    "    concatenates mlp_dim user and item embeddings and runs them through\n",
    "    Linear -> ReLU -> Dropout(0.1) layers of widths hidden_dims. Both branch outputs are\n",
    "    concatenated and projected to one logit per pair. Embeddings are initialised from\n",
    "    N(0, 0.02); Linear layers keep PyTorch's default initialisation.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "        super().__init__()\n",
    "        # Construction order fixes parameter registration and RNG consumption; keep it stable.\n",
    "        self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "        self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "        self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "        self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "        mlp_layers = []\n",
    "        prev = 2 * mlp_dim\n",
    "        for h in hidden_dims:\n",
    "            mlp_layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(0.1)])\n",
    "            prev = h\n",
    "        self.mlp = nn.Sequential(*mlp_layers)\n",
    "        self.out = nn.Linear(prev + mf_dim, 1)\n",
    "\n",
    "        for emb in [self.user_mf, self.item_mf, self.user_mlp, self.item_mlp]:\n",
    "            nn.init.normal_(emb.weight, std=0.02)\n",
    "\n",
    "    def forward(self, user_idx, item_idx):\n",
    "        # user_idx / item_idx: matching 1-D id tensors (cat along dim=1 needs 2-D embeddings);\n",
    "        # returns one unnormalised score per pair, shape (n,).\n",
    "        mf_part = self.user_mf(user_idx) * self.item_mf(item_idx)\n",
    "        mlp_in = torch.cat([self.user_mlp(user_idx), self.item_mlp(item_idx)], dim=1)\n",
    "        mlp_part = self.mlp(mlp_in)\n",
    "        x = torch.cat([mf_part, mlp_part], dim=1)\n",
    "        return self.out(x).squeeze(-1)\n",
    "\n",
    "def iter_positive_batches(ndata, batch_size):\n",
    "    \"\"\"Yield shuffled (user_idx, item_idx, weight) minibatches of the training positives.\n",
    "\n",
    "    The permutation is drawn on `device`, so gathering from the on-device positive tensors\n",
    "    needs no host round-trip. The final batch may be shorter than batch_size.\n",
    "    \"\"\"\n",
    "    users = ndata[\"pos_user_t\"]\n",
    "    items = ndata[\"pos_item_t\"]\n",
    "    weights = ndata[\"pos_weight_t\"]\n",
    "    total = users.shape[0]\n",
    "    order = torch.randperm(total, device=device)\n",
    "    for offset in range(0, total, batch_size):\n",
    "        chunk = order[offset:offset + batch_size]\n",
    "        yield users[chunk], items[chunk], weights[chunk]\n",
    "\n",
    "def iter_dense_user_batches(ndata, batch_size):\n",
    "    \"\"\"Yield (dense_rows, user_idx) minibatches of the binary user-item matrix in random user order.\n",
    "\n",
    "    Rows are gathered from the on-device copy when ndata[\"x_binary_on_device\"] is set,\n",
    "    otherwise sliced from x_binary_np and copied to `device` per batch. user_idx is a long\n",
    "    tensor on `device`. The shuffle uses NumPy's global RNG.\n",
    "    \"\"\"\n",
    "    x_np = ndata[\"x_binary_np\"]\n",
    "    n_users = x_np.shape[0]\n",
    "    perm = np.random.permutation(n_users)\n",
    "    for start in range(0, n_users, batch_size):\n",
    "        idx = perm[start:start + batch_size]\n",
    "        # Build the device index once and reuse it for the gather and the yielded ids,\n",
    "        # instead of also indexing with the NumPy array (converted to a host tensor first).\n",
    "        idx_t = torch.as_tensor(idx, device=device, dtype=torch.long)\n",
    "        if ndata[\"x_binary_on_device\"]:\n",
    "            yield ndata[\"x_binary_t\"][idx_t], idx_t\n",
    "        else:\n",
    "            yield torch.from_numpy(x_np[idx]).to(device), idx_t\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_twotower(bundle, ndata, model, model_name, item_batch_size=4096):\n",
    "    \"\"\"Rank every item for each evaluation user by tower similarity and accumulate top-k metrics.\n",
    "\n",
    "    Item vectors are encoded once, item_batch_size rows at a time; users are then scored\n",
    "    NEURAL_EVAL_BATCH at a time. Items already in a user's bundle[\"X_binary\"] row are\n",
    "    excluded from ranking. Returns finalize_metrics(acc, model_name, bundle[\"name\"]).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    item_feats = ndata[\"item_feat_t\"]\n",
    "    item_matrix = torch.cat(\n",
    "        [\n",
    "            model.item_tower(item_feats[start:start + item_batch_size])\n",
    "            for start in range(0, item_feats.shape[0], item_batch_size)\n",
    "        ],\n",
    "        dim=0,\n",
    "    )\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEURAL_EVAL_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEURAL_EVAL_BATCH]\n",
    "        user_vec = model.user_tower(ndata[\"user_feat_t\"][batch_uids])\n",
    "        scores = user_vec @ item_matrix.T\n",
    "\n",
    "        # Mask seen items from the sparse row coordinates instead of densifying the\n",
    "        # (batch x num_items) slice on the host and copying it to the device.\n",
    "        seen = x_source[batch_uids].tocoo()\n",
    "        positive = seen.data > 0  # same condition as the dense `x > 0` mask\n",
    "        seen_rows = torch.as_tensor(seen.row[positive], device=device, dtype=torch.long)\n",
    "        seen_cols = torch.as_tensor(seen.col[positive], device=device, dtype=torch.long)\n",
    "        scores[seen_rows, seen_cols] = -1e9\n",
    "\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del user_vec, scores, seen_rows, seen_cols, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_multvae(bundle, ndata, model, model_name):\n",
    "    \"\"\"Score all items with a MultVAE for each evaluation user and accumulate top-k metrics.\n",
    "\n",
    "    The user's bundle[\"X_binary\"] row is both the model input and the mask: items already\n",
    "    in it are excluded from ranking. model.eval() disables dropout and makes the latent code\n",
    "    the posterior mean. `ndata` is accepted for signature parity with evaluate_twotower and\n",
    "    is not used. Returns finalize_metrics(acc, model_name, bundle[\"name\"]).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEURAL_EVAL_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEURAL_EVAL_BATCH]\n",
    "        x_batch = torch.from_numpy(x_source[batch_uids].toarray().astype(np.float32, copy=False)).to(device)\n",
    "        logits, _, _ = model(x_batch)\n",
    "        logits = logits.masked_fill(x_batch > 0, -1e9)\n",
    "        topk = torch.topk(logits, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del x_batch, logits, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_neumf(bundle, model, model_name):\n",
    "    \"\"\"Score every (user, item) pair with NeuMF and accumulate top-k metrics.\n",
    "\n",
    "    Users are processed NEUMF_EVAL_USER_BATCH at a time and, within a user batch, items\n",
    "    NEUMF_EVAL_ITEM_BATCH at a time, so one forward pass sees at most\n",
    "    NEUMF_EVAL_USER_BATCH * NEUMF_EVAL_ITEM_BATCH pairs. Items already in the user's\n",
    "    bundle[\"X_binary\"] row are excluded from ranking.\n",
    "    Returns finalize_metrics(acc, model_name, bundle[\"name\"]).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    user_indices = get_eval_users(bundle)\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    x_source = bundle[\"X_binary\"]\n",
    "    num_items = bundle[\"num_items\"]\n",
    "    max_k = max(TOP_KS)\n",
    "\n",
    "    acc = {\"HR@5\": 0.0, \"HR@10\": 0.0, \"HR@20\": 0.0, \"MRR@10\": 0.0, \"NDCG@10\": 0.0, \"UsersEval\": 0}\n",
    "    all_items = torch.arange(num_items, device=device, dtype=torch.long)\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), NEUMF_EVAL_USER_BATCH), desc=f\"{bundle['name']} / {model_name}\"):\n",
    "        batch_uids = user_indices[start:start + NEUMF_EVAL_USER_BATCH]\n",
    "        user_batch = torch.as_tensor(batch_uids, device=device, dtype=torch.long)\n",
    "        n_batch = user_batch.shape[0]\n",
    "\n",
    "        parts = []\n",
    "        for istart in range(0, num_items, NEUMF_EVAL_ITEM_BATCH):\n",
    "            item_batch = all_items[istart:istart + NEUMF_EVAL_ITEM_BATCH]\n",
    "            u_expand = user_batch[:, None].expand(-1, item_batch.shape[0]).reshape(-1)\n",
    "            i_expand = item_batch[None, :].expand(n_batch, -1).reshape(-1)\n",
    "            parts.append(model(u_expand, i_expand).reshape(n_batch, -1))\n",
    "        scores = torch.cat(parts, dim=1)\n",
    "        del parts\n",
    "\n",
    "        # Mask seen items from the sparse row coordinates instead of densifying the\n",
    "        # (batch x num_items) slice on the host and copying it to the device.\n",
    "        seen = x_source[batch_uids].tocoo()\n",
    "        positive = seen.data > 0  # same condition as the dense `x > 0` mask\n",
    "        seen_rows = torch.as_tensor(seen.row[positive], device=device, dtype=torch.long)\n",
    "        seen_cols = torch.as_tensor(seen.col[positive], device=device, dtype=torch.long)\n",
    "        scores[seen_rows, seen_cols] = -1e9\n",
    "\n",
    "        topk = torch.topk(scores, k=max_k, dim=1).indices.cpu().numpy().astype(np.int32, copy=False)\n",
    "        true_np = np.array([test_item_by_user[int(u)] for u in batch_uids], dtype=np.int32)\n",
    "        batch_metric_update(topk, true_np, acc, TOP_KS)\n",
    "        del user_batch, scores, seen_rows, seen_cols, topk\n",
    "\n",
    "    return finalize_metrics(acc, model_name, bundle[\"name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f000e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural training functions\n",
    "\n",
    "def train_two_tower(bundle, ndata, cfg):\n",
    "    \"\"\"Train a TwoTowerModel with a count-weighted in-batch sampled-softmax loss.\n",
    "\n",
    "    Each training positive is one row of the batch; the other rows' items act as its\n",
    "    negatives. With cfg[\"mask_in_batch_duplicates\"] (default True), columns that are in\n",
    "    fact positives for a row - the row's own item repeated elsewhere in the batch, or\n",
    "    another positive item of the row's user - are removed from that row's softmax instead\n",
    "    of being pushed down as negatives.\n",
    "\n",
    "    cfg keys: name, emb_dim, hidden_dims, out_dim, lr, wd, epochs, temperature and the\n",
    "    optional mask_in_batch_duplicates. Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    model = TwoTowerModel(\n",
    "        user_cards=ndata[\"user_feature_cards\"],\n",
    "        item_cards=ndata[\"item_feature_cards\"],\n",
    "        emb_dim=cfg[\"emb_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "        out_dim=cfg[\"out_dim\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    # torch.amp.GradScaler(\"cuda\", ...) replaces the deprecated torch.cuda.amp.GradScaler(...).\n",
    "    scaler = torch.amp.GradScaler(\"cuda\", enabled=USE_AMP)\n",
    "    mask_duplicates = cfg.get(\"mask_in_batch_duplicates\", True)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        # Accumulated on the device: per-step .item() calls would force a GPU sync every batch.\n",
    "        total_loss = torch.zeros((), device=device)\n",
    "        total_w = torch.zeros((), device=device)\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in iter_positive_batches(ndata, TWOTOWER_BATCH):\n",
    "            user_batch = ndata[\"user_feat_t\"][user_idx_batch]\n",
    "            item_batch = ndata[\"item_feat_t\"][item_idx_batch]\n",
    "            w_sum = weight_batch.sum()\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                user_vec = model.user_tower(user_batch)\n",
    "                item_vec = model.item_tower(item_batch)\n",
    "                logits = (user_vec @ item_vec.T) / cfg[\"temperature\"]\n",
    "                if mask_duplicates:\n",
    "                    # Off-diagonal column j is a false negative for row i when it repeats row i's\n",
    "                    # item or is another positive of row i's user. The diagonal is never masked,\n",
    "                    # so every row keeps a finite target logit.\n",
    "                    false_neg = (item_idx_batch[:, None] == item_idx_batch[None, :]) | (\n",
    "                        user_idx_batch[:, None] == user_idx_batch[None, :]\n",
    "                    )\n",
    "                    false_neg.fill_diagonal_(False)\n",
    "                    logits = logits.masked_fill(false_neg, float(\"-inf\"))\n",
    "                targets = torch.arange(logits.shape[0], device=device)\n",
    "                loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                loss = (loss_vec * weight_batch).sum() / w_sum\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += loss.detach().float() * w_sum\n",
    "            total_w += w_sum\n",
    "\n",
    "        epoch_loss = (total_loss / total_w.clamp_min(1e-8)).item()\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {epoch_loss:.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def train_multvae(bundle, ndata, cfg):\n",
    "    \"\"\"Train a MultVAE on the dense binary user-item matrix with linear KL annealing.\n",
    "\n",
    "    The KL weight grows linearly from 0 over the first 30% of all optimisation steps and is\n",
    "    capped at cfg[\"anneal_cap\"]. The mean loss is printed for the first epoch, every 10th\n",
    "    epoch and the last epoch.\n",
    "\n",
    "    cfg keys: name, hidden_dim, latent_dim, dropout, lr, wd, epochs, anneal_cap.\n",
    "    Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    model = MultVAE(\n",
    "        num_items=bundle[\"num_items\"],\n",
    "        hidden_dim=cfg[\"hidden_dim\"],\n",
    "        latent_dim=cfg[\"latent_dim\"],\n",
    "        dropout=cfg[\"dropout\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    # torch.amp.GradScaler(\"cuda\", ...) replaces the deprecated torch.cuda.amp.GradScaler(...).\n",
    "    scaler = torch.amp.GradScaler(\"cuda\", enabled=USE_AMP)\n",
    "\n",
    "    n_users = bundle[\"num_users\"]\n",
    "    steps_per_epoch = max(1, math.ceil(n_users / AUTOENC_BATCH))\n",
    "    total_steps = max(1, cfg[\"epochs\"] * steps_per_epoch)\n",
    "    anneal_steps = max(total_steps * 0.3, 1)\n",
    "    step = 0\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        # Kept on the device and read back only when printed, avoiding a GPU sync per step.\n",
    "        total_loss = torch.zeros((), device=device)\n",
    "        total_rows = 0\n",
    "\n",
    "        for batch_x, _ in iter_dense_user_batches(ndata, AUTOENC_BATCH):\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            anneal = min(cfg[\"anneal_cap\"], step / anneal_steps)\n",
    "\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits, mu, logvar = model(batch_x)\n",
    "                recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                loss = (recon + anneal * kl).mean()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += loss.detach().float() * batch_x.shape[0]\n",
    "            total_rows += batch_x.shape[0]\n",
    "            step += 1\n",
    "\n",
    "        epoch_no = epoch + 1\n",
    "        if epoch_no == 1 or epoch_no % 10 == 0 or epoch_no == cfg[\"epochs\"]:\n",
    "            print(f\"{bundle['name']} / {cfg['name']} epoch {epoch_no}/{cfg['epochs']} - loss: {total_loss.item() / max(total_rows, 1):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def train_neumf(bundle, cfg):\n",
    "    \"\"\"Train NeuMF with a count-weighted BPR loss and one uniform negative per positive.\n",
    "\n",
    "    Each training interaction (user, item, count) is paired with an item drawn uniformly\n",
    "    from all items (it may occasionally be one of the user's positives). The per-pair loss\n",
    "    -log sigmoid(s_pos - s_neg) is weighted by the interaction count.\n",
    "\n",
    "    cfg keys: name, mf_dim, mlp_dim, hidden_dims, lr, wd, epochs.\n",
    "    Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    n_users = bundle[\"num_users\"]\n",
    "    n_items = bundle[\"num_items\"]\n",
    "    train_df = bundle[\"train_interactions\"]\n",
    "\n",
    "    pos_user = torch.as_tensor(train_df[\"user_idx\"].to_numpy(dtype=np.int64), device=device)\n",
    "    pos_item = torch.as_tensor(train_df[\"item_idx\"].to_numpy(dtype=np.int64), device=device)\n",
    "    pos_weight = torch.as_tensor(train_df[\"count\"].to_numpy(dtype=np.float32), device=device)\n",
    "\n",
    "    model = NeuMF(\n",
    "        num_users=n_users,\n",
    "        num_items=n_items,\n",
    "        mf_dim=cfg[\"mf_dim\"],\n",
    "        mlp_dim=cfg[\"mlp_dim\"],\n",
    "        hidden_dims=cfg[\"hidden_dims\"],\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg[\"lr\"], weight_decay=cfg[\"wd\"])\n",
    "    # torch.amp.GradScaler(\"cuda\", ...) replaces the deprecated torch.cuda.amp.GradScaler(...).\n",
    "    scaler = torch.amp.GradScaler(\"cuda\", enabled=USE_AMP)\n",
    "\n",
    "    n = pos_user.shape[0]\n",
    "    model.train()\n",
    "\n",
    "    for epoch in range(cfg[\"epochs\"]):\n",
    "        perm = torch.randperm(n, device=device)\n",
    "        # Accumulated on the device: per-step .item() calls would force a GPU sync every batch.\n",
    "        total_loss = torch.zeros((), device=device)\n",
    "        total_w = torch.zeros((), device=device)\n",
    "\n",
    "        for start in range(0, n, PAIRWISE_BATCH):\n",
    "            idx = perm[start:start + PAIRWISE_BATCH]\n",
    "            u = pos_user[idx]\n",
    "            p = pos_item[idx]\n",
    "            w = pos_weight[idx]\n",
    "            w_sum = w.sum()\n",
    "            n_item = torch.randint(0, n_items, size=p.shape, device=device)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                pos_scores = model(u, p)\n",
    "                neg_scores = model(u, n_item)\n",
    "                loss_vec = -F.logsigmoid(pos_scores - neg_scores)\n",
    "                loss = (loss_vec * w).sum() / w_sum\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += loss.detach().float() * w_sum\n",
    "            total_w += w_sum\n",
    "\n",
    "        epoch_loss = (total_loss / total_w.clamp_min(1e-8)).item()\n",
    "        print(f\"{bundle['name']} / {cfg['name']} epoch {epoch + 1}/{cfg['epochs']} - loss: {epoch_loss:.6f}\")\n",
    "\n",
    "    return model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b2ee8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural search on the best formulations only\n",
    "\n",
    "def run_neural_configs(configs, fit_fn, eval_fn):\n",
    "    \"\"\"Train and evaluate every config of one model family and return the metric rows.\n",
    "\n",
    "    fit_fn(cfg) returns a trained model and eval_fn(model, name) returns one metrics row.\n",
    "    Each model is dropped and memory is cleared before the next config is trained.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for cfg in configs:\n",
    "        clear_memory()\n",
    "        model = fit_fn(cfg)\n",
    "        rows.append(eval_fn(model, cfg[\"name\"]))\n",
    "        del model\n",
    "        clear_memory()\n",
    "    return rows\n",
    "\n",
    "\n",
    "neural_results = []\n",
    "\n",
    "if RUN_NEURAL:\n",
    "    for formulation_name in selected_formulations:\n",
    "        print(\"=\" * 120)\n",
    "        print(\"Neural search on formulation:\", formulation_name)\n",
    "\n",
    "        bundle = bundles[formulation_name]\n",
    "        ndata = neural_data_by_formulation[formulation_name]\n",
    "\n",
    "        # The lambdas are called immediately inside run_neural_configs, so they see the\n",
    "        # current formulation's bundle / ndata.\n",
    "        neural_results.extend(run_neural_configs(\n",
    "            TWOTOWER_CONFIGS,\n",
    "            lambda cfg: train_two_tower(bundle, ndata, cfg),\n",
    "            lambda model, name: evaluate_twotower(bundle, ndata, model, name),\n",
    "        ))\n",
    "        neural_results.extend(run_neural_configs(\n",
    "            MULTVAE_CONFIGS,\n",
    "            lambda cfg: train_multvae(bundle, ndata, cfg),\n",
    "            lambda model, name: evaluate_multvae(bundle, ndata, model, name),\n",
    "        ))\n",
    "        neural_results.extend(run_neural_configs(\n",
    "            NEUMF_CONFIGS,\n",
    "            lambda cfg: train_neumf(bundle, cfg),\n",
    "            lambda model, name: evaluate_neumf(bundle, model, name),\n",
    "        ))\n",
    "\n",
    "# An empty DataFrame has no metric columns, so sorting it would raise KeyError.\n",
    "if neural_results:\n",
    "    neural_results_df = (\n",
    "        pd.DataFrame(neural_results)\n",
    "        .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "else:\n",
    "    neural_results_df = pd.DataFrame()\n",
    "\n",
    "print(\"Top neural results:\")\n",
    "display(neural_results_df.head(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d0e73c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional fusion on the best formulation\n",
    "\n",
    "def make_python_recommenders_for_fusion(bundle, formulation_name, regular_results_df, neural_results_df):\n",
    "    \"\"\"Pick the best EASE, RP3beta, TwoTower and MultVAE model names for one formulation.\n",
    "\n",
    "    Despite its name this returns model *names*, not recommender objects; `bundle` is kept\n",
    "    for interface compatibility and is not used. EASE prefers the best binary variant and\n",
    "    falls back to the best count variant.\n",
    "\n",
    "    Returns:\n",
    "        (best_ease, best_rp3, best_tt, best_vae). An entry is None when top_model_name finds\n",
    "        no match, and the neural entries are None when neural_results_df is empty.\n",
    "    \"\"\"\n",
    "    reg_sub = regular_results_df[regular_results_df[\"Formulation\"] == formulation_name]\n",
    "    best_rp3 = top_model_name(reg_sub, \"RP3beta_\")\n",
    "    best_ease = top_model_name(reg_sub, \"EASE_binary_\")\n",
    "    if best_ease is None:\n",
    "        best_ease = top_model_name(reg_sub, \"EASE_count_\")\n",
    "\n",
    "    if neural_results_df.empty:\n",
    "        best_tt = best_vae = None\n",
    "    else:\n",
    "        neu_sub = neural_results_df[neural_results_df[\"Formulation\"] == formulation_name]\n",
    "        best_tt = top_model_name(neu_sub, \"TwoTower_\")\n",
    "        best_vae = top_model_name(neu_sub, \"MultVAE_\")\n",
    "\n",
    "    return best_ease, best_rp3, best_tt, best_vae\n",
    "\n",
    "fusion_results = []\n",
    "\n",
    "if RUN_FUSION and not neural_results_df.empty:\n",
    "    best_formulation = regular_results_df.iloc[0][\"Formulation\"]\n",
    "    bundle = bundles[best_formulation]\n",
    "    print(\"Fusion on formulation:\", best_formulation)\n",
    "\n",
    "    ease_name, rp3_name, tt_name, vae_name = make_python_recommenders_for_fusion(\n",
    "        bundle, best_formulation, regular_results_df, neural_results_df\n",
    "    )\n",
    "\n",
    "    print(\"Best for fusion:\", ease_name, rp3_name, tt_name, vae_name)\n",
    "else:\n",
    "    print(\"Fusion skipped or no neural results.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb491a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final comparison tables\n",
    "\n",
    "all_frames = [regular_results_df]\n",
    "if not neural_results_df.empty:\n",
    "    all_frames.append(neural_results_df)\n",
    "\n",
    "all_results_df = (\n",
    "    pd.concat(all_frames, ignore_index=True)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(\"Top overall results:\")\n",
    "display(all_results_df.head(30))\n",
    "\n",
    "print(\"Best per formulation:\")\n",
    "# groupby().head(10) keeps each formulation's 10 best rows (all_results_df is already ranked)\n",
    "# but leaves them interleaved; re-sort so each formulation's rows are contiguous and ranked.\n",
    "best_per_formulation = (\n",
    "    all_results_df.groupby(\"Formulation\")\n",
    "    .head(10)\n",
    "    .sort_values(\n",
    "        [\"Formulation\", \"HR@10\", \"NDCG@10\", \"MRR@10\"],\n",
    "        ascending=[True, False, False, False],\n",
    "    )\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(best_per_formulation)\n",
    "\n",
    "print(\"Best per model family:\")\n",
    "# Family = EASE plus its source (binary/count), otherwise everything before the first \"_\"\n",
    "# (e.g. \"RP3beta_a0_5_b0_3\" -> \"RP3beta\"). A letters-only pattern would cut \"RP3beta\" to\n",
    "# \"RP\" and fold a letter-led config suffix into the family name.\n",
    "family_df = all_results_df.assign(\n",
    "    Family=all_results_df[\"Model\"].str.extract(r\"^(EASE_(?:binary|count)|[^_]+)\", expand=False)\n",
    ")\n",
    "display(\n",
    "    family_df.groupby(\"Family\")[[\"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]]\n",
    "    .max()\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\"], ascending=False)\n",
    "    .reset_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc27b61",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "This notebook removes the old automatic screening stage that previously caused very long EASE runs and high RAM usage.\n",
    "\n",
    "What is threaded or batched here:\n",
    "- BLAS / LAPACK thread env vars are set before importing NumPy and SciPy\n",
    "- `threadpoolctl` is applied when available\n",
    "- EASE uses a Cholesky-based dense solve\n",
    "- EASE and RP3 evaluation are fully batched\n",
    "- heavy score computation is moved to the GPU when available\n",
    "- neural training uses manual on-device batching instead of multiprocessing DataLoaders\n",
    "\n",
    "If you still need more speed or lower memory use:\n",
    "- lower `TOP_FORMULATIONS_FOR_NEURAL`\n",
    "- shorten the `TWOTOWER_CONFIGS` / `MULTVAE_CONFIGS` / `NEUMF_CONFIGS` lists\n",
    "- lower `EVAL_MAX_USERS`\n",
    "- lower `LINEAR_EVAL_BATCH` or `NEURAL_EVAL_BATCH` if GPU memory becomes the bottleneck\n",
    "\n",
    "Additional EASE fit changes in this version:\n",
    "- Gram matrices are built once per formulation and reused across all EASE lambdas\n",
    "- EASE fitting can use GPU Cholesky solve when CUDA is available\n",
    "- fit time and eval time are printed separately, so long stalls before a progress bar appears are visible\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
