{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8cd9eae1",
   "metadata": {},
   "source": [
    "# Car recommender notebook — aggressive search version\n",
    "\n",
    "This version is designed to do two things:\n",
    "\n",
    "1. search across **multiple task formulations** to find a stronger recommendation setup;\n",
    "2. train **more GPU-heavy models** so the GPU does much more work than in the previous notebook.\n",
    "\n",
    "Important note:\n",
    "- with the previous formulation there were only **88 items**, which is too small to fully saturate a modern GPU;\n",
    "- this notebook increases the item granularity and uses **large-batch neural training**, **full-softmax classification**, and **in-batch contrastive ranking** to push GPU usage much higher.\n",
    "\n",
    "This notebook intentionally tries **many options**. After it finishes, keep the strongest few models and throw away the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f5ca59df",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch device: cuda\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import gc\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset, TensorDataset\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Torch device:\", device)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    try:\n",
    "        torch.set_float32_matmul_precision(\"high\")\n",
    "    except Exception:\n",
    "        pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3156b716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['F1_country_price_mileage_cond__brand_model', 'F2_country_price_mileage_cond_age__brand_model_age_cond', 'F3_country_price_mileage_age__brand_model_age_cond', 'F4_country_price_mileage__brand_model_age_cond']\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "# Binning\n",
    "N_PRICE_BINS = 10\n",
    "N_MILEAGE_BINS = 10\n",
    "N_AGE_BINS = 8\n",
    "\n",
    "# Interaction filtering\n",
    "MIN_ITEMS_PER_USER = 5\n",
    "MIN_USERS_PER_ITEM = 10\n",
    "MAX_FILTER_ITERS = 8\n",
    "\n",
    "# Ranking metrics\n",
    "TOP_KS = [5, 10, 20]\n",
    "\n",
    "# Model search grids\n",
    "ITEMKNN_NEIGHBORS_GRID = [50, 100, 150]\n",
    "EASE_LAMBDA_GRID = [50.0, 100.0, 200.0, 500.0, 1000.0]\n",
    "\n",
    "# Neural training defaults\n",
    "SOFTMAX_EPOCHS = 18\n",
    "TWOTOWER_EPOCHS = 18\n",
    "MULTVAE_EPOCHS = 60\n",
    "\n",
    "# Conservative defaults for an RTX 3080; increase later if you still have free VRAM.\n",
    "SOFTMAX_BATCH_SIZE = 8192\n",
    "TWOTOWER_BATCH_SIZE = 4096\n",
    "MULTVAE_BATCH_SIZE = 2048\n",
    "\n",
    "# Important: keep this at 0 in notebooks / Python 3.14 to avoid fragile worker crashes.\n",
    "NUM_WORKERS = 0\n",
    "USE_AMP = torch.cuda.is_available()\n",
    "\n",
    "# Candidate task formulations.\n",
    "# Each one defines a pseudo-user profile and an item granularity.\n",
    "FORMULATIONS = {\n",
    "    \"F1_country_price_mileage_cond__brand_model\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\"],\n",
    "    },\n",
    "    \"F2_country_price_mileage_cond_age__brand_model_age_cond\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
    "    },\n",
    "    \"F3_country_price_mileage_age__brand_model_age_cond\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"AgeBin\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
    "    },\n",
    "    \"F4_country_price_mileage__brand_model_age_cond\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b60787ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape: (1000000, 15)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513.0</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>1</td>\n",
       "      <td>p2</td>\n",
       "      <td>m2</td>\n",
       "      <td>a0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990.0</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>24</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703.0</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>10</td>\n",
       "      <td>p5</td>\n",
       "      <td>m0</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353.0</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>21</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>a6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854.0</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>12</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20  58513.0   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14  60990.0   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93   1703.0   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94  25353.0  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01  76854.0  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    1   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   24   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   10   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   21   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   12   \n",
       "\n",
       "  PriceBin MileageBin AgeBin  \n",
       "0       p2         m2     a0  \n",
       "1       p0         m3     a7  \n",
       "2       p5         m0     a3  \n",
       "3       p0         m1     a6  \n",
       "4       p0         m3     a3  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Age</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Condition</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>1</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513.0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>p2</td>\n",
       "      <td>m2</td>\n",
       "      <td>a0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>24</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990.0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Italy</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>10</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703.0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>UK</td>\n",
       "      <td>p5</td>\n",
       "      <td>m0</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>21</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353.0</td>\n",
       "      <td>Used</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>a6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>12</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854.0</td>\n",
       "      <td>Used</td>\n",
       "      <td>USA</td>\n",
       "      <td>p0</td>\n",
       "      <td>m3</td>\n",
       "      <td>a3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year  Age     Price  Mileage            Condition  \\\n",
       "0       Honda        Civic  2023    1  25627.20  58513.0  Certified Pre-Owned   \n",
       "1       Mazda       Mazda3  2000   24  12027.14  60990.0  Certified Pre-Owned   \n",
       "2       Mazda         CX-5  2014   10  49194.93   1703.0  Certified Pre-Owned   \n",
       "3     Hyundai       Tucson  2003   21  11955.94  25353.0                 Used   \n",
       "4  Land Rover  Range Rover  2012   12  10910.01  76854.0                 Used   \n",
       "\n",
       "  Country PriceBin MileageBin AgeBin  \n",
       "0  Brazil       p2         m2     a0  \n",
       "1   Italy       p0         m3     a7  \n",
       "2      UK       p5         m0     a3  \n",
       "3  Mexico       p0         m1     a6  \n",
       "4     USA       p0         m3     a3  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean data\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\", \"First Name\", \"Last Name\", \"Address\"]:\n",
    "    if col in df.columns:\n",
    "        df[col] = df[col].astype(str).str.strip()\n",
    "\n",
    "df[\"Year\"] = df[\"Year\"].astype(int)\n",
    "df[\"Price\"] = df[\"Price\"].astype(float)\n",
    "df[\"Mileage\"] = df[\"Mileage\"].astype(float)\n",
    "\n",
    "# Age from the newest year in the file\n",
    "max_year = int(df[\"Year\"].max())\n",
    "df[\"Age\"] = (max_year - df[\"Year\"]).astype(int)\n",
    "\n",
    "def qbin(series, q, prefix):\n",
    "    cat = pd.qcut(series, q=q, duplicates=\"drop\")\n",
    "    return prefix + cat.cat.codes.astype(str)\n",
    "\n",
    "df[\"PriceBin\"] = qbin(df[\"Price\"], N_PRICE_BINS, \"p\")\n",
    "df[\"MileageBin\"] = qbin(df[\"Mileage\"], N_MILEAGE_BINS, \"m\")\n",
    "df[\"AgeBin\"] = qbin(df[\"Age\"], N_AGE_BINS, \"a\")\n",
    "\n",
    "print(\"Shape:\", df.shape)\n",
    "display(df.head())\n",
    "display(df[[\"Brand\", \"Model\", \"Year\", \"Age\", \"Price\", \"Mileage\", \"Condition\", \"Country\", \"PriceBin\", \"MileageBin\", \"AgeBin\"]].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6739a0dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user=5, min_users_per_item=10, max_iters=8):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_shape = out.shape[0]\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if out.shape[0] == old_shape:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "           .size()\n",
    "           .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = iterative_filter(\n",
    "        interactions,\n",
    "        min_items_per_user=MIN_ITEMS_PER_USER,\n",
    "        min_users_per_item=MIN_USERS_PER_ITEM,\n",
    "        max_iters=MAX_FILTER_ITERS\n",
    "    ).reset_index(drop=True)\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    num_users = interactions[\"user_idx\"].nunique()\n",
    "    num_items = interactions[\"item_idx\"].nunique()\n",
    "\n",
    "    # user/item feature tables for the selected formulation\n",
    "    user_feature_df = (\n",
    "        tmp[[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "\n",
    "    # weighted leave-one-out split\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = np.random.default_rng(RANDOM_STATE + int(uid)).choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(picked)\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    # matrices from train only\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    user_seen = {\n",
    "        int(uid): set(g[\"item_idx\"].astype(int).tolist())\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    user_strength = {\n",
    "        int(uid): {int(i): float(c) for i, c in zip(g[\"item_idx\"], g[\"count\"])}\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(uid): int(i)\n",
    "        for uid, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy()\n",
    "    )\n",
    "\n",
    "    item_ids = item_encoder.classes_.tolist()\n",
    "    user_ids = user_encoder.classes_.tolist()\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df.sort_values(\"user_idx\").reset_index(drop=True),\n",
    "        \"item_feature_df\": item_feature_df.sort_values(\"item_idx\").reset_index(drop=True),\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen\": user_seen,\n",
    "        \"user_strength\": user_strength,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_ids,\n",
    "        \"user_ids\": user_ids,\n",
    "    }\n",
    "\n",
    "def hit_rate_at_k(recs, true_item, k):\n",
    "    return 1.0 if true_item in recs[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1.0)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(model_func, model_name, data_bundle, user_indices=None, ks=(5, 10, 20)):\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(data_bundle[\"test_item_by_user\"].keys()))\n",
    "\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr = 0.0\n",
    "    ndcg = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    for uid in tqdm(user_indices, desc=model_name):\n",
    "        uid = int(uid)\n",
    "        true_item = data_bundle[\"test_item_by_user\"].get(uid)\n",
    "        if true_item is None:\n",
    "            continue\n",
    "\n",
    "        recs = model_func(uid, n=max(ks))\n",
    "        valid += 1\n",
    "        for k in ks:\n",
    "            hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "        mrr += mrr_at_k(recs, true_item, 10)\n",
    "        ndcg += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    row = {\"Model\": model_name}\n",
    "    for k in ks:\n",
    "        row[f\"HR@{k}\"] = hits[k] / valid if valid else 0.0\n",
    "    row[\"MRR@10\"] = mrr / valid if valid else 0.0\n",
    "    row[\"NDCG@10\"] = ndcg / valid if valid else 0.0\n",
    "    return row\n",
    "\n",
    "def topn_from_scores(scores, seen, n):\n",
    "    scores = scores.copy()\n",
    "    if seen:\n",
    "        scores[list(seen)] = -np.inf\n",
    "    n = min(n, len(scores) - len(seen))\n",
    "    if n <= 0:\n",
    "        return []\n",
    "    idx = np.argpartition(scores, -n)[-n:]\n",
    "    idx = idx[np.argsort(scores[idx])[::-1]]\n",
    "    return [int(i) for i in idx]\n",
    "\n",
    "def bm25_weight(X, K1=100, B=0.8):\n",
    "    X = X.tocoo(copy=True).astype(np.float32)\n",
    "    N = float(X.shape[0])\n",
    "\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    avgdl = row_sums.mean() + 1e-8\n",
    "\n",
    "    df = np.bincount(X.col, minlength=X.shape[1]).astype(np.float32)\n",
    "    idf = np.log((N - df + 0.5) / (df + 0.5))\n",
    "    idf = np.maximum(idf, 0)\n",
    "\n",
    "    denom = X.data + K1 * (1 - B + B * row_sums[X.row] / avgdl)\n",
    "    data = X.data * (K1 + 1) / denom * idf[X.col]\n",
    "\n",
    "    return csr_matrix((data, (X.row, X.col)), shape=X.shape, dtype=np.float32)\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Train rows  :\", bundle[\"train_interactions\"].shape)\n",
    "    print(\"Test rows   :\", bundle[\"test_interactions\"].shape)\n",
    "    print(\"Matrix shape:\", bundle[\"X_binary\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54e4aa85",
   "metadata": {},
   "source": [
    "## Quick formulation search\n",
    "\n",
    "This step tries several ways of defining the pseudo-user and the item.\n",
    "The goal is to find a setup that gives a stronger recommendation signal **before** training the heavier neural models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "df30b27b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: F1_country_price_mileage_cond__brand_model\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition']\n",
      "Item cols  : ['Brand', 'Model']\n",
      "Users      : 3000\n",
      "Items      : 88\n",
      "Train rows  : (254673, 5)\n",
      "Test rows   : (3000, 5)\n",
      "Matrix shape: (3000, 88)\n",
      "====================================================================================================\n",
      "Formulation: F2_country_price_mileage_cond_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 24000\n",
      "Items      : 2112\n",
      "Train rows  : (769325, 5)\n",
      "Test rows   : (24000, 5)\n",
      "Matrix shape: (24000, 2112)\n",
      "====================================================================================================\n",
      "Formulation: F3_country_price_mileage_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 8000\n",
      "Items      : 2112\n",
      "Train rows  : (785325, 5)\n",
      "Test rows   : (8000, 5)\n",
      "Matrix shape: (8000, 2112)\n",
      "====================================================================================================\n",
      "Formulation: F4_country_price_mileage__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 1000\n",
      "Items      : 2112\n",
      "Train rows  : (792325, 5)\n",
      "Test rows   : (1000, 5)\n",
      "Matrix shape: (1000, 2112)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>Interactions</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>F2_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>24000</td>\n",
       "      <td>2112</td>\n",
       "      <td>793325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>F3_country_price_mileage_age__brand_model_age_...</td>\n",
       "      <td>8000</td>\n",
       "      <td>2112</td>\n",
       "      <td>793325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>F4_country_price_mileage__brand_model_age_cond</td>\n",
       "      <td>1000</td>\n",
       "      <td>2112</td>\n",
       "      <td>793325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>F1_country_price_mileage_cond__brand_model</td>\n",
       "      <td>3000</td>\n",
       "      <td>88</td>\n",
       "      <td>257673</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation  Users  Items  \\\n",
       "0  F2_country_price_mileage_cond_age__brand_model...  24000   2112   \n",
       "1  F3_country_price_mileage_age__brand_model_age_...   8000   2112   \n",
       "2     F4_country_price_mileage__brand_model_age_cond   1000   2112   \n",
       "3         F1_country_price_mileage_cond__brand_model   3000     88   \n",
       "\n",
       "   Interactions  \n",
       "0        793325  \n",
       "1        793325  \n",
       "2        793325  \n",
       "3        257673  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build all candidate formulations\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    bundles[name] = bundle\n",
    "    print_bundle_summary(bundle)\n",
    "\n",
    "    summary_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"Interactions\": bundle[\"interactions\"].shape[0],\n",
    "    })\n",
    "\n",
    "summary_df = pd.DataFrame(summary_rows).sort_values([\"Items\", \"Users\"], ascending=False).reset_index(drop=True)\n",
    "display(summary_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "74500d09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F1_country_price_mileage_cond__brand_model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F1_country_price_mileage_cond__brand_model / Pop: 100%|██████████| 3000/3000 [00:00<00:00, 129310.15it/s]\n",
      "F1_country_price_mileage_cond__brand_model / EASE200: 100%|██████████| 3000/3000 [00:00<00:00, 24565.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F2_country_price_mileage_cond_age__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F2_country_price_mileage_cond_age__brand_model_age_cond / Pop: 100%|██████████| 24000/24000 [00:04<00:00, 5509.46it/s]\n",
      "F2_country_price_mileage_cond_age__brand_model_age_cond / EASE200: 100%|██████████| 24000/24000 [00:26<00:00, 890.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F3_country_price_mileage_age__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F3_country_price_mileage_age__brand_model_age_cond / Pop: 100%|██████████| 8000/8000 [00:01<00:00, 5541.81it/s]\n",
      "F3_country_price_mileage_age__brand_model_age_cond / EASE200: 100%|██████████| 8000/8000 [00:09<00:00, 885.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: F4_country_price_mileage__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F4_country_price_mileage__brand_model_age_cond / Pop: 100%|██████████| 1000/1000 [00:00<00:00, 4790.18it/s]\n",
      "F4_country_price_mileage__brand_model_age_cond / EASE200: 100%|██████████| 1000/1000 [00:01<00:00, 878.77it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>Pop_HR@10</th>\n",
       "      <th>EASE200_HR@10</th>\n",
       "      <th>EASE200_NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>F1_country_price_mileage_cond__brand_model</td>\n",
       "      <td>3000</td>\n",
       "      <td>88</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.751848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>F2_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>24000</td>\n",
       "      <td>2112</td>\n",
       "      <td>0.007000</td>\n",
       "      <td>0.201208</td>\n",
       "      <td>0.097243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>F3_country_price_mileage_age__brand_model_age_...</td>\n",
       "      <td>8000</td>\n",
       "      <td>2112</td>\n",
       "      <td>0.007125</td>\n",
       "      <td>0.064375</td>\n",
       "      <td>0.029549</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>F4_country_price_mileage__brand_model_age_cond</td>\n",
       "      <td>1000</td>\n",
       "      <td>2112</td>\n",
       "      <td>0.024000</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>0.004541</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation  Users  Items  Pop_HR@10  \\\n",
       "0         F1_country_price_mileage_cond__brand_model   3000     88   1.000000   \n",
       "1  F2_country_price_mileage_cond_age__brand_model...  24000   2112   0.007000   \n",
       "2  F3_country_price_mileage_age__brand_model_age_...   8000   2112   0.007125   \n",
       "3     F4_country_price_mileage__brand_model_age_cond   1000   2112   0.024000   \n",
       "\n",
       "   EASE200_HR@10  EASE200_NDCG@10  \n",
       "0       1.000000         0.751848  \n",
       "1       0.201208         0.097243  \n",
       "2       0.064375         0.029549  \n",
       "3       0.010000         0.004541  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulation: F1_country_price_mileage_cond__brand_model\n",
      "Formulation: F1_country_price_mileage_cond__brand_model\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition']\n",
      "Item cols  : ['Brand', 'Model']\n",
      "Users      : 3000\n",
      "Items      : 88\n",
      "Train rows  : (254673, 5)\n",
      "Test rows   : (3000, 5)\n",
      "Matrix shape: (3000, 88)\n"
     ]
    }
   ],
   "source": [
    "# Quick screen with two cheap baselines:\n",
    "# 1) global popularity\n",
    "# 2) EASE with lambda=200 on binary interactions\n",
    "# The best formulation by EASE HR@10 becomes the main formulation.\n",
    "\n",
    "screen_rows = []\n",
    "\n",
    "for name, bundle in bundles.items():\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Screening formulation:\", name)\n",
    "\n",
    "    def rec_pop(uid, n=10, bundle=bundle):\n",
    "        seen = bundle[\"user_seen\"].get(uid, set())\n",
    "        return [int(i) for i in bundle[\"global_pop_rank\"] if int(i) not in seen][:n]\n",
    "\n",
    "    pop_res = evaluate_model(rec_pop, f\"{name} / Pop\", bundle, ks=TOP_KS)\n",
    "\n",
    "    Xb = bundle[\"X_binary\"]\n",
    "    G = (Xb.T @ Xb).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += 200.0\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "\n",
    "    def rec_ease(uid, n=10, bundle=bundle, B=B):\n",
    "        scores = bundle[\"X_binary\"].getrow(uid).toarray().ravel() @ B\n",
    "        seen = bundle[\"user_seen\"].get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "\n",
    "    ease_res = evaluate_model(rec_ease, f\"{name} / EASE200\", bundle, ks=TOP_KS)\n",
    "\n",
    "    screen_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"Pop_HR@10\": pop_res[\"HR@10\"],\n",
    "        \"EASE200_HR@10\": ease_res[\"HR@10\"],\n",
    "        \"EASE200_NDCG@10\": ease_res[\"NDCG@10\"],\n",
    "    })\n",
    "\n",
    "screen_df = pd.DataFrame(screen_rows).sort_values(\n",
    "    [\"EASE200_HR@10\", \"EASE200_NDCG@10\", \"Items\"], ascending=False\n",
    ").reset_index(drop=True)\n",
    "\n",
    "display(screen_df)\n",
    "\n",
    "BEST_FORMULATION = screen_df.iloc[0][\"Formulation\"]\n",
    "print(\"Selected formulation:\", BEST_FORMULATION)\n",
    "\n",
    "data = bundles[BEST_FORMULATION]\n",
    "print_bundle_summary(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "014f1da7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>user_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | p0 | m0 | Certified Pre-Owned</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m0</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Australia | p0 | m0 | New</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m0</td>\n",
       "      <td>New</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Australia | p0 | m0 | Used</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m0</td>\n",
       "      <td>Used</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Australia | p0 | m1 | Certified Pre-Owned</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Australia | p0 | m1 | New</td>\n",
       "      <td>Australia</td>\n",
       "      <td>p0</td>\n",
       "      <td>m1</td>\n",
       "      <td>New</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                     user_id    Country PriceBin MileageBin  \\\n",
       "0  Australia | p0 | m0 | Certified Pre-Owned  Australia       p0         m0   \n",
       "1                  Australia | p0 | m0 | New  Australia       p0         m0   \n",
       "2                 Australia | p0 | m0 | Used  Australia       p0         m0   \n",
       "3  Australia | p0 | m1 | Certified Pre-Owned  Australia       p0         m1   \n",
       "4                  Australia | p0 | m1 | New  Australia       p0         m1   \n",
       "\n",
       "             Condition  user_idx  \n",
       "0  Certified Pre-Owned         0  \n",
       "1                  New         1  \n",
       "2                 Used         2  \n",
       "3  Certified Pre-Owned         3  \n",
       "4                  New         4  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>item_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Audi :: A3</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Audi :: A4</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A4</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Audi :: A6</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A6</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Audi :: Q5</td>\n",
       "      <td>Audi</td>\n",
       "      <td>Q5</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Audi :: Q7</td>\n",
       "      <td>Audi</td>\n",
       "      <td>Q7</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      item_id Brand Model  item_idx\n",
       "0  Audi :: A3  Audi    A3         0\n",
       "1  Audi :: A4  Audi    A4         1\n",
       "2  Audi :: A6  Audi    A6         2\n",
       "3  Audi :: Q5  Audi    Q5         3\n",
       "4  Audi :: Q7  Audi    Q7         4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Prepare reusable arrays and feature tables for the selected formulation\n",
    "\n",
    "user_feature_df = data[\"user_feature_df\"].copy()\n",
    "item_feature_df = data[\"item_feature_df\"].copy()\n",
    "\n",
    "train_interactions = data[\"train_interactions\"].copy()\n",
    "test_interactions = data[\"test_interactions\"].copy()\n",
    "\n",
    "X_counts = data[\"X_counts\"].tocsr()\n",
    "X_binary = data[\"X_binary\"].tocsr()\n",
    "\n",
    "num_users = data[\"num_users\"]\n",
    "num_items = data[\"num_items\"]\n",
    "\n",
    "user_seen = data[\"user_seen\"]\n",
    "user_strength = data[\"user_strength\"]\n",
    "test_item_by_user = data[\"test_item_by_user\"]\n",
    "global_pop_rank = data[\"global_pop_rank\"]\n",
    "item_ids = data[\"item_ids\"]\n",
    "\n",
    "# Dense versions only for models that need them\n",
    "X_binary_dense = X_binary.toarray().astype(np.float32)\n",
    "X_counts_dense = X_counts.toarray().astype(np.float32)\n",
    "\n",
    "display(user_feature_df.head())\n",
    "display(item_feature_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a7b336d",
   "metadata": {},
   "source": [
    "## Regular baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8c765ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 1: global popularity\n",
    "\n",
    "def recommend_popularity(user_idx, n=10):\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1bcb91ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 2: country-aware popularity baseline\n",
    "\n",
    "country_col = \"Country\" if \"Country\" in user_feature_df.columns else None\n",
    "user_country = {}\n",
    "\n",
    "if country_col is not None:\n",
    "    user_country = dict(zip(user_feature_df[\"user_idx\"], user_feature_df[country_col]))\n",
    "\n",
    "    train_with_country = train_interactions.merge(\n",
    "        user_feature_df[[\"user_idx\", country_col]],\n",
    "        on=\"user_idx\",\n",
    "        how=\"left\"\n",
    "    )\n",
    "\n",
    "    country_pop_rank = {}\n",
    "    for country, g in train_with_country.groupby(country_col):\n",
    "        country_pop_rank[country] = (\n",
    "            g.groupby(\"item_idx\")[\"count\"]\n",
    "             .sum()\n",
    "             .sort_values(ascending=False)\n",
    "             .index.to_numpy()\n",
    "        )\n",
    "else:\n",
    "    country_pop_rank = {}\n",
    "\n",
    "def recommend_country_popularity(user_idx, n=10):\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    country = user_country.get(int(user_idx), None)\n",
    "\n",
    "    if country in country_pop_rank:\n",
    "        recs = [int(i) for i in country_pop_rank[country] if int(i) not in seen][:n]\n",
    "        if len(recs) >= n:\n",
    "            return recs\n",
    "\n",
    "    # fallback\n",
    "    out = [int(i) for i in global_pop_rank if int(i) not in seen][:n]\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "114bafa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 3: ItemKNN on binary interactions\n",
    "\n",
    "def fit_itemknn_recommender(X_matrix, neighbors=100):\n",
    "    item_user = X_matrix.T.tocsr()\n",
    "\n",
    "    knn = NearestNeighbors(\n",
    "        metric=\"cosine\",\n",
    "        algorithm=\"brute\",\n",
    "        n_neighbors=min(neighbors, item_user.shape[0]),\n",
    "        n_jobs=-1\n",
    "    )\n",
    "    knn.fit(item_user)\n",
    "\n",
    "    distances, indices = knn.kneighbors(item_user, n_neighbors=min(neighbors, item_user.shape[0]))\n",
    "    similarities = (1.0 - distances).astype(np.float32)\n",
    "    return indices, similarities\n",
    "\n",
    "def make_itemknn_recommender(X_matrix, user_seen_dict, user_strength_dict, neighbors=100, use_strength=True):\n",
    "    neighbor_idx, neighbor_sim = fit_itemknn_recommender(X_matrix, neighbors=neighbors)\n",
    "\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        strength_map = user_strength_dict.get(uid, {})\n",
    "        scores = {}\n",
    "\n",
    "        for item_idx in seen:\n",
    "            weight = float(strength_map.get(item_idx, 1.0)) if use_strength else 1.0\n",
    "            nbrs = neighbor_idx[item_idx]\n",
    "            sims = neighbor_sim[item_idx]\n",
    "\n",
    "            for j, sim in zip(nbrs, sims):\n",
    "                j = int(j)\n",
    "                if j == item_idx or j in seen:\n",
    "                    continue\n",
    "                scores[j] = scores.get(j, 0.0) + float(sim) * weight\n",
    "\n",
    "        if not scores:\n",
    "            return recommend_popularity(uid, n=n)\n",
    "\n",
    "        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:n]\n",
    "        return [int(i) for i, _ in ranked]\n",
    "\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5c57d7fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 4: ItemKNN with BM25 weighting\n",
    "\n",
    "X_bm25 = bm25_weight(X_counts, K1=100, B=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "74f0ff1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 5: EASE search\n",
    "# We try both binary and count-weighted inputs and keep all runs in the results table.\n",
    "\n",
    "def fit_ease(X_matrix, lam):\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += lam\n",
    "\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32)\n",
    "\n",
    "def make_ease_recommender(X_train_dense, B_matrix, user_seen_dict):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_train_dense[uid] @ B_matrix\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acf6838c",
   "metadata": {},
   "source": [
    "## Neural models with heavier GPU usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f525bb97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "User feature columns: ['Country', 'PriceBin', 'MileageBin', 'Condition']\n",
      "Item feature columns: ['Brand', 'Model']\n",
      "Country user cardinality: 10\n",
      "PriceBin user cardinality: 10\n",
      "MileageBin user cardinality: 10\n",
      "Condition user cardinality: 3\n",
      "Brand item cardinality: 18\n",
      "Model item cardinality: 88\n"
     ]
    }
   ],
   "source": [
    "# Feature encoders for neural models\n",
    "# Keep feature tensors on the target device so training batches do not bounce through the CPU.\n",
    "\n",
    "user_cols = data[\"user_cols\"]\n",
    "item_cols = data[\"item_cols\"]\n",
    "\n",
    "feature_encoders = {}\n",
    "for col in sorted(set(user_cols + item_cols)):\n",
    "    enc = LabelEncoder()\n",
    "    values = pd.concat([\n",
    "        user_feature_df[col].astype(str) if col in user_feature_df.columns else pd.Series(dtype=str),\n",
    "        item_feature_df[col].astype(str) if col in item_feature_df.columns else pd.Series(dtype=str),\n",
    "    ], ignore_index=True)\n",
    "    enc.fit(values)\n",
    "    feature_encoders[col] = enc\n",
    "\n",
    "user_feature_arrays = {}\n",
    "for col in user_cols:\n",
    "    user_feature_arrays[col] = feature_encoders[col].transform(user_feature_df[col].astype(str))\n",
    "\n",
    "item_feature_arrays = {}\n",
    "for col in item_cols:\n",
    "    item_feature_arrays[col] = feature_encoders[col].transform(item_feature_df[col].astype(str))\n",
    "\n",
    "user_feature_tensors = {\n",
    "    col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "    for col, arr in user_feature_arrays.items()\n",
    "}\n",
    "item_feature_tensors = {\n",
    "    col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "    for col, arr in item_feature_arrays.items()\n",
    "}\n",
    "\n",
    "print(\"User feature columns:\", user_cols)\n",
    "print(\"Item feature columns:\", item_cols)\n",
    "for col in user_cols:\n",
    "    print(col, \"user cardinality:\", len(feature_encoders[col].classes_))\n",
    "for col in item_cols:\n",
    "    print(col, \"item cardinality:\", len(feature_encoders[col].classes_))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "25793118",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive rows: 254673\n",
      "Dense user matrix shape: (3000, 88)\n"
     ]
    }
   ],
   "source": [
    "# GPU-friendly training tensors and manual batch iterators\n",
    "# This avoids fragile multiprocessing DataLoader workers and keeps the hot path on-device.\n",
    "\n",
    "positive_user_idx = torch.tensor(train_interactions[\"user_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "positive_item_idx = torch.tensor(train_interactions[\"item_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "positive_weight = torch.tensor(train_interactions[\"count\"].astype(np.float32).to_numpy(), dtype=torch.float32, device=device)\n",
    "\n",
    "X_binary_dense_tensor = torch.tensor(X_binary_dense, dtype=torch.float32, device=device)\n",
    "X_counts_dense_tensor = torch.tensor(X_counts_dense, dtype=torch.float32, device=device)\n",
    "\n",
    "n_positive_rows = int(positive_user_idx.shape[0])\n",
    "\n",
    "\n",
    "def iterate_positive_batches(batch_size, shuffle=True):\n",
    "    n = n_positive_rows\n",
    "    order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "    for start in range(0, n, batch_size):\n",
    "        idx = order[start:start + batch_size]\n",
    "        yield positive_user_idx[idx], positive_item_idx[idx], positive_weight[idx]\n",
    "\n",
    "\n",
    "def iterate_dense_user_batches(batch_size, shuffle=True):\n",
    "    n = int(X_binary_dense_tensor.shape[0])\n",
    "    order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "    for start in range(0, n, batch_size):\n",
    "        idx = order[start:start + batch_size]\n",
    "        yield X_binary_dense_tensor[idx], idx\n",
    "\n",
    "print(\"Positive rows:\", n_positive_rows)\n",
    "print(\"Dense user matrix shape:\", tuple(X_binary_dense_tensor.shape))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0c256a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural model 1: full-softmax feature MLP\n",
    "# This is deliberately GPU-friendly: large batches and a full item softmax.\n",
    "\n",
    "class FeatureEmbeddingBlock(nn.Module):\n",
    "    def __init__(self, feature_cardinalities, emb_dim):\n",
    "        super().__init__()\n",
    "        self.feature_names = list(feature_cardinalities.keys())\n",
    "        self.embs = nn.ModuleDict({\n",
    "            name: nn.Embedding(card, emb_dim)\n",
    "            for name, card in feature_cardinalities.items()\n",
    "        })\n",
    "\n",
    "    def forward(self, feature_batch):\n",
    "        parts = [self.embs[name](feature_batch[name]) for name in self.feature_names]\n",
    "        return torch.cat(parts, dim=-1)\n",
    "\n",
    "class SoftmaxSegmentMLP(nn.Module):\n",
    "    def __init__(self, feature_cardinalities, num_items, emb_dim=64, hidden_dims=(512, 512, 256), dropout=0.15):\n",
    "        super().__init__()\n",
    "        self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "        input_dim = len(feature_cardinalities) * emb_dim\n",
    "\n",
    "        layers = []\n",
    "        last = input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "            last = h\n",
    "        self.mlp = nn.Sequential(*layers)\n",
    "        self.out = nn.Linear(last, num_items)\n",
    "\n",
    "    def forward(self, feature_batch):\n",
    "        x = self.encoder(feature_batch)\n",
    "        x = self.mlp(x)\n",
    "        return self.out(x)\n",
    "\n",
    "\n",
    "def make_user_feature_batch(user_idx_batch):\n",
    "    return {col: user_feature_tensors[col][user_idx_batch] for col in user_cols}\n",
    "\n",
    "\n",
    "def train_softmax_model(model_name, emb_dim=64, hidden_dims=(512, 512, 256), epochs=18, lr=2e-3, wd=1e-5):\n",
    "    feature_cardinalities = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "    model = SoftmaxSegmentMLP(\n",
    "        feature_cardinalities=feature_cardinalities,\n",
    "        num_items=num_items,\n",
    "        emb_dim=emb_dim,\n",
    "        hidden_dims=hidden_dims,\n",
    "        dropout=0.15,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "    criterion = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(SOFTMAX_BATCH_SIZE, shuffle=True):\n",
    "            feat_batch = make_user_feature_batch(user_idx_batch)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits = model(feat_batch)\n",
    "                loss_vec = criterion(logits, item_idx_batch)\n",
    "                loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "            total_w += float(weight_batch.sum().item())\n",
    "\n",
    "        print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "\n",
    "def make_softmax_recommender(model):\n",
    "    @torch.no_grad()\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "        feat_batch = make_user_feature_batch(u)\n",
    "        logits = model(feat_batch).squeeze(0).detach().cpu().numpy()\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(logits, seen, n)\n",
    "    return recommend\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a15df0f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural model 2: feature two-tower with in-batch negatives\n",
    "# This typically drives GPU usage much better than pointwise negative sampling.\n",
    "\n",
    "class TowerEncoder(nn.Module):\n",
    "    def __init__(self, feature_cardinalities, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "        input_dim = len(feature_cardinalities) * emb_dim\n",
    "\n",
    "        layers = []\n",
    "        last = input_dim\n",
    "        for h in hidden_dims:\n",
    "            layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "            last = h\n",
    "        layers += [nn.Linear(last, out_dim)]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, feature_batch):\n",
    "        x = self.encoder(feature_batch)\n",
    "        x = self.net(x)\n",
    "        x = F.normalize(x, dim=-1)\n",
    "        return x\n",
    "\n",
    "class TwoTowerModel(nn.Module):\n",
    "    def __init__(self, user_feature_cards, item_feature_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "        super().__init__()\n",
    "        self.user_tower = TowerEncoder(user_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "        self.item_tower = TowerEncoder(item_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "\n",
    "    def forward(self, user_batch, item_batch):\n",
    "        u = self.user_tower(user_batch)\n",
    "        i = self.item_tower(item_batch)\n",
    "        return u, i\n",
    "\n",
    "\n",
    "def make_item_feature_batch(item_idx_batch):\n",
    "    return {col: item_feature_tensors[col][item_idx_batch] for col in item_cols}\n",
    "\n",
    "\n",
    "def train_two_tower(model_name, emb_dim=64, hidden_dims=(512, 256), out_dim=128, epochs=18, lr=2e-3, wd=1e-5, temperature=0.07):\n",
    "    user_feature_cards = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "    item_feature_cards = {col: len(feature_encoders[col].classes_) for col in item_cols}\n",
    "\n",
    "    model = TwoTowerModel(\n",
    "        user_feature_cards=user_feature_cards,\n",
    "        item_feature_cards=item_feature_cards,\n",
    "        emb_dim=emb_dim,\n",
    "        hidden_dims=hidden_dims,\n",
    "        out_dim=out_dim,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        total_w = 0.0\n",
    "\n",
    "        for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(TWOTOWER_BATCH_SIZE, shuffle=True):\n",
    "            user_batch = make_user_feature_batch(user_idx_batch)\n",
    "            item_batch = make_item_feature_batch(item_idx_batch)\n",
    "\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                user_vec, item_vec = model(user_batch, item_batch)\n",
    "                logits = (user_vec @ item_vec.T) / temperature\n",
    "                targets = torch.arange(logits.shape[0], device=device)\n",
    "                loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "            total_w += float(weight_batch.sum().item())\n",
    "\n",
    "        print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "\n",
    "def make_two_tower_recommender(model):\n",
    "    model.eval()\n",
    "    all_item_idx = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "    item_matrix = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, num_items, 4096):\n",
    "            idx = all_item_idx[start:start+4096]\n",
    "            batch = make_item_feature_batch(idx)\n",
    "            vec = model.item_tower(batch)\n",
    "            item_matrix.append(vec)\n",
    "    item_matrix = torch.cat(item_matrix, dim=0)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "        user_batch = make_user_feature_batch(u)\n",
    "        user_vec = model.user_tower(user_batch)\n",
    "        scores = (user_vec @ item_matrix.T).squeeze(0).detach().cpu().numpy()\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "\n",
    "    return recommend\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1bb5a7ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural model 3: larger MultVAE\n",
    "# This stays close to recommender-system literature and uses the whole user-item vector as input.\n",
    "\n",
    "class MultVAE(nn.Module):\n",
    "    def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.3):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        h = self.encoder(self.dropout(x))\n",
    "        return self.mu(h), self.logvar(h)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        if self.training:\n",
    "            std = torch.exp(0.5 * logvar)\n",
    "            eps = torch.randn_like(std)\n",
    "            return mu + eps * std\n",
    "        return mu\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        logits = self.decoder(z)\n",
    "        return logits, mu, logvar\n",
    "\n",
    "\n",
    "def train_multvae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.3, epochs=60, lr=1e-3, wd=0.0, anneal_cap=1.0):\n",
    "    model = MultVAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "    steps_per_epoch = max(1, math.ceil(num_users / MULTVAE_BATCH_SIZE))\n",
    "    total_steps = max(1, epochs * steps_per_epoch)\n",
    "    step = 0\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        total_rows = 0\n",
    "\n",
    "        for batch_x, _ in iterate_dense_user_batches(MULTVAE_BATCH_SIZE, shuffle=True):\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "            anneal = min(anneal_cap, step / max(total_steps * 0.3, 1))\n",
    "            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                logits, mu, logvar = model(batch_x)\n",
    "                recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                loss = (recon + anneal * kl).mean()\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "\n",
    "            total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "            total_rows += batch_x.shape[0]\n",
    "            step += 1\n",
    "\n",
    "        if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f} - anneal: {anneal:.3f}\")\n",
    "\n",
    "    return model.eval()\n",
    "\n",
    "\n",
    "def make_multvae_recommender(model):\n",
    "    @torch.no_grad()\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        x = X_binary_dense_tensor[uid:uid+1]\n",
    "        logits, _, _ = model(x)\n",
    "        scores = logits.squeeze(0).detach().cpu().numpy()\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7181984",
   "metadata": {},
   "source": [
    "## Train the model zoo\n",
    "\n",
    "This section intentionally tries **many variants**.\n",
    "If you leave the notebook running, it should finish with a large comparison table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e711dcc4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: Popularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Popularity: 100%|██████████| 3000/3000 [00:00<00:00, 56449.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: CountryPopularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CountryPopularity: 100%|██████████| 3000/3000 [00:00<00:00, 50090.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN neighbors=50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "ItemKNN_binary_k50: 100%|██████████| 3000/3000 [00:01<00:00, 2256.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN neighbors=100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k100: 100%|██████████| 3000/3000 [00:02<00:00, 1346.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN neighbors=150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k150: 100%|██████████| 3000/3000 [00:02<00:00, 1414.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k50: 100%|██████████| 3000/3000 [00:01<00:00, 2275.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k100: 100%|██████████| 3000/3000 [00:02<00:00, 1386.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k150: 100%|██████████| 3000/3000 [00:02<00:00, 1380.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=50.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l50: 100%|██████████| 3000/3000 [00:00<00:00, 66717.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=50.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l50: 100%|██████████| 3000/3000 [00:00<00:00, 65935.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l100: 100%|██████████| 3000/3000 [00:00<00:00, 67566.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l100: 100%|██████████| 3000/3000 [00:00<00:00, 65828.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l200: 100%|██████████| 3000/3000 [00:00<00:00, 65606.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l200: 100%|██████████| 3000/3000 [00:00<00:00, 66693.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=500.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l500: 100%|██████████| 3000/3000 [00:00<00:00, 66561.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=500.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l500: 100%|██████████| 3000/3000 [00:00<00:00, 63302.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l1000: 100%|██████████| 3000/3000 [00:00<00:00, 66746.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l1000: 100%|██████████| 3000/3000 [00:00<00:00, 66720.29it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>EASE_binary_l500</td>\n",
       "      <td>0.987000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.671896</td>\n",
       "      <td>0.754519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ItemKNN_binary_k150</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>0.989333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670471</td>\n",
       "      <td>0.753456</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>EASE_binary_l200</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.668304</td>\n",
       "      <td>0.751848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>EASE_binary_l50</td>\n",
       "      <td>0.985333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665105</td>\n",
       "      <td>0.749452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_binary_l100</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665007</td>\n",
       "      <td>0.749386</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>EASE_count_l100</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663710</td>\n",
       "      <td>0.748281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>EASE_count_l50</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663693</td>\n",
       "      <td>0.748266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>CountryPopularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>EASE_count_l200</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663510</td>\n",
       "      <td>0.748131</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>EASE_count_l500</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663443</td>\n",
       "      <td>0.748078</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.662649</td>\n",
       "      <td>0.747472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>ItemKNN_bm25_k100</td>\n",
       "      <td>0.991333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.651311</td>\n",
       "      <td>0.739278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>ItemKNN_bm25_k150</td>\n",
       "      <td>0.991333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.651311</td>\n",
       "      <td>0.739278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>ItemKNN_binary_k50</td>\n",
       "      <td>0.856333</td>\n",
       "      <td>0.858000</td>\n",
       "      <td>0.858000</td>\n",
       "      <td>0.620887</td>\n",
       "      <td>0.681031</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>ItemKNN_bm25_k50</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.546567</td>\n",
       "      <td>0.587167</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model      HR@5     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0      EASE_binary_l500  0.987000  1.000000  1.000000  0.671896  0.754519\n",
       "1   ItemKNN_binary_k100  0.992000  1.000000  1.000000  0.670833  0.753761\n",
       "2   ItemKNN_binary_k150  0.992000  1.000000  1.000000  0.670833  0.753761\n",
       "3     EASE_binary_l1000  0.989333  1.000000  1.000000  0.670471  0.753456\n",
       "4      EASE_binary_l200  0.986000  1.000000  1.000000  0.668304  0.751848\n",
       "5       EASE_binary_l50  0.985333  1.000000  1.000000  0.665105  0.749452\n",
       "6      EASE_binary_l100  0.986000  1.000000  1.000000  0.665007  0.749386\n",
       "7       EASE_count_l100  0.988667  1.000000  1.000000  0.663710  0.748281\n",
       "8        EASE_count_l50  0.988667  1.000000  1.000000  0.663693  0.748266\n",
       "9            Popularity  0.986000  1.000000  1.000000  0.663483  0.748141\n",
       "10    CountryPopularity  0.986000  1.000000  1.000000  0.663483  0.748141\n",
       "11      EASE_count_l200  0.989000  1.000000  1.000000  0.663510  0.748131\n",
       "12      EASE_count_l500  0.989000  1.000000  1.000000  0.663443  0.748078\n",
       "13     EASE_count_l1000  0.988667  1.000000  1.000000  0.662649  0.747472\n",
       "14    ItemKNN_bm25_k100  0.991333  1.000000  1.000000  0.651311  0.739278\n",
       "15    ItemKNN_bm25_k150  0.991333  1.000000  1.000000  0.651311  0.739278\n",
       "16   ItemKNN_binary_k50  0.856333  0.858000  0.858000  0.620887  0.681031\n",
       "17     ItemKNN_bm25_k50  0.704667  0.704667  0.704667  0.546567  0.587167"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit and evaluate regular models\n",
    "\n",
    "regular_results = []\n",
    "regular_models = {}\n",
    "\n",
    "eval_users = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: Popularity\")\n",
    "regular_results.append(evaluate_model(recommend_popularity, \"Popularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "regular_models[\"Popularity\"] = recommend_popularity\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: CountryPopularity\")\n",
    "regular_results.append(evaluate_model(recommend_country_popularity, \"CountryPopularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "regular_models[\"CountryPopularity\"] = recommend_country_popularity\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_binary, user_seen, user_strength, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_binary_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN BM25 neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_bm25, user_seen, user_strength, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_bm25_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for lam in EASE_LAMBDA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting EASE binary lambda={lam}\")\n",
    "    B_bin = fit_ease(X_binary, lam=lam)\n",
    "    rec_bin = make_ease_recommender(X_binary_dense, B_bin, user_seen)\n",
    "    name_bin = f\"EASE_binary_l{int(lam)}\"\n",
    "    regular_models[name_bin] = rec_bin\n",
    "    regular_results.append(evaluate_model(rec_bin, name_bin, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(f\"Fitting EASE count lambda={lam}\")\n",
    "    B_cnt = fit_ease(X_counts, lam=lam)\n",
    "    rec_cnt = make_ease_recommender(X_counts_dense, B_cnt, user_seen)\n",
    "    name_cnt = f\"EASE_count_l{int(lam)}\"\n",
    "    regular_models[name_cnt] = rec_cnt\n",
    "    regular_results.append(evaluate_model(rec_cnt, name_cnt, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "regular_results_df = pd.DataFrame(regular_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "display(regular_results_df.head(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1f125d6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training Softmax MLP base\n",
      "SoftmaxMLP_base epoch 1/18 - loss: 4.474615\n",
      "SoftmaxMLP_base epoch 2/18 - loss: 4.473007\n",
      "SoftmaxMLP_base epoch 3/18 - loss: 4.472695\n",
      "SoftmaxMLP_base epoch 4/18 - loss: 4.472480\n",
      "SoftmaxMLP_base epoch 5/18 - loss: 4.472484\n",
      "SoftmaxMLP_base epoch 6/18 - loss: 4.472268\n",
      "SoftmaxMLP_base epoch 7/18 - loss: 4.472184\n",
      "SoftmaxMLP_base epoch 8/18 - loss: 4.472228\n",
      "SoftmaxMLP_base epoch 9/18 - loss: 4.472040\n",
      "SoftmaxMLP_base epoch 10/18 - loss: 4.472055\n",
      "SoftmaxMLP_base epoch 11/18 - loss: 4.472115\n",
      "SoftmaxMLP_base epoch 12/18 - loss: 4.471971\n",
      "SoftmaxMLP_base epoch 13/18 - loss: 4.471942\n",
      "SoftmaxMLP_base epoch 14/18 - loss: 4.471918\n",
      "SoftmaxMLP_base epoch 15/18 - loss: 4.471896\n",
      "SoftmaxMLP_base epoch 16/18 - loss: 4.471815\n",
      "SoftmaxMLP_base epoch 17/18 - loss: 4.471831\n",
      "SoftmaxMLP_base epoch 18/18 - loss: 4.471953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SoftmaxMLP_base: 100%|██████████| 3000/3000 [00:00<00:00, 3252.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training Softmax MLP large\n",
      "SoftmaxMLP_large epoch 1/18 - loss: 4.474768\n",
      "SoftmaxMLP_large epoch 2/18 - loss: 4.473153\n",
      "SoftmaxMLP_large epoch 3/18 - loss: 4.472868\n",
      "SoftmaxMLP_large epoch 4/18 - loss: 4.472367\n",
      "SoftmaxMLP_large epoch 5/18 - loss: 4.472347\n",
      "SoftmaxMLP_large epoch 6/18 - loss: 4.472274\n",
      "SoftmaxMLP_large epoch 7/18 - loss: 4.472229\n",
      "SoftmaxMLP_large epoch 8/18 - loss: 4.472126\n",
      "SoftmaxMLP_large epoch 9/18 - loss: 4.471973\n",
      "SoftmaxMLP_large epoch 10/18 - loss: 4.472106\n",
      "SoftmaxMLP_large epoch 11/18 - loss: 4.471998\n",
      "SoftmaxMLP_large epoch 12/18 - loss: 4.472055\n",
      "SoftmaxMLP_large epoch 13/18 - loss: 4.471984\n",
      "SoftmaxMLP_large epoch 14/18 - loss: 4.471943\n",
      "SoftmaxMLP_large epoch 15/18 - loss: 4.471994\n",
      "SoftmaxMLP_large epoch 16/18 - loss: 4.471955\n",
      "SoftmaxMLP_large epoch 17/18 - loss: 4.471955\n",
      "SoftmaxMLP_large epoch 18/18 - loss: 4.471945\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SoftmaxMLP_large: 100%|██████████| 3000/3000 [00:01<00:00, 2947.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower base\n",
      "TwoTower_base epoch 1/18 - loss: 8.331500\n",
      "TwoTower_base epoch 2/18 - loss: 8.307952\n",
      "TwoTower_base epoch 3/18 - loss: 8.308065\n",
      "TwoTower_base epoch 4/18 - loss: 8.308047\n",
      "TwoTower_base epoch 5/18 - loss: 8.308043\n",
      "TwoTower_base epoch 6/18 - loss: 8.307884\n",
      "TwoTower_base epoch 7/18 - loss: 8.307784\n",
      "TwoTower_base epoch 8/18 - loss: 8.307836\n",
      "TwoTower_base epoch 9/18 - loss: 8.307878\n",
      "TwoTower_base epoch 10/18 - loss: 8.307800\n",
      "TwoTower_base epoch 11/18 - loss: 8.307775\n",
      "TwoTower_base epoch 12/18 - loss: 8.307774\n",
      "TwoTower_base epoch 13/18 - loss: 8.307700\n",
      "TwoTower_base epoch 14/18 - loss: 8.307647\n",
      "TwoTower_base epoch 15/18 - loss: 8.307713\n",
      "TwoTower_base epoch 16/18 - loss: 8.307732\n",
      "TwoTower_base epoch 17/18 - loss: 8.307689\n",
      "TwoTower_base epoch 18/18 - loss: 8.307714\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_base: 100%|██████████| 3000/3000 [00:00<00:00, 3151.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower large\n",
      "TwoTower_large epoch 1/18 - loss: 8.327863\n",
      "TwoTower_large epoch 2/18 - loss: 8.308226\n",
      "TwoTower_large epoch 3/18 - loss: 8.308305\n",
      "TwoTower_large epoch 4/18 - loss: 8.308264\n",
      "TwoTower_large epoch 5/18 - loss: 8.308018\n",
      "TwoTower_large epoch 6/18 - loss: 8.307931\n",
      "TwoTower_large epoch 7/18 - loss: 8.307845\n",
      "TwoTower_large epoch 8/18 - loss: 8.307820\n",
      "TwoTower_large epoch 9/18 - loss: 8.307988\n",
      "TwoTower_large epoch 10/18 - loss: 8.307590\n",
      "TwoTower_large epoch 11/18 - loss: 8.307900\n",
      "TwoTower_large epoch 12/18 - loss: 8.307975\n",
      "TwoTower_large epoch 13/18 - loss: 8.307590\n",
      "TwoTower_large epoch 14/18 - loss: 8.307721\n",
      "TwoTower_large epoch 15/18 - loss: 8.307705\n",
      "TwoTower_large epoch 16/18 - loss: 8.307673\n",
      "TwoTower_large epoch 17/18 - loss: 8.307538\n",
      "TwoTower_large epoch 18/18 - loss: 8.307616\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_large: 100%|██████████| 3000/3000 [00:01<00:00, 2711.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE base\n",
      "MultVAE_base epoch 1/60 - loss: 382.892679 - anneal: 0.028\n",
      "MultVAE_base epoch 10/60 - loss: 380.400827 - anneal: 0.528\n",
      "MultVAE_base epoch 20/60 - loss: 380.204298 - anneal: 1.000\n",
      "MultVAE_base epoch 30/60 - loss: 380.174375 - anneal: 1.000\n",
      "MultVAE_base epoch 40/60 - loss: 380.168634 - anneal: 1.000\n",
      "MultVAE_base epoch 50/60 - loss: 380.161901 - anneal: 1.000\n",
      "MultVAE_base epoch 60/60 - loss: 380.158152 - anneal: 1.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_base: 100%|██████████| 3000/3000 [00:00<00:00, 4809.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE_large\n",
      "MultVAE_large epoch 1/60 - loss: 383.119435 - anneal: 0.028\n",
      "MultVAE_large epoch 10/60 - loss: 380.418164 - anneal: 0.528\n",
      "MultVAE_large epoch 20/60 - loss: 380.238343 - anneal: 1.000\n",
      "MultVAE_large epoch 30/60 - loss: 380.195111 - anneal: 1.000\n",
      "MultVAE_large epoch 40/60 - loss: 380.184380 - anneal: 1.000\n",
      "MultVAE_large epoch 50/60 - loss: 380.180346 - anneal: 1.000\n",
      "MultVAE_large epoch 60/60 - loss: 380.177089 - anneal: 1.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_large: 100%|██████████| 3000/3000 [00:00<00:00, 4873.00it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultVAE_large</td>\n",
       "      <td>0.992667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.668940</td>\n",
       "      <td>0.752430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_base</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.665517</td>\n",
       "      <td>0.749753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>SoftmaxMLP_base</td>\n",
       "      <td>0.987667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663778</td>\n",
       "      <td>0.748419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>TwoTower_large</td>\n",
       "      <td>0.987667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663210</td>\n",
       "      <td>0.747932</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SoftmaxMLP_large</td>\n",
       "      <td>0.985667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.661187</td>\n",
       "      <td>0.746365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>TwoTower_base</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.656112</td>\n",
       "      <td>0.742579</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Model      HR@5  HR@10  HR@20    MRR@10   NDCG@10\n",
       "0     MultVAE_large  0.992667    1.0    1.0  0.668940  0.752430\n",
       "1      MultVAE_base  0.989000    1.0    1.0  0.665517  0.749753\n",
       "2   SoftmaxMLP_base  0.987667    1.0    1.0  0.663778  0.748419\n",
       "3    TwoTower_large  0.987667    1.0    1.0  0.663210  0.747932\n",
       "4  SoftmaxMLP_large  0.985667    1.0    1.0  0.661187  0.746365\n",
       "5     TwoTower_base  0.986000    1.0    1.0  0.656112  0.742579"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit and evaluate neural models\n",
    "# These models are intentionally larger than before.\n",
    "\n",
    "neural_results = []\n",
    "neural_models = {}\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training Softmax MLP base\")\n",
    "softmax_base = train_softmax_model(\n",
    "    model_name=\"SoftmaxMLP_base\",\n",
    "    emb_dim=64,\n",
    "    hidden_dims=(512, 512, 256),\n",
    "    epochs=SOFTMAX_EPOCHS,\n",
    "    lr=2e-3,\n",
    "    wd=1e-5,\n",
    ")\n",
    "rec_softmax_base = make_softmax_recommender(softmax_base)\n",
    "neural_models[\"SoftmaxMLP_base\"] = rec_softmax_base\n",
    "neural_results.append(evaluate_model(rec_softmax_base, \"SoftmaxMLP_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training Softmax MLP large\")\n",
    "softmax_large = train_softmax_model(\n",
    "    model_name=\"SoftmaxMLP_large\",\n",
    "    emb_dim=96,\n",
    "    hidden_dims=(1024, 1024, 512, 256),\n",
    "    epochs=SOFTMAX_EPOCHS,\n",
    "    lr=1.5e-3,\n",
    "    wd=1e-5,\n",
    ")\n",
    "rec_softmax_large = make_softmax_recommender(softmax_large)\n",
    "neural_models[\"SoftmaxMLP_large\"] = rec_softmax_large\n",
    "neural_results.append(evaluate_model(rec_softmax_large, \"SoftmaxMLP_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training TwoTower base\")\n",
    "twotower_base = train_two_tower(\n",
    "    model_name=\"TwoTower_base\",\n",
    "    emb_dim=64,\n",
    "    hidden_dims=(512, 256),\n",
    "    out_dim=128,\n",
    "    epochs=TWOTOWER_EPOCHS,\n",
    "    lr=2e-3,\n",
    "    wd=1e-5,\n",
    "    temperature=0.07,\n",
    ")\n",
    "rec_twotower_base = make_two_tower_recommender(twotower_base)\n",
    "neural_models[\"TwoTower_base\"] = rec_twotower_base\n",
    "neural_results.append(evaluate_model(rec_twotower_base, \"TwoTower_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training TwoTower large\")\n",
    "twotower_large = train_two_tower(\n",
    "    model_name=\"TwoTower_large\",\n",
    "    emb_dim=96,\n",
    "    hidden_dims=(1024, 512, 256),\n",
    "    out_dim=192,\n",
    "    epochs=TWOTOWER_EPOCHS,\n",
    "    lr=1.5e-3,\n",
    "    wd=1e-5,\n",
    "    temperature=0.05,\n",
    ")\n",
    "rec_twotower_large = make_two_tower_recommender(twotower_large)\n",
    "neural_models[\"TwoTower_large\"] = rec_twotower_large\n",
    "neural_results.append(evaluate_model(rec_twotower_large, \"TwoTower_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training MultVAE base\")\n",
    "multvae_base = train_multvae(\n",
    "    model_name=\"MultVAE_base\",\n",
    "    hidden_dim=768,\n",
    "    latent_dim=192,\n",
    "    dropout=0.25,\n",
    "    epochs=MULTVAE_EPOCHS,\n",
    "    lr=1e-3,\n",
    "    wd=0.0,\n",
    ")\n",
    "rec_multvae_base = make_multvae_recommender(multvae_base)\n",
    "neural_models[\"MultVAE_base\"] = rec_multvae_base\n",
    "neural_results.append(evaluate_model(rec_multvae_base, \"MultVAE_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Training MultVAE_large\")\n",
    "multvae_large = train_multvae(\n",
    "    model_name=\"MultVAE_large\",\n",
    "    hidden_dim=1024,\n",
    "    latent_dim=256,\n",
    "    dropout=0.30,\n",
    "    epochs=MULTVAE_EPOCHS,\n",
    "    lr=8e-4,\n",
    "    wd=0.0,\n",
    ")\n",
    "rec_multvae_large = make_multvae_recommender(multvae_large)\n",
    "neural_models[\"MultVAE_large\"] = rec_multvae_large\n",
    "neural_results.append(evaluate_model(rec_multvae_large, \"MultVAE_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "neural_results_df = pd.DataFrame(neural_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "display(neural_results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c0a525a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "      <th>Family</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>EASE_binary_l500</td>\n",
       "      <td>0.987000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.671896</td>\n",
       "      <td>0.754519</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ItemKNN_binary_k150</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>0.989333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.670471</td>\n",
       "      <td>0.753456</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>MultVAE_large</td>\n",
       "      <td>0.992667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.668940</td>\n",
       "      <td>0.752430</td>\n",
       "      <td>Neural</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>EASE_binary_l200</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.668304</td>\n",
       "      <td>0.751848</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>MultVAE_base</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665517</td>\n",
       "      <td>0.749753</td>\n",
       "      <td>Neural</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>EASE_binary_l50</td>\n",
       "      <td>0.985333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665105</td>\n",
       "      <td>0.749452</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>EASE_binary_l100</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.665007</td>\n",
       "      <td>0.749386</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>SoftmaxMLP_base</td>\n",
       "      <td>0.987667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663778</td>\n",
       "      <td>0.748419</td>\n",
       "      <td>Neural</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>EASE_count_l100</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663710</td>\n",
       "      <td>0.748281</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>EASE_count_l50</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663693</td>\n",
       "      <td>0.748266</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>CountryPopularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>EASE_count_l200</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663510</td>\n",
       "      <td>0.748131</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>EASE_count_l500</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663443</td>\n",
       "      <td>0.748078</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>TwoTower_large</td>\n",
       "      <td>0.987667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.663210</td>\n",
       "      <td>0.747932</td>\n",
       "      <td>Neural</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.662649</td>\n",
       "      <td>0.747472</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>SoftmaxMLP_large</td>\n",
       "      <td>0.985667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.661187</td>\n",
       "      <td>0.746365</td>\n",
       "      <td>Neural</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>TwoTower_base</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.656112</td>\n",
       "      <td>0.742579</td>\n",
       "      <td>Neural</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>ItemKNN_bm25_k100</td>\n",
       "      <td>0.991333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.651311</td>\n",
       "      <td>0.739278</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>ItemKNN_bm25_k150</td>\n",
       "      <td>0.991333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.651311</td>\n",
       "      <td>0.739278</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>ItemKNN_binary_k50</td>\n",
       "      <td>0.856333</td>\n",
       "      <td>0.858000</td>\n",
       "      <td>0.858000</td>\n",
       "      <td>0.620887</td>\n",
       "      <td>0.681031</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>ItemKNN_bm25_k50</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.704667</td>\n",
       "      <td>0.546567</td>\n",
       "      <td>0.587167</td>\n",
       "      <td>Regular</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model      HR@5     HR@10     HR@20    MRR@10   NDCG@10  \\\n",
       "0      EASE_binary_l500  0.987000  1.000000  1.000000  0.671896  0.754519   \n",
       "1   ItemKNN_binary_k100  0.992000  1.000000  1.000000  0.670833  0.753761   \n",
       "2   ItemKNN_binary_k150  0.992000  1.000000  1.000000  0.670833  0.753761   \n",
       "3     EASE_binary_l1000  0.989333  1.000000  1.000000  0.670471  0.753456   \n",
       "4         MultVAE_large  0.992667  1.000000  1.000000  0.668940  0.752430   \n",
       "5      EASE_binary_l200  0.986000  1.000000  1.000000  0.668304  0.751848   \n",
       "6          MultVAE_base  0.989000  1.000000  1.000000  0.665517  0.749753   \n",
       "7       EASE_binary_l50  0.985333  1.000000  1.000000  0.665105  0.749452   \n",
       "8      EASE_binary_l100  0.986000  1.000000  1.000000  0.665007  0.749386   \n",
       "9       SoftmaxMLP_base  0.987667  1.000000  1.000000  0.663778  0.748419   \n",
       "10      EASE_count_l100  0.988667  1.000000  1.000000  0.663710  0.748281   \n",
       "11       EASE_count_l50  0.988667  1.000000  1.000000  0.663693  0.748266   \n",
       "12           Popularity  0.986000  1.000000  1.000000  0.663483  0.748141   \n",
       "13    CountryPopularity  0.986000  1.000000  1.000000  0.663483  0.748141   \n",
       "14      EASE_count_l200  0.989000  1.000000  1.000000  0.663510  0.748131   \n",
       "15      EASE_count_l500  0.989000  1.000000  1.000000  0.663443  0.748078   \n",
       "16       TwoTower_large  0.987667  1.000000  1.000000  0.663210  0.747932   \n",
       "17     EASE_count_l1000  0.988667  1.000000  1.000000  0.662649  0.747472   \n",
       "18     SoftmaxMLP_large  0.985667  1.000000  1.000000  0.661187  0.746365   \n",
       "19        TwoTower_base  0.986000  1.000000  1.000000  0.656112  0.742579   \n",
       "20    ItemKNN_bm25_k100  0.991333  1.000000  1.000000  0.651311  0.739278   \n",
       "21    ItemKNN_bm25_k150  0.991333  1.000000  1.000000  0.651311  0.739278   \n",
       "22   ItemKNN_binary_k50  0.856333  0.858000  0.858000  0.620887  0.681031   \n",
       "23     ItemKNN_bm25_k50  0.704667  0.704667  0.704667  0.546567  0.587167   \n",
       "\n",
       "     Family  \n",
       "0   Regular  \n",
       "1   Regular  \n",
       "2   Regular  \n",
       "3   Regular  \n",
       "4    Neural  \n",
       "5   Regular  \n",
       "6    Neural  \n",
       "7   Regular  \n",
       "8   Regular  \n",
       "9    Neural  \n",
       "10  Regular  \n",
       "11  Regular  \n",
       "12  Regular  \n",
       "13  Regular  \n",
       "14  Regular  \n",
       "15  Regular  \n",
       "16   Neural  \n",
       "17  Regular  \n",
       "18   Neural  \n",
       "19   Neural  \n",
       "20  Regular  \n",
       "21  Regular  \n",
       "22  Regular  \n",
       "23  Regular  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Top regular models:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>EASE_binary_l500</td>\n",
       "      <td>0.987000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.671896</td>\n",
       "      <td>0.754519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ItemKNN_binary_k150</td>\n",
       "      <td>0.992000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.753761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>0.989333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.670471</td>\n",
       "      <td>0.753456</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>EASE_binary_l200</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.668304</td>\n",
       "      <td>0.751848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>EASE_binary_l50</td>\n",
       "      <td>0.985333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.665105</td>\n",
       "      <td>0.749452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_binary_l100</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.665007</td>\n",
       "      <td>0.749386</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>EASE_count_l100</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663710</td>\n",
       "      <td>0.748281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>EASE_count_l50</td>\n",
       "      <td>0.988667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663693</td>\n",
       "      <td>0.748266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663483</td>\n",
       "      <td>0.748141</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Model      HR@5  HR@10  HR@20    MRR@10   NDCG@10\n",
       "0     EASE_binary_l500  0.987000    1.0    1.0  0.671896  0.754519\n",
       "1  ItemKNN_binary_k100  0.992000    1.0    1.0  0.670833  0.753761\n",
       "2  ItemKNN_binary_k150  0.992000    1.0    1.0  0.670833  0.753761\n",
       "3    EASE_binary_l1000  0.989333    1.0    1.0  0.670471  0.753456\n",
       "4     EASE_binary_l200  0.986000    1.0    1.0  0.668304  0.751848\n",
       "5      EASE_binary_l50  0.985333    1.0    1.0  0.665105  0.749452\n",
       "6     EASE_binary_l100  0.986000    1.0    1.0  0.665007  0.749386\n",
       "7      EASE_count_l100  0.988667    1.0    1.0  0.663710  0.748281\n",
       "8       EASE_count_l50  0.988667    1.0    1.0  0.663693  0.748266\n",
       "9           Popularity  0.986000    1.0    1.0  0.663483  0.748141"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Top neural models:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultVAE_large</td>\n",
       "      <td>0.992667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.668940</td>\n",
       "      <td>0.752430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_base</td>\n",
       "      <td>0.989000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.665517</td>\n",
       "      <td>0.749753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>SoftmaxMLP_base</td>\n",
       "      <td>0.987667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663778</td>\n",
       "      <td>0.748419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>TwoTower_large</td>\n",
       "      <td>0.987667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.663210</td>\n",
       "      <td>0.747932</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SoftmaxMLP_large</td>\n",
       "      <td>0.985667</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.661187</td>\n",
       "      <td>0.746365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>TwoTower_base</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.656112</td>\n",
       "      <td>0.742579</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Model      HR@5  HR@10  HR@20    MRR@10   NDCG@10\n",
       "0     MultVAE_large  0.992667    1.0    1.0  0.668940  0.752430\n",
       "1      MultVAE_base  0.989000    1.0    1.0  0.665517  0.749753\n",
       "2   SoftmaxMLP_base  0.987667    1.0    1.0  0.663778  0.748419\n",
       "3    TwoTower_large  0.987667    1.0    1.0  0.663210  0.747932\n",
       "4  SoftmaxMLP_large  0.985667    1.0    1.0  0.661187  0.746365\n",
       "5     TwoTower_base  0.986000    1.0    1.0  0.656112  0.742579"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Final combined comparison\n",
    "\n",
    "all_results_df = pd.concat(\n",
    "    [regular_results_df.assign(Family=\"Regular\"), neural_results_df.assign(Family=\"Neural\")],\n",
    "    ignore_index=True\n",
    ").sort_values([\"HR@10\", \"NDCG@10\", \"HR@20\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "display(all_results_df)\n",
    "\n",
    "print()\n",
    "print(\"Top regular models:\")\n",
    "display(regular_results_df.head(10))\n",
    "\n",
    "print()\n",
    "print(\"Top neural models:\")\n",
    "display(neural_results_df.head(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8f4f73bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best overall model: EASE_binary_l500\n",
      "Example user / segment:\n",
      "UK | p2 | m2 | Used\n",
      "\n",
      "Seen items (first 20):\n",
      "['Audi :: A4', 'Audi :: A6', 'Audi :: Q5', 'Audi :: Q7', 'BMW :: 3 Series', 'BMW :: 5 Series', 'BMW :: X3', 'BMW :: X5', 'BMW :: Z4', 'Chevrolet :: Malibu', 'Chevrolet :: Silverado', 'Chevrolet :: Trax', 'Chrysler :: 300', 'Chrysler :: Pacifica', 'Chrysler :: Voyager', 'Dodge :: Challenger', 'Dodge :: Charger', 'Dodge :: Durango', 'Dodge :: Journey', 'Dodge :: Ram 1500']\n",
      "\n",
      "Recommended items:\n",
      "['Chevrolet :: Equinox', 'Chevrolet :: Camaro', 'Subaru :: Forester', 'Audi :: A3']\n"
     ]
    }
   ],
   "source": [
    "# Show example recommendations from the best overall model\n",
    "\n",
    "best_model_name = all_results_df.iloc[0][\"Model\"]\n",
    "print(\"Best overall model:\", best_model_name)\n",
    "\n",
    "if best_model_name in regular_models:\n",
    "    best_recommender = regular_models[best_model_name]\n",
    "else:\n",
    "    best_recommender = neural_models[best_model_name]\n",
    "\n",
    "example_user = int(train_interactions[\"user_idx\"].sample(1, random_state=RANDOM_STATE).iloc[0])\n",
    "\n",
    "print(\"Example user / segment:\")\n",
    "print(data[\"user_ids\"][example_user])\n",
    "print()\n",
    "\n",
    "print(\"Seen items (first 20):\")\n",
    "print([item_ids[i] for i in sorted(list(user_seen[example_user]))[:20]])\n",
    "print()\n",
    "\n",
    "print(\"Recommended items:\")\n",
    "print([item_ids[i] for i in best_recommender(example_user, n=10)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b394fadb",
   "metadata": {},
   "source": [
    "## Notes for the next iteration\n",
    "\n",
    "If you still want even better quality after this run, the next tuning steps are:\n",
    "\n",
    "1. keep only the best **formulation** from the quick search;\n",
    "2. keep only the best **2–3 regular** and **2–3 neural** models from the result table;\n",
    "3. increase:\n",
    "   - item granularity a little more if the item space is still too small;\n",
    "   - batch sizes if the GPU still has spare memory;\n",
    "   - hidden sizes for the strongest neural model;\n",
    "4. tune:\n",
    "   - `EASE_LAMBDA_GRID`\n",
    "   - `ITEMKNN_NEIGHBORS_GRID`\n",
    "   - `temperature` in the two-tower model\n",
    "   - `latent_dim` and `hidden_dim` in MultVAE\n",
    "5. if the GPU is still underused, the most direct way to push it harder is:\n",
    "   - larger item granularity,\n",
    "   - larger batches,\n",
    "   - larger hidden layers,\n",
    "   - or replacing the softmax / two-tower models with even wider networks.\n",
    "\n",
    "This notebook is intentionally written to search broadly first."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
