{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8cd9eae1",
   "metadata": {},
   "source": [
    "# Car recommender notebook — aggressive search version\n",
    "\n",
    "This version is designed to do two things:\n",
    "\n",
    "1. search across **multiple task formulations** to find a stronger recommendation setup;\n",
    "2. train **more GPU-heavy models** so the GPU does much more work than in the previous notebook.\n",
    "\n",
    "Important note:\n",
    "- with the previous formulation there were only **88 items**, which is too small to fully saturate a modern GPU;\n",
    "- this notebook increases the item granularity and uses **large-batch neural training**, **full-softmax classification**, and **in-batch contrastive ranking** to push GPU usage much higher.\n",
    "\n",
    "This notebook intentionally tries **many options**. After it finishes, keep the strongest few models and throw away the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f5ca59df",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch device: cuda\n"
     ]
    }
   ],
   "source": [
     "from pathlib import Path\n",
     "import gc\n",
     "import math\n",
     "import os\n",
     "import random\n",
     "import time\n",
     "import warnings\n",
     "\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "import scipy.sparse as sp\n",
     "from scipy.sparse import csr_matrix\n",
     "from sklearn.preprocessing import LabelEncoder\n",
     "from sklearn.neighbors import NearestNeighbors\n",
     "from tqdm.auto import tqdm\n",
     "\n",
     "import torch\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
     "from torch.utils.data import DataLoader, Dataset, TensorDataset\n",
     "\n",
     "# NOTE(review): blanket-suppressing all warnings hides real problems; consider\n",
     "# filtering only the specific categories known to be noisy here.\n",
     "warnings.filterwarnings(\"ignore\")\n",
     "\n",
     "# Seed every RNG in use (numpy, random, torch) so runs are reproducible.\n",
     "RANDOM_STATE = 42\n",
     "np.random.seed(RANDOM_STATE)\n",
     "random.seed(RANDOM_STATE)\n",
     "torch.manual_seed(RANDOM_STATE)\n",
     "\n",
     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "print(\"Torch device:\", device)\n",
     "\n",
     "if torch.cuda.is_available():\n",
     "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
     "    # TF32 trades a little matmul precision for speed on Ampere+ GPUs.\n",
     "    torch.backends.cuda.matmul.allow_tf32 = True\n",
     "    torch.backends.cudnn.allow_tf32 = True\n",
     "    try:\n",
     "        torch.set_float32_matmul_precision(\"high\")\n",
     "    except Exception:\n",
     "        # Older torch versions lack this API; safe to ignore.\n",
     "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3156b716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['F1_country_price_mileage_cond__brand_model', 'F2_country_price_mileage_cond_age__brand_model_age_cond', 'F3_country_price_mileage_age__brand_model_age_cond', 'F4_country_price_mileage__brand_model_age_cond']\n"
     ]
    }
   ],
   "source": [
     "# Paths and high-level settings\n",
     "\n",
     "# NOTE(review): hardcoded absolute home-directory path; prefer an\n",
     "# env-configurable DATA_DIR so the notebook runs on other machines.\n",
     "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
     "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
     "\n",
     "# Fail fast with a clear message instead of an opaque pandas read error later.\n",
     "if not CSV_PATH.exists():\n",
     "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
     "\n",
     "# Binning: quantile-bin granularity for the continuous columns.\n",
     "N_PRICE_BINS = 10\n",
     "N_MILEAGE_BINS = 10\n",
     "N_AGE_BINS = 8\n",
     "\n",
     "# Interaction filtering: iterative pruning thresholds (see iterative_filter).\n",
     "MIN_ITEMS_PER_USER = 5\n",
     "MIN_USERS_PER_ITEM = 10\n",
     "MAX_FILTER_ITERS = 8\n",
     "\n",
     "# Ranking metrics: HR@k is computed for each of these cutoffs.\n",
     "TOP_KS = [5, 10, 20]\n",
     "\n",
     "# Model search grids\n",
     "ITEMKNN_NEIGHBORS_GRID = [50, 100, 150]\n",
     "EASE_LAMBDA_GRID = [50.0, 100.0, 200.0, 500.0, 1000.0]\n",
     "\n",
     "# Neural training defaults\n",
     "SOFTMAX_EPOCHS = 18\n",
     "TWOTOWER_EPOCHS = 18\n",
     "MULTVAE_EPOCHS = 60\n",
     "\n",
     "# Large batches are deliberate: the notebook's stated goal is GPU saturation.\n",
     "SOFTMAX_BATCH_SIZE = 8192\n",
     "TWOTOWER_BATCH_SIZE = 4096\n",
     "MULTVAE_BATCH_SIZE = 1024\n",
     "\n",
     "# DataLoader workers: multiprocessing workers are unreliable on Windows kernels.\n",
     "NUM_WORKERS = 4 if os.name != \"nt\" else 0\n",
     "USE_AMP = torch.cuda.is_available()\n",
     "\n",
     "# Candidate task formulations.\n",
     "# Each one defines a pseudo-user profile and an item granularity.\n",
     "FORMULATIONS = {\n",
     "    \"F1_country_price_mileage_cond__brand_model\": {\n",
     "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\"],\n",
     "        \"item_cols\": [\"Brand\", \"Model\"],\n",
     "    },\n",
     "    \"F2_country_price_mileage_cond_age__brand_model_age_cond\": {\n",
     "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\"],\n",
     "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
     "    },\n",
     "    \"F3_country_price_mileage_age__brand_model_age_cond\": {\n",
     "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"AgeBin\"],\n",
     "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
     "    },\n",
     "    \"F4_country_price_mileage__brand_model_age_cond\": {\n",
     "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\"],\n",
     "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
     "    },\n",
     "}\n",
     "\n",
     "print(\"CSV:\", CSV_PATH)\n",
     "print(\"Formulations:\", list(FORMULATIONS.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b60787ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape: (1000000, 15)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513.0</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>1</td>\n",
       "      <td>p2</td>\n",
       "      <td>m2</td>\n",
       "      <td>a0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990.0</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>24</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703.0</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>10</td>\n",
       "      <td>p5</td>\n",
       "      <td>m0</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353.0</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>21</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>a6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854.0</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>12</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20  58513.0   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14  60990.0   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93   1703.0   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94  25353.0  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01  76854.0  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    1   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   24   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   10   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   21   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   12   \n",
       "\n",
       "  PriceBin MileageBin AgeBin  \n",
       "0       p2         m2     a0  \n",
       "1       p0         m3     a7  \n",
       "2       p5         m0     a3  \n",
       "3       p0         m1     a6  \n",
       "4       p0         m3     a3  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Age</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Condition</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>1</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513.0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>p2</td>\n",
       "      <td>m2</td>\n",
       "      <td>a0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>24</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990.0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Italy</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>10</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703.0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>UK</td>\n",
       "      <td>p5</td>\n",
       "      <td>m0</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>21</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353.0</td>\n",
       "      <td>Used</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>a6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>12</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854.0</td>\n",
       "      <td>Used</td>\n",
       "      <td>USA</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year  Age     Price  Mileage            Condition  \\\n",
       "0       Honda        Civic  2023    1  25627.20  58513.0  Certified Pre-Owned   \n",
       "1       Mazda       Mazda3  2000   24  12027.14  60990.0  Certified Pre-Owned   \n",
       "2       Mazda         CX-5  2014   10  49194.93   1703.0  Certified Pre-Owned   \n",
       "3     Hyundai       Tucson  2003   21  11955.94  25353.0                 Used   \n",
       "4  Land Rover  Range Rover  2012   12  10910.01  76854.0                 Used   \n",
       "\n",
       "  Country PriceBin MileageBin AgeBin  \n",
       "0  Brazil       p2         m2     a0  \n",
       "1   Italy       p0         m3     a7  \n",
       "2      UK       p5         m0     a3  \n",
       "3  Mexico       p0         m1     a6  \n",
       "4     USA       p0         m3     a3  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
     "# Load and clean data\n",
     "\n",
     "df = pd.read_csv(CSV_PATH)\n",
     "\n",
     "# Normalize string columns (strip stray whitespace) where present.\n",
     "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\", \"First Name\", \"Last Name\", \"Address\"]:\n",
     "    if col in df.columns:\n",
     "        df[col] = df[col].astype(str).str.strip()\n",
     "\n",
     "df[\"Year\"] = df[\"Year\"].astype(int)\n",
     "df[\"Price\"] = df[\"Price\"].astype(float)\n",
     "df[\"Mileage\"] = df[\"Mileage\"].astype(float)\n",
     "\n",
     "# Age from the newest year in the file\n",
     "max_year = int(df[\"Year\"].max())\n",
     "df[\"Age\"] = (max_year - df[\"Year\"]).astype(int)\n",
     "\n",
     "def qbin(series, q, prefix):\n",
     "    \"\"\"Quantile-bin ``series`` into at most ``q`` buckets labelled prefix0..prefixN.\n",
     "\n",
     "    duplicates=\"drop\" means fewer than ``q`` bins when quantile edges collide.\n",
     "    \"\"\"\n",
     "    cat = pd.qcut(series, q=q, duplicates=\"drop\")\n",
     "    return prefix + cat.cat.codes.astype(str)\n",
     "\n",
     "df[\"PriceBin\"] = qbin(df[\"Price\"], N_PRICE_BINS, \"p\")\n",
     "df[\"MileageBin\"] = qbin(df[\"Mileage\"], N_MILEAGE_BINS, \"m\")\n",
     "df[\"AgeBin\"] = qbin(df[\"Age\"], N_AGE_BINS, \"a\")\n",
     "\n",
     "print(\"Shape:\", df.shape)\n",
     "# NOTE(review): head() below renders names/addresses (PII) into saved output;\n",
     "# consider redacting before sharing this notebook.\n",
     "display(df.head())\n",
     "display(df[[\"Brand\", \"Model\", \"Year\", \"Age\", \"Price\", \"Mileage\", \"Condition\", \"Country\", \"PriceBin\", \"MileageBin\", \"AgeBin\"]].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6739a0dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user=5, min_users_per_item=10, max_iters=8):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_shape = out.shape[0]\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if out.shape[0] == old_shape:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "           .size()\n",
    "           .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = iterative_filter(\n",
    "        interactions,\n",
    "        min_items_per_user=MIN_ITEMS_PER_USER,\n",
    "        min_users_per_item=MIN_USERS_PER_ITEM,\n",
    "        max_iters=MAX_FILTER_ITERS\n",
    "    ).reset_index(drop=True)\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    num_users = interactions[\"user_idx\"].nunique()\n",
    "    num_items = interactions[\"item_idx\"].nunique()\n",
    "\n",
    "    # user/item feature tables for the selected formulation\n",
    "    user_feature_df = (\n",
    "        tmp[[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "\n",
    "    # weighted leave-one-out split\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = np.random.default_rng(RANDOM_STATE + int(uid)).choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(picked)\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    # matrices from train only\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    user_seen = {\n",
    "        int(uid): set(g[\"item_idx\"].astype(int).tolist())\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    user_strength = {\n",
    "        int(uid): {int(i): float(c) for i, c in zip(g[\"item_idx\"], g[\"count\"])}\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(uid): int(i)\n",
    "        for uid, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy()\n",
    "    )\n",
    "\n",
    "    item_ids = item_encoder.classes_.tolist()\n",
    "    user_ids = user_encoder.classes_.tolist()\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df.sort_values(\"user_idx\").reset_index(drop=True),\n",
    "        \"item_feature_df\": item_feature_df.sort_values(\"item_idx\").reset_index(drop=True),\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen\": user_seen,\n",
    "        \"user_strength\": user_strength,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_ids,\n",
    "        \"user_ids\": user_ids,\n",
    "    }\n",
    "\n",
    "def hit_rate_at_k(recs, true_item, k):\n",
    "    return 1.0 if true_item in recs[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1.0)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(model_func, model_name, data_bundle, user_indices=None, ks=(5, 10, 20)):\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(data_bundle[\"test_item_by_user\"].keys()))\n",
    "\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr = 0.0\n",
    "    ndcg = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    for uid in tqdm(user_indices, desc=model_name):\n",
    "        uid = int(uid)\n",
    "        true_item = data_bundle[\"test_item_by_user\"].get(uid)\n",
    "        if true_item is None:\n",
    "            continue\n",
    "\n",
    "        recs = model_func(uid, n=max(ks))\n",
    "        valid += 1\n",
    "        for k in ks:\n",
    "            hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "        mrr += mrr_at_k(recs, true_item, 10)\n",
    "        ndcg += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    row = {\"Model\": model_name}\n",
    "    for k in ks:\n",
    "        row[f\"HR@{k}\"] = hits[k] / valid if valid else 0.0\n",
    "    row[\"MRR@10\"] = mrr / valid if valid else 0.0\n",
    "    row[\"NDCG@10\"] = ndcg / valid if valid else 0.0\n",
    "    return row\n",
    "\n",
    "def topn_from_scores(scores, seen, n):\n",
    "    scores = scores.copy()\n",
    "    if seen:\n",
    "        scores[list(seen)] = -np.inf\n",
    "    n = min(n, len(scores) - len(seen))\n",
    "    if n <= 0:\n",
    "        return []\n",
    "    idx = np.argpartition(scores, -n)[-n:]\n",
    "    idx = idx[np.argsort(scores[idx])[::-1]]\n",
    "    return [int(i) for i in idx]\n",
    "\n",
     "def bm25_weight(X, K1=100, B=0.8):\n",
     "    \"\"\"BM25-weight a sparse matrix (rows play the role of documents).\n",
     "\n",
     "    K1 controls term-frequency saturation; B controls row-length normalization.\n",
     "    Returns a new float32 CSR matrix; the input is not modified (copy=True).\n",
     "    \"\"\"\n",
     "    X = X.tocoo(copy=True).astype(np.float32)\n",
     "    N = float(X.shape[0])\n",
     "\n",
     "    # Average row sum (\"document length\"); epsilon guards an all-empty matrix.\n",
     "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
     "    avgdl = row_sums.mean() + 1e-8\n",
     "\n",
     "    # Document frequency per column -> IDF, clipped so it is never negative.\n",
     "    df = np.bincount(X.col, minlength=X.shape[1]).astype(np.float32)\n",
     "    idf = np.log((N - df + 0.5) / (df + 0.5))\n",
     "    idf = np.maximum(idf, 0)\n",
     "\n",
     "    # Classic BM25 term weight, vectorized over the nonzero COO entries.\n",
     "    denom = X.data + K1 * (1 - B + B * row_sums[X.row] / avgdl)\n",
     "    data = X.data * (K1 + 1) / denom * idf[X.col]\n",
     "\n",
     "    return csr_matrix((data, (X.row, X.col)), shape=X.shape, dtype=np.float32)\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Train rows  :\", bundle[\"train_interactions\"].shape)\n",
    "    print(\"Test rows   :\", bundle[\"test_interactions\"].shape)\n",
    "    print(\"Matrix shape:\", bundle[\"X_binary\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54e4aa85",
   "metadata": {},
   "source": [
    "## Quick formulation search\n",
    "\n",
    "This step tries several ways of defining the pseudo-user and the item.\n",
    "The goal is to find a setup that gives a stronger recommendation signal **before** training the heavier neural models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "df30b27b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: F1_country_price_mileage_cond__brand_model\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition']\n",
      "Item cols  : ['Brand', 'Model']\n",
      "Users      : 3000\n",
      "Items      : 88\n",
      "Train rows  : (254673, 5)\n",
      "Test rows   : (3000, 5)\n",
      "Matrix shape: (3000, 88)\n",
      "====================================================================================================\n",
      "Formulation: F2_country_price_mileage_cond_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 24000\n",
      "Items      : 2112\n",
      "Train rows  : (769325, 5)\n",
      "Test rows   : (24000, 5)\n",
      "Matrix shape: (24000, 2112)\n",
      "====================================================================================================\n",
      "Formulation: F3_country_price_mileage_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 8000\n",
      "Items      : 2112\n",
      "Train rows  : (785325, 5)\n",
      "Test rows   : (8000, 5)\n",
      "Matrix shape: (8000, 2112)\n",
      "====================================================================================================\n",
      "Formulation: F4_country_price_mileage__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 1000\n",
      "Items      : 2112\n",
      "Train rows  : (792325, 5)\n",
      "Test rows   : (1000, 5)\n",
      "Matrix shape: (1000, 2112)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>Interactions</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>F2_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>24000</td>\n",
       "      <td>2112</td>\n",
       "      <td>793325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>F3_country_price_mileage_age__brand_model_age_...</td>\n",
       "      <td>8000</td>\n",
       "      <td>2112</td>\n",
       "      <td>793325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>F4_country_price_mileage__brand_model_age_cond</td>\n",
       "      <td>1000</td>\n",
       "      <td>2112</td>\n",
       "      <td>793325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>F1_country_price_mileage_cond__brand_model</td>\n",
       "      <td>3000</td>\n",
       "      <td>88</td>\n",
       "      <td>257673</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation  Users  Items  \\\n",
       "0  F2_country_price_mileage_cond_age__brand_model...  24000   2112   \n",
       "1  F3_country_price_mileage_age__brand_model_age_...   8000   2112   \n",
       "2     F4_country_price_mileage__brand_model_age_cond   1000   2112   \n",
       "3         F1_country_price_mileage_cond__brand_model   3000     88   \n",
       "\n",
       "   Interactions  \n",
       "0        793325  \n",
       "1        793325  \n",
       "2        793325  \n",
       "3        257673  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
     "# Build all candidate formulations\n",
     "\n",
     "bundles = {}\n",
     "summary_rows = []\n",
     "\n",
     "for name, cfg in FORMULATIONS.items():\n",
     "    print(\"=\" * 100)\n",
     "    bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
     "    bundles[name] = bundle\n",
     "    print_bundle_summary(bundle)\n",
     "\n",
     "    summary_rows.append({\n",
     "        \"Formulation\": name,\n",
     "        \"Users\": bundle[\"num_users\"],\n",
     "        \"Items\": bundle[\"num_items\"],\n",
     "        \"Interactions\": bundle[\"interactions\"].shape[0],\n",
     "    })\n",
     "\n",
     "# Sort largest item/user spaces first: per the intro, bigger item vocabularies\n",
     "# are what give the GPU-heavy models enough work.\n",
     "summary_df = pd.DataFrame(summary_rows).sort_values([\"Items\", \"Users\"], ascending=False).reset_index(drop=True)\n",
     "display(summary_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "74500d09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F1_country_price_mileage_cond__brand_model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F1_country_price_mileage_cond__brand_model / Pop: 100%|██████████| 3000/3000 [00:00<00:00, 132205.39it/s]\n",
      "F1_country_price_mileage_cond__brand_model / EASE200: 100%|██████████| 3000/3000 [00:00<00:00, 25403.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F2_country_price_mileage_cond_age__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F2_country_price_mileage_cond_age__brand_model_age_cond / Pop: 100%|██████████| 24000/24000 [00:04<00:00, 5516.54it/s]\n",
      "F2_country_price_mileage_cond_age__brand_model_age_cond / EASE200: 100%|██████████| 24000/24000 [00:26<00:00, 889.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F3_country_price_mileage_age__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F3_country_price_mileage_age__brand_model_age_cond / Pop: 100%|██████████| 8000/8000 [00:01<00:00, 5489.72it/s]\n",
      "F3_country_price_mileage_age__brand_model_age_cond / EASE200: 100%|██████████| 8000/8000 [00:08<00:00, 894.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F4_country_price_mileage__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F4_country_price_mileage__brand_model_age_cond / Pop: 100%|██████████| 1000/1000 [00:00<00:00, 4907.36it/s]\n",
      "F4_country_price_mileage__brand_model_age_cond / EASE200: 100%|██████████| 1000/1000 [00:01<00:00, 899.58it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>Pop_HR@10</th>\n",
       "      <th>EASE200_HR@10</th>\n",
       "      <th>EASE200_NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>F1_country_price_mileage_cond__brand_model</td>\n",
       "      <td>3000</td>\n",
       "      <td>88</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.751848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>F2_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>24000</td>\n",
       "      <td>2112</td>\n",
       "      <td>0.007000</td>\n",
       "      <td>0.201208</td>\n",
       "      <td>0.097243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>F3_country_price_mileage_age__brand_model_age_...</td>\n",
       "      <td>8000</td>\n",
       "      <td>2112</td>\n",
       "      <td>0.007125</td>\n",
       "      <td>0.064375</td>\n",
       "      <td>0.029549</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>F4_country_price_mileage__brand_model_age_cond</td>\n",
       "      <td>1000</td>\n",
       "      <td>2112</td>\n",
       "      <td>0.024000</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>0.004541</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation  Users  Items  Pop_HR@10  \\\n",
       "0         F1_country_price_mileage_cond__brand_model   3000     88   1.000000   \n",
       "1  F2_country_price_mileage_cond_age__brand_model...  24000   2112   0.007000   \n",
       "2  F3_country_price_mileage_age__brand_model_age_...   8000   2112   0.007125   \n",
       "3     F4_country_price_mileage__brand_model_age_cond   1000   2112   0.024000   \n",
       "\n",
       "   EASE200_HR@10  EASE200_NDCG@10  \n",
       "0       1.000000         0.751848  \n",
       "1       0.201208         0.097243  \n",
       "2       0.064375         0.029549  \n",
       "3       0.010000         0.004541  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulation: F1_country_price_mileage_cond__brand_model\n",
      "Formulation: F1_country_price_mileage_cond__brand_model\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition']\n",
      "Item cols  : ['Brand', 'Model']\n",
      "Users      : 3000\n",
      "Items      : 88\n",
      "Train rows  : (254673, 5)\n",
      "Test rows   : (3000, 5)\n",
      "Matrix shape: (3000, 88)\n"
     ]
    }
   ],
   "source": [
    "# Quick screen with two cheap baselines:\n",
    "# 1) global popularity\n",
    "# 2) EASE with lambda=200 on binary interactions\n",
    "# The best formulation by EASE HR@10 becomes the main formulation.\n",
    "\n",
    "screen_rows = []\n",
    "\n",
    "for name, bundle in bundles.items():\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Screening formulation:\", name)\n",
    "\n",
    "    # NOTE: bundle is bound via a default argument so each closure keeps its\n",
    "    # own bundle instead of the loop variable's final value.\n",
    "    def rec_pop(uid, n=10, bundle=bundle):\n",
    "        seen = bundle[\"user_seen\"].get(uid, set())\n",
    "        return [int(i) for i in bundle[\"global_pop_rank\"] if int(i) not in seen][:n]\n",
    "\n",
    "    pop_res = evaluate_model(rec_pop, f\"{name} / Pop\", bundle, ks=TOP_KS)\n",
    "\n",
    "    # EASE closed form (Steck 2019): P = (X^T X + lam*I)^-1, B = -P / diag(P),\n",
    "    # with the diagonal zeroed so an item never recommends itself.\n",
    "    Xb = bundle[\"X_binary\"]\n",
    "    G = (Xb.T @ Xb).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += 200.0\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "\n",
    "    # B is likewise frozen per-iteration via a default argument.\n",
    "    def rec_ease(uid, n=10, bundle=bundle, B=B):\n",
    "        scores = bundle[\"X_binary\"].getrow(uid).toarray().ravel() @ B\n",
    "        seen = bundle[\"user_seen\"].get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "\n",
    "    ease_res = evaluate_model(rec_ease, f\"{name} / EASE200\", bundle, ks=TOP_KS)\n",
    "\n",
    "    screen_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"Pop_HR@10\": pop_res[\"HR@10\"],\n",
    "        \"EASE200_HR@10\": ease_res[\"HR@10\"],\n",
    "        \"EASE200_NDCG@10\": ease_res[\"NDCG@10\"],\n",
    "    })\n",
    "\n",
    "# Rank by EASE quality first; ties broken by NDCG, then by item-space size.\n",
    "screen_df = pd.DataFrame(screen_rows).sort_values(\n",
    "    [\"EASE200_HR@10\", \"EASE200_NDCG@10\", \"Items\"], ascending=False\n",
    ").reset_index(drop=True)\n",
    "\n",
    "display(screen_df)\n",
    "\n",
    "BEST_FORMULATION = screen_df.iloc[0][\"Formulation\"]\n",
    "print(\"Selected formulation:\", BEST_FORMULATION)\n",
    "\n",
    "data = bundles[BEST_FORMULATION]\n",
    "print_bundle_summary(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "014f1da7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>user_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | p0 | m0 | Certified Pre-Owned</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Australia | p0 | m0 | New</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m0</td>\n",
       "      <td>New</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Australia | p0 | m0 | Used</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m0</td>\n",
       "      <td>Used</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Australia | p0 | m1 | Certified Pre-Owned</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Australia | p0 | m1 | New</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>New</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                     user_id    Country PriceBin MileageBin  \\\n",
       "0  Australia | p0 | m0 | Certified Pre-Owned  Australia       p0         m0   \n",
       "1                  Australia | p0 | m0 | New  Australia       p0         m0   \n",
       "2                 Australia | p0 | m0 | Used  Australia       p0         m0   \n",
       "3  Australia | p0 | m1 | Certified Pre-Owned  Australia       p0         m1   \n",
       "4                  Australia | p0 | m1 | New  Australia       p0         m1   \n",
       "\n",
       "             Condition  user_idx  \n",
       "0  Certified Pre-Owned         0  \n",
       "1                  New         1  \n",
       "2                 Used         2  \n",
       "3  Certified Pre-Owned         3  \n",
       "4                  New         4  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>item_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Audi :: A3</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Audi :: A4</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A4</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Audi :: A6</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A6</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Audi :: Q5</td>\n",
       "      <td>Audi</td>\n",
       "      <td>Q5</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Audi :: Q7</td>\n",
       "      <td>Audi</td>\n",
       "      <td>Q7</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      item_id Brand Model  item_idx\n",
       "0  Audi :: A3  Audi    A3         0\n",
       "1  Audi :: A4  Audi    A4         1\n",
       "2  Audi :: A6  Audi    A6         2\n",
       "3  Audi :: Q5  Audi    Q5         3\n",
       "4  Audi :: Q7  Audi    Q7         4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Prepare reusable arrays and feature tables for the selected formulation\n",
    "\n",
    "# .copy() so later cells can mutate these frames without touching the bundle\n",
    "user_feature_df = data[\"user_feature_df\"].copy()\n",
    "item_feature_df = data[\"item_feature_df\"].copy()\n",
    "\n",
    "train_interactions = data[\"train_interactions\"].copy()\n",
    "test_interactions = data[\"test_interactions\"].copy()\n",
    "\n",
    "# CSR layout gives fast per-user row slicing for the recommenders below\n",
    "X_counts = data[\"X_counts\"].tocsr()\n",
    "X_binary = data[\"X_binary\"].tocsr()\n",
    "\n",
    "num_users = data[\"num_users\"]\n",
    "num_items = data[\"num_items\"]\n",
    "\n",
    "user_seen = data[\"user_seen\"]\n",
    "user_strength = data[\"user_strength\"]\n",
    "test_item_by_user = data[\"test_item_by_user\"]\n",
    "global_pop_rank = data[\"global_pop_rank\"]\n",
    "item_ids = data[\"item_ids\"]\n",
    "\n",
    "# Dense versions only for models that need them\n",
    "# (cheap for the selected formulation's matrix size — see the summary above)\n",
    "X_binary_dense = X_binary.toarray().astype(np.float32)\n",
    "X_counts_dense = X_counts.toarray().astype(np.float32)\n",
    "\n",
    "display(user_feature_df.head())\n",
    "display(item_feature_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a7b336d",
   "metadata": {},
   "source": [
    "## Regular baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8c765ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 1: global popularity\n",
    "\n",
    "def recommend_popularity(user_idx, n=10):\n",
    "    \"\"\"Top-n globally popular items the user has not already interacted with.\"\"\"\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    recs = []\n",
    "    for raw_idx in global_pop_rank:\n",
    "        item = int(raw_idx)\n",
    "        if item in seen:\n",
    "            continue\n",
    "        recs.append(item)\n",
    "        if len(recs) == n:\n",
    "            break\n",
    "    return recs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1bcb91ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 2: country-aware popularity baseline\n",
    "\n",
    "# Guard against formulations whose user key does not include Country.\n",
    "country_col = \"Country\" if \"Country\" in user_feature_df.columns else None\n",
    "user_country = {}\n",
    "\n",
    "if country_col is not None:\n",
    "    user_country = dict(zip(user_feature_df[\"user_idx\"], user_feature_df[country_col]))\n",
    "\n",
    "    train_with_country = train_interactions.merge(\n",
    "        user_feature_df[[\"user_idx\", country_col]],\n",
    "        on=\"user_idx\",\n",
    "        how=\"left\"\n",
    "    )\n",
    "\n",
    "    # Per-country item ranking by total interaction count, most popular first.\n",
    "    country_pop_rank = {}\n",
    "    for country, g in train_with_country.groupby(country_col):\n",
    "        country_pop_rank[country] = (\n",
    "            g.groupby(\"item_idx\")[\"count\"]\n",
    "             .sum()\n",
    "             .sort_values(ascending=False)\n",
    "             .index.to_numpy()\n",
    "        )\n",
    "else:\n",
    "    country_pop_rank = {}\n",
    "\n",
    "def recommend_country_popularity(user_idx, n=10):\n",
    "    \"\"\"Top-n unseen items by the user's country ranking; global fallback.\"\"\"\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    country = user_country.get(int(user_idx), None)\n",
    "\n",
    "    if country in country_pop_rank:\n",
    "        recs = [int(i) for i in country_pop_rank[country] if int(i) not in seen][:n]\n",
    "        if len(recs) >= n:\n",
    "            return recs\n",
    "\n",
    "    # fallback: global popularity when the country list cannot fill all n slots\n",
    "    out = [int(i) for i in global_pop_rank if int(i) not in seen][:n]\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "114bafa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 3: ItemKNN on binary interactions\n",
    "\n",
    "def fit_itemknn_recommender(X_matrix, neighbors=100):\n",
    "    \"\"\"Precompute cosine neighbor lists for every item.\n",
    "\n",
    "    Returns (indices, similarities), each of shape (num_items, k), with\n",
    "    k = min(neighbors, num_items).\n",
    "    \"\"\"\n",
    "    # rows = items, columns = users, so neighbors are computed item-to-item\n",
    "    item_user = X_matrix.T.tocsr()\n",
    "\n",
    "    knn = NearestNeighbors(\n",
    "        metric=\"cosine\",\n",
    "        algorithm=\"brute\",\n",
    "        n_neighbors=min(neighbors, item_user.shape[0]),\n",
    "        n_jobs=-1\n",
    "    )\n",
    "    knn.fit(item_user)\n",
    "\n",
    "    distances, indices = knn.kneighbors(item_user, n_neighbors=min(neighbors, item_user.shape[0]))\n",
    "    # cosine similarity = 1 - cosine distance\n",
    "    similarities = (1.0 - distances).astype(np.float32)\n",
    "    return indices, similarities\n",
    "\n",
    "def make_itemknn_recommender(X_matrix, user_seen_dict, user_strength_dict, neighbors=100, use_strength=True):\n",
    "    \"\"\"Build a recommend(user_idx, n) closure scoring neighbors of seen items.\"\"\"\n",
    "    neighbor_idx, neighbor_sim = fit_itemknn_recommender(X_matrix, neighbors=neighbors)\n",
    "\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        strength_map = user_strength_dict.get(uid, {})\n",
    "        scores = {}\n",
    "\n",
    "        # accumulate similarity mass from each seen item's neighbor list,\n",
    "        # optionally weighted by that item's interaction strength\n",
    "        for item_idx in seen:\n",
    "            weight = float(strength_map.get(item_idx, 1.0)) if use_strength else 1.0\n",
    "            nbrs = neighbor_idx[item_idx]\n",
    "            sims = neighbor_sim[item_idx]\n",
    "\n",
    "            for j, sim in zip(nbrs, sims):\n",
    "                j = int(j)\n",
    "                if j == item_idx or j in seen:\n",
    "                    continue\n",
    "                scores[j] = scores.get(j, 0.0) + float(sim) * weight\n",
    "\n",
    "        if not scores:\n",
    "            # cold user (nothing seen): fall back to global popularity\n",
    "            return recommend_popularity(uid, n=n)\n",
    "\n",
    "        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:n]\n",
    "        return [int(i) for i, _ in ranked]\n",
    "\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5c57d7fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 4: ItemKNN with BM25 weighting\n",
    "\n",
    "# bm25_weight is presumably defined earlier in the notebook (outside this view)\n",
    "# — TODO confirm its K1/B semantics match the usual BM25 term weighting.\n",
    "X_bm25 = bm25_weight(X_counts, K1=100, B=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "74f0ff1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 5: EASE search\n",
    "# We try both binary and count-weighted inputs and keep all runs in the results table.\n",
    "\n",
    "def fit_ease(X_matrix, lam):\n",
    "    \"\"\"Closed-form EASE (Steck 2019): item-item weight matrix B.\n",
    "\n",
    "    lam is the L2 penalty added to the Gram diagonal; with\n",
    "    P = (X^T X + lam*I)^-1, B = -P / diag(P). The diagonal of B is zeroed\n",
    "    so an item never recommends itself.\n",
    "    \"\"\"\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += lam\n",
    "\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32)\n",
    "\n",
    "def make_ease_recommender(X_train_dense, B_matrix, user_seen_dict):\n",
    "    \"\"\"Build recommend(user_idx, n): scores = user row @ B, excluding seen items.\"\"\"\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_train_dense[uid] @ B_matrix\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acf6838c",
   "metadata": {},
   "source": [
    "## Neural models with heavier GPU usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f525bb97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "User feature columns: ['Country', 'PriceBin', 'MileageBin', 'Condition']\n",
      "Item feature columns: ['Brand', 'Model']\n",
      "Country user cardinality: 10\n",
      "PriceBin user cardinality: 10\n",
      "MileageBin user cardinality: 10\n",
      "Condition user cardinality: 3\n",
      "Brand item cardinality: 18\n",
      "Model item cardinality: 88\n"
     ]
    }
   ],
   "source": [
    "# Feature encoders for neural models\n",
    "\n",
    "user_cols = data[\"user_cols\"]\n",
    "item_cols = data[\"item_cols\"]\n",
    "\n",
    "# One LabelEncoder per feature column. A column appearing on both sides is fit\n",
    "# on the union of user and item values so ids stay consistent across towers.\n",
    "feature_encoders = {}\n",
    "for col in sorted(set(user_cols + item_cols)):\n",
    "    enc = LabelEncoder()\n",
    "    values = pd.concat([\n",
    "        user_feature_df[col].astype(str) if col in user_feature_df.columns else pd.Series(dtype=str),\n",
    "        item_feature_df[col].astype(str) if col in item_feature_df.columns else pd.Series(dtype=str),\n",
    "    ], ignore_index=True)\n",
    "    enc.fit(values)\n",
    "    feature_encoders[col] = enc\n",
    "\n",
    "# Encoded feature arrays, row-aligned with user_idx / item_idx.\n",
    "user_feature_arrays = {}\n",
    "for col in user_cols:\n",
    "    user_feature_arrays[col] = feature_encoders[col].transform(user_feature_df[col].astype(str))\n",
    "\n",
    "item_feature_arrays = {}\n",
    "for col in item_cols:\n",
    "    item_feature_arrays[col] = feature_encoders[col].transform(item_feature_df[col].astype(str))\n",
    "\n",
    "# CPU LongTensors; batches are indexed out of these and moved to the device later.\n",
    "user_feature_tensors = {\n",
    "    col: torch.tensor(arr, dtype=torch.long)\n",
    "    for col, arr in user_feature_arrays.items()\n",
    "}\n",
    "item_feature_tensors = {\n",
    "    col: torch.tensor(arr, dtype=torch.long)\n",
    "    for col, arr in item_feature_arrays.items()\n",
    "}\n",
    "\n",
    "print(\"User feature columns:\", user_cols)\n",
    "print(\"Item feature columns:\", item_cols)\n",
    "for col in user_cols:\n",
    "    print(col, \"user cardinality:\", len(feature_encoders[col].classes_))\n",
    "for col in item_cols:\n",
    "    print(col, \"item cardinality:\", len(feature_encoders[col].classes_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "25793118",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive rows: 254673\n"
     ]
    }
   ],
   "source": [
    "# Training datasets\n",
    "\n",
    "class WeightedPositiveDataset(Dataset):\n",
    "    # one row per observed (user, item) pair with its count as weight\n",
    "    def __init__(self, interactions_df):\n",
    "        self.user_idx = torch.tensor(interactions_df[\"user_idx\"].to_numpy(), dtype=torch.long)\n",
    "        self.item_idx = torch.tensor(interactions_df[\"item_idx\"].to_numpy(), dtype=torch.long)\n",
    "        self.weight = torch.tensor(interactions_df[\"count\"].astype(np.float32).to_numpy(), dtype=torch.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.user_idx)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.user_idx[idx], self.item_idx[idx], self.weight[idx]\n",
    "\n",
    "class UserDenseDataset(Dataset):\n",
    "    # one row per user: the user's dense interaction vector plus its row index\n",
    "    def __init__(self, dense_matrix):\n",
    "        self.X = torch.tensor(dense_matrix, dtype=torch.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.X.shape[0]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], idx\n",
    "\n",
    "# Same positive pairs, two batch sizes: the softmax model and the two-tower\n",
    "# model are batched independently. pin_memory / persistent_workers only take\n",
    "# effect when CUDA / worker processes are actually in use.\n",
    "positive_dataset = WeightedPositiveDataset(train_interactions)\n",
    "positive_loader = DataLoader(\n",
    "    positive_dataset,\n",
    "    batch_size=SOFTMAX_BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    "    persistent_workers=(NUM_WORKERS > 0)\n",
    ")\n",
    "\n",
    "positive_loader_twotower = DataLoader(\n",
    "    positive_dataset,\n",
    "    batch_size=TWOTOWER_BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    "    persistent_workers=(NUM_WORKERS > 0)\n",
    ")\n",
    "\n",
    "# Dense per-user rows feed the autoencoder-style model (MULTVAE batch size).\n",
    "user_dense_dataset = UserDenseDataset(X_binary_dense)\n",
    "user_dense_loader = DataLoader(\n",
    "    user_dense_dataset,\n",
    "    batch_size=MULTVAE_BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    "    persistent_workers=(NUM_WORKERS > 0)\n",
    ")\n",
    "\n",
    "print(\"Positive rows:\", len(positive_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0c256a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural model 1: full-softmax feature MLP\n",
    "# This is deliberately GPU-friendly: large batches and a full item softmax.\n",
    "\n",
    "class FeatureEmbeddingBlock(nn.Module):\n",
    "    \"\"\"Embeds a dict of categorical feature id tensors and concatenates them.\"\"\"\n",
    "\n",
    "    def __init__(self, feature_cardinalities, emb_dim):\n",
    "        super().__init__()\n",
    "        self.feature_names = list(feature_cardinalities.keys())\n",
    "        self.embs = nn.ModuleDict({\n",
    "            name: nn.Embedding(card, emb_dim)\n",
    "            for name, card in feature_cardinalities.items()\n",
    "        })\n",
    "\n",
    "    def forward(self, feature_batch):\n",
    "        # feature_batch maps name -> LongTensor; output concatenates one\n",
    "        # emb_dim-wide embedding per feature along the last axis\n",
    "        parts = [self.embs[name](feature_batch[name]) for name in self.feature_names]\n",
    "        return torch.cat(parts, dim=-1)\n",
    "\n",
    "class SoftmaxSegmentMLP(nn.Module):\n",
    "    \"\"\"MLP over user-feature embeddings emitting one logit per item.\"\"\"\n",
    "\n",
    "    def __init__(self, feature_cardinalities, num_items, emb_dim=64, hidden_dims=(512, 512, 256), dropout=0.15):\n",
    "        super().__init__()\n",
    "        self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "        input_dim = len(feature_cardinalities) * emb_dim\n",
    "\n",
    "        layers = []\n",
    "        last = input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "            last = h\n",
    "        self.mlp = nn.Sequential(*layers)\n",
    "        self.out = nn.Linear(last, num_items)\n",
    "\n",
    "    def forward(self, feature_batch):\n",
    "        x = self.encoder(feature_batch)\n",
    "        x = self.mlp(x)\n",
    "        return self.out(x)\n",
    "\n",
    "def make_user_feature_batch(user_idx_batch):\n",
    "    \"\"\"Gather each user-feature column for a batch of user indices onto device.\"\"\"\n",
    "    batch = {}\n",
    "    idx_cpu = user_idx_batch.detach().cpu()  # feature tensors live on CPU\n",
    "    for col in user_cols:\n",
    "        batch[col] = user_feature_tensors[col][idx_cpu].to(device, non_blocking=True)\n",
    "    return batch\n",
    "\n",
    "def train_softmax_model(model_name, emb_dim=64, hidden_dims=(512, 512, 256), epochs=18, lr=2e-3, wd=1e-5):\n",
    "    \"\"\"Train SoftmaxSegmentMLP on positive_loader with count-weighted CE loss.\n",
    "\n",
    "    Returns the trained model switched to eval mode.\n",
    "    \"\"\"\n",
    "    feature_cardinalities = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "    model = SoftmaxSegmentMLP(\n",
    "        feature_cardinalities=feature_cardinalities,\n",
    "        num_items=num_items,\n",
    "        emb_dim=emb_dim,\n",
    "        hidden_dims=hidden_dims,\n",
    "        dropout=0.15,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    # torch.cuda.amp.GradScaler is deprecated; the generic torch.amp variant\n",
    "    # matches the device-generic torch.autocast call below and no-ops on CPU.\n",
    "    scaler = torch.amp.GradScaler(device.type, enabled=USE_AMP)\n",
    "    criterion = nn.CrossEntropyLoss(reduction=\"none\")  # per-sample losses, weighted below\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in positive_loader:\n",
    "            user_idx_batch = user_idx_batch.to(device, non_blocking=True)\n",
    "            item_idx_batch = item_idx_batch.to(device, non_blocking=True)\n",
    "            weight_batch = weight_batch.to(device, non_blocking=True)\n",
    "\n",
    "            feat_batch = make_user_feature_batch(user_idx_batch)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits = model(feat_batch)\n",
    "                loss_vec = criterion(logits, item_idx_batch)\n",
    "                # interaction-count weighted mean cross-entropy\n",
    "                loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            # undo the per-batch normalization to accumulate an epoch-level mean\n",
    "            total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "            total_w += float(weight_batch.sum().item())\n",
    "\n",
    "        print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def make_softmax_recommender(model):\n",
    "    \"\"\"Wrap a trained model into recommend(user_idx, n) -> top-n unseen items.\"\"\"\n",
    "    @torch.no_grad()\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "        feat_batch = make_user_feature_batch(u)\n",
    "        logits = model(feat_batch).squeeze(0).detach().cpu().numpy()\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(logits, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a15df0f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural model 2: feature two-tower with in-batch negatives\n",
    "# This typically drives GPU usage much better than pointwise negative sampling.\n",
    "\n",
    "class TowerEncoder(nn.Module):\n",
    "    def __init__(self, feature_cardinalities, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "        input_dim = len(feature_cardinalities) * emb_dim\n",
    "\n",
    "        layers = []\n",
    "        last = input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "            last = h\n",
    "        layers += [nn.Linear(last, out_dim)]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, feature_batch):\n",
    "        x = self.encoder(feature_batch)\n",
    "        x = self.net(x)\n",
    "        x = F.normalize(x, dim=-1)\n",
    "        return x\n",
    "\n",
    "class TwoTowerModel(nn.Module):\n",
    "    \"\"\"Pair of independent encoders producing comparable user/item embeddings.\"\"\"\n",
    "\n",
    "    def __init__(self, user_feature_cards, item_feature_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "        super().__init__()\n",
    "        tower_kwargs = dict(emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "        self.user_tower = TowerEncoder(user_feature_cards, **tower_kwargs)\n",
    "        self.item_tower = TowerEncoder(item_feature_cards, **tower_kwargs)\n",
    "\n",
    "    def forward(self, user_batch, item_batch):\n",
    "        return self.user_tower(user_batch), self.item_tower(item_batch)\n",
    "\n",
    "def make_item_feature_batch(item_idx_batch):\n",
    "    \"\"\"Gather each item feature column for the given item indices and move it to `device`.\"\"\"\n",
    "    idx = item_idx_batch.detach().cpu()\n",
    "    return {\n",
    "        col: item_feature_tensors[col][idx].to(device, non_blocking=True)\n",
    "        for col in item_cols\n",
    "    }\n",
    "\n",
    "def train_two_tower(model_name, emb_dim=64, hidden_dims=(512, 256), out_dim=128, epochs=18, lr=2e-3, wd=1e-5, temperature=0.07):\n",
    "    \"\"\"Train the two-tower model with weighted in-batch softmax (InfoNCE) loss.\n",
    "\n",
    "    Each positive (user, item) pair treats the other items in the batch as\n",
    "    negatives; `temperature` sharpens the similarity logits. Per-pair weights\n",
    "    scale the cross-entropy terms. Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    user_feature_cards = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "    item_feature_cards = {col: len(feature_encoders[col].classes_) for col in item_cols}\n",
    "\n",
    "    model = TwoTowerModel(\n",
    "        user_feature_cards=user_feature_cards,\n",
    "        item_feature_cards=item_feature_cards,\n",
    "        emb_dim=emb_dim,\n",
    "        hidden_dims=hidden_dims,\n",
    "        out_dim=out_dim,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    # torch.cuda.amp.GradScaler is deprecated; the torch.amp variant is also\n",
    "    # consistent with the device-agnostic autocast context used below.\n",
    "    scaler = torch.amp.GradScaler(device.type, enabled=USE_AMP)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in positive_loader_twotower:\n",
    "            user_idx_batch = user_idx_batch.to(device, non_blocking=True)\n",
    "            item_idx_batch = item_idx_batch.to(device, non_blocking=True)\n",
    "            weight_batch = weight_batch.to(device, non_blocking=True)\n",
    "\n",
    "            user_batch = make_user_feature_batch(user_idx_batch)\n",
    "            item_batch = make_item_feature_batch(item_idx_batch)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                user_vec, item_vec = model(user_batch, item_batch)\n",
    "                # Diagonal entries are the true pairs; off-diagonals act as negatives.\n",
    "                logits = (user_vec @ item_vec.T) / temperature\n",
    "                targets = torch.arange(logits.shape[0], device=device)\n",
    "                loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            # Sum the batch weight once instead of twice per step.\n",
    "            batch_weight = float(weight_batch.sum().item())\n",
    "            total_loss += float(loss.item()) * batch_weight\n",
    "            total_w += batch_weight\n",
    "\n",
    "        print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def make_two_tower_recommender(model):\n",
    "    \"\"\"Build a `recommend(user_idx, n)` closure from a trained two-tower model.\n",
    "\n",
    "    Item embeddings are precomputed once in chunks; each call then only\n",
    "    encodes a single user and scores it against the cached item matrix.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    chunk_size = 4096\n",
    "    # Keep the index tensor on CPU: make_item_feature_batch gathers on CPU\n",
    "    # and moves the gathered features to `device` itself, so shipping the raw\n",
    "    # indices to the GPU first was just a wasted round trip.\n",
    "    all_item_idx = torch.arange(num_items, dtype=torch.long)\n",
    "    item_matrix = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, num_items, chunk_size):\n",
    "            idx = all_item_idx[start:start + chunk_size]\n",
    "            batch = make_item_feature_batch(idx)\n",
    "            vec = model.item_tower(batch)\n",
    "            item_matrix.append(vec.detach().cpu())\n",
    "    item_matrix = torch.cat(item_matrix, dim=0).to(device)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "        user_batch = make_user_feature_batch(u)\n",
    "        user_vec = model.user_tower(user_batch)\n",
    "        scores = (user_vec @ item_matrix.T).squeeze(0).detach().cpu().numpy()\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1bb5a7ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural model 3: larger MultVAE\n",
    "# This stays close to recommender-system literature and uses the whole user-item vector as input.\n",
    "\n",
    "class MultVAE(nn.Module):\n",
    "    def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.3):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        h = self.encoder(self.dropout(x))\n",
    "        return self.mu(h), self.logvar(h)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        if self.training:\n",
    "            std = torch.exp(0.5 * logvar)\n",
    "            eps = torch.randn_like(std)\n",
    "            return mu + eps * std\n",
    "        return mu\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        logits = self.decoder(z)\n",
    "        return logits, mu, logvar\n",
    "\n",
    "def train_multvae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.3, epochs=60, lr=1e-3, wd=0.0, anneal_cap=1.0):\n",
    "    \"\"\"Train a MultVAE on the dense user-item matrix.\n",
    "\n",
    "    Uses a multinomial reconstruction term plus a KL term whose weight is\n",
    "    linearly annealed from 0 up to `anneal_cap` over the first 30% of all\n",
    "    training steps. Returns the trained model in eval mode.\n",
    "    \"\"\"\n",
    "    model = MultVAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    # torch.cuda.amp.GradScaler is deprecated; the torch.amp variant is also\n",
    "    # consistent with the device-agnostic autocast context used below.\n",
    "    scaler = torch.amp.GradScaler(device.type, enabled=USE_AMP)\n",
    "\n",
    "    total_steps = max(1, epochs * len(user_dense_loader))\n",
    "    step = 0\n",
    "    anneal = 0.0  # bound before the loop so the epoch log cannot hit a NameError on an empty loader\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        total_rows = 0\n",
    "\n",
    "        for batch_x, _ in user_dense_loader:\n",
    "            batch_x = batch_x.to(device, non_blocking=True)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "            # KL annealing: ramp the KL weight up over the first 30% of steps.\n",
    "            anneal = min(anneal_cap, step / max(total_steps * 0.3, 1))\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits, mu, logvar = model(batch_x)\n",
    "                # Multinomial negative log-likelihood of the observed interactions.\n",
    "                recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                loss = (recon + anneal * kl).mean()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "            total_rows += batch_x.shape[0]\n",
    "            step += 1\n",
    "\n",
    "        if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f} - anneal: {anneal:.3f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "def make_multvae_recommender(model):\n",
    "    \"\"\"Wrap a trained MultVAE as a `recommend(user_idx, n)` callable.\"\"\"\n",
    "    @torch.no_grad()\n",
    "    def recommend(user_idx, n=10):\n",
    "        user_id = int(user_idx)\n",
    "        row = torch.tensor(X_binary_dense[user_id:user_id+1], dtype=torch.float32, device=device)\n",
    "        logits, _, _ = model(row)\n",
    "        item_scores = logits.squeeze(0).detach().cpu().numpy()\n",
    "        return topn_from_scores(item_scores, user_seen.get(user_id, set()), n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7181984",
   "metadata": {},
   "source": [
    "## Train the model zoo\n",
    "\n",
    "This section intentionally tries **many variants**.\n",
    "If you leave the notebook running, it should finish with a large comparison table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e711dcc4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: Popularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Popularity: 100%|██████████| 3000/3000 [00:00<00:00, 107553.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: CountryPopularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CountryPopularity: 100%|██████████| 3000/3000 [00:00<00:00, 73559.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN neighbors=50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "ItemKNN_binary_k50: 100%|██████████| 3000/3000 [00:01<00:00, 2280.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN neighbors=100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k100: 100%|██████████| 3000/3000 [00:02<00:00, 1397.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN neighbors=150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k150: 100%|██████████| 3000/3000 [00:02<00:00, 1401.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k50: 100%|██████████| 3000/3000 [00:01<00:00, 2290.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k100: 100%|██████████| 3000/3000 [00:02<00:00, 1397.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k150: 100%|██████████| 3000/3000 [00:02<00:00, 1405.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=50.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l50: 100%|██████████| 3000/3000 [00:00<00:00, 64795.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=50.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l50: 100%|██████████| 3000/3000 [00:00<00:00, 67335.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l100: 100%|██████████| 3000/3000 [00:00<00:00, 68073.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l100: 100%|██████████| 3000/3000 [00:00<00:00, 66730.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l200: 100%|██████████| 3000/3000 [00:00<00:00, 67209.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l200: 100%|██████████| 3000/3000 [00:00<00:00, 69367.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=500.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l500: 100%|██████████| 3000/3000 [00:00<00:00, 68853.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=500.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l500: 100%|██████████| 3000/3000 [00:00<00:00, 68537.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l1000: 100%|██████████| 3000/3000 [00:00<00:00, 69250.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l1000: 100%|██████████| 3000/3000 [00:00<00:00, 54954.89it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>EASE_binary_l500</td>\n",
       "      <td>0.987000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.671896</td>\n",
       "      <td>0.754519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ItemKNN_binary_k150</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>0.989333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670471</td>\n",
       "      <td>0.753456</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>EASE_binary_l200</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.668304</td>\n",
       "      <td>0.751848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>EASE_binary_l50</td>\n",
       "      <td>0.985333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665105</td>\n",
       "      <td>0.749452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_binary_l100</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665007</td>\n",
       "      <td>0.749386</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>EASE_count_l100</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663710</td>\n",
       "      <td>0.748281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>EASE_count_l50</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663693</td>\n",
       "      <td>0.748266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>CountryPopularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>EASE_count_l200</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663510</td>\n",
       "      <td>0.748131</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>EASE_count_l500</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663443</td>\n",
       "      <td>0.748078</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.662649</td>\n",
       "      <td>0.747472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>ItemKNN_bm25_k100</td>\n",
       "      <td>0.991333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.651311</td>\n",
       "      <td>0.739278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>ItemKNN_bm25_k150</td>\n",
       "      <td>0.991333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.651311</td>\n",
       "      <td>0.739278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>ItemKNN_binary_k50</td>\n",
       "      <td>0.856333</td>\n",
       "      <td>0.858000</td>\n",
       "      <td>0.858000</td>\n",
       "      <td>0.620887</td>\n",
       "      <td>0.681031</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>ItemKNN_bm25_k50</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.546567</td>\n",
       "      <td>0.587167</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model      HR@5     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0      EASE_binary_l500  0.987000  1.000000  1.000000  0.671896  0.754519\n",
       "1   ItemKNN_binary_k100  0.992000  1.000000  1.000000  0.670833  0.753761\n",
       "2   ItemKNN_binary_k150  0.992000  1.000000  1.000000  0.670833  0.753761\n",
       "3     EASE_binary_l1000  0.989333  1.000000  1.000000  0.670471  0.753456\n",
       "4      EASE_binary_l200  0.986000  1.000000  1.000000  0.668304  0.751848\n",
       "5       EASE_binary_l50  0.985333  1.000000  1.000000  0.665105  0.749452\n",
       "6      EASE_binary_l100  0.986000  1.000000  1.000000  0.665007  0.749386\n",
       "7       EASE_count_l100  0.988667  1.000000  1.000000  0.663710  0.748281\n",
       "8        EASE_count_l50  0.988667  1.000000  1.000000  0.663693  0.748266\n",
       "9            Popularity  0.986000  1.000000  1.000000  0.663483  0.748141\n",
       "10    CountryPopularity  0.986000  1.000000  1.000000  0.663483  0.748141\n",
       "11      EASE_count_l200  0.989000  1.000000  1.000000  0.663510  0.748131\n",
       "12      EASE_count_l500  0.989000  1.000000  1.000000  0.663443  0.748078\n",
       "13     EASE_count_l1000  0.988667  1.000000  1.000000  0.662649  0.747472\n",
       "14    ItemKNN_bm25_k100  0.991333  1.000000  1.000000  0.651311  0.739278\n",
       "15    ItemKNN_bm25_k150  0.991333  1.000000  1.000000  0.651311  0.739278\n",
       "16   ItemKNN_binary_k50  0.856333  0.858000  0.858000  0.620887  0.681031\n",
       "17     ItemKNN_bm25_k50  0.704667  0.704667  0.704667  0.546567  0.587167"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit and evaluate regular models\n",
    "\n",
    "regular_results = []\n",
    "regular_models = {}\n",
    "\n",
    "eval_users = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: Popularity\")\n",
    "regular_results.append(evaluate_model(recommend_popularity, \"Popularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "regular_models[\"Popularity\"] = recommend_popularity\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: CountryPopularity\")\n",
    "regular_results.append(evaluate_model(recommend_country_popularity, \"CountryPopularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "regular_models[\"CountryPopularity\"] = recommend_country_popularity\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_binary, user_seen, user_strength, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_binary_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN BM25 neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_bm25, user_seen, user_strength, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_bm25_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for lam in EASE_LAMBDA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting EASE binary lambda={lam}\")\n",
    "    B_bin = fit_ease(X_binary, lam=lam)\n",
    "    rec_bin = make_ease_recommender(X_binary_dense, B_bin, user_seen)\n",
    "    name_bin = f\"EASE_binary_l{int(lam)}\"\n",
    "    regular_models[name_bin] = rec_bin\n",
    "    regular_results.append(evaluate_model(rec_bin, name_bin, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(f\"Fitting EASE count lambda={lam}\")\n",
    "    B_cnt = fit_ease(X_counts, lam=lam)\n",
    "    rec_cnt = make_ease_recommender(X_counts_dense, B_cnt, user_seen)\n",
    "    name_cnt = f\"EASE_count_l{int(lam)}\"\n",
    "    regular_models[name_cnt] = rec_cnt\n",
    "    regular_results.append(evaluate_model(rec_cnt, name_cnt, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "regular_results_df = pd.DataFrame(regular_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "display(regular_results_df.head(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1f125d6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training Softmax MLP base\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m344\u001b[0m, in \u001b[35mmain\u001b[0m\n",
      "    code = _serve_one(child_r, fds,\n",
      "                      unused_fds,\n",
      "                      old_handlers)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m384\u001b[0m, in \u001b[35m_serve_one\u001b[0m\n",
      "    code = spawn._main(child_r, parent_sentinel)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/spawn.py\"\u001b[0m, line \u001b[35m132\u001b[0m, in \u001b[35m_main\u001b[0m\n",
      "    self = reduction.pickle.load(from_parent)\n",
      "\u001b[1;35mAttributeError\u001b[0m: \u001b[35mmodule '__main__' has no attribute 'WeightedPositiveDataset'\u001b[0m\n",
      "Traceback (most recent call last):\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m344\u001b[0m, in \u001b[35mmain\u001b[0m\n",
      "    code = _serve_one(child_r, fds,\n",
      "                      unused_fds,\n",
      "                      old_handlers)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m384\u001b[0m, in \u001b[35m_serve_one\u001b[0m\n",
      "    code = spawn._main(child_r, parent_sentinel)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/spawn.py\"\u001b[0m, line \u001b[35m132\u001b[0m, in \u001b[35m_main\u001b[0m\n",
      "    self = reduction.pickle.load(from_parent)\n",
      "\u001b[1;35mAttributeError\u001b[0m: \u001b[35mmodule '__main__' has no attribute 'WeightedPositiveDataset'\u001b[0m\n",
      "Traceback (most recent call last):\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m344\u001b[0m, in \u001b[35mmain\u001b[0m\n",
      "    code = _serve_one(child_r, fds,\n",
      "                      unused_fds,\n",
      "                      old_handlers)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m384\u001b[0m, in \u001b[35m_serve_one\u001b[0m\n",
      "    code = spawn._main(child_r, parent_sentinel)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/spawn.py\"\u001b[0m, line \u001b[35m132\u001b[0m, in \u001b[35m_main\u001b[0m\n",
      "    self = reduction.pickle.load(from_parent)\n",
      "\u001b[1;35mAttributeError\u001b[0m: \u001b[35mmodule '__main__' has no attribute 'WeightedPositiveDataset'\u001b[0m\n",
      "Traceback (most recent call last):\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m344\u001b[0m, in \u001b[35mmain\u001b[0m\n",
      "    code = _serve_one(child_r, fds,\n",
      "                      unused_fds,\n",
      "                      old_handlers)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/forkserver.py\"\u001b[0m, line \u001b[35m384\u001b[0m, in \u001b[35m_serve_one\u001b[0m\n",
      "    code = spawn._main(child_r, parent_sentinel)\n",
      "  File \u001b[35m\"/usr/lib/python3.14/multiprocessing/spawn.py\"\u001b[0m, line \u001b[35m132\u001b[0m, in \u001b[35m_main\u001b[0m\n",
      "    self = reduction.pickle.load(from_parent)\n",
      "\u001b[1;35mAttributeError\u001b[0m: \u001b[35mmodule '__main__' has no attribute 'WeightedPositiveDataset'\u001b[0m\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "DataLoader worker (pid(s) 113038, 113039, 113040, 113041) exited unexpectedly",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mEmpty\u001b[39m                                     Traceback (most recent call last)",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/site-packages/torch/utils/data/dataloader.py:1310\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._try_get_data\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m   1309\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1310\u001b[39m     data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_data_queue\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1311\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/queue.py:209\u001b[39m, in \u001b[36mQueue.get\u001b[39m\u001b[34m(self, block, timeout)\u001b[39m\n\u001b[32m    208\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m remaining <= \u001b[32m0.0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m209\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[32m    210\u001b[39m \u001b[38;5;28mself\u001b[39m.not_empty.wait(remaining)\n",
      "\u001b[31mEmpty\u001b[39m: ",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[31mRuntimeError\u001b[39m                              Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[19]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m      5\u001b[39m neural_models = {}\n\u001b[32m      6\u001b[39m \n\u001b[32m      7\u001b[39m print(\u001b[33m\"=\"\u001b[39m * \u001b[32m100\u001b[39m)\n\u001b[32m      8\u001b[39m print(\u001b[33m\"Training Softmax MLP base\"\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m softmax_base = train_softmax_model(\n\u001b[32m     10\u001b[39m     model_name=\u001b[33m\"SoftmaxMLP_base\"\u001b[39m,\n\u001b[32m     11\u001b[39m     emb_dim=\u001b[32m64\u001b[39m,\n\u001b[32m     12\u001b[39m     hidden_dims=(\u001b[32m512\u001b[39m, \u001b[32m512\u001b[39m, \u001b[32m256\u001b[39m),\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[15]\u001b[39m\u001b[32m, line 62\u001b[39m, in \u001b[36mtrain_softmax_model\u001b[39m\u001b[34m(model_name, emb_dim, hidden_dims, epochs, lr, wd)\u001b[39m\n\u001b[32m     58\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;28;01min\u001b[39;00m range(epochs):\n\u001b[32m     59\u001b[39m         total_loss = \u001b[32m0.0\u001b[39m\n\u001b[32m     60\u001b[39m         total_w = \u001b[32m0.0\u001b[39m\n\u001b[32m     61\u001b[39m \n\u001b[32m---> \u001b[39m\u001b[32m62\u001b[39m         \u001b[38;5;28;01mfor\u001b[39;00m user_idx_batch, item_idx_batch, weight_batch \u001b[38;5;28;01min\u001b[39;00m positive_loader:\n\u001b[32m     63\u001b[39m             user_idx_batch = user_idx_batch.to(device, non_blocking=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m     64\u001b[39m             item_idx_batch = item_idx_batch.to(device, non_blocking=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m     65\u001b[39m             weight_batch = weight_batch.to(device, non_blocking=\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/site-packages/torch/utils/data/dataloader.py:741\u001b[39m, in \u001b[36m_BaseDataLoaderIter.__next__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m    738\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    739\u001b[39m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[32m    740\u001b[39m     \u001b[38;5;28mself\u001b[39m._reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m741\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    742\u001b[39m \u001b[38;5;28mself\u001b[39m._num_yielded += \u001b[32m1\u001b[39m\n\u001b[32m    743\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m    744\u001b[39m     \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable\n\u001b[32m    745\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m    746\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._num_yielded > \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called\n\u001b[32m    747\u001b[39m ):\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/site-packages/torch/utils/data/dataloader.py:1524\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._next_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m   1520\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._shutdown \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._tasks_outstanding <= \u001b[32m0\u001b[39m:\n\u001b[32m   1521\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[32m   1522\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33mInvalid iterator state: shutdown or no outstanding tasks when fetching next data\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m   1523\u001b[39m     )\n\u001b[32m-> \u001b[39m\u001b[32m1524\u001b[39m idx, data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1525\u001b[39m \u001b[38;5;28mself\u001b[39m._tasks_outstanding -= \u001b[32m1\u001b[39m\n\u001b[32m   1526\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable:\n\u001b[32m   1527\u001b[39m     \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/site-packages/torch/utils/data/dataloader.py:1473\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._get_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m   1471\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._pin_memory:\n\u001b[32m   1472\u001b[39m     \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m._pin_memory_thread.is_alive():\n\u001b[32m-> \u001b[39m\u001b[32m1473\u001b[39m         success, data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1474\u001b[39m         \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[32m   1475\u001b[39m             \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.14/site-packages/torch/utils/data/dataloader.py:1323\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._try_get_data\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m   1321\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) > \u001b[32m0\u001b[39m:\n\u001b[32m   1322\u001b[39m     pids_str = \u001b[33m\"\u001b[39m\u001b[33m, \u001b[39m\u001b[33m\"\u001b[39m.join(\u001b[38;5;28mstr\u001b[39m(w.pid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[32m-> \u001b[39m\u001b[32m1323\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m   1324\u001b[39m         \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) exited unexpectedly\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m   1325\u001b[39m     ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m   1326\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue.Empty):\n\u001b[32m   1327\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
      "\u001b[31mRuntimeError\u001b[39m: DataLoader worker (pid(s) 113038, 113039, 113040, 113041) exited unexpectedly"
     ]
    }
   ],
   "source": [
    "# Fit and evaluate neural models\n",
    "# These models are intentionally larger than before.\n",
    "\n",
    "neural_results = []\n",
    "neural_models = {}\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training Softmax MLP base\")\n",
    "softmax_base = train_softmax_model(\n",
    "    model_name=\"SoftmaxMLP_base\",\n",
    "    emb_dim=64,\n",
    "    hidden_dims=(512, 512, 256),\n",
    "    epochs=SOFTMAX_EPOCHS,\n",
    "    lr=2e-3,\n",
    "    wd=1e-5,\n",
    ")\n",
    "rec_softmax_base = make_softmax_recommender(softmax_base)\n",
    "neural_models[\"SoftmaxMLP_base\"] = rec_softmax_base\n",
    "neural_results.append(evaluate_model(rec_softmax_base, \"SoftmaxMLP_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training Softmax MLP large\")\n",
    "softmax_large = train_softmax_model(\n",
    "    model_name=\"SoftmaxMLP_large\",\n",
    "    emb_dim=96,\n",
    "    hidden_dims=(1024, 1024, 512, 256),\n",
    "    epochs=SOFTMAX_EPOCHS,\n",
    "    lr=1.5e-3,\n",
    "    wd=1e-5,\n",
    ")\n",
    "rec_softmax_large = make_softmax_recommender(softmax_large)\n",
    "neural_models[\"SoftmaxMLP_large\"] = rec_softmax_large\n",
    "neural_results.append(evaluate_model(rec_softmax_large, \"SoftmaxMLP_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training TwoTower base\")\n",
    "twotower_base = train_two_tower(\n",
    "    model_name=\"TwoTower_base\",\n",
    "    emb_dim=64,\n",
    "    hidden_dims=(512, 256),\n",
    "    out_dim=128,\n",
    "    epochs=TWOTOWER_EPOCHS,\n",
    "    lr=2e-3,\n",
    "    wd=1e-5,\n",
    "    temperature=0.07,\n",
    ")\n",
    "rec_twotower_base = make_two_tower_recommender(twotower_base)\n",
    "neural_models[\"TwoTower_base\"] = rec_twotower_base\n",
    "neural_results.append(evaluate_model(rec_twotower_base, \"TwoTower_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training TwoTower large\")\n",
    "twotower_large = train_two_tower(\n",
    "    model_name=\"TwoTower_large\",\n",
    "    emb_dim=96,\n",
    "    hidden_dims=(1024, 512, 256),\n",
    "    out_dim=192,\n",
    "    epochs=TWOTOWER_EPOCHS,\n",
    "    lr=1.5e-3,\n",
    "    wd=1e-5,\n",
    "    temperature=0.05,\n",
    ")\n",
    "rec_twotower_large = make_two_tower_recommender(twotower_large)\n",
    "neural_models[\"TwoTower_large\"] = rec_twotower_large\n",
    "neural_results.append(evaluate_model(rec_twotower_large, \"TwoTower_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training MultVAE base\")\n",
    "multvae_base = train_multvae(\n",
    "    model_name=\"MultVAE_base\",\n",
    "    hidden_dim=768,\n",
    "    latent_dim=192,\n",
    "    dropout=0.25,\n",
    "    epochs=MULTVAE_EPOCHS,\n",
    "    lr=1e-3,\n",
    "    wd=0.0,\n",
    ")\n",
    "rec_multvae_base = make_multvae_recommender(multvae_base)\n",
    "neural_models[\"MultVAE_base\"] = rec_multvae_base\n",
    "neural_results.append(evaluate_model(rec_multvae_base, \"MultVAE_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training MultVAE_large\")\n",
    "multvae_large = train_multvae(\n",
    "    model_name=\"MultVAE_large\",\n",
    "    hidden_dim=1024,\n",
    "    latent_dim=256,\n",
    "    dropout=0.30,\n",
    "    epochs=MULTVAE_EPOCHS,\n",
    "    lr=8e-4,\n",
    "    wd=0.0,\n",
    ")\n",
    "rec_multvae_large = make_multvae_recommender(multvae_large)\n",
    "neural_models[\"MultVAE_large\"] = rec_multvae_large\n",
    "neural_results.append(evaluate_model(rec_multvae_large, \"MultVAE_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "neural_results_df = pd.DataFrame(neural_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "display(neural_results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0a525a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final combined comparison\n",
    "\n",
    "all_results_df = pd.concat(\n",
    "    [regular_results_df.assign(Family=\"Regular\"), neural_results_df.assign(Family=\"Neural\")],\n",
    "    ignore_index=True\n",
    ").sort_values([\"HR@10\", \"NDCG@10\", \"HR@20\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "display(all_results_df)\n",
    "\n",
    "print()\n",
    "print(\"Top regular models:\")\n",
    "display(regular_results_df.head(10))\n",
    "\n",
    "print()\n",
    "print(\"Top neural models:\")\n",
    "display(neural_results_df.head(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f4f73bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show example recommendations from the best overall model\n",
    "\n",
    "best_model_name = all_results_df.iloc[0][\"Model\"]\n",
    "print(\"Best overall model:\", best_model_name)\n",
    "\n",
    "if best_model_name in regular_models:\n",
    "    best_recommender = regular_models[best_model_name]\n",
    "else:\n",
    "    best_recommender = neural_models[best_model_name]\n",
    "\n",
    "example_user = int(train_interactions[\"user_idx\"].sample(1, random_state=RANDOM_STATE).iloc[0])\n",
    "\n",
    "print(\"Example user / segment:\")\n",
    "print(data[\"user_ids\"][example_user])\n",
    "print()\n",
    "\n",
    "print(\"Seen items (first 20):\")\n",
    "print([item_ids[i] for i in sorted(list(user_seen[example_user]))[:20]])\n",
    "print()\n",
    "\n",
    "print(\"Recommended items:\")\n",
    "print([item_ids[i] for i in best_recommender(example_user, n=10)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b394fadb",
   "metadata": {},
   "source": [
    "## Notes for the next iteration\n",
    "\n",
    "If you still want even better quality after this run, the next tuning steps are:\n",
    "\n",
    "1. keep only the best **formulation** from the quick search;\n",
    "2. keep only the best **2–3 regular** and **2–3 neural** models from the result table;\n",
    "3. increase:\n",
    "   - item granularity a little more if the item space is still too small;\n",
    "   - batch sizes if the GPU still has spare memory;\n",
    "   - hidden sizes for the strongest neural model;\n",
    "4. tune:\n",
    "   - `EASE_LAMBDA_GRID`\n",
    "   - `ITEMKNN_NEIGHBORS_GRID`\n",
    "   - `temperature` in the two-tower model\n",
    "   - `latent_dim` and `hidden_dim` in MultVAE\n",
    "5. if the GPU is still underused, the most direct way to push it harder is:\n",
    "   - larger item granularity,\n",
    "   - larger batches,\n",
    "   - larger hidden layers,\n",
    "   - or replacing the softmax / two-tower models with even wider networks.\n",
    "\n",
    "This notebook is intentionally written to search broadly first."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
