{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2e584ae2",
   "metadata": {},
   "source": [
    "# Car recommender notebook — hard formulation, broader model zoo, GPU-heavy neural search\n",
    "\n",
    "This notebook is the corrected follow-up to the earlier car recommender runs.\n",
    "\n",
    "Main changes:\n",
    "- screens out trivial task formulations automatically\n",
    "- uses only harder pseudo-user / item constructions\n",
    "- evaluates a larger set of regular and neural recommenders\n",
    "- keeps the neural training path GPU-friendly with manual on-device batching\n",
    "- produces one combined ranking table for comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3dd295bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU thread target: 16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import gc\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "# Thread settings for NumPy / SciPy BLAS backends.\n",
    "# These must be set before importing numpy/scipy to have the best chance of taking effect.\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "os.environ[\"OMP_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"MKL_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"NUMEXPR_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = str(CPU_THREADS)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn.preprocessing import LabelEncoder, normalize\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from tqdm.auto import tqdm\n",
    "try:\n",
    "    from threadpoolctl import threadpool_limits\n",
    "except Exception:\n",
    "    threadpool_limits = None\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "\n",
    "if threadpool_limits is not None:\n",
    "    try:\n",
    "        threadpool_limits(limits=CPU_THREADS)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)"
   ]
  },
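  {
   "cell_type": "markdown",
   "id": "1f2e3d4c",
   "metadata": {},
   "source": [
    "Optional sanity check: `threadpoolctl` can report which BLAS/OpenMP pools actually picked up the limit set above (a quick sketch, only useful if `threadpoolctl` is installed):\n",
    "\n",
    "```python\n",
    "from threadpoolctl import threadpool_info\n",
    "\n",
    "for pool in threadpool_info():\n",
    "    print(pool[\"internal_api\"], pool.get(\"num_threads\"))\n",
    "```"
   ]
  },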
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77a4eb83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Formulations: ['H1_country_price_mileage_cond_age__brand_model_age_cond', 'H2_country_price_mileage_cond_age_color__brand_model_age_cond_color', 'H3_country_price_mileage_age_color__brand_model_age_cond_color', 'H4_country_price_mileage_cond_color__brand_model_age_cond_color']\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "N_PRICE_BINS = 12\n",
    "N_MILEAGE_BINS = 12\n",
    "N_AGE_BINS = 10\n",
    "\n",
    "# Interaction filtering\n",
    "MIN_ITEMS_PER_USER = 5\n",
    "MIN_USERS_PER_ITEM = 10\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# If a formulation is too sparse under the default thresholds, progressively relax them.\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "\n",
    "# Non-triviality constraints for automatic formulation selection\n",
    "MIN_ITEMS_FOR_VALID_FORMULATION = 500\n",
    "MAX_POP_HR10_FOR_VALID_FORMULATION = 0.15\n",
    "MAX_AVG_INTERACTIONS_PER_USER = 50\n",
    "\n",
    "# Ranking metrics\n",
    "TOP_KS = [5, 10, 20]\n",
    "\n",
    "# Regular model search grids\n",
    "ITEMKNN_NEIGHBORS_GRID = [100, 200, 300]\n",
    "EASE_LAMBDA_GRID = [100.0, 200.0, 500.0, 1000.0, 2000.0]\n",
    "P3_ALPHA_GRID = [0.5, 1.0]\n",
    "RP3_GRID = [(0.8, 0.3), (1.0, 0.6)]\n",
    "\n",
    "# Neural training defaults\n",
    "SOFTMAX_EPOCHS = 20\n",
    "TWOTOWER_EPOCHS = 20\n",
    "AUTOENC_EPOCHS = 80\n",
    "PAIRWISE_EPOCHS = 24\n",
    "\n",
    "SOFTMAX_BATCH_SIZE = 16384\n",
    "TWOTOWER_BATCH_SIZE = 8192\n",
    "AUTOENC_BATCH_SIZE = 4096\n",
    "PAIRWISE_BATCH_SIZE = 32768\n",
    "\n",
    "USE_AMP = False  # initialized safely before optional torch import; updated later if neural models are enabled\n",
    "\n",
    "# Harder candidate formulations only.\n",
    "# These all keep the richer item granularity and avoid the trivial 88-item setup.\n",
    "FORMULATIONS = {\n",
    "    \"H1_country_price_mileage_cond_age__brand_model_age_cond\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
    "    },\n",
    "    \"H2_country_price_mileage_cond_age_color__brand_model_age_cond_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H3_country_price_mileage_age_color__brand_model_age_cond_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"AgeBin\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H4_country_price_mileage_cond_color__brand_model_age_cond_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))\n",
    "\n",
    "EASE_EVAL_BATCH_SIZE = 2048\n",
    "EASE_USE_GPU = True\n",
    "EASE_GPU_BATCH_SIZE = 4096\n"
   ]
  },
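  {
   "cell_type": "markdown",
   "id": "2a3b4c5d",
   "metadata": {},
   "source": [
    "To make the formulations concrete: each pseudo-user id is the concatenation of that row's `user_cols` values and each item id the concatenation of its `item_cols` values, using the same separators as the `join_cols` helper defined further down. A sketch with illustrative values taken from the sample rows shown below (not executed here):\n",
    "\n",
    "```python\n",
    "# Hypothetical H1 row -> pseudo-user / item ids\n",
    "user_id = \" | \".join([\"Brazil\", \"price_(23744.848 to  29982.3]\",\n",
    "                      \"mileage_(49995.0 to  66608.0]\", \"Certified Pre-Owned\",\n",
    "                      \"age_(1.999 to  4.0]\"])\n",
    "item_id = \" :: \".join([\"Honda\", \"Civic\", \"age_(1.999 to  4.0]\", \"Certified Pre-Owned\"])\n",
    "```\n",
    "\n",
    "Every (pseudo-user, item) pair that co-occurs in the sales table becomes one implicit-feedback interaction, with the co-occurrence count kept as its weight."
   ]
  },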
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "44983c26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch device: cuda\n",
      "USE_AMP: True\n"
     ]
    }
   ],
   "source": [
    "# Runtime toggles\n",
    "\n",
    "# Neural models section\n",
    "RUN_NEURAL = True\n",
    "PREFER_CUDA = True\n",
    "\n",
    "# EASE can also use torch on GPU for batched scoring even before the neural section.\n",
    "NEED_TORCH = RUN_NEURAL or EASE_USE_GPU\n",
    "\n",
    "HAS_TORCH = False\n",
    "torch = None\n",
    "nn = None\n",
    "F = None\n",
    "device = None\n",
    "\n",
    "if NEED_TORCH:\n",
    "    import torch\n",
    "    import torch.nn as nn\n",
    "    import torch.nn.functional as F\n",
    "\n",
    "    HAS_TORCH = True\n",
    "    torch.manual_seed(RANDOM_STATE)\n",
    "\n",
    "    device = torch.device(\"cuda\" if PREFER_CUDA and torch.cuda.is_available() else \"cpu\")\n",
    "    USE_AMP = (device.type == \"cuda\")\n",
    "    print(\"Torch device:\", device)\n",
    "    print(\"USE_AMP:\", USE_AMP)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "        torch.backends.cuda.matmul.allow_tf32 = True\n",
    "        torch.backends.cudnn.allow_tf32 = True\n",
    "        torch.backends.cudnn.benchmark = True\n",
    "        try:\n",
    "            torch.set_float32_matmul_precision(\"high\")\n",
    "        except Exception:\n",
    "            pass\n",
    "else:\n",
    "    USE_AMP = False\n",
    "    print(\"PyTorch disabled for this run.\")\n"
   ]
  },
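  {
   "cell_type": "markdown",
   "id": "3b4c5d6e",
   "metadata": {},
   "source": [
    "`USE_AMP` is only a flag at this point; the neural training cells decide how to consume it. A minimal sketch of the usual autocast + GradScaler pattern it enables (assuming a `model`, an `optimizer`, and batch tensors `xb`, `yb` already on `device`; this is not the notebook's exact training loop):\n",
    "\n",
    "```python\n",
    "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "with torch.autocast(device_type=device.type, enabled=USE_AMP):\n",
    "    loss = F.cross_entropy(model(xb), yb)\n",
    "\n",
    "scaler.scale(loss).backward()\n",
    "scaler.step(optimizer)\n",
    "scaler.update()\n",
    "optimizer.zero_grad(set_to_none=True)\n",
    "```"
   ]
  },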
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4c1a8ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows: 1000000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>price_(23744.848 to  29982.3]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(24.0 to  26.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>price_(48755.822 to  55009.26]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(16665.0 to  33316.0]</td>\n",
       "      <td>age_(22.0 to  24.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>price_(5000.059 to  11233.652]</td>\n",
       "      <td>mileage_(66608.0 to  83237.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    3   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   26   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   12   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   23   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   14   \n",
       "\n",
       "                         PriceBin                     MileageBin  \\\n",
       "0   price_(23744.848 to  29982.3]  mileage_(49995.0 to  66608.0]   \n",
       "1  price_(11233.652 to  17482.96]  mileage_(49995.0 to  66608.0]   \n",
       "2  price_(48755.822 to  55009.26]   mileage_(-0.001 to  16665.0]   \n",
       "3  price_(11233.652 to  17482.96]  mileage_(16665.0 to  33316.0]   \n",
       "4  price_(5000.059 to  11233.652]  mileage_(66608.0 to  83237.0]   \n",
       "\n",
       "                AgeBin  \n",
       "0  age_(1.999 to  4.0]  \n",
       "1  age_(24.0 to  26.0]  \n",
       "2  age_(11.0 to  14.0]  \n",
       "3  age_(22.0 to  24.0]  \n",
       "4  age_(11.0 to  14.0]  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>price_(23744.848 to  29982.3]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(24.0 to  26.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>price_(48755.822 to  55009.26]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(16665.0 to  33316.0]</td>\n",
       "      <td>age_(22.0 to  24.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>price_(5000.059 to  11233.652]</td>\n",
       "      <td>mileage_(66608.0 to  83237.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition Country  Age                        PriceBin  \\\n",
       "0  Certified Pre-Owned  Brazil    3   price_(23744.848 to  29982.3]   \n",
       "1  Certified Pre-Owned   Italy   26  price_(11233.652 to  17482.96]   \n",
       "2  Certified Pre-Owned      UK   12  price_(48755.822 to  55009.26]   \n",
       "3                 Used  Mexico   23  price_(11233.652 to  17482.96]   \n",
       "4                 Used     USA   14  price_(5000.059 to  11233.652]   \n",
       "\n",
       "                      MileageBin               AgeBin  \n",
       "0  mileage_(49995.0 to  66608.0]  age_(1.999 to  4.0]  \n",
       "1  mileage_(49995.0 to  66608.0]  age_(24.0 to  26.0]  \n",
       "2   mileage_(-0.001 to  16665.0]  age_(11.0 to  14.0]  \n",
       "3  mileage_(16665.0 to  33316.0]  age_(22.0 to  24.0]  \n",
       "4  mileage_(66608.0 to  83237.0]  age_(11.0 to  14.0]  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean data\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "expected_cols = [\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]\n",
    "missing = [c for c in expected_cols if c not in df.columns]\n",
    "if missing:\n",
    "    raise ValueError(f\"Missing expected columns: {missing}\")\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    df[col] = df[col].astype(str).str.strip().replace({\"\": \"Unknown\", \"nan\": \"Unknown\"})\n",
    "\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "df = df.dropna(subset=[\"Year\", \"Price\", \"Mileage\"]).copy()\n",
    "\n",
    "df = df[df[\"Year\"].between(1990, CURRENT_YEAR)].copy()\n",
    "df = df[df[\"Price\"] > 0].copy()\n",
    "df = df[df[\"Mileage\"] >= 0].copy()\n",
    "\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0, upper=50)\n",
    "\n",
    "def make_qbin(series, n_bins, prefix):\n",
    "    cat = pd.qcut(series, q=n_bins, duplicates=\"drop\")\n",
    "    return cat.astype(str).str.replace(\",\", \" to \", regex=False).map(lambda x: f\"{prefix}_{x}\")\n",
    "\n",
    "df[\"PriceBin\"] = make_qbin(df[\"Price\"], N_PRICE_BINS, \"price\")\n",
    "df[\"MileageBin\"] = make_qbin(df[\"Mileage\"], N_MILEAGE_BINS, \"mileage\")\n",
    "df[\"AgeBin\"] = make_qbin(df[\"Age\"], N_AGE_BINS, \"age\")\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())\n",
    "display(df[[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\", \"Age\", \"PriceBin\", \"MileageBin\", \"AgeBin\"]].head())"
   ]
  },
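  {
   "cell_type": "markdown",
   "id": "4c5d6e7f",
   "metadata": {},
   "source": [
    "`make_qbin` simply wraps `pd.qcut` and turns the interval objects into readable string labels (the comma inside the interval repr is replaced by \" to \", which is why the labels carry a double space). A self-contained toy example of the same idea:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8])\n",
    "labels = (\n",
    "    pd.qcut(s, q=4, duplicates=\"drop\")\n",
    "      .astype(str)\n",
    "      .str.replace(\",\", \" to \", regex=False)\n",
    "      .map(lambda x: f\"toy_{x}\")\n",
    ")\n",
    "print(labels.unique())  # ['toy_(0.999 to  2.75]', 'toy_(2.75 to  4.5]', ...]\n",
    "```"
   ]
  },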
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "071498f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user=5, min_users_per_item=10, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = out.shape[0]\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if out.shape[0] == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "           .size()\n",
    "           .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = None\n",
    "    used_thresholds = None\n",
    "\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS\n",
    "        ).reset_index(drop=True)\n",
    "\n",
    "        if not candidate.empty:\n",
    "            interactions = candidate\n",
    "            used_thresholds = (min_items_per_user, min_users_per_item)\n",
    "            break\n",
    "\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(\n",
    "            f\"{formulation_name} produced no interactions after filtering, \"\n",
    "            f\"even after trying FILTER_SCHEDULE={FILTER_SCHEDULE}\"\n",
    "        )\n",
    "\n",
    "    valid_users = set(interactions[\"user_id\"])\n",
    "    valid_items = set(interactions[\"item_id\"])\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    num_users = interactions[\"user_idx\"].nunique()\n",
    "    num_items = interactions[\"item_idx\"].nunique()\n",
    "\n",
    "    user_feature_df = (\n",
    "        tmp[tmp[\"user_id\"].isin(valid_users)][[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "    user_feature_df = user_feature_df.sort_values(\"user_idx\").reset_index(drop=True)\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[tmp[\"item_id\"].isin(valid_items)][[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "    item_feature_df = item_feature_df.sort_values(\"item_idx\").reset_index(drop=True)\n",
    "\n",
    "    # Weighted leave-one-out split\n",
    "    rng = np.random.default_rng(RANDOM_STATE)\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(picked)\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    user_seen = {\n",
    "        int(uid): set(g[\"item_idx\"].astype(int).tolist())\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "    user_seen_arrays = {\n",
    "        uid: np.array(sorted(list(seen)), dtype=np.int32)\n",
    "        for uid, seen in user_seen.items()\n",
    "    }\n",
    "\n",
    "    user_strength = {\n",
    "        int(uid): {int(i): float(c) for i, c in zip(g[\"item_idx\"], g[\"count\"])}\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(uid): int(i)\n",
    "        for uid, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy()\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen\": user_seen,\n",
    "        \"user_seen_arrays\": user_seen_arrays,\n",
    "        \"user_strength\": user_strength,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"user_ids\": user_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "\n",
    "    avg_train_per_user = bundle[\"train_interactions\"].shape[0] / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle.get(\"used_thresholds\"))\n",
    "    print(\"Train rows :\", bundle[\"train_interactions\"].shape[0])\n",
    "    print(\"Test rows  :\", bundle[\"test_interactions\"].shape[0])\n",
    "    print(\"Avg train interactions/user:\", round(avg_train_per_user, 3))\n",
    "    print(\"Matrix shape:\", bundle[\"X_binary\"].shape)\n",
    "\n",
    "def topn_from_scores(scores, seen, n):\n",
    "    scores = np.asarray(scores, dtype=np.float32).copy()\n",
    "    if seen:\n",
    "        seen_idx = np.fromiter(seen, dtype=np.int32)\n",
    "        scores[seen_idx] = -np.inf\n",
    "    n = min(int(n), scores.shape[0])\n",
    "    if n <= 0:\n",
    "        return []\n",
    "    idx = np.argpartition(scores, -n)[-n:]\n",
    "    idx = idx[np.argsort(scores[idx])[::-1]]\n",
    "    return idx.astype(int).tolist()\n",
    "\n",
    "def topn_from_torch_scores(scores, seen, n):\n",
    "    x = scores.clone()\n",
    "    if seen:\n",
    "        idx = torch.tensor(list(seen), dtype=torch.long, device=x.device)\n",
    "        x[idx] = -1e9\n",
    "    k = min(int(n), int(x.shape[0]))\n",
    "    return torch.topk(x, k=k).indices.detach().cpu().tolist()\n",
    "\n",
    "def hit_rate_at_k(recs, true_item, k):\n",
    "    return 1.0 if true_item in recs[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(recommend_fn, model_name, bundle, user_indices=None, ks=(5, 10, 20)):\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr10 = 0.0\n",
    "    ndcg10 = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    for uid in tqdm(user_indices, desc=model_name):\n",
    "        uid = int(uid)\n",
    "        true_item = test_item_by_user.get(uid, None)\n",
    "        if true_item is None:\n",
    "            continue\n",
    "\n",
    "        recs = recommend_fn(uid, n=max(ks))\n",
    "        valid += 1\n",
    "\n",
    "        for k in ks:\n",
    "            hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "        mrr10 += mrr_at_k(recs, true_item, 10)\n",
    "        ndcg10 += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    out = {\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": valid,\n",
    "    }\n",
    "    for k in ks:\n",
    "        out[f\"HR@{k}\"] = hits[k] / max(valid, 1)\n",
    "    out[\"MRR@10\"] = mrr10 / max(valid, 1)\n",
    "    out[\"NDCG@10\"] = ndcg10 / max(valid, 1)\n",
    "    return out\n",
    "\n",
    "def evaluate_ease_batched(\n",
    "    B_matrix,\n",
    "    X_train_matrix,\n",
    "    user_seen_dict,\n",
    "    test_item_by_user,\n",
    "    model_name,\n",
    "    user_indices=None,\n",
    "    ks=(5, 10, 20),\n",
    "    batch_size=2048,\n",
    "    user_seen_arrays_dict=None,\n",
    "):\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(test_item_by_user.keys()), dtype=np.int32)\n",
    "    else:\n",
    "        user_indices = np.asarray(user_indices, dtype=np.int32)\n",
    "\n",
    "    max_k = max(ks)\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr10 = 0.0\n",
    "    ndcg10 = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    use_gpu = bool(\n",
    "        HAS_TORCH and EASE_USE_GPU and (device is not None) and (str(device).startswith(\"cuda\"))\n",
    "    )\n",
    "\n",
    "    B_t = None\n",
    "    if use_gpu:\n",
    "        B_t = torch.as_tensor(B_matrix, dtype=torch.float32, device=device)\n",
    "        batch_size = max(batch_size, EASE_GPU_BATCH_SIZE)\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), batch_size), desc=model_name):\n",
    "        batch_uids = user_indices[start:start + batch_size]\n",
    "\n",
    "        X_batch = X_train_matrix[batch_uids].toarray().astype(np.float32, copy=False)\n",
    "\n",
    "        if use_gpu:\n",
    "            with torch.inference_mode():\n",
    "                X_batch_t = torch.from_numpy(X_batch).to(device, non_blocking=True)\n",
    "                scores_t = X_batch_t @ B_t\n",
    "\n",
    "                for row_idx, uid in enumerate(batch_uids):\n",
    "                    uid_int = int(uid)\n",
    "                    if user_seen_arrays_dict is not None:\n",
    "                        seen_idx_np = user_seen_arrays_dict.get(uid_int, None)\n",
    "                    else:\n",
    "                        seen = user_seen_dict.get(uid_int, None)\n",
    "                        seen_idx_np = None if not seen else np.fromiter(seen, dtype=np.int32)\n",
    "                    if seen_idx_np is not None and len(seen_idx_np) > 0:\n",
    "                        seen_idx_t = torch.as_tensor(seen_idx_np, dtype=torch.long, device=device)\n",
    "                        scores_t[row_idx, seen_idx_t] = float(\"-inf\")\n",
    "\n",
    "                recs_batch = torch.topk(scores_t, k=max_k, dim=1).indices.detach().cpu().numpy()\n",
    "\n",
    "                del X_batch_t, scores_t\n",
    "        else:\n",
    "            scores = X_batch @ B_matrix\n",
    "            scores = np.asarray(scores, dtype=np.float32)\n",
    "\n",
    "            for row_idx, uid in enumerate(batch_uids):\n",
    "                uid_int = int(uid)\n",
    "                if user_seen_arrays_dict is not None:\n",
    "                    seen_idx = user_seen_arrays_dict.get(uid_int, None)\n",
    "                else:\n",
    "                    seen = user_seen_dict.get(uid_int, None)\n",
    "                    seen_idx = None if not seen else np.fromiter(seen, dtype=np.int32)\n",
    "                if seen_idx is not None and len(seen_idx) > 0:\n",
    "                    scores[row_idx, seen_idx] = -np.inf\n",
    "\n",
    "            top_idx = np.argpartition(scores, -max_k, axis=1)[:, -max_k:]\n",
    "            top_scores = np.take_along_axis(scores, top_idx, axis=1)\n",
    "            order = np.argsort(top_scores, axis=1)[:, ::-1]\n",
    "            recs_batch = np.take_along_axis(top_idx, order, axis=1)\n",
    "\n",
    "        for row_idx, uid in enumerate(batch_uids):\n",
    "            uid = int(uid)\n",
    "            true_item = test_item_by_user.get(uid, None)\n",
    "            if true_item is None:\n",
    "                continue\n",
    "\n",
    "            recs = recs_batch[row_idx].astype(int).tolist()\n",
    "            valid += 1\n",
    "\n",
    "            for k in ks:\n",
    "                hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "            mrr10 += mrr_at_k(recs, true_item, 10)\n",
    "            ndcg10 += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    if use_gpu:\n",
    "        del B_t\n",
    "        try:\n",
    "            torch.cuda.empty_cache()\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "    out = {\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": valid,\n",
    "    }\n",
    "    for k in ks:\n",
    "        out[f\"HR@{k}\"] = hits[k] / max(valid, 1)\n",
    "    out[\"MRR@10\"] = mrr10 / max(valid, 1)\n",
    "    out[\"NDCG@10\"] = ndcg10 / max(valid, 1)\n",
    "    return out\n",
    "\n",
    "def bm25_weight(X, K1=100, B=0.8):\n",
    "    X = X.tocoo(copy=True).astype(np.float32)\n",
    "    N = float(X.shape[0])\n",
    "\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    avgdl = row_sums.mean() + 1e-8\n",
    "\n",
    "    df = np.bincount(X.col, minlength=X.shape[1]).astype(np.float32)\n",
    "    idf = np.log((N - df + 0.5) / (df + 0.5))\n",
    "    idf = np.maximum(idf, 0)\n",
    "\n",
    "    denom = X.data + K1 * (1 - B + B * row_sums[X.row] / avgdl)\n",
    "    data = X.data * (K1 + 1) / denom * idf[X.col]\n",
    "\n",
    "    return csr_matrix((data, (X.row, X.col)), shape=X.shape, dtype=np.float32)\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr().astype(np.float32)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    Pui = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    Piu = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "    S = (Piu @ Pui).astype(np.float32).tocsr()\n",
    "    S.setdiag(0.0)\n",
    "    S.eliminate_zeros()\n",
    "\n",
    "    if beta > 0:\n",
    "        item_degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "        penalty = np.ones_like(item_degree, dtype=np.float32)\n",
    "        mask = item_degree > 0\n",
    "        penalty[mask] = np.power(item_degree[mask], -beta)\n",
    "        S = S @ sp.diags(penalty.astype(np.float32))\n",
    "\n",
    "    return S.tocsr()\n",
    "\n",
    "def make_sparse_similarity_recommender(X_matrix, S_matrix, user_seen_dict):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_matrix.getrow(uid) @ S_matrix\n",
    "        scores = np.asarray(scores.todense()).ravel()\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
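  {
   "cell_type": "markdown",
   "id": "5d6e7f8a",
   "metadata": {},
   "source": [
    "For reference, the helpers above implement standard formulas:\n",
    "\n",
    "- `bm25_weight` re-weights each interaction count $c_{ui}$ as $w_{ui} = \\mathrm{idf}(i)\\,\\frac{(k_1+1)\\,c_{ui}}{c_{ui} + k_1\\,(1 - b + b\\,|u|/\\mathrm{avgdl})}$ with $\\mathrm{idf}(i) = \\max\\big(0,\\ \\log\\frac{N - \\mathrm{df}_i + 0.5}{\\mathrm{df}_i + 0.5}\\big)$, where $|u|$ is the user's total count, $\\mathrm{df}_i$ the number of users who interacted with item $i$, and $N$ the number of users.\n",
    "- `fit_p3alpha` builds the item-item matrix $S = P_{iu}^{\\alpha} P_{ui}^{\\alpha}$ from L1 row-normalized transition matrices (element-wise powers); with $\\beta > 0$ each column is additionally divided by the item's popularity raised to $\\beta$ (RP3beta).\n",
    "- With a single held-out item per user, $\\mathrm{HR@}k$ is 1 if that item appears in the top-$k$, $\\mathrm{MRR@}10 = 1/\\mathrm{rank}$ and $\\mathrm{NDCG@}10 = 1/\\log_2(\\mathrm{rank}+1)$ if the item is ranked within the cut-off, and 0 otherwise.\n",
    "\n",
    "A small usage sketch of the evaluation path, assuming `bundle` is any dict returned by `build_formulation` above (the screening cell below does the real calls):\n",
    "\n",
    "```python\n",
    "S = fit_p3alpha(bundle[\"X_binary\"], alpha=1.0)\n",
    "rec = make_sparse_similarity_recommender(bundle[\"X_binary\"], S, bundle[\"user_seen\"])\n",
    "result = evaluate_model(rec, \"P3alpha(alpha=1.0)\", bundle, ks=TOP_KS)\n",
    "```"
   ]
  },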
  {
   "cell_type": "markdown",
   "id": "a5b918a2",
   "metadata": {},
   "source": [
    "## Quick formulation screening\n",
    "\n",
    "The notebook will build several harder pseudo-user / item formulations, evaluate cheap baselines on each one, and automatically reject trivial setups.\n",
    "\n",
    "The main selection rule is:\n",
    "- enough items\n",
    "- popularity baseline not too strong\n",
    "- average interactions per pseudo-user not too high\n",
    "- strong EASE screen score"
   ]
  },
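  {
   "cell_type": "markdown",
   "id": "6e7f8a9b",
   "metadata": {},
   "source": [
    "The EASE screen used below relies on the closed-form solution (Steck, 2019): with $P = (X^{\\top}X + \\lambda I)^{-1}$, the item-item weight matrix is $B = I - P\\,\\mathrm{diagMat}(1/\\mathrm{diag}(P))$ with its diagonal forced to zero, and user scores are just $XB$. The eligibility test and screen score combine the constants defined earlier; a sketch of the rule (placeholder numbers only, taken from the H1 row of the screening table below):\n",
    "\n",
    "```python\n",
    "# Placeholder values purely for illustration\n",
    "num_items, avg_train_per_user = 2640, 19.24\n",
    "pop_hr10, ease_hr10 = 0.006, 0.162\n",
    "\n",
    "eligible = (\n",
    "    num_items >= MIN_ITEMS_FOR_VALID_FORMULATION\n",
    "    and pop_hr10 <= MAX_POP_HR10_FOR_VALID_FORMULATION\n",
    "    and avg_train_per_user <= MAX_AVG_INTERACTIONS_PER_USER\n",
    ")\n",
    "screen_score = ease_hr10 - 0.75 * pop_hr10 + 0.01 * np.log1p(num_items)\n",
    "```\n",
    "\n",
    "Eligible formulations are ranked by `screen_score`; if none is eligible, the best-scoring formulation is kept anyway."
   ]
  },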
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d081b84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Formulation: H1_country_price_mileage_cond_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "Matrix shape: (43199, 2640)\n",
      "====================================================================================================\n",
      "Formulation: H2_country_price_mileage_cond_age_color__brand_model_age_cond_color\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'AgeBin', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition', 'Color']\n",
      "Users      : 62921\n",
      "Items      : 13198\n",
      "Thresholds : (4, 8)\n",
      "Train rows : 237108\n",
      "Test rows  : 62921\n",
      "Avg train interactions/user: 3.768\n",
      "Matrix shape: (62921, 13198)\n",
      "====================================================================================================\n",
      "Formulation: H3_country_price_mileage_age_color__brand_model_age_cond_color\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'AgeBin', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition', 'Color']\n",
      "Users      : 112434\n",
      "Items      : 26394\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 773614\n",
      "Test rows  : 112434\n",
      "Avg train interactions/user: 6.881\n",
      "Matrix shape: (112434, 26394)\n",
      "====================================================================================================\n",
      "Formulation: H4_country_price_mileage_cond_color__brand_model_age_cond_color\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'Color']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition', 'Color']\n",
      "Users      : 43200\n",
      "Items      : 26400\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 943146\n",
      "Test rows  : 43200\n",
      "Avg train interactions/user: 21.832\n",
      "Matrix shape: (43200, 26400)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>TrainRows</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Thresholds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H4_country_price_mileage_cond_color__brand_mod...</td>\n",
       "      <td>43200</td>\n",
       "      <td>26400</td>\n",
       "      <td>943146</td>\n",
       "      <td>21.832083</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H3_country_price_mileage_age_color__brand_mode...</td>\n",
       "      <td>112434</td>\n",
       "      <td>26394</td>\n",
       "      <td>773614</td>\n",
       "      <td>6.880606</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H2_country_price_mileage_cond_age_color__brand...</td>\n",
       "      <td>62921</td>\n",
       "      <td>13198</td>\n",
       "      <td>237108</td>\n",
       "      <td>3.768344</td>\n",
       "      <td>(4, 8)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H1_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>831152</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation   Users  Items  \\\n",
       "0  H4_country_price_mileage_cond_color__brand_mod...   43200  26400   \n",
       "1  H3_country_price_mileage_age_color__brand_mode...  112434  26394   \n",
       "2  H2_country_price_mileage_cond_age_color__brand...   62921  13198   \n",
       "3  H1_country_price_mileage_cond_age__brand_model...   43199   2640   \n",
       "\n",
       "   TrainRows  AvgTrainPerUser Thresholds  \n",
       "0     943146        21.832083    (5, 10)  \n",
       "1     773614         6.880606    (5, 10)  \n",
       "2     237108         3.768344     (4, 8)  \n",
       "3     831152        19.240075    (5, 10)  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Build all candidate formulations\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "failed_formulations = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    try:\n",
    "        bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    except Exception as e:\n",
    "        print(f\"Skipping {name}: {type(e).__name__}: {e}\")\n",
    "        failed_formulations.append({\n",
    "            \"Formulation\": name,\n",
    "            \"Status\": \"failed\",\n",
    "            \"Reason\": f\"{type(e).__name__}: {e}\",\n",
    "        })\n",
    "        continue\n",
    "\n",
    "    bundles[name] = bundle\n",
    "    print_bundle_summary(bundle)\n",
    "\n",
    "    summary_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"TrainRows\": bundle[\"train_interactions\"].shape[0],\n",
    "        \"AvgTrainPerUser\": bundle[\"train_interactions\"].shape[0] / bundle[\"num_users\"],\n",
    "        \"Thresholds\": bundle.get(\"used_thresholds\"),\n",
    "    })\n",
    "\n",
    "if not bundles:\n",
    "    raise RuntimeError(\n",
    "        \"All candidate formulations failed. Try relaxing FILTER_SCHEDULE, reducing bin counts, \"\n",
    "        \"or simplifying item/user definitions.\"\n",
    "    )\n",
    "\n",
    "summary_df = pd.DataFrame(summary_rows).sort_values(\n",
    "    [\"Items\", \"Users\"], ascending=False\n",
    ").reset_index(drop=True)\n",
    "display(summary_df)\n",
    "\n",
    "if failed_formulations:\n",
    "    failed_df = pd.DataFrame(failed_formulations)\n",
    "    display(failed_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da9f1404",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: H1_country_price_mileage_cond_age__brand_model_age_cond\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H1_country_price_mileage_cond_age__brand_model_age_cond / Pop: 100%|██████████| 43199/43199 [00:09<00:00, 4327.28it/s]\n",
      "H1_country_price_mileage_cond_age__brand_model_age_cond / EASE200_batched: 100%|██████████| 11/11 [00:02<00:00,  5.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: H2_country_price_mileage_cond_age_color__brand_model_age_cond_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H2_country_price_mileage_cond_age_color__brand_model_age_cond_color / Pop: 100%|██████████| 62921/62921 [01:13<00:00, 858.15it/s]\n",
      "H2_country_price_mileage_cond_age_color__brand_model_age_cond_color / EASE200_batched: 100%|██████████| 16/16 [00:03<00:00,  4.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: H3_country_price_mileage_age_color__brand_model_age_cond_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H3_country_price_mileage_age_color__brand_model_age_cond_color / Pop: 100%|██████████| 112434/112434 [04:21<00:00, 430.72it/s]\n",
      "H3_country_price_mileage_age_color__brand_model_age_cond_color / EASE200_batched: 100%|██████████| 28/28 [00:11<00:00,  2.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Screening formulation: H4_country_price_mileage_cond_color__brand_model_age_cond_color\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "H4_country_price_mileage_cond_color__brand_model_age_cond_color / Pop: 100%|██████████| 43200/43200 [01:39<00:00, 435.56it/s]\n",
      "H4_country_price_mileage_cond_color__brand_model_age_cond_color / EASE200_batched: 100%|██████████| 11/11 [00:04<00:00,  2.50it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Pop_HR@10</th>\n",
       "      <th>EASE200_HR@10</th>\n",
       "      <th>EASE200_NDCG@10</th>\n",
       "      <th>Eligible</th>\n",
       "      <th>ScreenScore</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>0.006019</td>\n",
       "      <td>0.161717</td>\n",
       "      <td>0.077498</td>\n",
       "      <td>True</td>\n",
       "      <td>0.235992</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H2_country_price_mileage_cond_age_color__brand...</td>\n",
       "      <td>62921</td>\n",
       "      <td>13198</td>\n",
       "      <td>3.768344</td>\n",
       "      <td>0.001144</td>\n",
       "      <td>0.123091</td>\n",
       "      <td>0.055807</td>\n",
       "      <td>True</td>\n",
       "      <td>0.217112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H3_country_price_mileage_age_color__brand_mode...</td>\n",
       "      <td>112434</td>\n",
       "      <td>26394</td>\n",
       "      <td>6.880606</td>\n",
       "      <td>0.000587</td>\n",
       "      <td>0.041393</td>\n",
       "      <td>0.018853</td>\n",
       "      <td>True</td>\n",
       "      <td>0.142762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H4_country_price_mileage_cond_color__brand_mod...</td>\n",
       "      <td>43200</td>\n",
       "      <td>26400</td>\n",
       "      <td>21.832083</td>\n",
       "      <td>0.000671</td>\n",
       "      <td>0.015394</td>\n",
       "      <td>0.007158</td>\n",
       "      <td>True</td>\n",
       "      <td>0.116702</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation   Users  Items  \\\n",
       "0  H1_country_price_mileage_cond_age__brand_model...   43199   2640   \n",
       "1  H2_country_price_mileage_cond_age_color__brand...   62921  13198   \n",
       "2  H3_country_price_mileage_age_color__brand_mode...  112434  26394   \n",
       "3  H4_country_price_mileage_cond_color__brand_mod...   43200  26400   \n",
       "\n",
       "   AvgTrainPerUser  Pop_HR@10  EASE200_HR@10  EASE200_NDCG@10  Eligible  \\\n",
       "0        19.240075   0.006019       0.161717         0.077498      True   \n",
       "1         3.768344   0.001144       0.123091         0.055807      True   \n",
       "2         6.880606   0.000587       0.041393         0.018853      True   \n",
       "3        21.832083   0.000671       0.015394         0.007158      True   \n",
       "\n",
       "   ScreenScore  \n",
       "0     0.235992  \n",
       "1     0.217112  \n",
       "2     0.142762  \n",
       "3     0.116702  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected formulation: H1_country_price_mileage_cond_age__brand_model_age_cond\n",
      "Formulation: H1_country_price_mileage_cond_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "Matrix shape: (43199, 2640)\n"
     ]
    }
   ],
   "source": [
    "# Cheap screen:\n",
    "# 1) global popularity\n",
    "# 2) EASE with lambda=200 on binary interactions\n",
    "# Then reject trivial formulations automatically.\n",
    "\n",
    "screen_rows = []\n",
    "\n",
    "for name, bundle in bundles.items():\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Screening formulation:\", name)\n",
    "\n",
    "    def rec_pop(uid, n=10, bundle=bundle):\n",
    "        seen = bundle[\"user_seen\"].get(uid, set())\n",
    "        return [int(i) for i in bundle[\"global_pop_rank\"] if int(i) not in seen][:n]\n",
    "\n",
    "    pop_res = evaluate_model(rec_pop, f\"{name} / Pop\", bundle, ks=TOP_KS)\n",
    "\n",
    "    Xb = bundle[\"X_binary\"]\n",
    "    G = (Xb.T @ Xb).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += 200.0\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    B = B.astype(np.float32)\n",
    "\n",
    "    ease_res = evaluate_ease_batched(\n",
    "        B_matrix=B,\n",
    "        X_train_matrix=Xb,\n",
    "        user_seen_dict=bundle[\"user_seen\"],\n",
    "        test_item_by_user=bundle[\"test_item_by_user\"],\n",
    "        model_name=f\"{name} / EASE200_batched\",\n",
    "        user_indices=None,\n",
    "        ks=TOP_KS,\n",
    "        batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "        user_seen_arrays_dict=bundle[\"user_seen_arrays\"],\n",
    "    )\n",
    "\n",
    "    avg_train_per_user = bundle[\"train_interactions\"].shape[0] / bundle[\"num_users\"]\n",
    "\n",
    "    eligible = (\n",
    "        (bundle[\"num_items\"] >= MIN_ITEMS_FOR_VALID_FORMULATION) and\n",
    "        (pop_res[\"HR@10\"] <= MAX_POP_HR10_FOR_VALID_FORMULATION) and\n",
    "        (avg_train_per_user <= MAX_AVG_INTERACTIONS_PER_USER)\n",
    "    )\n",
    "\n",
    "    screen_score = (\n",
    "        ease_res[\"HR@10\"]\n",
    "        - 0.75 * pop_res[\"HR@10\"]\n",
    "        + 0.01 * np.log1p(bundle[\"num_items\"])\n",
    "    )\n",
    "\n",
    "    screen_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"AvgTrainPerUser\": avg_train_per_user,\n",
    "        \"Pop_HR@10\": pop_res[\"HR@10\"],\n",
    "        \"EASE200_HR@10\": ease_res[\"HR@10\"],\n",
    "        \"EASE200_NDCG@10\": ease_res[\"NDCG@10\"],\n",
    "        \"Eligible\": eligible,\n",
    "        \"ScreenScore\": screen_score,\n",
    "    })\n",
    "\n",
    "screen_df = pd.DataFrame(screen_rows).sort_values(\n",
    "    [\"Eligible\", \"ScreenScore\", \"EASE200_HR@10\", \"Items\"],\n",
    "    ascending=[False, False, False, False]\n",
    ").reset_index(drop=True)\n",
    "\n",
    "display(screen_df)\n",
    "\n",
    "if screen_df[\"Eligible\"].any():\n",
    "    BEST_FORMULATION = screen_df.loc[screen_df[\"Eligible\"]].iloc[0][\"Formulation\"]\n",
    "else:\n",
    "    BEST_FORMULATION = screen_df.iloc[0][\"Formulation\"]\n",
    "\n",
    "print(\"Selected formulation:\", BEST_FORMULATION)\n",
    "data = bundles[BEST_FORMULATION]\n",
    "print_bundle_summary(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ab89fdb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final working matrix: (43199, 2640)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>AgeBin</th>\n",
       "      <th>user_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(14.0 to  16.0]</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(16.0 to  19.0]</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(19.0 to  22.0]</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             user_id    Country  \\\n",
       "0  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "1  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "2  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "3  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "4  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "\n",
       "                         PriceBin                    MileageBin  \\\n",
       "0  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "1  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "2  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "3  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "4  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "\n",
       "             Condition               AgeBin  user_idx  \n",
       "0  Certified Pre-Owned  age_(1.999 to  4.0]         0  \n",
       "1  Certified Pre-Owned  age_(11.0 to  14.0]         1  \n",
       "2  Certified Pre-Owned  age_(14.0 to  16.0]         2  \n",
       "3  Certified Pre-Owned  age_(16.0 to  19.0]         3  \n",
       "4  Certified Pre-Owned  age_(19.0 to  22.0]         4  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>AgeBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>item_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Audi :: A3 :: age_(1.999 to  4.0] :: Certified...</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Audi :: A3 :: age_(1.999 to  4.0] :: New</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>New</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Audi :: A3 :: age_(1.999 to  4.0] :: Used</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>Used</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Audi :: A3 :: age_(11.0 to  14.0] :: Certified...</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Audi :: A3 :: age_(11.0 to  14.0] :: New</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "      <td>New</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             item_id Brand Model  \\\n",
       "0  Audi :: A3 :: age_(1.999 to  4.0] :: Certified...  Audi    A3   \n",
       "1           Audi :: A3 :: age_(1.999 to  4.0] :: New  Audi    A3   \n",
       "2          Audi :: A3 :: age_(1.999 to  4.0] :: Used  Audi    A3   \n",
       "3  Audi :: A3 :: age_(11.0 to  14.0] :: Certified...  Audi    A3   \n",
       "4           Audi :: A3 :: age_(11.0 to  14.0] :: New  Audi    A3   \n",
       "\n",
       "                AgeBin            Condition  item_idx  \n",
       "0  age_(1.999 to  4.0]  Certified Pre-Owned         0  \n",
       "1  age_(1.999 to  4.0]                  New         1  \n",
       "2  age_(1.999 to  4.0]                 Used         2  \n",
       "3  age_(11.0 to  14.0]  Certified Pre-Owned         3  \n",
       "4  age_(11.0 to  14.0]                  New         4  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Prepare reusable arrays and feature tables for the selected formulation\n",
    "\n",
    "user_feature_df = data[\"user_feature_df\"].copy()\n",
    "item_feature_df = data[\"item_feature_df\"].copy()\n",
    "\n",
    "train_interactions = data[\"train_interactions\"].copy()\n",
    "test_interactions = data[\"test_interactions\"].copy()\n",
    "\n",
    "X_counts = data[\"X_counts\"].tocsr()\n",
    "X_binary = data[\"X_binary\"].tocsr()\n",
    "\n",
    "num_users = data[\"num_users\"]\n",
    "num_items = data[\"num_items\"]\n",
    "\n",
    "user_seen = data[\"user_seen\"]\n",
    "user_seen_arrays = data[\"user_seen_arrays\"]\n",
    "user_strength = data[\"user_strength\"]\n",
    "test_item_by_user = data[\"test_item_by_user\"]\n",
    "global_pop_rank = data[\"global_pop_rank\"]\n",
    "item_ids = data[\"item_ids\"]\n",
    "\n",
    "X_binary_dense = X_binary.toarray().astype(np.float32)\n",
    "X_counts_dense = X_counts.toarray().astype(np.float32)\n",
    "\n",
    "eval_users = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "print(\"Final working matrix:\", X_binary.shape)\n",
    "display(user_feature_df.head())\n",
    "display(item_feature_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa1814d8",
   "metadata": {},
   "source": [
    "## Regular model zoo\n",
    "\n",
    "This section includes:\n",
    "- popularity baselines\n",
    "- item-based collaborative filtering\n",
    "- content-based KNN\n",
    "- graph-style recommenders\n",
    "- EASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a1afb92e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Popularity baselines\n",
    "\n",
    "def recommend_popularity(user_idx, n=10):\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]\n",
    "\n",
    "country_col = \"Country\" if \"Country\" in user_feature_df.columns else None\n",
    "user_country = {}\n",
    "country_pop_rank = {}\n",
    "\n",
    "if country_col is not None:\n",
    "    user_country = dict(zip(user_feature_df[\"user_idx\"], user_feature_df[country_col]))\n",
    "    train_with_country = train_interactions.merge(\n",
    "        user_feature_df[[\"user_idx\", country_col]],\n",
    "        on=\"user_idx\",\n",
    "        how=\"left\"\n",
    "    )\n",
    "\n",
    "    for country, g in train_with_country.groupby(country_col):\n",
    "        country_pop_rank[country] = (\n",
    "            g.groupby(\"item_idx\")[\"count\"]\n",
    "             .sum()\n",
    "             .sort_values(ascending=False)\n",
    "             .index.to_numpy()\n",
    "        )\n",
    "\n",
    "def recommend_country_popularity(user_idx, n=10):\n",
    "    uid = int(user_idx)\n",
    "    seen = user_seen.get(uid, set())\n",
    "    country = user_country.get(uid, None)\n",
    "\n",
    "    if country in country_pop_rank:\n",
    "        recs = [int(i) for i in country_pop_rank[country] if int(i) not in seen][:n]\n",
    "        if len(recs) >= n:\n",
    "            return recs\n",
    "\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "89251180",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ItemKNN variants\n",
    "\n",
    "def fit_itemknn_recommender(X_matrix, neighbors=100):\n",
    "    item_user = X_matrix.T.tocsr()\n",
    "\n",
    "    knn = NearestNeighbors(\n",
    "        metric=\"cosine\",\n",
    "        algorithm=\"brute\",\n",
    "        n_neighbors=min(neighbors, item_user.shape[0]),\n",
    "        n_jobs=-1\n",
    "    )\n",
    "    knn.fit(item_user)\n",
    "\n",
    "    distances, indices = knn.kneighbors(item_user, n_neighbors=min(neighbors, item_user.shape[0]))\n",
    "    similarities = (1.0 - distances).astype(np.float32)\n",
    "    return indices, similarities\n",
    "\n",
    "def make_itemknn_recommender(X_matrix, neighbors=100, use_strength=True):\n",
    "    neighbor_idx, neighbor_sim = fit_itemknn_recommender(X_matrix, neighbors=neighbors)\n",
    "\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        seen = user_seen.get(uid, set())\n",
    "        strength_map = user_strength.get(uid, {})\n",
    "        scores = {}\n",
    "\n",
    "        for item_idx in seen:\n",
    "            weight = float(strength_map.get(item_idx, 1.0)) if use_strength else 1.0\n",
    "            nbrs = neighbor_idx[item_idx]\n",
    "            sims = neighbor_sim[item_idx]\n",
    "\n",
    "            for j, sim in zip(nbrs, sims):\n",
    "                j = int(j)\n",
    "                if j == item_idx or j in seen:\n",
    "                    continue\n",
    "                scores[j] = scores.get(j, 0.0) + float(sim) * weight\n",
    "\n",
    "        if not scores:\n",
    "            return recommend_popularity(uid, n=n)\n",
    "\n",
    "        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:n]\n",
    "        return [int(i) for i, _ in ranked]\n",
    "\n",
    "    return recommend\n",
    "\n",
    "X_bm25 = bm25_weight(X_counts, K1=100, B=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "24fed1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Content-based KNN on item attributes\n",
    "\n",
    "item_feature_cols = data[\"item_cols\"]\n",
    "item_feature_onehot = pd.get_dummies(item_feature_df[item_feature_cols].astype(str), sparse=True)\n",
    "item_feature_sparse = csr_matrix(item_feature_onehot.sparse.to_coo()).astype(np.float32)\n",
    "\n",
    "item_feature_norm = normalize(item_feature_sparse, norm=\"l2\", axis=1)\n",
    "content_sim = (item_feature_norm @ item_feature_norm.T).astype(np.float32).toarray()\n",
    "np.fill_diagonal(content_sim, 0.0)\n",
    "\n",
    "def recommend_content_knn(user_idx, n=10):\n",
    "    uid = int(user_idx)\n",
    "    seen = user_seen.get(uid, set())\n",
    "    user_vec = X_counts.getrow(uid).toarray().ravel().astype(np.float32)\n",
    "    scores = user_vec @ content_sim\n",
    "    return topn_from_scores(scores, seen, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bee1b5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph recommenders: P3alpha and RP3beta\n",
    "\n",
    "def make_p3_recommender(X_input, alpha=1.0, beta=0.0):\n",
    "    S = fit_p3alpha(X_input, alpha=alpha, beta=beta)\n",
    "    return make_sparse_similarity_recommender(X_input, S, user_seen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "500d74ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EASE\n",
    "\n",
    "def fit_ease(X_matrix, lam):\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += lam\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32)\n",
    "\n",
    "def make_ease_recommender(X_train_dense, B_matrix):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_train_dense[uid] @ B_matrix\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af5374c9",
   "metadata": {},
   "source": [
    "## Neural model zoo\n",
    "\n",
    "This section favors models that actually keep the GPU busy:\n",
    "- full-softmax MLPs\n",
    "- two-tower retrieval\n",
    "- dense autoencoders\n",
    "- pairwise embedding models\n",
    "- NeuMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "366f6ed2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HAS_TORCH: True\n",
      "device: cuda\n",
      "USE_AMP: True\n"
     ]
    }
   ],
   "source": [
    "# PyTorch setup was already performed in the earlier runtime cell.\n",
    "\n",
    "print(\"HAS_TORCH:\", HAS_TORCH)\n",
    "print(\"device:\", device)\n",
    "print(\"USE_AMP:\", USE_AMP)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "89dad4e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive interaction rows: 831152\n",
      "Dense matrix shape: (43199, 2640)\n"
     ]
    }
   ],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Feature encoders for neural models\n",
    "\n",
    "    user_cols = data[\"user_cols\"]\n",
    "    item_cols = data[\"item_cols\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "    for col in sorted(set(user_cols + item_cols)):\n",
    "        enc = LabelEncoder()\n",
    "        values = pd.concat([\n",
    "            user_feature_df[col].astype(str) if col in user_feature_df.columns else pd.Series(dtype=str),\n",
    "            item_feature_df[col].astype(str) if col in item_feature_df.columns else pd.Series(dtype=str),\n",
    "        ], ignore_index=True)\n",
    "        enc.fit(values)\n",
    "        feature_encoders[col] = enc\n",
    "\n",
    "    user_feature_arrays = {\n",
    "        col: feature_encoders[col].transform(user_feature_df[col].astype(str))\n",
    "        for col in user_cols\n",
    "    }\n",
    "    item_feature_arrays = {\n",
    "        col: feature_encoders[col].transform(item_feature_df[col].astype(str))\n",
    "        for col in item_cols\n",
    "    }\n",
    "\n",
    "    user_feature_tensors = {\n",
    "        col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "        for col, arr in user_feature_arrays.items()\n",
    "    }\n",
    "    item_feature_tensors = {\n",
    "        col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "        for col, arr in item_feature_arrays.items()\n",
    "    }\n",
    "\n",
    "    positive_user_idx = torch.tensor(train_interactions[\"user_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "    positive_item_idx = torch.tensor(train_interactions[\"item_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "    positive_weight = torch.tensor(train_interactions[\"count\"].astype(np.float32).to_numpy(), dtype=torch.float32, device=device)\n",
    "\n",
    "    X_binary_dense_tensor = torch.tensor(X_binary_dense, dtype=torch.float32, device=device)\n",
    "    X_counts_dense_tensor = torch.tensor(X_counts_dense, dtype=torch.float32, device=device)\n",
    "\n",
    "    n_positive_rows = int(positive_user_idx.shape[0])\n",
    "    print(\"Positive interaction rows:\", n_positive_rows)\n",
    "    print(\"Dense matrix shape:\", tuple(X_binary_dense_tensor.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8fe48d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Manual batch iterators\n",
    "\n",
    "    def iterate_positive_batches(batch_size, shuffle=True):\n",
    "        n = n_positive_rows\n",
    "        order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            yield positive_user_idx[idx], positive_item_idx[idx], positive_weight[idx]\n",
    "\n",
    "    def iterate_dense_user_batches(batch_size, shuffle=True):\n",
    "        n = int(X_binary_dense_tensor.shape[0])\n",
    "        order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            yield X_binary_dense_tensor[idx], idx\n",
    "\n",
    "    def make_user_feature_batch(user_idx_batch):\n",
    "        return {col: user_feature_tensors[col][user_idx_batch] for col in user_cols}\n",
    "\n",
    "    def make_item_feature_batch(item_idx_batch):\n",
    "        return {col: item_feature_tensors[col][item_idx_batch] for col in item_cols}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f30d69b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Shared building blocks\n",
    "\n",
    "    class FeatureEmbeddingBlock(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, emb_dim):\n",
    "            super().__init__()\n",
    "            self.feature_names = list(feature_cardinalities.keys())\n",
    "            self.embs = nn.ModuleDict({\n",
    "                name: nn.Embedding(card, emb_dim)\n",
    "                for name, card in feature_cardinalities.items()\n",
    "            })\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            parts = [self.embs[name](feature_batch[name]) for name in self.feature_names]\n",
    "            return torch.cat(parts, dim=-1)\n",
    "\n",
    "    class MLPBlock(nn.Module):\n",
    "        def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None, final_activation=False):\n",
    "            super().__init__()\n",
    "            layers = []\n",
    "            last = input_dim\n",
    "            for h in hidden_dims:\n",
    "                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "                last = h\n",
    "            if final_dim is not None:\n",
    "                layers.append(nn.Linear(last, final_dim))\n",
    "                last = final_dim\n",
    "                if final_activation:\n",
    "                    layers.append(nn.ReLU())\n",
    "            self.net = nn.Sequential(*layers)\n",
    "            self.output_dim = last\n",
    "\n",
    "        def forward(self, x):\n",
    "            return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0c79d1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 1: full-softmax feature MLP\n",
    "\n",
    "    class SoftmaxSegmentMLP(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, num_items, emb_dim=64, hidden_dims=(512, 512, 256), dropout=0.15):\n",
    "            super().__init__()\n",
    "            self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "            input_dim = len(feature_cardinalities) * emb_dim\n",
    "            self.mlp = MLPBlock(input_dim, hidden_dims, dropout=dropout)\n",
    "            self.out = nn.Linear(self.mlp.output_dim, num_items)\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            x = self.encoder(feature_batch)\n",
    "            x = self.mlp(x)\n",
    "            return self.out(x)\n",
    "\n",
    "    def train_softmax_model(model_name, emb_dim=64, hidden_dims=(512, 512, 256), epochs=20, lr=2e-3, wd=1e-5):\n",
    "        feature_cardinalities = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "        model = SoftmaxSegmentMLP(\n",
    "            feature_cardinalities=feature_cardinalities,\n",
    "            num_items=num_items,\n",
    "            emb_dim=emb_dim,\n",
    "            hidden_dims=hidden_dims,\n",
    "            dropout=0.15,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "        criterion = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "            for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(SOFTMAX_BATCH_SIZE, shuffle=True):\n",
    "                feat_batch = make_user_feature_batch(user_idx_batch)\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model(feat_batch)\n",
    "                    loss_vec = criterion(logits, item_idx_batch)\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_softmax_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "            feat_batch = make_user_feature_batch(u)\n",
    "            logits = model(feat_batch).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(logits, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c2fc4897",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 2: feature two-tower with in-batch negatives\n",
    "\n",
    "    class TowerEncoder(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "            super().__init__()\n",
    "            self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "            input_dim = len(feature_cardinalities) * emb_dim\n",
    "            self.net = MLPBlock(input_dim, hidden_dims, dropout=dropout, final_dim=out_dim)\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            x = self.encoder(feature_batch)\n",
    "            x = self.net(x)\n",
    "            return F.normalize(x, dim=-1)\n",
    "\n",
    "    class TwoTowerModel(nn.Module):\n",
    "        def __init__(self, user_feature_cards, item_feature_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "            super().__init__()\n",
    "            self.user_tower = TowerEncoder(user_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "            self.item_tower = TowerEncoder(item_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "\n",
    "    def train_two_tower(model_name, emb_dim=64, hidden_dims=(512, 256), out_dim=128, epochs=20, lr=2e-3, wd=1e-5, temperature=0.07):\n",
    "        user_feature_cards = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "        item_feature_cards = {col: len(feature_encoders[col].classes_) for col in item_cols}\n",
    "\n",
    "        model = TwoTowerModel(\n",
    "            user_feature_cards=user_feature_cards,\n",
    "            item_feature_cards=item_feature_cards,\n",
    "            emb_dim=emb_dim,\n",
    "            hidden_dims=hidden_dims,\n",
    "            out_dim=out_dim,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "            for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(TWOTOWER_BATCH_SIZE, shuffle=True):\n",
    "                user_batch = make_user_feature_batch(user_idx_batch)\n",
    "                item_batch = make_item_feature_batch(item_idx_batch)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    user_vec = model.user_tower(user_batch)\n",
    "                    item_vec = model.item_tower(item_batch)\n",
    "                    logits = (user_vec @ item_vec.T) / temperature\n",
    "                    targets = torch.arange(logits.shape[0], device=device)\n",
    "                    loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_two_tower_recommender(model):\n",
    "        model.eval()\n",
    "        all_item_idx = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "        item_matrix_parts = []\n",
    "        with torch.no_grad():\n",
    "            for start in range(0, num_items, 4096):\n",
    "                idx = all_item_idx[start:start + 4096]\n",
    "                batch = make_item_feature_batch(idx)\n",
    "                item_matrix_parts.append(model.item_tower(batch))\n",
    "        item_matrix = torch.cat(item_matrix_parts, dim=0)\n",
    "\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "            user_batch = make_user_feature_batch(u)\n",
    "            user_vec = model.user_tower(user_batch)\n",
    "            scores = (user_vec @ item_matrix.T).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "918da172",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 3: dense autoencoders\n",
    "\n",
    "    class MultDAE(nn.Module):\n",
    "        def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.2):\n",
    "            super().__init__()\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Sequential(\n",
    "                nn.Linear(num_items, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, latent_dim),\n",
    "                nn.Tanh(),\n",
    "            )\n",
    "            self.decoder = nn.Sequential(\n",
    "                nn.Linear(latent_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, num_items),\n",
    "            )\n",
    "\n",
    "        def forward(self, x):\n",
    "            z = self.encoder(self.dropout(x))\n",
    "            logits = self.decoder(z)\n",
    "            return logits\n",
    "\n",
    "    class MultVAE(nn.Module):\n",
    "        def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.3):\n",
    "            super().__init__()\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Sequential(\n",
    "                nn.Linear(num_items, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "            )\n",
    "            self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.decoder = nn.Sequential(\n",
    "                nn.Linear(latent_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, num_items),\n",
    "            )\n",
    "\n",
    "        def encode(self, x):\n",
    "            h = self.encoder(self.dropout(x))\n",
    "            return self.mu(h), self.logvar(h)\n",
    "\n",
    "        def reparameterize(self, mu, logvar):\n",
    "            if self.training:\n",
    "                std = torch.exp(0.5 * logvar)\n",
    "                eps = torch.randn_like(std)\n",
    "                return mu + eps * std\n",
    "            return mu\n",
    "\n",
    "        def forward(self, x):\n",
    "            mu, logvar = self.encode(x)\n",
    "            z = self.reparameterize(mu, logvar)\n",
    "            logits = self.decoder(z)\n",
    "            return logits, mu, logvar\n",
    "\n",
    "    def train_multdae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.2, epochs=80, lr=1e-3, wd=0.0):\n",
    "        model = MultDAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_rows = 0\n",
    "            for batch_x, _ in iterate_dense_user_batches(AUTOENC_BATCH_SIZE, shuffle=True):\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model(batch_x)\n",
    "                    loss = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1).mean()\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                total_rows += batch_x.shape[0]\n",
    "\n",
    "            if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "                print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def train_multvae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.3, epochs=80, lr=1e-3, wd=0.0, anneal_cap=1.0):\n",
    "        model = MultVAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        steps_per_epoch = max(1, math.ceil(num_users / AUTOENC_BATCH_SIZE))\n",
    "        total_steps = max(1, epochs * steps_per_epoch)\n",
    "        step = 0\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_rows = 0\n",
    "\n",
    "            for batch_x, _ in iterate_dense_user_batches(AUTOENC_BATCH_SIZE, shuffle=True):\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                anneal = min(anneal_cap, step / max(total_steps * 0.3, 1))\n",
    "\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits, mu, logvar = model(batch_x)\n",
    "                    recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                    loss = (recon + anneal * kl).mean()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                total_rows += batch_x.shape[0]\n",
    "                step += 1\n",
    "\n",
    "            if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "                print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f} - anneal: {anneal:.3f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_multdae_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            x = X_binary_dense_tensor[uid:uid + 1]\n",
    "            logits = model(x).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(logits, seen, n)\n",
    "        return recommend\n",
    "\n",
    "    def make_multvae_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            x = X_binary_dense_tensor[uid:uid + 1]\n",
    "            logits, _, _ = model(x)\n",
    "            scores = logits.squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "2d7bc556",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 4: BPR matrix factorization\n",
    "\n",
    "    class BPRMF(nn.Module):\n",
    "        def __init__(self, num_users, num_items, dim=128):\n",
    "            super().__init__()\n",
    "            self.user_emb = nn.Embedding(num_users, dim)\n",
    "            self.item_emb = nn.Embedding(num_items, dim)\n",
    "            self.item_bias = nn.Embedding(num_items, 1)\n",
    "\n",
    "            nn.init.normal_(self.user_emb.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_emb.weight, std=0.02)\n",
    "            nn.init.zeros_(self.item_bias.weight)\n",
    "\n",
    "        def forward(self, user_idx, pos_idx, neg_idx):\n",
    "            u = self.user_emb(user_idx)\n",
    "            p = self.item_emb(pos_idx)\n",
    "            n = self.item_emb(neg_idx)\n",
    "            pb = self.item_bias(pos_idx).squeeze(-1)\n",
    "            nb = self.item_bias(neg_idx).squeeze(-1)\n",
    "\n",
    "            pos_scores = (u * p).sum(dim=1) + pb\n",
    "            neg_scores = (u * n).sum(dim=1) + nb\n",
    "            return pos_scores, neg_scores\n",
    "\n",
    "    def train_bprmf(model_name, dim=128, epochs=24, lr=2e-3, wd=1e-6):\n",
    "        model = BPRMF(num_users=num_users, num_items=num_items, dim=dim).to(device)\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "\n",
    "            for user_idx_batch, pos_item_batch, weight_batch in iterate_positive_batches(PAIRWISE_BATCH_SIZE, shuffle=True):\n",
    "                neg_item_batch = torch.randint(0, num_items, size=pos_item_batch.shape, device=device)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    pos_scores, neg_scores = model(user_idx_batch, pos_item_batch, neg_item_batch)\n",
    "                    loss_vec = -F.logsigmoid(pos_scores - neg_scores)\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_bprmf_recommender(model):\n",
    "        item_matrix = model.item_emb.weight\n",
    "        item_bias = model.item_bias.weight.squeeze(-1)\n",
    "\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = model.user_emb.weight[uid]\n",
    "            scores = item_matrix @ u + item_bias\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "48fb928d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 5: NeuMF\n",
    "\n",
    "    class NeuMF(nn.Module):\n",
    "        def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "            super().__init__()\n",
    "            self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "            self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "            self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "            self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "            layers = []\n",
    "            last = mlp_dim * 2\n",
    "            for h in hidden_dims:\n",
    "                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(0.1)]\n",
    "                last = h\n",
    "            self.mlp = nn.Sequential(*layers)\n",
    "            self.out = nn.Linear(last + mf_dim, 1)\n",
    "\n",
    "            nn.init.normal_(self.user_mf.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_mf.weight, std=0.02)\n",
    "            nn.init.normal_(self.user_mlp.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_mlp.weight, std=0.02)\n",
    "\n",
    "        def score(self, user_idx, item_idx):\n",
    "            mf_u = self.user_mf(user_idx)\n",
    "            mf_i = self.item_mf(item_idx)\n",
    "            mf = mf_u * mf_i\n",
    "\n",
    "            mlp_u = self.user_mlp(user_idx)\n",
    "            mlp_i = self.item_mlp(item_idx)\n",
    "            mlp = self.mlp(torch.cat([mlp_u, mlp_i], dim=-1))\n",
    "\n",
    "            x = torch.cat([mf, mlp], dim=-1)\n",
    "            return self.out(x).squeeze(-1)\n",
    "\n",
    "    def train_neumf(model_name, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128), epochs=24, lr=2e-3, wd=1e-6):\n",
    "        model = NeuMF(num_users=num_users, num_items=num_items, mf_dim=mf_dim, mlp_dim=mlp_dim, hidden_dims=hidden_dims).to(device)\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "\n",
    "            for user_idx_batch, pos_item_batch, weight_batch in iterate_positive_batches(PAIRWISE_BATCH_SIZE, shuffle=True):\n",
    "                neg_item_batch = torch.randint(0, num_items, size=pos_item_batch.shape, device=device)\n",
    "\n",
    "                user_cat = torch.cat([user_idx_batch, user_idx_batch], dim=0)\n",
    "                item_cat = torch.cat([pos_item_batch, neg_item_batch], dim=0)\n",
    "                target = torch.cat([\n",
    "                    torch.ones_like(pos_item_batch, dtype=torch.float32),\n",
    "                    torch.zeros_like(neg_item_batch, dtype=torch.float32),\n",
    "                ], dim=0)\n",
    "                sample_weight = torch.cat([weight_batch, weight_batch], dim=0)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model.score(user_cat, item_cat)\n",
    "                    loss_vec = F.binary_cross_entropy_with_logits(logits, target, reduction=\"none\")\n",
    "                    loss = (loss_vec * sample_weight).sum() / sample_weight.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(sample_weight.sum().item())\n",
    "                total_w += float(sample_weight.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_neumf_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            users = torch.full((num_items,), uid, dtype=torch.long, device=device)\n",
    "            items = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "            scores = model.score(users, items)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c9e63b69",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: Popularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Popularity: 100%|██████████| 43199/43199 [00:10<00:00, 4268.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: CountryPopularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CountryPopularity: 100%|██████████| 43199/43199 [00:09<00:00, 4393.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: ContentKNN\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ContentKNN: 100%|██████████| 43199/43199 [01:16<00:00, 563.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN binary neighbors=100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k100: 100%|██████████| 43199/43199 [00:16<00:00, 2559.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN binary neighbors=200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k200: 100%|██████████| 43199/43199 [00:30<00:00, 1439.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN binary neighbors=300\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_binary_k300: 100%|██████████| 43199/43199 [00:44<00:00, 960.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k100: 100%|██████████| 43199/43199 [00:16<00:00, 2619.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k200: 100%|██████████| 43199/43199 [00:30<00:00, 1425.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting ItemKNN BM25 neighbors=300\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ItemKNN_bm25_k300: 100%|██████████| 43199/43199 [00:45<00:00, 955.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a0_5: 100%|██████████| 43199/43199 [00:05<00:00, 8126.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a1_0: 100%|██████████| 43199/43199 [00:05<00:00, 7895.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=0.8 beta=0.3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a0_8_b0_3: 100%|██████████| 43199/43199 [00:05<00:00, 7915.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=1.0 beta=0.6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a1_0_b0_6: 100%|██████████| 43199/43199 [00:05<00:00, 8249.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l100: 100%|██████████| 11/11 [00:02<00:00,  4.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l100: 100%|██████████| 11/11 [00:02<00:00,  4.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l200: 100%|██████████| 11/11 [00:02<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l200: 100%|██████████| 11/11 [00:02<00:00,  4.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=500.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l500: 100%|██████████| 11/11 [00:02<00:00,  5.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=500.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l500: 100%|██████████| 11/11 [00:02<00:00,  5.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l1000: 100%|██████████| 11/11 [00:02<00:00,  4.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l1000: 100%|██████████| 11/11 [00:02<00:00,  4.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=2000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l2000: 100%|██████████| 11/11 [00:02<00:00,  5.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting EASE count lambda=2000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l2000: 100%|██████████| 11/11 [00:02<00:00,  5.03it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169448</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060618</td>\n",
       "      <td>0.085705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097711</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>P3alpha_a1_0</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060332</td>\n",
       "      <td>0.085333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096924</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310563</td>\n",
       "      <td>0.060030</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ItemKNN_bm25_k100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ItemKNN_bm25_k200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>ItemKNN_bm25_k300</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ItemKNN_binary_k200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ItemKNN_binary_k300</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>EASE_binary_l2000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168453</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060929</td>\n",
       "      <td>0.085723</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>P3alpha_a0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097387</td>\n",
       "      <td>0.168407</td>\n",
       "      <td>0.309405</td>\n",
       "      <td>0.060372</td>\n",
       "      <td>0.085252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>RP3beta_a0_8_b0_3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097016</td>\n",
       "      <td>0.168337</td>\n",
       "      <td>0.308086</td>\n",
       "      <td>0.060347</td>\n",
       "      <td>0.085230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>EASE_count_l500</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092271</td>\n",
       "      <td>0.167759</td>\n",
       "      <td>0.309359</td>\n",
       "      <td>0.056853</td>\n",
       "      <td>0.082288</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>EASE_count_l2000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097618</td>\n",
       "      <td>0.167527</td>\n",
       "      <td>0.309058</td>\n",
       "      <td>0.060644</td>\n",
       "      <td>0.085284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>EASE_binary_l500</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095118</td>\n",
       "      <td>0.167434</td>\n",
       "      <td>0.309660</td>\n",
       "      <td>0.058010</td>\n",
       "      <td>0.083173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>EASE_count_l200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.086229</td>\n",
       "      <td>0.162342</td>\n",
       "      <td>0.306720</td>\n",
       "      <td>0.052313</td>\n",
       "      <td>0.077514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>EASE_binary_l200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.088636</td>\n",
       "      <td>0.161717</td>\n",
       "      <td>0.306952</td>\n",
       "      <td>0.052433</td>\n",
       "      <td>0.077498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>EASE_count_l100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.083289</td>\n",
       "      <td>0.160235</td>\n",
       "      <td>0.304289</td>\n",
       "      <td>0.050601</td>\n",
       "      <td>0.075694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>EASE_binary_l100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.084909</td>\n",
       "      <td>0.158916</td>\n",
       "      <td>0.304799</td>\n",
       "      <td>0.050151</td>\n",
       "      <td>0.075070</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>ContentKNN</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.072942</td>\n",
       "      <td>0.145235</td>\n",
       "      <td>0.291812</td>\n",
       "      <td>0.043139</td>\n",
       "      <td>0.066449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.003218</td>\n",
       "      <td>0.006019</td>\n",
       "      <td>0.011621</td>\n",
       "      <td>0.001746</td>\n",
       "      <td>0.002724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>CountryPopularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.002639</td>\n",
       "      <td>0.005185</td>\n",
       "      <td>0.010579</td>\n",
       "      <td>0.001451</td>\n",
       "      <td>0.002306</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0     RP3beta_a1_0_b0_6      43199  0.097687  0.169448  0.308896  0.060618   \n",
       "1     EASE_binary_l1000      43199  0.097711  0.169356  0.309614  0.060787   \n",
       "2          P3alpha_a1_0      43199  0.098104  0.168870  0.310609  0.060332   \n",
       "3      EASE_count_l1000      43199  0.096924  0.168870  0.310563  0.060030   \n",
       "4     ItemKNN_bm25_k100      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "5     ItemKNN_bm25_k200      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "6     ItemKNN_bm25_k300      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "7   ItemKNN_binary_k100      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "8   ItemKNN_binary_k200      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "9   ItemKNN_binary_k300      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "10    EASE_binary_l2000      43199  0.098104  0.168453  0.309614  0.060929   \n",
       "11         P3alpha_a0_5      43199  0.097387  0.168407  0.309405  0.060372   \n",
       "12    RP3beta_a0_8_b0_3      43199  0.097016  0.168337  0.308086  0.060347   \n",
       "13      EASE_count_l500      43199  0.092271  0.167759  0.309359  0.056853   \n",
       "14     EASE_count_l2000      43199  0.097618  0.167527  0.309058  0.060644   \n",
       "15     EASE_binary_l500      43199  0.095118  0.167434  0.309660  0.058010   \n",
       "16      EASE_count_l200      43199  0.086229  0.162342  0.306720  0.052313   \n",
       "17     EASE_binary_l200      43199  0.088636  0.161717  0.306952  0.052433   \n",
       "18      EASE_count_l100      43199  0.083289  0.160235  0.304289  0.050601   \n",
       "19     EASE_binary_l100      43199  0.084909  0.158916  0.304799  0.050151   \n",
       "20           ContentKNN      43199  0.072942  0.145235  0.291812  0.043139   \n",
       "21           Popularity      43199  0.003218  0.006019  0.011621  0.001746   \n",
       "22    CountryPopularity      43199  0.002639  0.005185  0.010579  0.001451   \n",
       "\n",
       "     NDCG@10  \n",
       "0   0.085705  \n",
       "1   0.085790  \n",
       "2   0.085333  \n",
       "3   0.085076  \n",
       "4   0.085358  \n",
       "5   0.085358  \n",
       "6   0.085358  \n",
       "7   0.085517  \n",
       "8   0.085517  \n",
       "9   0.085517  \n",
       "10  0.085723  \n",
       "11  0.085252  \n",
       "12  0.085230  \n",
       "13  0.082288  \n",
       "14  0.085284  \n",
       "15  0.083173  \n",
       "16  0.077514  \n",
       "17  0.077498  \n",
       "18  0.075694  \n",
       "19  0.075070  \n",
       "20  0.066449  \n",
       "21  0.002724  \n",
       "22  0.002306  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit and evaluate regular models\n",
    "\n",
    "regular_results = []\n",
    "regular_models = {}\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: Popularity\")\n",
    "regular_models[\"Popularity\"] = recommend_popularity\n",
    "regular_results.append(evaluate_model(recommend_popularity, \"Popularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: CountryPopularity\")\n",
    "regular_models[\"CountryPopularity\"] = recommend_country_popularity\n",
    "regular_results.append(evaluate_model(recommend_country_popularity, \"CountryPopularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: ContentKNN\")\n",
    "regular_models[\"ContentKNN\"] = recommend_content_knn\n",
    "regular_results.append(evaluate_model(recommend_content_knn, \"ContentKNN\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN binary neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_binary, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_binary_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN BM25 neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_bm25, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_bm25_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for alpha in P3_ALPHA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting P3alpha alpha={alpha}\")\n",
    "    rec = make_p3_recommender(X_binary, alpha=alpha, beta=0.0)\n",
    "    name = f\"P3alpha_a{str(alpha).replace('.', '_')}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for alpha, beta in RP3_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting RP3beta alpha={alpha} beta={beta}\")\n",
    "    rec = make_p3_recommender(X_binary, alpha=alpha, beta=beta)\n",
    "    name = f\"RP3beta_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for lam in EASE_LAMBDA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting EASE binary lambda={lam}\")\n",
    "    B_bin = fit_ease(X_binary, lam=lam)\n",
    "    rec_bin = make_ease_recommender(X_binary_dense, B_bin)\n",
    "    name_bin = f\"EASE_binary_l{int(lam)}\"\n",
    "    regular_models[name_bin] = rec_bin\n",
    "    regular_results.append(\n",
    "        evaluate_ease_batched(\n",
    "            B_matrix=B_bin,\n",
    "            X_train_matrix=X_binary,\n",
    "            user_seen_dict=user_seen,\n",
    "            test_item_by_user=test_item_by_user,\n",
    "            model_name=name_bin,\n",
    "            user_indices=eval_users,\n",
    "            ks=TOP_KS,\n",
    "            batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "            user_seen_arrays_dict=user_seen_arrays,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    print(f\"Fitting EASE count lambda={lam}\")\n",
    "    B_cnt = fit_ease(X_counts, lam=lam)\n",
    "    rec_cnt = make_ease_recommender(X_counts_dense, B_cnt)\n",
    "    name_cnt = f\"EASE_count_l{int(lam)}\"\n",
    "    regular_models[name_cnt] = rec_cnt\n",
    "    regular_results.append(\n",
    "        evaluate_ease_batched(\n",
    "            B_matrix=B_cnt,\n",
    "            X_train_matrix=X_counts,\n",
    "            user_seen_dict=user_seen,\n",
    "            test_item_by_user=test_item_by_user,\n",
    "            model_name=name_cnt,\n",
    "            user_indices=eval_users,\n",
    "            ks=TOP_KS,\n",
    "            batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "            user_seen_arrays_dict=user_seen_arrays,\n",
    "        )\n",
    "    )\n",
    "\n",
    "regular_results_df = pd.DataFrame(regular_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "display(regular_results_df)"
   ]
  },
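  {
   "cell_type": "markdown",
   "id": "f7e21c9a",
   "metadata": {},
   "source": [
    "Side note on the EASE grid above: only `lambda` is varied because EASE has a closed-form solution. The next cell is a minimal, self-contained sketch of that closed form, assuming a dense item Gram matrix that fits in memory; it is not the notebook's own `fit_ease`, which may batch or regularise differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c9d1e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch of the EASE closed form (Steck, WWW 2019), independent of fit_ease above:\n",
    "# with P = (X^T X + lambda * I)^{-1}, the item-item weights are B_ij = -P_ij / P_jj, diag(B) = 0.\n",
    "import numpy as np\n",
    "\n",
    "def ease_closed_form_sketch(X_dense, lam=1000.0):\n",
    "    X = np.asarray(X_dense, dtype=np.float64)\n",
    "    G = X.T @ X                              # item-item Gram matrix\n",
    "    G[np.diag_indices_from(G)] += lam        # ridge / L2 regularisation\n",
    "    P = np.linalg.inv(G)\n",
    "    B = P / (-np.diag(P))                    # scale column j by -1 / P[j, j]\n",
    "    B[np.diag_indices_from(B)] = 0.0         # forbid an item from predicting itself\n",
    "    return B\n",
    "\n",
    "# Scores would then be X @ B, with already-seen items masked out before ranking."
   ]
  },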
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0bdd000a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training SoftmaxMLP_base\n",
      "SoftmaxMLP_base epoch 1/20 - loss: 5.433627\n",
      "SoftmaxMLP_base epoch 2/20 - loss: 4.594458\n",
      "SoftmaxMLP_base epoch 3/20 - loss: 4.559392\n",
      "SoftmaxMLP_base epoch 4/20 - loss: 4.541153\n",
      "SoftmaxMLP_base epoch 5/20 - loss: 4.529110\n",
      "SoftmaxMLP_base epoch 6/20 - loss: 4.521301\n",
      "SoftmaxMLP_base epoch 7/20 - loss: 4.514780\n",
      "SoftmaxMLP_base epoch 8/20 - loss: 4.509405\n",
      "SoftmaxMLP_base epoch 9/20 - loss: 4.505360\n",
      "SoftmaxMLP_base epoch 10/20 - loss: 4.502343\n",
      "SoftmaxMLP_base epoch 11/20 - loss: 4.498367\n",
      "SoftmaxMLP_base epoch 12/20 - loss: 4.496181\n",
      "SoftmaxMLP_base epoch 13/20 - loss: 4.492616\n",
      "SoftmaxMLP_base epoch 14/20 - loss: 4.489436\n",
      "SoftmaxMLP_base epoch 15/20 - loss: 4.487635\n",
      "SoftmaxMLP_base epoch 16/20 - loss: 4.484647\n",
      "SoftmaxMLP_base epoch 17/20 - loss: 4.481383\n",
      "SoftmaxMLP_base epoch 18/20 - loss: 4.478235\n",
      "SoftmaxMLP_base epoch 19/20 - loss: 4.475141\n",
      "SoftmaxMLP_base epoch 20/20 - loss: 4.472415\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SoftmaxMLP_base: 100%|██████████| 43199/43199 [00:16<00:00, 2583.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training SoftmaxMLP_large\n",
      "SoftmaxMLP_large epoch 1/20 - loss: 5.647399\n",
      "SoftmaxMLP_large epoch 2/20 - loss: 4.615100\n",
      "SoftmaxMLP_large epoch 3/20 - loss: 4.566708\n",
      "SoftmaxMLP_large epoch 4/20 - loss: 4.544700\n",
      "SoftmaxMLP_large epoch 5/20 - loss: 4.530629\n",
      "SoftmaxMLP_large epoch 6/20 - loss: 4.521154\n",
      "SoftmaxMLP_large epoch 7/20 - loss: 4.514441\n",
      "SoftmaxMLP_large epoch 8/20 - loss: 4.510036\n",
      "SoftmaxMLP_large epoch 9/20 - loss: 4.505512\n",
      "SoftmaxMLP_large epoch 10/20 - loss: 4.502212\n",
      "SoftmaxMLP_large epoch 11/20 - loss: 4.499344\n",
      "SoftmaxMLP_large epoch 12/20 - loss: 4.497216\n",
      "SoftmaxMLP_large epoch 13/20 - loss: 4.494908\n",
      "SoftmaxMLP_large epoch 14/20 - loss: 4.492697\n",
      "SoftmaxMLP_large epoch 15/20 - loss: 4.490482\n",
      "SoftmaxMLP_large epoch 16/20 - loss: 4.488693\n",
      "SoftmaxMLP_large epoch 17/20 - loss: 4.486507\n",
      "SoftmaxMLP_large epoch 18/20 - loss: 4.484540\n",
      "SoftmaxMLP_large epoch 19/20 - loss: 4.482527\n",
      "SoftmaxMLP_large epoch 20/20 - loss: 4.480358\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SoftmaxMLP_large: 100%|██████████| 43199/43199 [00:18<00:00, 2362.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower_base\n",
      "TwoTower_base epoch 1/20 - loss: 5.822068\n",
      "TwoTower_base epoch 2/20 - loss: 5.637971\n",
      "TwoTower_base epoch 3/20 - loss: 5.636291\n",
      "TwoTower_base epoch 4/20 - loss: 5.635547\n",
      "TwoTower_base epoch 5/20 - loss: 5.634698\n",
      "TwoTower_base epoch 6/20 - loss: 5.634148\n",
      "TwoTower_base epoch 7/20 - loss: 5.633886\n",
      "TwoTower_base epoch 8/20 - loss: 5.633663\n",
      "TwoTower_base epoch 9/20 - loss: 5.633324\n",
      "TwoTower_base epoch 10/20 - loss: 5.633090\n",
      "TwoTower_base epoch 11/20 - loss: 5.632885\n",
      "TwoTower_base epoch 12/20 - loss: 5.632590\n",
      "TwoTower_base epoch 13/20 - loss: 5.632494\n",
      "TwoTower_base epoch 14/20 - loss: 5.632353\n",
      "TwoTower_base epoch 15/20 - loss: 5.632221\n",
      "TwoTower_base epoch 16/20 - loss: 5.632073\n",
      "TwoTower_base epoch 17/20 - loss: 5.631933\n",
      "TwoTower_base epoch 18/20 - loss: 5.631662\n",
      "TwoTower_base epoch 19/20 - loss: 5.631791\n",
      "TwoTower_base epoch 20/20 - loss: 5.631503\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_base: 100%|██████████| 43199/43199 [00:17<00:00, 2400.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower_large\n",
      "TwoTower_large epoch 1/20 - loss: 6.060688\n",
      "TwoTower_large epoch 2/20 - loss: 5.642607\n",
      "TwoTower_large epoch 3/20 - loss: 5.638540\n",
      "TwoTower_large epoch 4/20 - loss: 5.636511\n",
      "TwoTower_large epoch 5/20 - loss: 5.635549\n",
      "TwoTower_large epoch 6/20 - loss: 5.634881\n",
      "TwoTower_large epoch 7/20 - loss: 5.634233\n",
      "TwoTower_large epoch 8/20 - loss: 5.634051\n",
      "TwoTower_large epoch 9/20 - loss: 5.633371\n",
      "TwoTower_large epoch 10/20 - loss: 5.633048\n",
      "TwoTower_large epoch 11/20 - loss: 5.632787\n",
      "TwoTower_large epoch 12/20 - loss: 5.632671\n",
      "TwoTower_large epoch 13/20 - loss: 5.632314\n",
      "TwoTower_large epoch 14/20 - loss: 5.632094\n",
      "TwoTower_large epoch 15/20 - loss: 5.631977\n",
      "TwoTower_large epoch 16/20 - loss: 5.631923\n",
      "TwoTower_large epoch 17/20 - loss: 5.631531\n",
      "TwoTower_large epoch 18/20 - loss: 5.631616\n",
      "TwoTower_large epoch 19/20 - loss: 5.631395\n",
      "TwoTower_large epoch 20/20 - loss: 5.631218\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_large: 100%|██████████| 43199/43199 [00:19<00:00, 2231.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultDAE\n",
      "MultDAE epoch 1/80 - loss: 139.663078\n",
      "MultDAE epoch 10/80 - loss: 81.691693\n",
      "MultDAE epoch 20/80 - loss: 75.510694\n",
      "MultDAE epoch 30/80 - loss: 70.891737\n",
      "MultDAE epoch 40/80 - loss: 69.398385\n",
      "MultDAE epoch 50/80 - loss: 69.018782\n",
      "MultDAE epoch 60/80 - loss: 68.817604\n",
      "MultDAE epoch 70/80 - loss: 68.760805\n",
      "MultDAE epoch 80/80 - loss: 68.693466\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultDAE: 100%|██████████| 43199/43199 [00:10<00:00, 4127.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE_base\n",
      "MultVAE_base epoch 1/80 - loss: 145.110954 - anneal: 0.038\n",
      "MultVAE_base epoch 10/80 - loss: 92.739109 - anneal: 0.413\n",
      "MultVAE_base epoch 20/80 - loss: 96.354798 - anneal: 0.830\n",
      "MultVAE_base epoch 30/80 - loss: 97.721237 - anneal: 1.000\n",
      "MultVAE_base epoch 40/80 - loss: 97.314592 - anneal: 1.000\n",
      "MultVAE_base epoch 50/80 - loss: 97.163861 - anneal: 1.000\n",
      "MultVAE_base epoch 60/80 - loss: 96.885542 - anneal: 1.000\n",
      "MultVAE_base epoch 70/80 - loss: 96.561948 - anneal: 1.000\n",
      "MultVAE_base epoch 80/80 - loss: 95.870900 - anneal: 1.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_base: 100%|██████████| 43199/43199 [00:12<00:00, 3582.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE_large\n",
      "MultVAE_large epoch 1/80 - loss: 144.021453 - anneal: 0.038\n",
      "MultVAE_large epoch 10/80 - loss: 92.897243 - anneal: 0.413\n",
      "MultVAE_large epoch 20/80 - loss: 96.548958 - anneal: 0.830\n",
      "MultVAE_large epoch 30/80 - loss: 97.881042 - anneal: 1.000\n",
      "MultVAE_large epoch 40/80 - loss: 97.613863 - anneal: 1.000\n",
      "MultVAE_large epoch 50/80 - loss: 97.391887 - anneal: 1.000\n",
      "MultVAE_large epoch 60/80 - loss: 97.258981 - anneal: 1.000\n",
      "MultVAE_large epoch 70/80 - loss: 97.145295 - anneal: 1.000\n",
      "MultVAE_large epoch 80/80 - loss: 97.001129 - anneal: 1.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_large: 100%|██████████| 43199/43199 [00:11<00:00, 3720.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training BPRMF\n",
      "BPRMF epoch 1/24 - loss: 0.691725\n",
      "BPRMF epoch 2/24 - loss: 0.673615\n",
      "BPRMF epoch 3/24 - loss: 0.569194\n",
      "BPRMF epoch 4/24 - loss: 0.336917\n",
      "BPRMF epoch 5/24 - loss: 0.150869\n",
      "BPRMF epoch 6/24 - loss: 0.075791\n",
      "BPRMF epoch 7/24 - loss: 0.050720\n",
      "BPRMF epoch 8/24 - loss: 0.040561\n",
      "BPRMF epoch 9/24 - loss: 0.035357\n",
      "BPRMF epoch 10/24 - loss: 0.032310\n",
      "BPRMF epoch 11/24 - loss: 0.030016\n",
      "BPRMF epoch 12/24 - loss: 0.028853\n",
      "BPRMF epoch 13/24 - loss: 0.027745\n",
      "BPRMF epoch 14/24 - loss: 0.027455\n",
      "BPRMF epoch 15/24 - loss: 0.026372\n",
      "BPRMF epoch 16/24 - loss: 0.026184\n",
      "BPRMF epoch 17/24 - loss: 0.025756\n",
      "BPRMF epoch 18/24 - loss: 0.025307\n",
      "BPRMF epoch 19/24 - loss: 0.025115\n",
      "BPRMF epoch 20/24 - loss: 0.024741\n",
      "BPRMF epoch 21/24 - loss: 0.024742\n",
      "BPRMF epoch 22/24 - loss: 0.024441\n",
      "BPRMF epoch 23/24 - loss: 0.024390\n",
      "BPRMF epoch 24/24 - loss: 0.024141\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "BPRMF: 100%|██████████| 43199/43199 [00:06<00:00, 6871.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training NeuMF\n",
      "NeuMF epoch 1/24 - loss: 0.687082\n",
      "NeuMF epoch 2/24 - loss: 0.567259\n",
      "NeuMF epoch 3/24 - loss: 0.432680\n",
      "NeuMF epoch 4/24 - loss: 0.324502\n",
      "NeuMF epoch 5/24 - loss: 0.238247\n",
      "NeuMF epoch 6/24 - loss: 0.180756\n",
      "NeuMF epoch 7/24 - loss: 0.146119\n",
      "NeuMF epoch 8/24 - loss: 0.117219\n",
      "NeuMF epoch 9/24 - loss: 0.098352\n",
      "NeuMF epoch 10/24 - loss: 0.088241\n",
      "NeuMF epoch 11/24 - loss: 0.083465\n",
      "NeuMF epoch 12/24 - loss: 0.082000\n",
      "NeuMF epoch 13/24 - loss: 0.079005\n",
      "NeuMF epoch 14/24 - loss: 0.078644\n",
      "NeuMF epoch 15/24 - loss: 0.078368\n",
      "NeuMF epoch 16/24 - loss: 0.078357\n",
      "NeuMF epoch 17/24 - loss: 0.077738\n",
      "NeuMF epoch 18/24 - loss: 0.077661\n",
      "NeuMF epoch 19/24 - loss: 0.077143\n",
      "NeuMF epoch 20/24 - loss: 0.076916\n",
      "NeuMF epoch 21/24 - loss: 0.076846\n",
      "NeuMF epoch 22/24 - loss: 0.076226\n",
      "NeuMF epoch 23/24 - loss: 0.075804\n",
      "NeuMF epoch 24/24 - loss: 0.076763\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "NeuMF: 100%|██████████| 43199/43199 [00:14<00:00, 2976.83it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TwoTower_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096067</td>\n",
       "      <td>0.170351</td>\n",
       "      <td>0.315192</td>\n",
       "      <td>0.059292</td>\n",
       "      <td>0.084829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094099</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.310077</td>\n",
       "      <td>0.057802</td>\n",
       "      <td>0.083448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NeuMF</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094354</td>\n",
       "      <td>0.169263</td>\n",
       "      <td>0.310378</td>\n",
       "      <td>0.058961</td>\n",
       "      <td>0.084273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SoftmaxMLP_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.093197</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.313850</td>\n",
       "      <td>0.057147</td>\n",
       "      <td>0.082778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>BPRMF</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094169</td>\n",
       "      <td>0.166717</td>\n",
       "      <td>0.309961</td>\n",
       "      <td>0.057814</td>\n",
       "      <td>0.082833</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>TwoTower_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097456</td>\n",
       "      <td>0.166555</td>\n",
       "      <td>0.310331</td>\n",
       "      <td>0.060473</td>\n",
       "      <td>0.084917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>MultVAE_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092873</td>\n",
       "      <td>0.165976</td>\n",
       "      <td>0.309359</td>\n",
       "      <td>0.056188</td>\n",
       "      <td>0.081413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>SoftmaxMLP_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.091252</td>\n",
       "      <td>0.164795</td>\n",
       "      <td>0.309683</td>\n",
       "      <td>0.054060</td>\n",
       "      <td>0.079489</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>MultDAE</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.088011</td>\n",
       "      <td>0.162018</td>\n",
       "      <td>0.307553</td>\n",
       "      <td>0.052818</td>\n",
       "      <td>0.077864</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0     TwoTower_base      43199  0.096067  0.170351  0.315192  0.059292   \n",
       "1      MultVAE_base      43199  0.094099  0.169518  0.310077  0.057802   \n",
       "2             NeuMF      43199  0.094354  0.169263  0.310378  0.058961   \n",
       "3  SoftmaxMLP_large      43199  0.093197  0.168870  0.313850  0.057147   \n",
       "4             BPRMF      43199  0.094169  0.166717  0.309961  0.057814   \n",
       "5    TwoTower_large      43199  0.097456  0.166555  0.310331  0.060473   \n",
       "6     MultVAE_large      43199  0.092873  0.165976  0.309359  0.056188   \n",
       "7   SoftmaxMLP_base      43199  0.091252  0.164795  0.309683  0.054060   \n",
       "8           MultDAE      43199  0.088011  0.162018  0.307553  0.052818   \n",
       "\n",
       "    NDCG@10  \n",
       "0  0.084829  \n",
       "1  0.083448  \n",
       "2  0.084273  \n",
       "3  0.082778  \n",
       "4  0.082833  \n",
       "5  0.084917  \n",
       "6  0.081413  \n",
       "7  0.079489  \n",
       "8  0.077864  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if not RUN_NEURAL:\n",
    "    neural_results = []\n",
    "    neural_models = {}\n",
    "    neural_results_df = pd.DataFrame(columns=['Model','UsersEval','HR@5','HR@10','HR@20','MRR@10','NDCG@10'])\n",
    "    print('Skipping neural model training because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Fit and evaluate neural models\n",
    "\n",
    "    neural_results = []\n",
    "    neural_models = {}\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training SoftmaxMLP_base\")\n",
    "    softmax_base = train_softmax_model(\n",
    "        model_name=\"SoftmaxMLP_base\",\n",
    "        emb_dim=64,\n",
    "        hidden_dims=(512, 512, 256),\n",
    "        epochs=SOFTMAX_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-5,\n",
    "    )\n",
    "    rec = make_softmax_recommender(softmax_base)\n",
    "    neural_models[\"SoftmaxMLP_base\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"SoftmaxMLP_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training SoftmaxMLP_large\")\n",
    "    softmax_large = train_softmax_model(\n",
    "        model_name=\"SoftmaxMLP_large\",\n",
    "        emb_dim=96,\n",
    "        hidden_dims=(1024, 1024, 512, 256),\n",
    "        epochs=SOFTMAX_EPOCHS,\n",
    "        lr=1.5e-3,\n",
    "        wd=1e-5,\n",
    "    )\n",
    "    rec = make_softmax_recommender(softmax_large)\n",
    "    neural_models[\"SoftmaxMLP_large\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"SoftmaxMLP_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training TwoTower_base\")\n",
    "    twotower_base = train_two_tower(\n",
    "        model_name=\"TwoTower_base\",\n",
    "        emb_dim=64,\n",
    "        hidden_dims=(512, 256),\n",
    "        out_dim=128,\n",
    "        epochs=TWOTOWER_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-5,\n",
    "        temperature=0.07,\n",
    "    )\n",
    "    rec = make_two_tower_recommender(twotower_base)\n",
    "    neural_models[\"TwoTower_base\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"TwoTower_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training TwoTower_large\")\n",
    "    twotower_large = train_two_tower(\n",
    "        model_name=\"TwoTower_large\",\n",
    "        emb_dim=96,\n",
    "        hidden_dims=(1024, 512, 256),\n",
    "        out_dim=192,\n",
    "        epochs=TWOTOWER_EPOCHS,\n",
    "        lr=1.5e-3,\n",
    "        wd=1e-5,\n",
    "        temperature=0.05,\n",
    "    )\n",
    "    rec = make_two_tower_recommender(twotower_large)\n",
    "    neural_models[\"TwoTower_large\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"TwoTower_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training MultDAE\")\n",
    "    multdae = train_multdae(\n",
    "        model_name=\"MultDAE\",\n",
    "        hidden_dim=1024,\n",
    "        latent_dim=256,\n",
    "        dropout=0.2,\n",
    "        epochs=AUTOENC_EPOCHS,\n",
    "        lr=1e-3,\n",
    "        wd=0.0,\n",
    "    )\n",
    "    rec = make_multdae_recommender(multdae)\n",
    "    neural_models[\"MultDAE\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"MultDAE\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training MultVAE_base\")\n",
    "    multvae_base = train_multvae(\n",
    "        model_name=\"MultVAE_base\",\n",
    "        hidden_dim=1024,\n",
    "        latent_dim=256,\n",
    "        dropout=0.25,\n",
    "        epochs=AUTOENC_EPOCHS,\n",
    "        lr=1e-3,\n",
    "        wd=0.0,\n",
    "    )\n",
    "    rec = make_multvae_recommender(multvae_base)\n",
    "    neural_models[\"MultVAE_base\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"MultVAE_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training MultVAE_large\")\n",
    "    multvae_large = train_multvae(\n",
    "        model_name=\"MultVAE_large\",\n",
    "        hidden_dim=1536,\n",
    "        latent_dim=384,\n",
    "        dropout=0.30,\n",
    "        epochs=AUTOENC_EPOCHS,\n",
    "        lr=8e-4,\n",
    "        wd=0.0,\n",
    "    )\n",
    "    rec = make_multvae_recommender(multvae_large)\n",
    "    neural_models[\"MultVAE_large\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"MultVAE_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training BPRMF\")\n",
    "    bprmf = train_bprmf(\n",
    "        model_name=\"BPRMF\",\n",
    "        dim=128,\n",
    "        epochs=PAIRWISE_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-6,\n",
    "    )\n",
    "    rec = make_bprmf_recommender(bprmf)\n",
    "    neural_models[\"BPRMF\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"BPRMF\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training NeuMF\")\n",
    "    neumf = train_neumf(\n",
    "        model_name=\"NeuMF\",\n",
    "        mf_dim=64,\n",
    "        mlp_dim=128,\n",
    "        hidden_dims=(256, 128),\n",
    "        epochs=PAIRWISE_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-6,\n",
    "    )\n",
    "    rec = make_neumf_recommender(neumf)\n",
    "    neural_models[\"NeuMF\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"NeuMF\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    neural_results_df = pd.DataFrame(neural_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "    display(neural_results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d1ab8431",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top regular models:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169448</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060618</td>\n",
       "      <td>0.085705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097711</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>P3alpha_a1_0</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060332</td>\n",
       "      <td>0.085333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096924</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310563</td>\n",
       "      <td>0.060030</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ItemKNN_bm25_k100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ItemKNN_bm25_k200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>ItemKNN_bm25_k300</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ItemKNN_binary_k200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ItemKNN_binary_k300</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>EASE_binary_l2000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168453</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060929</td>\n",
       "      <td>0.085723</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>P3alpha_a0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097387</td>\n",
       "      <td>0.168407</td>\n",
       "      <td>0.309405</td>\n",
       "      <td>0.060372</td>\n",
       "      <td>0.085252</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0     RP3beta_a1_0_b0_6      43199  0.097687  0.169448  0.308896  0.060618   \n",
       "1     EASE_binary_l1000      43199  0.097711  0.169356  0.309614  0.060787   \n",
       "2          P3alpha_a1_0      43199  0.098104  0.168870  0.310609  0.060332   \n",
       "3      EASE_count_l1000      43199  0.096924  0.168870  0.310563  0.060030   \n",
       "4     ItemKNN_bm25_k100      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "5     ItemKNN_bm25_k200      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "6     ItemKNN_bm25_k300      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "7   ItemKNN_binary_k100      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "8   ItemKNN_binary_k200      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "9   ItemKNN_binary_k300      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "10    EASE_binary_l2000      43199  0.098104  0.168453  0.309614  0.060929   \n",
       "11         P3alpha_a0_5      43199  0.097387  0.168407  0.309405  0.060372   \n",
       "\n",
       "     NDCG@10  \n",
       "0   0.085705  \n",
       "1   0.085790  \n",
       "2   0.085333  \n",
       "3   0.085076  \n",
       "4   0.085358  \n",
       "5   0.085358  \n",
       "6   0.085358  \n",
       "7   0.085517  \n",
       "8   0.085517  \n",
       "9   0.085517  \n",
       "10  0.085723  \n",
       "11  0.085252  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top neural models:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TwoTower_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096067</td>\n",
       "      <td>0.170351</td>\n",
       "      <td>0.315192</td>\n",
       "      <td>0.059292</td>\n",
       "      <td>0.084829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094099</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.310077</td>\n",
       "      <td>0.057802</td>\n",
       "      <td>0.083448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NeuMF</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094354</td>\n",
       "      <td>0.169263</td>\n",
       "      <td>0.310378</td>\n",
       "      <td>0.058961</td>\n",
       "      <td>0.084273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SoftmaxMLP_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.093197</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.313850</td>\n",
       "      <td>0.057147</td>\n",
       "      <td>0.082778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>BPRMF</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094169</td>\n",
       "      <td>0.166717</td>\n",
       "      <td>0.309961</td>\n",
       "      <td>0.057814</td>\n",
       "      <td>0.082833</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>TwoTower_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097456</td>\n",
       "      <td>0.166555</td>\n",
       "      <td>0.310331</td>\n",
       "      <td>0.060473</td>\n",
       "      <td>0.084917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>MultVAE_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092873</td>\n",
       "      <td>0.165976</td>\n",
       "      <td>0.309359</td>\n",
       "      <td>0.056188</td>\n",
       "      <td>0.081413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>SoftmaxMLP_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.091252</td>\n",
       "      <td>0.164795</td>\n",
       "      <td>0.309683</td>\n",
       "      <td>0.054060</td>\n",
       "      <td>0.079489</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>MultDAE</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.088011</td>\n",
       "      <td>0.162018</td>\n",
       "      <td>0.307553</td>\n",
       "      <td>0.052818</td>\n",
       "      <td>0.077864</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0     TwoTower_base      43199  0.096067  0.170351  0.315192  0.059292   \n",
       "1      MultVAE_base      43199  0.094099  0.169518  0.310077  0.057802   \n",
       "2             NeuMF      43199  0.094354  0.169263  0.310378  0.058961   \n",
       "3  SoftmaxMLP_large      43199  0.093197  0.168870  0.313850  0.057147   \n",
       "4             BPRMF      43199  0.094169  0.166717  0.309961  0.057814   \n",
       "5    TwoTower_large      43199  0.097456  0.166555  0.310331  0.060473   \n",
       "6     MultVAE_large      43199  0.092873  0.165976  0.309359  0.056188   \n",
       "7   SoftmaxMLP_base      43199  0.091252  0.164795  0.309683  0.054060   \n",
       "8           MultDAE      43199  0.088011  0.162018  0.307553  0.052818   \n",
       "\n",
       "    NDCG@10  \n",
       "0  0.084829  \n",
       "1  0.083448  \n",
       "2  0.084273  \n",
       "3  0.082778  \n",
       "4  0.082833  \n",
       "5  0.084917  \n",
       "6  0.081413  \n",
       "7  0.079489  \n",
       "8  0.077864  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall ranking:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TwoTower_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096067</td>\n",
       "      <td>0.170351</td>\n",
       "      <td>0.315192</td>\n",
       "      <td>0.059292</td>\n",
       "      <td>0.084829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094099</td>\n",
       "      <td>0.169518</td>\n",
       "      <td>0.310077</td>\n",
       "      <td>0.057802</td>\n",
       "      <td>0.083448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169448</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060618</td>\n",
       "      <td>0.085705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097711</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NeuMF</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094354</td>\n",
       "      <td>0.169263</td>\n",
       "      <td>0.310378</td>\n",
       "      <td>0.058961</td>\n",
       "      <td>0.084273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>P3alpha_a1_0</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060332</td>\n",
       "      <td>0.085333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096924</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310563</td>\n",
       "      <td>0.060030</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>SoftmaxMLP_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.093197</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.313850</td>\n",
       "      <td>0.057147</td>\n",
       "      <td>0.082778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ItemKNN_bm25_k100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ItemKNN_bm25_k200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>ItemKNN_bm25_k300</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.308757</td>\n",
       "      <td>0.060349</td>\n",
       "      <td>0.085358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>ItemKNN_binary_k100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>ItemKNN_binary_k200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>ItemKNN_binary_k300</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308016</td>\n",
       "      <td>0.060613</td>\n",
       "      <td>0.085517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>EASE_binary_l2000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168453</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060929</td>\n",
       "      <td>0.085723</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>P3alpha_a0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097387</td>\n",
       "      <td>0.168407</td>\n",
       "      <td>0.309405</td>\n",
       "      <td>0.060372</td>\n",
       "      <td>0.085252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>RP3beta_a0_8_b0_3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097016</td>\n",
       "      <td>0.168337</td>\n",
       "      <td>0.308086</td>\n",
       "      <td>0.060347</td>\n",
       "      <td>0.085230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>EASE_count_l500</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092271</td>\n",
       "      <td>0.167759</td>\n",
       "      <td>0.309359</td>\n",
       "      <td>0.056853</td>\n",
       "      <td>0.082288</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>EASE_count_l2000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097618</td>\n",
       "      <td>0.167527</td>\n",
       "      <td>0.309058</td>\n",
       "      <td>0.060644</td>\n",
       "      <td>0.085284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>EASE_binary_l500</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095118</td>\n",
       "      <td>0.167434</td>\n",
       "      <td>0.309660</td>\n",
       "      <td>0.058010</td>\n",
       "      <td>0.083173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>BPRMF</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094169</td>\n",
       "      <td>0.166717</td>\n",
       "      <td>0.309961</td>\n",
       "      <td>0.057814</td>\n",
       "      <td>0.082833</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>TwoTower_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097456</td>\n",
       "      <td>0.166555</td>\n",
       "      <td>0.310331</td>\n",
       "      <td>0.060473</td>\n",
       "      <td>0.084917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>MultVAE_large</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.092873</td>\n",
       "      <td>0.165976</td>\n",
       "      <td>0.309359</td>\n",
       "      <td>0.056188</td>\n",
       "      <td>0.081413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>SoftmaxMLP_base</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.091252</td>\n",
       "      <td>0.164795</td>\n",
       "      <td>0.309683</td>\n",
       "      <td>0.054060</td>\n",
       "      <td>0.079489</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>EASE_count_l200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.086229</td>\n",
       "      <td>0.162342</td>\n",
       "      <td>0.306720</td>\n",
       "      <td>0.052313</td>\n",
       "      <td>0.077514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>MultDAE</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.088011</td>\n",
       "      <td>0.162018</td>\n",
       "      <td>0.307553</td>\n",
       "      <td>0.052818</td>\n",
       "      <td>0.077864</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>EASE_binary_l200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.088636</td>\n",
       "      <td>0.161717</td>\n",
       "      <td>0.306952</td>\n",
       "      <td>0.052433</td>\n",
       "      <td>0.077498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>EASE_count_l100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.083289</td>\n",
       "      <td>0.160235</td>\n",
       "      <td>0.304289</td>\n",
       "      <td>0.050601</td>\n",
       "      <td>0.075694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>EASE_binary_l100</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.084909</td>\n",
       "      <td>0.158916</td>\n",
       "      <td>0.304799</td>\n",
       "      <td>0.050151</td>\n",
       "      <td>0.075070</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>ContentKNN</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.072942</td>\n",
       "      <td>0.145235</td>\n",
       "      <td>0.291812</td>\n",
       "      <td>0.043139</td>\n",
       "      <td>0.066449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.003218</td>\n",
       "      <td>0.006019</td>\n",
       "      <td>0.011621</td>\n",
       "      <td>0.001746</td>\n",
       "      <td>0.002724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>CountryPopularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.002639</td>\n",
       "      <td>0.005185</td>\n",
       "      <td>0.010579</td>\n",
       "      <td>0.001451</td>\n",
       "      <td>0.002306</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0         TwoTower_base      43199  0.096067  0.170351  0.315192  0.059292   \n",
       "1          MultVAE_base      43199  0.094099  0.169518  0.310077  0.057802   \n",
       "2     RP3beta_a1_0_b0_6      43199  0.097687  0.169448  0.308896  0.060618   \n",
       "3     EASE_binary_l1000      43199  0.097711  0.169356  0.309614  0.060787   \n",
       "4                 NeuMF      43199  0.094354  0.169263  0.310378  0.058961   \n",
       "5          P3alpha_a1_0      43199  0.098104  0.168870  0.310609  0.060332   \n",
       "6      EASE_count_l1000      43199  0.096924  0.168870  0.310563  0.060030   \n",
       "7      SoftmaxMLP_large      43199  0.093197  0.168870  0.313850  0.057147   \n",
       "8     ItemKNN_bm25_k100      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "9     ItemKNN_bm25_k200      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "10    ItemKNN_bm25_k300      43199  0.098104  0.168731  0.308757  0.060349   \n",
       "11  ItemKNN_binary_k100      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "12  ItemKNN_binary_k200      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "13  ItemKNN_binary_k300      43199  0.098104  0.168638  0.308016  0.060613   \n",
       "14    EASE_binary_l2000      43199  0.098104  0.168453  0.309614  0.060929   \n",
       "15         P3alpha_a0_5      43199  0.097387  0.168407  0.309405  0.060372   \n",
       "16    RP3beta_a0_8_b0_3      43199  0.097016  0.168337  0.308086  0.060347   \n",
       "17      EASE_count_l500      43199  0.092271  0.167759  0.309359  0.056853   \n",
       "18     EASE_count_l2000      43199  0.097618  0.167527  0.309058  0.060644   \n",
       "19     EASE_binary_l500      43199  0.095118  0.167434  0.309660  0.058010   \n",
       "20                BPRMF      43199  0.094169  0.166717  0.309961  0.057814   \n",
       "21       TwoTower_large      43199  0.097456  0.166555  0.310331  0.060473   \n",
       "22        MultVAE_large      43199  0.092873  0.165976  0.309359  0.056188   \n",
       "23      SoftmaxMLP_base      43199  0.091252  0.164795  0.309683  0.054060   \n",
       "24      EASE_count_l200      43199  0.086229  0.162342  0.306720  0.052313   \n",
       "25              MultDAE      43199  0.088011  0.162018  0.307553  0.052818   \n",
       "26     EASE_binary_l200      43199  0.088636  0.161717  0.306952  0.052433   \n",
       "27      EASE_count_l100      43199  0.083289  0.160235  0.304289  0.050601   \n",
       "28     EASE_binary_l100      43199  0.084909  0.158916  0.304799  0.050151   \n",
       "29           ContentKNN      43199  0.072942  0.145235  0.291812  0.043139   \n",
       "30           Popularity      43199  0.003218  0.006019  0.011621  0.001746   \n",
       "31    CountryPopularity      43199  0.002639  0.005185  0.010579  0.001451   \n",
       "\n",
       "     NDCG@10  \n",
       "0   0.084829  \n",
       "1   0.083448  \n",
       "2   0.085705  \n",
       "3   0.085790  \n",
       "4   0.084273  \n",
       "5   0.085333  \n",
       "6   0.085076  \n",
       "7   0.082778  \n",
       "8   0.085358  \n",
       "9   0.085358  \n",
       "10  0.085358  \n",
       "11  0.085517  \n",
       "12  0.085517  \n",
       "13  0.085517  \n",
       "14  0.085723  \n",
       "15  0.085252  \n",
       "16  0.085230  \n",
       "17  0.082288  \n",
       "18  0.085284  \n",
       "19  0.083173  \n",
       "20  0.082833  \n",
       "21  0.084917  \n",
       "22  0.081413  \n",
       "23  0.079489  \n",
       "24  0.077514  \n",
       "25  0.077864  \n",
       "26  0.077498  \n",
       "27  0.075694  \n",
       "28  0.075070  \n",
       "29  0.066449  \n",
       "30  0.002724  \n",
       "31  0.002306  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best overall model: TwoTower_base\n"
     ]
    }
   ],
   "source": [
    "# Final combined comparison\n",
    "\n",
    "frames = [regular_results_df]\n",
    "if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "    frames.append(neural_results_df)\n",
    "\n",
    "all_results_df = pd.concat(frames, ignore_index=True)\n",
    "all_results_df = all_results_df.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "print(\"Top regular models:\")\n",
    "display(regular_results_df.head(12))\n",
    "\n",
    "if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "    print(\"Top neural models:\")\n",
    "    display(neural_results_df.head(12))\n",
    "else:\n",
    "    print(\"Neural models were skipped.\")\n",
    "\n",
    "print(\"Overall ranking:\")\n",
    "display(all_results_df)\n",
    "\n",
    "best_model_name = all_results_df.iloc[0][\"Model\"]\n",
    "print(\"Best overall model:\", best_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "bed7dcdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example pseudo-user index: 0\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>AgeBin</th>\n",
       "      <th>user_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             user_id    Country  \\\n",
       "0  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "\n",
       "                         PriceBin                    MileageBin  \\\n",
       "0  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "\n",
       "             Condition               AgeBin  user_idx  \n",
       "0  Certified Pre-Owned  age_(1.999 to  4.0]         0  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rank</th>\n",
       "      <th>recommended_item</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Chrysler :: Pacifica :: age_(1.999 to  4.0] ::...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>Chrysler :: Voyager :: age_(1.999 to  4.0] :: ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Dodge :: Challenger :: age_(1.999 to  4.0] :: ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>BMW :: Z4 :: age_(1.999 to  4.0] :: Certified ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>Dodge :: Journey :: age_(1.999 to  4.0] :: Cer...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>Land Rover :: Range Rover :: age_(1.999 to  4....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>Dodge :: Durango :: age_(1.999 to  4.0] :: Cer...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>Volkswagen :: Jetta :: age_(1.999 to  4.0] :: ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>Lexus :: RX :: age_(1.999 to  4.0] :: Certifie...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>Dodge :: Charger :: age_(1.999 to  4.0] :: Cer...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   rank                                   recommended_item\n",
       "0     1  Chrysler :: Pacifica :: age_(1.999 to  4.0] ::...\n",
       "1     2  Chrysler :: Voyager :: age_(1.999 to  4.0] :: ...\n",
       "2     3  Dodge :: Challenger :: age_(1.999 to  4.0] :: ...\n",
       "3     4  BMW :: Z4 :: age_(1.999 to  4.0] :: Certified ...\n",
       "4     5  Dodge :: Journey :: age_(1.999 to  4.0] :: Cer...\n",
       "5     6  Land Rover :: Range Rover :: age_(1.999 to  4....\n",
       "6     7  Dodge :: Durango :: age_(1.999 to  4.0] :: Cer...\n",
       "7     8  Volkswagen :: Jetta :: age_(1.999 to  4.0] :: ...\n",
       "8     9  Lexus :: RX :: age_(1.999 to  4.0] :: Certifie...\n",
       "9    10  Dodge :: Charger :: age_(1.999 to  4.0] :: Cer..."
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show example recommendations from the best overall model\n",
    "\n",
    "all_model_funcs = {}\n",
    "all_model_funcs.update(regular_models)\n",
    "if 'neural_models' in globals():\n",
    "    all_model_funcs.update(neural_models)\n",
    "\n",
    "best_recommender = all_model_funcs[best_model_name]\n",
    "\n",
    "example_user = int(eval_users[0])\n",
    "example_item_ids = best_recommender(example_user, n=10)\n",
    "example_item_names = [item_ids[i] for i in example_item_ids]\n",
    "\n",
    "print(\"Example pseudo-user index:\", example_user)\n",
    "display(user_feature_df[user_feature_df[\"user_idx\"] == example_user])\n",
    "\n",
    "example_df = pd.DataFrame({\n",
    "    \"rank\": np.arange(1, len(example_item_names) + 1),\n",
    "    \"recommended_item\": example_item_names,\n",
    "})\n",
    "display(example_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26869d31",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "If you still want even harder settings after this run, the next levers are:\n",
    "- raise item granularity again\n",
    "- require lower popularity ceiling for formulation selection\n",
    "- increase negative-sampling pressure for pairwise neural models\n",
    "- tune the strongest 2 to 3 models only instead of running the full zoo"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
