{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2e584ae2",
   "metadata": {},
   "source": [
    "# Car recommender notebook — hard formulation, broader model zoo, GPU-heavy neural search\n",
    "\n",
    "This notebook is the corrected follow-up to the earlier car recommender runs.\n",
    "\n",
    "Main changes:\n",
    "- screens out trivial task formulations automatically\n",
    "- uses only harder pseudo-user / item constructions\n",
    "- evaluates a larger set of regular and neural recommenders\n",
    "- keeps the neural training path GPU-friendly with manual on-device batching\n",
    "- produces one combined ranking table for comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3dd295bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU thread target: 16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import gc\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "# Thread settings for NumPy / SciPy BLAS backends.\n",
    "# These must be set before importing numpy/scipy to have the best chance of taking effect.\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "os.environ[\"OMP_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"MKL_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"NUMEXPR_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = str(CPU_THREADS)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn.preprocessing import LabelEncoder, normalize\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from tqdm.auto import tqdm\n",
    "try:\n",
    "    from threadpoolctl import threadpool_limits\n",
    "except Exception:\n",
    "    threadpool_limits = None\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "\n",
    "if threadpool_limits is not None:\n",
    "    try:\n",
    "        threadpool_limits(limits=CPU_THREADS)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77a4eb83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n",
      "Fixed formulation: H1_country_price_mileage_cond_age__brand_model_age_cond\n"
     ]
    }
   ],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "N_PRICE_BINS = 12\n",
    "N_MILEAGE_BINS = 12\n",
    "N_AGE_BINS = 10\n",
    "\n",
    "# Interaction filtering for the fixed hard formulation\n",
    "MIN_ITEMS_PER_USER = 5\n",
    "MIN_USERS_PER_ITEM = 10\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# If the formulation is too sparse under the default thresholds, progressively relax them.\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "\n",
    "# Ranking metrics\n",
    "TOP_KS = [5, 10, 20]\n",
    "\n",
    "# We now skip the slow automatic formulation screening entirely.\n",
    "# H1 was already the best non-trivial formulation in the previous believable run.\n",
    "FIXED_FORMULATION_NAME = \"H1_country_price_mileage_cond_age__brand_model_age_cond\"\n",
    "FIXED_USER_COLS = [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\"]\n",
    "FIXED_ITEM_COLS = [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"]\n",
    "\n",
    "# Focus only on the strongest families from the previous run.\n",
    "P3_ALPHA_GRID = [0.85, 0.95, 1.00, 1.05, 1.15]\n",
    "RP3_GRID = [\n",
    "    (0.90, 0.40),\n",
    "    (0.95, 0.50),\n",
    "    (1.00, 0.50),\n",
    "    (1.00, 0.60),\n",
    "    (1.05, 0.60),\n",
    "    (1.10, 0.70),\n",
    "]\n",
    "EASE_BINARY_LAMBDAS = [600.0, 800.0, 1000.0, 1200.0, 1600.0, 2200.0, 3000.0]\n",
    "EASE_COUNT_LAMBDAS = [600.0, 800.0, 1000.0, 1200.0, 1600.0, 2200.0, 3000.0]\n",
    "\n",
    "# Neural tuning grids for the strongest neural families\n",
    "TWOTOWER_CONFIGS = [\n",
    "    {\"name\": \"TwoTower_t1\", \"emb_dim\": 64,  \"hidden_dims\": (512, 256),       \"out_dim\": 128, \"epochs\": 24, \"lr\": 2e-3,   \"wd\": 1e-5, \"temperature\": 0.07},\n",
    "    {\"name\": \"TwoTower_t2\", \"emb_dim\": 96,  \"hidden_dims\": (768, 384, 192),  \"out_dim\": 160, \"epochs\": 28, \"lr\": 1.5e-3, \"wd\": 1e-5, \"temperature\": 0.05},\n",
    "    {\"name\": \"TwoTower_t3\", \"emb_dim\": 128, \"hidden_dims\": (1024, 512, 256), \"out_dim\": 192, \"epochs\": 30, \"lr\": 1.2e-3, \"wd\": 1e-5, \"temperature\": 0.04},\n",
    "]\n",
    "\n",
    "MULTVAE_CONFIGS = [\n",
    "    {\"name\": \"MultVAE_v1\", \"hidden_dim\": 1024, \"latent_dim\": 256, \"dropout\": 0.20, \"epochs\": 90,  \"lr\": 1e-3,  \"wd\": 0.0, \"anneal_cap\": 0.70},\n",
    "    {\"name\": \"MultVAE_v2\", \"hidden_dim\": 1536, \"latent_dim\": 384, \"dropout\": 0.25, \"epochs\": 100, \"lr\": 8e-4,  \"wd\": 0.0, \"anneal_cap\": 0.80},\n",
    "    {\"name\": \"MultVAE_v3\", \"hidden_dim\": 2048, \"latent_dim\": 512, \"dropout\": 0.30, \"epochs\": 110, \"lr\": 6e-4,  \"wd\": 0.0, \"anneal_cap\": 1.00},\n",
    "]\n",
    "\n",
    "NEUMF_CONFIGS = [\n",
    "    {\"name\": \"NeuMF_n1\", \"mf_dim\": 64,  \"mlp_dim\": 128, \"hidden_dims\": (256, 128),      \"epochs\": 28, \"lr\": 2e-3,   \"wd\": 1e-6},\n",
    "    {\"name\": \"NeuMF_n2\", \"mf_dim\": 96,  \"mlp_dim\": 192, \"hidden_dims\": (384, 192, 96),  \"epochs\": 32, \"lr\": 1.5e-3, \"wd\": 1e-6},\n",
    "]\n",
    "\n",
    "# Neural training defaults / batch sizes\n",
    "SOFTMAX_EPOCHS = 0\n",
    "PAIRWISE_EPOCHS = 0\n",
    "TWOTOWER_BATCH_SIZE = 16384\n",
    "AUTOENC_BATCH_SIZE = 8192\n",
    "PAIRWISE_BATCH_SIZE = 32768\n",
    "\n",
    "USE_AMP = False  # initialized safely before torch setup; updated later if CUDA is active\n",
    "RUN_FUSION = True\n",
    "FUSION_FETCH_N = 100\n",
    "FUSION_RRF_K = 60\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Fixed formulation:\", FIXED_FORMULATION_NAME)\n",
    "\n",
    "EASE_EVAL_BATCH_SIZE = 2048\n",
    "EASE_USE_GPU = True\n",
    "EASE_GPU_BATCH_SIZE = 8192\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "44983c26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch device: cuda\n",
      "USE_AMP: True\n"
     ]
    }
   ],
   "source": [
    "# Runtime toggles\n",
    "\n",
    "# Neural models section\n",
    "RUN_NEURAL = True\n",
    "PREFER_CUDA = True\n",
    "\n",
    "# EASE can also use torch on GPU for batched scoring even before the neural section.\n",
    "NEED_TORCH = RUN_NEURAL or EASE_USE_GPU\n",
    "\n",
    "HAS_TORCH = False\n",
    "torch = None\n",
    "nn = None\n",
    "F = None\n",
    "device = None\n",
    "\n",
    "if NEED_TORCH:\n",
    "    import torch\n",
    "    import torch.nn as nn\n",
    "    import torch.nn.functional as F\n",
    "\n",
    "    HAS_TORCH = True\n",
    "    torch.manual_seed(RANDOM_STATE)\n",
    "\n",
    "    device = torch.device(\"cuda\" if PREFER_CUDA and torch.cuda.is_available() else \"cpu\")\n",
    "    USE_AMP = (device.type == \"cuda\")\n",
    "    print(\"Torch device:\", device)\n",
    "    print(\"USE_AMP:\", USE_AMP)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "        torch.backends.cuda.matmul.allow_tf32 = True\n",
    "        torch.backends.cudnn.allow_tf32 = True\n",
    "        torch.backends.cudnn.benchmark = True\n",
    "        try:\n",
    "            torch.set_float32_matmul_precision(\"high\")\n",
    "        except Exception:\n",
    "            pass\n",
    "else:\n",
    "    USE_AMP = False\n",
    "    print(\"PyTorch disabled for this run.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4c1a8ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows: 1000000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>price_(23744.848 to  29982.3]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(24.0 to  26.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>price_(48755.822 to  55009.26]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(16665.0 to  33316.0]</td>\n",
       "      <td>age_(22.0 to  24.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>price_(5000.059 to  11233.652]</td>\n",
       "      <td>mileage_(66608.0 to  83237.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  Age  \\\n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil    3   \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy   26   \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK   12   \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico   23   \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA   14   \n",
       "\n",
       "                         PriceBin                     MileageBin  \\\n",
       "0   price_(23744.848 to  29982.3]  mileage_(49995.0 to  66608.0]   \n",
       "1  price_(11233.652 to  17482.96]  mileage_(49995.0 to  66608.0]   \n",
       "2  price_(48755.822 to  55009.26]   mileage_(-0.001 to  16665.0]   \n",
       "3  price_(11233.652 to  17482.96]  mileage_(16665.0 to  33316.0]   \n",
       "4  price_(5000.059 to  11233.652]  mileage_(66608.0 to  83237.0]   \n",
       "\n",
       "                AgeBin  \n",
       "0  age_(1.999 to  4.0]  \n",
       "1  age_(24.0 to  26.0]  \n",
       "2  age_(11.0 to  14.0]  \n",
       "3  age_(22.0 to  24.0]  \n",
       "4  age_(11.0 to  14.0]  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>Country</th>\n",
       "      <th>Age</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>AgeBin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>3</td>\n",
       "      <td>price_(23744.848 to  29982.3]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Italy</td>\n",
       "      <td>26</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(49995.0 to  66608.0]</td>\n",
       "      <td>age_(24.0 to  26.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>UK</td>\n",
       "      <td>12</td>\n",
       "      <td>price_(48755.822 to  55009.26]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>23</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(16665.0 to  33316.0]</td>\n",
       "      <td>age_(22.0 to  24.0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>USA</td>\n",
       "      <td>14</td>\n",
       "      <td>price_(5000.059 to  11233.652]</td>\n",
       "      <td>mileage_(66608.0 to  83237.0]</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition Country  Age                        PriceBin  \\\n",
       "0  Certified Pre-Owned  Brazil    3   price_(23744.848 to  29982.3]   \n",
       "1  Certified Pre-Owned   Italy   26  price_(11233.652 to  17482.96]   \n",
       "2  Certified Pre-Owned      UK   12  price_(48755.822 to  55009.26]   \n",
       "3                 Used  Mexico   23  price_(11233.652 to  17482.96]   \n",
       "4                 Used     USA   14  price_(5000.059 to  11233.652]   \n",
       "\n",
       "                      MileageBin               AgeBin  \n",
       "0  mileage_(49995.0 to  66608.0]  age_(1.999 to  4.0]  \n",
       "1  mileage_(49995.0 to  66608.0]  age_(24.0 to  26.0]  \n",
       "2   mileage_(-0.001 to  16665.0]  age_(11.0 to  14.0]  \n",
       "3  mileage_(16665.0 to  33316.0]  age_(22.0 to  24.0]  \n",
       "4  mileage_(66608.0 to  83237.0]  age_(11.0 to  14.0]  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and clean data\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "expected_cols = [\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]\n",
    "missing = [c for c in expected_cols if c not in df.columns]\n",
    "if missing:\n",
    "    raise ValueError(f\"Missing expected columns: {missing}\")\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    df[col] = df[col].astype(str).str.strip().replace({\"\": \"Unknown\", \"nan\": \"Unknown\"})\n",
    "\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "df = df.dropna(subset=[\"Year\", \"Price\", \"Mileage\"]).copy()\n",
    "\n",
    "df = df[df[\"Year\"].between(1990, CURRENT_YEAR)].copy()\n",
    "df = df[df[\"Price\"] > 0].copy()\n",
    "df = df[df[\"Mileage\"] >= 0].copy()\n",
    "\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0, upper=50)\n",
    "\n",
    "def make_qbin(series, n_bins, prefix):\n",
    "    cat = pd.qcut(series, q=n_bins, duplicates=\"drop\")\n",
    "    return cat.astype(str).str.replace(\",\", \" to \", regex=False).map(lambda x: f\"{prefix}_{x}\")\n",
    "\n",
    "df[\"PriceBin\"] = make_qbin(df[\"Price\"], N_PRICE_BINS, \"price\")\n",
    "df[\"MileageBin\"] = make_qbin(df[\"Mileage\"], N_MILEAGE_BINS, \"mileage\")\n",
    "df[\"AgeBin\"] = make_qbin(df[\"Age\"], N_AGE_BINS, \"age\")\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())\n",
    "display(df[[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\", \"Age\", \"PriceBin\", \"MileageBin\", \"AgeBin\"]].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "071498f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user=5, min_users_per_item=10, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = out.shape[0]\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if out.shape[0] == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "           .size()\n",
    "           .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = None\n",
    "    used_thresholds = None\n",
    "\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS\n",
    "        ).reset_index(drop=True)\n",
    "\n",
    "        if not candidate.empty:\n",
    "            interactions = candidate\n",
    "            used_thresholds = (min_items_per_user, min_users_per_item)\n",
    "            break\n",
    "\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(\n",
    "            f\"{formulation_name} produced no interactions after filtering, \"\n",
    "            f\"even after trying FILTER_SCHEDULE={FILTER_SCHEDULE}\"\n",
    "        )\n",
    "\n",
    "    valid_users = set(interactions[\"user_id\"])\n",
    "    valid_items = set(interactions[\"item_id\"])\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    num_users = interactions[\"user_idx\"].nunique()\n",
    "    num_items = interactions[\"item_idx\"].nunique()\n",
    "\n",
    "    user_feature_df = (\n",
    "        tmp[tmp[\"user_id\"].isin(valid_users)][[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "    user_feature_df = user_feature_df.sort_values(\"user_idx\").reset_index(drop=True)\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[tmp[\"item_id\"].isin(valid_items)][[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "    item_feature_df = item_feature_df.sort_values(\"item_idx\").reset_index(drop=True)\n",
    "\n",
    "    # Weighted leave-one-out split\n",
    "    rng = np.random.default_rng(RANDOM_STATE)\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(picked)\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    user_seen = {\n",
    "        int(uid): set(g[\"item_idx\"].astype(int).tolist())\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "    user_seen_arrays = {\n",
    "        uid: np.array(sorted(list(seen)), dtype=np.int32)\n",
    "        for uid, seen in user_seen.items()\n",
    "    }\n",
    "\n",
    "    user_strength = {\n",
    "        int(uid): {int(i): float(c) for i, c in zip(g[\"item_idx\"], g[\"count\"])}\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(uid): int(i)\n",
    "        for uid, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy()\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen\": user_seen,\n",
    "        \"user_seen_arrays\": user_seen_arrays,\n",
    "        \"user_strength\": user_strength,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"user_ids\": user_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "\n",
    "    avg_train_per_user = bundle[\"train_interactions\"].shape[0] / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle.get(\"used_thresholds\"))\n",
    "    print(\"Train rows :\", bundle[\"train_interactions\"].shape[0])\n",
    "    print(\"Test rows  :\", bundle[\"test_interactions\"].shape[0])\n",
    "    print(\"Avg train interactions/user:\", round(avg_train_per_user, 3))\n",
    "    print(\"Matrix shape:\", bundle[\"X_binary\"].shape)\n",
    "\n",
    "def topn_from_scores(scores, seen, n):\n",
    "    scores = np.asarray(scores, dtype=np.float32).copy()\n",
    "    if seen:\n",
    "        seen_idx = np.fromiter(seen, dtype=np.int32)\n",
    "        scores[seen_idx] = -np.inf\n",
    "    n = min(int(n), scores.shape[0])\n",
    "    if n <= 0:\n",
    "        return []\n",
    "    idx = np.argpartition(scores, -n)[-n:]\n",
    "    idx = idx[np.argsort(scores[idx])[::-1]]\n",
    "    return idx.astype(int).tolist()\n",
    "\n",
    "def topn_from_torch_scores(scores, seen, n):\n",
    "    x = scores.clone()\n",
    "    if seen:\n",
    "        idx = torch.tensor(list(seen), dtype=torch.long, device=x.device)\n",
    "        x[idx] = -1e9\n",
    "    k = min(int(n), int(x.shape[0]))\n",
    "    return torch.topk(x, k=k).indices.detach().cpu().tolist()\n",
    "\n",
    "def hit_rate_at_k(recs, true_item, k):\n",
    "    return 1.0 if true_item in recs[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(recommend_fn, model_name, bundle, user_indices=None, ks=(5, 10, 20)):\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr10 = 0.0\n",
    "    ndcg10 = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    for uid in tqdm(user_indices, desc=model_name):\n",
    "        uid = int(uid)\n",
    "        true_item = test_item_by_user.get(uid, None)\n",
    "        if true_item is None:\n",
    "            continue\n",
    "\n",
    "        recs = recommend_fn(uid, n=max(ks))\n",
    "        valid += 1\n",
    "\n",
    "        for k in ks:\n",
    "            hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "        mrr10 += mrr_at_k(recs, true_item, 10)\n",
    "        ndcg10 += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    out = {\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": valid,\n",
    "    }\n",
    "    for k in ks:\n",
    "        out[f\"HR@{k}\"] = hits[k] / max(valid, 1)\n",
    "    out[\"MRR@10\"] = mrr10 / max(valid, 1)\n",
    "    out[\"NDCG@10\"] = ndcg10 / max(valid, 1)\n",
    "    return out\n",
    "\n",
    "def evaluate_ease_batched(\n",
    "    B_matrix,\n",
    "    X_train_matrix,\n",
    "    user_seen_dict,\n",
    "    test_item_by_user,\n",
    "    model_name,\n",
    "    user_indices=None,\n",
    "    ks=(5, 10, 20),\n",
    "    batch_size=2048,\n",
    "    user_seen_arrays_dict=None,\n",
    "):\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(test_item_by_user.keys()), dtype=np.int32)\n",
    "    else:\n",
    "        user_indices = np.asarray(user_indices, dtype=np.int32)\n",
    "\n",
    "    max_k = max(ks)\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr10 = 0.0\n",
    "    ndcg10 = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    use_gpu = bool(\n",
    "        HAS_TORCH and EASE_USE_GPU and (device is not None) and (str(device).startswith(\"cuda\"))\n",
    "    )\n",
    "\n",
    "    B_t = None\n",
    "    if use_gpu:\n",
    "        B_t = torch.as_tensor(B_matrix, dtype=torch.float32, device=device)\n",
    "        batch_size = max(batch_size, EASE_GPU_BATCH_SIZE)\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), batch_size), desc=model_name):\n",
    "        batch_uids = user_indices[start:start + batch_size]\n",
    "\n",
    "        X_batch = X_train_matrix[batch_uids].toarray().astype(np.float32, copy=False)\n",
    "\n",
    "        if use_gpu:\n",
    "            with torch.inference_mode():\n",
    "                X_batch_t = torch.from_numpy(X_batch).to(device, non_blocking=True)\n",
    "                scores_t = X_batch_t @ B_t\n",
    "\n",
    "                for row_idx, uid in enumerate(batch_uids):\n",
    "                    uid_int = int(uid)\n",
    "                    if user_seen_arrays_dict is not None:\n",
    "                        seen_idx_np = user_seen_arrays_dict.get(uid_int, None)\n",
    "                    else:\n",
    "                        seen = user_seen_dict.get(uid_int, None)\n",
    "                        seen_idx_np = None if not seen else np.fromiter(seen, dtype=np.int32)\n",
    "                    if seen_idx_np is not None and len(seen_idx_np) > 0:\n",
    "                        seen_idx_t = torch.as_tensor(seen_idx_np, dtype=torch.long, device=device)\n",
    "                        scores_t[row_idx, seen_idx_t] = float(\"-inf\")\n",
    "\n",
    "                recs_batch = torch.topk(scores_t, k=max_k, dim=1).indices.detach().cpu().numpy()\n",
    "\n",
    "                del X_batch_t, scores_t\n",
    "        else:\n",
    "            scores = X_batch @ B_matrix\n",
    "            scores = np.asarray(scores, dtype=np.float32)\n",
    "\n",
    "            for row_idx, uid in enumerate(batch_uids):\n",
    "                uid_int = int(uid)\n",
    "                if user_seen_arrays_dict is not None:\n",
    "                    seen_idx = user_seen_arrays_dict.get(uid_int, None)\n",
    "                else:\n",
    "                    seen = user_seen_dict.get(uid_int, None)\n",
    "                    seen_idx = None if not seen else np.fromiter(seen, dtype=np.int32)\n",
    "                if seen_idx is not None and len(seen_idx) > 0:\n",
    "                    scores[row_idx, seen_idx] = -np.inf\n",
    "\n",
    "            top_idx = np.argpartition(scores, -max_k, axis=1)[:, -max_k:]\n",
    "            top_scores = np.take_along_axis(scores, top_idx, axis=1)\n",
    "            order = np.argsort(top_scores, axis=1)[:, ::-1]\n",
    "            recs_batch = np.take_along_axis(top_idx, order, axis=1)\n",
    "\n",
    "        for row_idx, uid in enumerate(batch_uids):\n",
    "            uid = int(uid)\n",
    "            true_item = test_item_by_user.get(uid, None)\n",
    "            if true_item is None:\n",
    "                continue\n",
    "\n",
    "            recs = recs_batch[row_idx].astype(int).tolist()\n",
    "            valid += 1\n",
    "\n",
    "            for k in ks:\n",
    "                hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "            mrr10 += mrr_at_k(recs, true_item, 10)\n",
    "            ndcg10 += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    if use_gpu:\n",
    "        del B_t\n",
    "        try:\n",
    "            torch.cuda.empty_cache()\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "    out = {\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": valid,\n",
    "    }\n",
    "    for k in ks:\n",
    "        out[f\"HR@{k}\"] = hits[k] / max(valid, 1)\n",
    "    out[\"MRR@10\"] = mrr10 / max(valid, 1)\n",
    "    out[\"NDCG@10\"] = ndcg10 / max(valid, 1)\n",
    "    return out\n",
    "\n",
    "def bm25_weight(X, K1=100, B=0.8):\n",
    "    X = X.tocoo(copy=True).astype(np.float32)\n",
    "    N = float(X.shape[0])\n",
    "\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    avgdl = row_sums.mean() + 1e-8\n",
    "\n",
    "    df = np.bincount(X.col, minlength=X.shape[1]).astype(np.float32)\n",
    "    idf = np.log((N - df + 0.5) / (df + 0.5))\n",
    "    idf = np.maximum(idf, 0)\n",
    "\n",
    "    denom = X.data + K1 * (1 - B + B * row_sums[X.row] / avgdl)\n",
    "    data = X.data * (K1 + 1) / denom * idf[X.col]\n",
    "\n",
    "    return csr_matrix((data, (X.row, X.col)), shape=X.shape, dtype=np.float32)\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr().astype(np.float32)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    Pui = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    Piu = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "    S = (Piu @ Pui).astype(np.float32).tocsr()\n",
    "    S.setdiag(0.0)\n",
    "    S.eliminate_zeros()\n",
    "\n",
    "    if beta > 0:\n",
    "        item_degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "        penalty = np.ones_like(item_degree, dtype=np.float32)\n",
    "        mask = item_degree > 0\n",
    "        penalty[mask] = np.power(item_degree[mask], -beta)\n",
    "        S = S @ sp.diags(penalty.astype(np.float32))\n",
    "\n",
    "    return S.tocsr()\n",
    "\n",
    "def make_sparse_similarity_recommender(X_matrix, S_matrix, user_seen_dict):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_matrix.getrow(uid) @ S_matrix\n",
    "        scores = np.asarray(scores.todense()).ravel()\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend\n",
    "\n",
    "def make_rrf_recommender(named_recommenders, fetch_n=100, rrf_k=60):\n",
    "    # named_recommenders: list of (name, recommend_fn)\n",
    "    def recommend(user_idx, n=10):\n",
    "        rank_scores = {}\n",
    "        for model_name, rec_fn in named_recommenders:\n",
    "            recs = rec_fn(int(user_idx), n=fetch_n)\n",
    "            for rank, item_idx in enumerate(recs, start=1):\n",
    "                rank_scores[int(item_idx)] = rank_scores.get(int(item_idx), 0.0) + 1.0 / (rrf_k + rank)\n",
    "\n",
    "        if not rank_scores:\n",
    "            return recommend_popularity(int(user_idx), n=n) if 'recommend_popularity' in globals() else []\n",
    "\n",
    "        ranked = sorted(rank_scores.items(), key=lambda x: x[1], reverse=True)[:n]\n",
    "        return [int(i) for i, _ in ranked]\n",
    "\n",
    "    return recommend\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5b918a2",
   "metadata": {},
   "source": [
    "## Quick formulation screening\n",
    "\n",
    "The notebook will build several harder pseudo-user / item formulations, evaluate cheap baselines on each one, and automatically reject trivial setups.\n",
    "\n",
    "The main selection rule is:\n",
    "- enough items\n",
    "- popularity baseline not too strong\n",
    "- average interactions per pseudo-user not too high\n",
    "- strong EASE screen score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d081b84",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Formulation</th>\n",
       "      <th>Users</th>\n",
       "      <th>Items</th>\n",
       "      <th>TrainRows</th>\n",
       "      <th>AvgTrainPerUser</th>\n",
       "      <th>Thresholds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1_country_price_mileage_cond_age__brand_model...</td>\n",
       "      <td>43199</td>\n",
       "      <td>2640</td>\n",
       "      <td>831152</td>\n",
       "      <td>19.240075</td>\n",
       "      <td>(5, 10)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         Formulation  Users  Items  TrainRows  \\\n",
       "0  H1_country_price_mileage_cond_age__brand_model...  43199   2640     831152   \n",
       "\n",
       "   AvgTrainPerUser Thresholds  \n",
       "0        19.240075    (5, 10)  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Formulation: H1_country_price_mileage_cond_age__brand_model_age_cond\n",
      "User cols  : ['Country', 'PriceBin', 'MileageBin', 'Condition', 'AgeBin']\n",
      "Item cols  : ['Brand', 'Model', 'AgeBin', 'Condition']\n",
      "Users      : 43199\n",
      "Items      : 2640\n",
      "Thresholds : (5, 10)\n",
      "Train rows : 831152\n",
      "Test rows  : 43199\n",
      "Avg train interactions/user: 19.24\n",
      "Matrix shape: (43199, 2640)\n"
     ]
    }
   ],
   "source": [
    "# Build the fixed hard formulation directly\n",
    "\n",
    "data = build_formulation(\n",
    "    df,\n",
    "    user_cols=FIXED_USER_COLS,\n",
    "    item_cols=FIXED_ITEM_COLS,\n",
    "    formulation_name=FIXED_FORMULATION_NAME,\n",
    ")\n",
    "\n",
    "summary_df = pd.DataFrame([{\n",
    "    \"Formulation\": data[\"name\"],\n",
    "    \"Users\": data[\"num_users\"],\n",
    "    \"Items\": data[\"num_items\"],\n",
    "    \"TrainRows\": data[\"train_interactions\"].shape[0],\n",
    "    \"AvgTrainPerUser\": data[\"train_interactions\"].shape[0] / data[\"num_users\"],\n",
    "    \"Thresholds\": data.get(\"used_thresholds\"),\n",
    "}])\n",
    "\n",
    "display(summary_df)\n",
    "print_bundle_summary(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da9f1404",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using fixed formulation without screening: H1_country_price_mileage_cond_age__brand_model_age_cond\n"
     ]
    }
   ],
   "source": [
    "# No automatic formulation screening in this version.\n",
    "# H1 is used directly because it was already the strongest non-trivial formulation\n",
    "# in the previous meaningful run, and the screening stage was the main time/RAM sink.\n",
    "\n",
    "print(\"Using fixed formulation without screening:\", data[\"name\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ab89fdb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final working matrix: (43199, 2640)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>AgeBin</th>\n",
       "      <th>user_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(14.0 to  16.0]</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(16.0 to  19.0]</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(19.0 to  22.0]</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             user_id    Country  \\\n",
       "0  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "1  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "2  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "3  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "4  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "\n",
       "                         PriceBin                    MileageBin  \\\n",
       "0  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "1  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "2  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "3  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "4  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "\n",
       "             Condition               AgeBin  user_idx  \n",
       "0  Certified Pre-Owned  age_(1.999 to  4.0]         0  \n",
       "1  Certified Pre-Owned  age_(11.0 to  14.0]         1  \n",
       "2  Certified Pre-Owned  age_(14.0 to  16.0]         2  \n",
       "3  Certified Pre-Owned  age_(16.0 to  19.0]         3  \n",
       "4  Certified Pre-Owned  age_(19.0 to  22.0]         4  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>AgeBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>item_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Audi :: A3 :: age_(1.999 to  4.0] :: Certified...</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Audi :: A3 :: age_(1.999 to  4.0] :: New</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>New</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Audi :: A3 :: age_(1.999 to  4.0] :: Used</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>Used</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Audi :: A3 :: age_(11.0 to  14.0] :: Certified...</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Audi :: A3 :: age_(11.0 to  14.0] :: New</td>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>age_(11.0 to  14.0]</td>\n",
       "      <td>New</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             item_id Brand Model  \\\n",
       "0  Audi :: A3 :: age_(1.999 to  4.0] :: Certified...  Audi    A3   \n",
       "1           Audi :: A3 :: age_(1.999 to  4.0] :: New  Audi    A3   \n",
       "2          Audi :: A3 :: age_(1.999 to  4.0] :: Used  Audi    A3   \n",
       "3  Audi :: A3 :: age_(11.0 to  14.0] :: Certified...  Audi    A3   \n",
       "4           Audi :: A3 :: age_(11.0 to  14.0] :: New  Audi    A3   \n",
       "\n",
       "                AgeBin            Condition  item_idx  \n",
       "0  age_(1.999 to  4.0]  Certified Pre-Owned         0  \n",
       "1  age_(1.999 to  4.0]                  New         1  \n",
       "2  age_(1.999 to  4.0]                 Used         2  \n",
       "3  age_(11.0 to  14.0]  Certified Pre-Owned         3  \n",
       "4  age_(11.0 to  14.0]                  New         4  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Prepare reusable arrays and feature tables for the selected formulation\n",
    "\n",
    "user_feature_df = data[\"user_feature_df\"].copy()\n",
    "item_feature_df = data[\"item_feature_df\"].copy()\n",
    "\n",
    "train_interactions = data[\"train_interactions\"].copy()\n",
    "test_interactions = data[\"test_interactions\"].copy()\n",
    "\n",
    "X_counts = data[\"X_counts\"].tocsr()\n",
    "X_binary = data[\"X_binary\"].tocsr()\n",
    "\n",
    "num_users = data[\"num_users\"]\n",
    "num_items = data[\"num_items\"]\n",
    "\n",
    "user_seen = data[\"user_seen\"]\n",
    "user_seen_arrays = data[\"user_seen_arrays\"]\n",
    "user_strength = data[\"user_strength\"]\n",
    "test_item_by_user = data[\"test_item_by_user\"]\n",
    "global_pop_rank = data[\"global_pop_rank\"]\n",
    "item_ids = data[\"item_ids\"]\n",
    "\n",
    "X_binary_dense = X_binary.toarray().astype(np.float32)\n",
    "X_counts_dense = X_counts.toarray().astype(np.float32)\n",
    "\n",
    "eval_users = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "print(\"Final working matrix:\", X_binary.shape)\n",
    "display(user_feature_df.head())\n",
    "display(item_feature_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa1814d8",
   "metadata": {},
   "source": [
    "## Regular model zoo\n",
    "\n",
    "This section includes:\n",
    "- popularity baselines\n",
    "- item-based collaborative filtering\n",
    "- content-based KNN\n",
    "- graph-style recommenders\n",
    "- EASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a1afb92e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Popularity baselines\n",
    "\n",
    "def recommend_popularity(user_idx, n=10):\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]\n",
    "\n",
    "country_col = \"Country\" if \"Country\" in user_feature_df.columns else None\n",
    "user_country = {}\n",
    "country_pop_rank = {}\n",
    "\n",
    "if country_col is not None:\n",
    "    user_country = dict(zip(user_feature_df[\"user_idx\"], user_feature_df[country_col]))\n",
    "    train_with_country = train_interactions.merge(\n",
    "        user_feature_df[[\"user_idx\", country_col]],\n",
    "        on=\"user_idx\",\n",
    "        how=\"left\"\n",
    "    )\n",
    "\n",
    "    for country, g in train_with_country.groupby(country_col):\n",
    "        country_pop_rank[country] = (\n",
    "            g.groupby(\"item_idx\")[\"count\"]\n",
    "             .sum()\n",
    "             .sort_values(ascending=False)\n",
    "             .index.to_numpy()\n",
    "        )\n",
    "\n",
    "def recommend_country_popularity(user_idx, n=10):\n",
    "    uid = int(user_idx)\n",
    "    seen = user_seen.get(uid, set())\n",
    "    country = user_country.get(uid, None)\n",
    "\n",
    "    if country in country_pop_rank:\n",
    "        recs = [int(i) for i in country_pop_rank[country] if int(i) not in seen][:n]\n",
    "        if len(recs) >= n:\n",
    "            return recs\n",
    "\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]"
   ]
  },
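  {
   "cell_type": "markdown",
   "id": "pop-sanity-note",
   "metadata": {},
   "source": [
    "A quick, purely illustrative sanity check of the shared recommender contract used throughout this notebook: every recommender is a plain callable `recommend(user_idx, n)` that returns up to `n` item indices absent from that user's training interactions. The `sample_uid` / `sample_recs` names below are throwaway."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pop-sanity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: exercise the popularity baseline for a single evaluation user.\n",
    "sample_uid = int(eval_users[0])\n",
    "sample_recs = recommend_popularity(sample_uid, n=5)\n",
    "\n",
    "# Recommendations must exclude items already seen in the training split.\n",
    "assert all(i not in user_seen.get(sample_uid, set()) for i in sample_recs)\n",
    "print(\"user\", sample_uid, \"->\", sample_recs)"
   ]
  },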
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "89251180",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ItemKNN variants\n",
    "\n",
    "def fit_itemknn_recommender(X_matrix, neighbors=100):\n",
    "    item_user = X_matrix.T.tocsr()\n",
    "\n",
    "    knn = NearestNeighbors(\n",
    "        metric=\"cosine\",\n",
    "        algorithm=\"brute\",\n",
    "        n_neighbors=min(neighbors, item_user.shape[0]),\n",
    "        n_jobs=-1\n",
    "    )\n",
    "    knn.fit(item_user)\n",
    "\n",
    "    distances, indices = knn.kneighbors(item_user, n_neighbors=min(neighbors, item_user.shape[0]))\n",
    "    similarities = (1.0 - distances).astype(np.float32)\n",
    "    return indices, similarities\n",
    "\n",
    "def make_itemknn_recommender(X_matrix, neighbors=100, use_strength=True):\n",
    "    neighbor_idx, neighbor_sim = fit_itemknn_recommender(X_matrix, neighbors=neighbors)\n",
    "\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        seen = user_seen.get(uid, set())\n",
    "        strength_map = user_strength.get(uid, {})\n",
    "        scores = {}\n",
    "\n",
    "        for item_idx in seen:\n",
    "            weight = float(strength_map.get(item_idx, 1.0)) if use_strength else 1.0\n",
    "            nbrs = neighbor_idx[item_idx]\n",
    "            sims = neighbor_sim[item_idx]\n",
    "\n",
    "            for j, sim in zip(nbrs, sims):\n",
    "                j = int(j)\n",
    "                if j == item_idx or j in seen:\n",
    "                    continue\n",
    "                scores[j] = scores.get(j, 0.0) + float(sim) * weight\n",
    "\n",
    "        if not scores:\n",
    "            return recommend_popularity(uid, n=n)\n",
    "\n",
    "        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:n]\n",
    "        return [int(i) for i, _ in ranked]\n",
    "\n",
    "    return recommend\n",
    "\n",
    "X_bm25 = bm25_weight(X_counts, K1=100, B=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "24fed1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Content-based KNN on item attributes\n",
    "\n",
    "item_feature_cols = data[\"item_cols\"]\n",
    "item_feature_onehot = pd.get_dummies(item_feature_df[item_feature_cols].astype(str), sparse=True)\n",
    "item_feature_sparse = csr_matrix(item_feature_onehot.sparse.to_coo()).astype(np.float32)\n",
    "\n",
    "item_feature_norm = normalize(item_feature_sparse, norm=\"l2\", axis=1)\n",
    "content_sim = (item_feature_norm @ item_feature_norm.T).astype(np.float32).toarray()\n",
    "np.fill_diagonal(content_sim, 0.0)\n",
    "\n",
    "def recommend_content_knn(user_idx, n=10):\n",
    "    uid = int(user_idx)\n",
    "    seen = user_seen.get(uid, set())\n",
    "    user_vec = X_counts.getrow(uid).toarray().ravel().astype(np.float32)\n",
    "    scores = user_vec @ content_sim\n",
    "    return topn_from_scores(scores, seen, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bee1b5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph recommenders: P3alpha and RP3beta\n",
    "\n",
    "def make_p3_recommender(X_input, alpha=1.0, beta=0.0):\n",
    "    S = fit_p3alpha(X_input, alpha=alpha, beta=beta)\n",
    "    return make_sparse_similarity_recommender(X_input, S, user_seen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "500d74ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EASE\n",
    "\n",
    "def fit_ease(X_matrix, lam):\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += lam\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32)\n",
    "\n",
    "def make_ease_recommender(X_train_dense, B_matrix):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_train_dense[uid] @ B_matrix\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af5374c9",
   "metadata": {},
   "source": [
    "## Neural model zoo\n",
    "\n",
    "This section favors models that actually keep the GPU busy:\n",
    "- full-softmax MLPs\n",
    "- two-tower retrieval\n",
    "- dense autoencoders\n",
    "- pairwise embedding models\n",
    "- NeuMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "366f6ed2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HAS_TORCH: True\n",
      "device: cuda\n",
      "USE_AMP: True\n"
     ]
    }
   ],
   "source": [
    "# PyTorch setup was already performed in the earlier runtime cell.\n",
    "\n",
    "print(\"HAS_TORCH:\", HAS_TORCH)\n",
    "print(\"device:\", device)\n",
    "print(\"USE_AMP:\", USE_AMP)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "89dad4e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive interaction rows: 831152\n",
      "Dense matrix shape: (43199, 2640)\n"
     ]
    }
   ],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Feature encoders for neural models\n",
    "\n",
    "    user_cols = data[\"user_cols\"]\n",
    "    item_cols = data[\"item_cols\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "    for col in sorted(set(user_cols + item_cols)):\n",
    "        enc = LabelEncoder()\n",
    "        values = pd.concat([\n",
    "            user_feature_df[col].astype(str) if col in user_feature_df.columns else pd.Series(dtype=str),\n",
    "            item_feature_df[col].astype(str) if col in item_feature_df.columns else pd.Series(dtype=str),\n",
    "        ], ignore_index=True)\n",
    "        enc.fit(values)\n",
    "        feature_encoders[col] = enc\n",
    "\n",
    "    user_feature_arrays = {\n",
    "        col: feature_encoders[col].transform(user_feature_df[col].astype(str))\n",
    "        for col in user_cols\n",
    "    }\n",
    "    item_feature_arrays = {\n",
    "        col: feature_encoders[col].transform(item_feature_df[col].astype(str))\n",
    "        for col in item_cols\n",
    "    }\n",
    "\n",
    "    user_feature_tensors = {\n",
    "        col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "        for col, arr in user_feature_arrays.items()\n",
    "    }\n",
    "    item_feature_tensors = {\n",
    "        col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "        for col, arr in item_feature_arrays.items()\n",
    "    }\n",
    "\n",
    "    positive_user_idx = torch.tensor(train_interactions[\"user_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "    positive_item_idx = torch.tensor(train_interactions[\"item_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "    positive_weight = torch.tensor(train_interactions[\"count\"].astype(np.float32).to_numpy(), dtype=torch.float32, device=device)\n",
    "\n",
    "    X_binary_dense_tensor = torch.tensor(X_binary_dense, dtype=torch.float32, device=device)\n",
    "    X_counts_dense_tensor = torch.tensor(X_counts_dense, dtype=torch.float32, device=device)\n",
    "\n",
    "    n_positive_rows = int(positive_user_idx.shape[0])\n",
    "    print(\"Positive interaction rows:\", n_positive_rows)\n",
    "    print(\"Dense matrix shape:\", tuple(X_binary_dense_tensor.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8fe48d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Manual batch iterators\n",
    "\n",
    "    def iterate_positive_batches(batch_size, shuffle=True):\n",
    "        n = n_positive_rows\n",
    "        order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            yield positive_user_idx[idx], positive_item_idx[idx], positive_weight[idx]\n",
    "\n",
    "    def iterate_dense_user_batches(batch_size, shuffle=True):\n",
    "        n = int(X_binary_dense_tensor.shape[0])\n",
    "        order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            yield X_binary_dense_tensor[idx], idx\n",
    "\n",
    "    def make_user_feature_batch(user_idx_batch):\n",
    "        return {col: user_feature_tensors[col][user_idx_batch] for col in user_cols}\n",
    "\n",
    "    def make_item_feature_batch(item_idx_batch):\n",
    "        return {col: item_feature_tensors[col][item_idx_batch] for col in item_cols}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f30d69b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Shared building blocks\n",
    "\n",
    "    class FeatureEmbeddingBlock(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, emb_dim):\n",
    "            super().__init__()\n",
    "            self.feature_names = list(feature_cardinalities.keys())\n",
    "            self.embs = nn.ModuleDict({\n",
    "                name: nn.Embedding(card, emb_dim)\n",
    "                for name, card in feature_cardinalities.items()\n",
    "            })\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            parts = [self.embs[name](feature_batch[name]) for name in self.feature_names]\n",
    "            return torch.cat(parts, dim=-1)\n",
    "\n",
    "    class MLPBlock(nn.Module):\n",
    "        def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None, final_activation=False):\n",
    "            super().__init__()\n",
    "            layers = []\n",
    "            last = input_dim\n",
    "            for h in hidden_dims:\n",
    "                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "                last = h\n",
    "            if final_dim is not None:\n",
    "                layers.append(nn.Linear(last, final_dim))\n",
    "                last = final_dim\n",
    "                if final_activation:\n",
    "                    layers.append(nn.ReLU())\n",
    "            self.net = nn.Sequential(*layers)\n",
    "            self.output_dim = last\n",
    "\n",
    "        def forward(self, x):\n",
    "            return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0c79d1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 1: full-softmax feature MLP\n",
    "\n",
    "    class SoftmaxSegmentMLP(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, num_items, emb_dim=64, hidden_dims=(512, 512, 256), dropout=0.15):\n",
    "            super().__init__()\n",
    "            self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "            input_dim = len(feature_cardinalities) * emb_dim\n",
    "            self.mlp = MLPBlock(input_dim, hidden_dims, dropout=dropout)\n",
    "            self.out = nn.Linear(self.mlp.output_dim, num_items)\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            x = self.encoder(feature_batch)\n",
    "            x = self.mlp(x)\n",
    "            return self.out(x)\n",
    "\n",
    "    def train_softmax_model(model_name, emb_dim=64, hidden_dims=(512, 512, 256), epochs=20, lr=2e-3, wd=1e-5):\n",
    "        feature_cardinalities = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "        model = SoftmaxSegmentMLP(\n",
    "            feature_cardinalities=feature_cardinalities,\n",
    "            num_items=num_items,\n",
    "            emb_dim=emb_dim,\n",
    "            hidden_dims=hidden_dims,\n",
    "            dropout=0.15,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "        criterion = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "            for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(SOFTMAX_BATCH_SIZE, shuffle=True):\n",
    "                feat_batch = make_user_feature_batch(user_idx_batch)\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model(feat_batch)\n",
    "                    loss_vec = criterion(logits, item_idx_batch)\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_softmax_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "            feat_batch = make_user_feature_batch(u)\n",
    "            logits = model(feat_batch).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(logits, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c2fc4897",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 2: feature two-tower with in-batch negatives\n",
    "\n",
    "    class TowerEncoder(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "            super().__init__()\n",
    "            self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "            input_dim = len(feature_cardinalities) * emb_dim\n",
    "            self.net = MLPBlock(input_dim, hidden_dims, dropout=dropout, final_dim=out_dim)\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            x = self.encoder(feature_batch)\n",
    "            x = self.net(x)\n",
    "            return F.normalize(x, dim=-1)\n",
    "\n",
    "    class TwoTowerModel(nn.Module):\n",
    "        def __init__(self, user_feature_cards, item_feature_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "            super().__init__()\n",
    "            self.user_tower = TowerEncoder(user_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "            self.item_tower = TowerEncoder(item_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "\n",
    "    def train_two_tower(model_name, emb_dim=64, hidden_dims=(512, 256), out_dim=128, epochs=20, lr=2e-3, wd=1e-5, temperature=0.07):\n",
    "        user_feature_cards = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "        item_feature_cards = {col: len(feature_encoders[col].classes_) for col in item_cols}\n",
    "\n",
    "        model = TwoTowerModel(\n",
    "            user_feature_cards=user_feature_cards,\n",
    "            item_feature_cards=item_feature_cards,\n",
    "            emb_dim=emb_dim,\n",
    "            hidden_dims=hidden_dims,\n",
    "            out_dim=out_dim,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "            for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(TWOTOWER_BATCH_SIZE, shuffle=True):\n",
    "                user_batch = make_user_feature_batch(user_idx_batch)\n",
    "                item_batch = make_item_feature_batch(item_idx_batch)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    user_vec = model.user_tower(user_batch)\n",
    "                    item_vec = model.item_tower(item_batch)\n",
    "                    logits = (user_vec @ item_vec.T) / temperature\n",
    "                    targets = torch.arange(logits.shape[0], device=device)\n",
    "                    loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_two_tower_recommender(model):\n",
    "        model.eval()\n",
    "        all_item_idx = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "        item_matrix_parts = []\n",
    "        with torch.no_grad():\n",
    "            for start in range(0, num_items, 4096):\n",
    "                idx = all_item_idx[start:start + 4096]\n",
    "                batch = make_item_feature_batch(idx)\n",
    "                item_matrix_parts.append(model.item_tower(batch))\n",
    "        item_matrix = torch.cat(item_matrix_parts, dim=0)\n",
    "\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "            user_batch = make_user_feature_batch(u)\n",
    "            user_vec = model.user_tower(user_batch)\n",
    "            scores = (user_vec @ item_matrix.T).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "918da172",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 3: dense autoencoders\n",
    "\n",
    "    class MultDAE(nn.Module):\n",
    "        def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.2):\n",
    "            super().__init__()\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Sequential(\n",
    "                nn.Linear(num_items, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, latent_dim),\n",
    "                nn.Tanh(),\n",
    "            )\n",
    "            self.decoder = nn.Sequential(\n",
    "                nn.Linear(latent_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, num_items),\n",
    "            )\n",
    "\n",
    "        def forward(self, x):\n",
    "            z = self.encoder(self.dropout(x))\n",
    "            logits = self.decoder(z)\n",
    "            return logits\n",
    "\n",
    "    class MultVAE(nn.Module):\n",
    "        def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.3):\n",
    "            super().__init__()\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Sequential(\n",
    "                nn.Linear(num_items, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "            )\n",
    "            self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.decoder = nn.Sequential(\n",
    "                nn.Linear(latent_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, num_items),\n",
    "            )\n",
    "\n",
    "        def encode(self, x):\n",
    "            h = self.encoder(self.dropout(x))\n",
    "            return self.mu(h), self.logvar(h)\n",
    "\n",
    "        def reparameterize(self, mu, logvar):\n",
    "            if self.training:\n",
    "                std = torch.exp(0.5 * logvar)\n",
    "                eps = torch.randn_like(std)\n",
    "                return mu + eps * std\n",
    "            return mu\n",
    "\n",
    "        def forward(self, x):\n",
    "            mu, logvar = self.encode(x)\n",
    "            z = self.reparameterize(mu, logvar)\n",
    "            logits = self.decoder(z)\n",
    "            return logits, mu, logvar\n",
    "\n",
    "    def train_multdae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.2, epochs=80, lr=1e-3, wd=0.0):\n",
    "        model = MultDAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_rows = 0\n",
    "            for batch_x, _ in iterate_dense_user_batches(AUTOENC_BATCH_SIZE, shuffle=True):\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model(batch_x)\n",
    "                    loss = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1).mean()\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                total_rows += batch_x.shape[0]\n",
    "\n",
    "            if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "                print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def train_multvae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.3, epochs=80, lr=1e-3, wd=0.0, anneal_cap=1.0):\n",
    "        model = MultVAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        steps_per_epoch = max(1, math.ceil(num_users / AUTOENC_BATCH_SIZE))\n",
    "        total_steps = max(1, epochs * steps_per_epoch)\n",
    "        step = 0\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_rows = 0\n",
    "\n",
    "            for batch_x, _ in iterate_dense_user_batches(AUTOENC_BATCH_SIZE, shuffle=True):\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                anneal = min(anneal_cap, step / max(total_steps * 0.3, 1))\n",
    "\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits, mu, logvar = model(batch_x)\n",
    "                    recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                    loss = (recon + anneal * kl).mean()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                total_rows += batch_x.shape[0]\n",
    "                step += 1\n",
    "\n",
    "            if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "                print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f} - anneal: {anneal:.3f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_multdae_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            x = X_binary_dense_tensor[uid:uid + 1]\n",
    "            logits = model(x).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(logits, seen, n)\n",
    "        return recommend\n",
    "\n",
    "    def make_multvae_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            x = X_binary_dense_tensor[uid:uid + 1]\n",
    "            logits, _, _ = model(x)\n",
    "            scores = logits.squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "2d7bc556",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 4: BPR matrix factorization\n",
    "\n",
    "    class BPRMF(nn.Module):\n",
    "        def __init__(self, num_users, num_items, dim=128):\n",
    "            super().__init__()\n",
    "            self.user_emb = nn.Embedding(num_users, dim)\n",
    "            self.item_emb = nn.Embedding(num_items, dim)\n",
    "            self.item_bias = nn.Embedding(num_items, 1)\n",
    "\n",
    "            nn.init.normal_(self.user_emb.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_emb.weight, std=0.02)\n",
    "            nn.init.zeros_(self.item_bias.weight)\n",
    "\n",
    "        def forward(self, user_idx, pos_idx, neg_idx):\n",
    "            u = self.user_emb(user_idx)\n",
    "            p = self.item_emb(pos_idx)\n",
    "            n = self.item_emb(neg_idx)\n",
    "            pb = self.item_bias(pos_idx).squeeze(-1)\n",
    "            nb = self.item_bias(neg_idx).squeeze(-1)\n",
    "\n",
    "            pos_scores = (u * p).sum(dim=1) + pb\n",
    "            neg_scores = (u * n).sum(dim=1) + nb\n",
    "            return pos_scores, neg_scores\n",
    "\n",
    "    def train_bprmf(model_name, dim=128, epochs=24, lr=2e-3, wd=1e-6):\n",
    "        model = BPRMF(num_users=num_users, num_items=num_items, dim=dim).to(device)\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "\n",
    "            for user_idx_batch, pos_item_batch, weight_batch in iterate_positive_batches(PAIRWISE_BATCH_SIZE, shuffle=True):\n",
    "                neg_item_batch = torch.randint(0, num_items, size=pos_item_batch.shape, device=device)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    pos_scores, neg_scores = model(user_idx_batch, pos_item_batch, neg_item_batch)\n",
    "                    loss_vec = -F.logsigmoid(pos_scores - neg_scores)\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_bprmf_recommender(model):\n",
    "        item_matrix = model.item_emb.weight\n",
    "        item_bias = model.item_bias.weight.squeeze(-1)\n",
    "\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = model.user_emb.weight[uid]\n",
    "            scores = item_matrix @ u + item_bias\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "48fb928d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 5: NeuMF\n",
    "\n",
    "    class NeuMF(nn.Module):\n",
    "        def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "            super().__init__()\n",
    "            self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "            self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "            self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "            self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "            layers = []\n",
    "            last = mlp_dim * 2\n",
    "            for h in hidden_dims:\n",
    "                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(0.1)]\n",
    "                last = h\n",
    "            self.mlp = nn.Sequential(*layers)\n",
    "            self.out = nn.Linear(last + mf_dim, 1)\n",
    "\n",
    "            nn.init.normal_(self.user_mf.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_mf.weight, std=0.02)\n",
    "            nn.init.normal_(self.user_mlp.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_mlp.weight, std=0.02)\n",
    "\n",
    "        def score(self, user_idx, item_idx):\n",
    "            mf_u = self.user_mf(user_idx)\n",
    "            mf_i = self.item_mf(item_idx)\n",
    "            mf = mf_u * mf_i\n",
    "\n",
    "            mlp_u = self.user_mlp(user_idx)\n",
    "            mlp_i = self.item_mlp(item_idx)\n",
    "            mlp = self.mlp(torch.cat([mlp_u, mlp_i], dim=-1))\n",
    "\n",
    "            x = torch.cat([mf, mlp], dim=-1)\n",
    "            return self.out(x).squeeze(-1)\n",
    "\n",
    "    def train_neumf(model_name, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128), epochs=24, lr=2e-3, wd=1e-6):\n",
    "        model = NeuMF(num_users=num_users, num_items=num_items, mf_dim=mf_dim, mlp_dim=mlp_dim, hidden_dims=hidden_dims).to(device)\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "\n",
    "            for user_idx_batch, pos_item_batch, weight_batch in iterate_positive_batches(PAIRWISE_BATCH_SIZE, shuffle=True):\n",
    "                neg_item_batch = torch.randint(0, num_items, size=pos_item_batch.shape, device=device)\n",
    "\n",
    "                user_cat = torch.cat([user_idx_batch, user_idx_batch], dim=0)\n",
    "                item_cat = torch.cat([pos_item_batch, neg_item_batch], dim=0)\n",
    "                target = torch.cat([\n",
    "                    torch.ones_like(pos_item_batch, dtype=torch.float32),\n",
    "                    torch.zeros_like(neg_item_batch, dtype=torch.float32),\n",
    "                ], dim=0)\n",
    "                sample_weight = torch.cat([weight_batch, weight_batch], dim=0)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model.score(user_cat, item_cat)\n",
    "                    loss_vec = F.binary_cross_entropy_with_logits(logits, target, reduction=\"none\")\n",
    "                    loss = (loss_vec * sample_weight).sum() / sample_weight.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(sample_weight.sum().item())\n",
    "                total_w += float(sample_weight.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_neumf_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            users = torch.full((num_items,), uid, dtype=torch.long, device=device)\n",
    "            items = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "            scores = model.score(users, items)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c9e63b69",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating regular baseline: Popularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Popularity: 100%|██████████| 43199/43199 [00:10<00:00, 4171.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=0.85\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a0_85: 100%|██████████| 43199/43199 [00:05<00:00, 8206.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=0.95\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a0_95: 100%|██████████| 43199/43199 [00:05<00:00, 8276.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a1_0: 100%|██████████| 43199/43199 [00:05<00:00, 8294.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=1.05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a1_05: 100%|██████████| 43199/43199 [00:05<00:00, 8192.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting P3alpha alpha=1.15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "P3alpha_a1_15: 100%|██████████| 43199/43199 [00:05<00:00, 8042.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=0.9 beta=0.4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a0_9_b0_4: 100%|██████████| 43199/43199 [00:05<00:00, 8035.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=0.95 beta=0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a0_95_b0_5: 100%|██████████| 43199/43199 [00:05<00:00, 8116.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=1.0 beta=0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a1_0_b0_5: 100%|██████████| 43199/43199 [00:05<00:00, 8149.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=1.0 beta=0.6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a1_0_b0_6: 100%|██████████| 43199/43199 [00:05<00:00, 8116.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=1.05 beta=0.6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a1_05_b0_6: 100%|██████████| 43199/43199 [00:05<00:00, 8056.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting RP3beta alpha=1.1 beta=0.7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RP3beta_a1_1_b0_7: 100%|██████████| 43199/43199 [00:05<00:00, 8167.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=600.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l600: 100%|██████████| 6/6 [00:02<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=800.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l800: 100%|██████████| 6/6 [00:02<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l1000: 100%|██████████| 6/6 [00:02<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=1200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l1200: 100%|██████████| 6/6 [00:02<00:00,  2.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=1600.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l1600: 100%|██████████| 6/6 [00:02<00:00,  2.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=2200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l2200: 100%|██████████| 6/6 [00:02<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE binary lambda=3000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_binary_l3000: 100%|██████████| 6/6 [00:02<00:00,  2.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=600.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l600: 100%|██████████| 6/6 [00:02<00:00,  2.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=800.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l800: 100%|██████████| 6/6 [00:02<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=1000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l1000: 100%|██████████| 6/6 [00:02<00:00,  2.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=1200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l1200: 100%|██████████| 6/6 [00:02<00:00,  2.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=1600.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l1600: 100%|██████████| 6/6 [00:02<00:00,  2.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=2200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l2200: 100%|██████████| 6/6 [00:01<00:00,  3.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Fitting EASE count lambda=3000.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "EASE_count_l3000: 100%|██████████| 6/6 [00:02<00:00,  2.83it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097734</td>\n",
       "      <td>0.169981</td>\n",
       "      <td>0.308433</td>\n",
       "      <td>0.060909</td>\n",
       "      <td>0.086051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>RP3beta_a1_05_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169680</td>\n",
       "      <td>0.308850</td>\n",
       "      <td>0.060639</td>\n",
       "      <td>0.085774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>RP3beta_a1_0_b0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097780</td>\n",
       "      <td>0.169657</td>\n",
       "      <td>0.310887</td>\n",
       "      <td>0.060827</td>\n",
       "      <td>0.085906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169448</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060618</td>\n",
       "      <td>0.085705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097711</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>P3alpha_a1_05</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098197</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.310540</td>\n",
       "      <td>0.060244</td>\n",
       "      <td>0.085325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098544</td>\n",
       "      <td>0.169055</td>\n",
       "      <td>0.309799</td>\n",
       "      <td>0.061045</td>\n",
       "      <td>0.085938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>EASE_count_l800</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095396</td>\n",
       "      <td>0.169009</td>\n",
       "      <td>0.309961</td>\n",
       "      <td>0.059349</td>\n",
       "      <td>0.084554</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>RP3beta_a0_95_b0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097595</td>\n",
       "      <td>0.168985</td>\n",
       "      <td>0.310007</td>\n",
       "      <td>0.060733</td>\n",
       "      <td>0.085685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>EASE_binary_l800</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097618</td>\n",
       "      <td>0.168985</td>\n",
       "      <td>0.310007</td>\n",
       "      <td>0.060290</td>\n",
       "      <td>0.085320</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>P3alpha_a1_0</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060332</td>\n",
       "      <td>0.085333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096924</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310563</td>\n",
       "      <td>0.060030</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>P3alpha_a0_95</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097826</td>\n",
       "      <td>0.168800</td>\n",
       "      <td>0.310493</td>\n",
       "      <td>0.060338</td>\n",
       "      <td>0.085320</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098567</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.309452</td>\n",
       "      <td>0.060942</td>\n",
       "      <td>0.085798</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>RP3beta_a0_9_b0_4</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097224</td>\n",
       "      <td>0.168615</td>\n",
       "      <td>0.310053</td>\n",
       "      <td>0.060512</td>\n",
       "      <td>0.085428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>P3alpha_a1_15</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098058</td>\n",
       "      <td>0.168592</td>\n",
       "      <td>0.310239</td>\n",
       "      <td>0.060189</td>\n",
       "      <td>0.085164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>EASE_binary_l3000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097780</td>\n",
       "      <td>0.168569</td>\n",
       "      <td>0.308734</td>\n",
       "      <td>0.060858</td>\n",
       "      <td>0.085689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>EASE_count_l600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.093752</td>\n",
       "      <td>0.168453</td>\n",
       "      <td>0.309915</td>\n",
       "      <td>0.057772</td>\n",
       "      <td>0.083179</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>EASE_count_l3000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098174</td>\n",
       "      <td>0.168268</td>\n",
       "      <td>0.308688</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085551</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>EASE_count_l2200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097757</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.309313</td>\n",
       "      <td>0.060807</td>\n",
       "      <td>0.085508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>P3alpha_a0_85</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097664</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.310678</td>\n",
       "      <td>0.060293</td>\n",
       "      <td>0.085120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>EASE_binary_l2200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098012</td>\n",
       "      <td>0.167990</td>\n",
       "      <td>0.309104</td>\n",
       "      <td>0.060916</td>\n",
       "      <td>0.085615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>EASE_binary_l600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095905</td>\n",
       "      <td>0.167990</td>\n",
       "      <td>0.309660</td>\n",
       "      <td>0.059069</td>\n",
       "      <td>0.084132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097410</td>\n",
       "      <td>0.167805</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060303</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097317</td>\n",
       "      <td>0.167342</td>\n",
       "      <td>0.310030</td>\n",
       "      <td>0.060572</td>\n",
       "      <td>0.085189</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.003218</td>\n",
       "      <td>0.006019</td>\n",
       "      <td>0.011621</td>\n",
       "      <td>0.001746</td>\n",
       "      <td>0.002724</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0    RP3beta_a1_1_b0_7      43199  0.097734  0.169981  0.308433  0.060909   \n",
       "1   RP3beta_a1_05_b0_6      43199  0.097687  0.169680  0.308850  0.060639   \n",
       "2    RP3beta_a1_0_b0_5      43199  0.097780  0.169657  0.310887  0.060827   \n",
       "3    RP3beta_a1_0_b0_6      43199  0.097687  0.169448  0.308896  0.060618   \n",
       "4    EASE_binary_l1000      43199  0.097711  0.169356  0.309614  0.060787   \n",
       "5        P3alpha_a1_05      43199  0.098197  0.169147  0.310540  0.060244   \n",
       "6    EASE_binary_l1200      43199  0.098544  0.169055  0.309799  0.061045   \n",
       "7      EASE_count_l800      43199  0.095396  0.169009  0.309961  0.059349   \n",
       "8   RP3beta_a0_95_b0_5      43199  0.097595  0.168985  0.310007  0.060733   \n",
       "9     EASE_binary_l800      43199  0.097618  0.168985  0.310007  0.060290   \n",
       "10        P3alpha_a1_0      43199  0.098104  0.168870  0.310609  0.060332   \n",
       "11    EASE_count_l1000      43199  0.096924  0.168870  0.310563  0.060030   \n",
       "12       P3alpha_a0_95      43199  0.097826  0.168800  0.310493  0.060338   \n",
       "13   EASE_binary_l1600      43199  0.098567  0.168731  0.309452  0.060942   \n",
       "14   RP3beta_a0_9_b0_4      43199  0.097224  0.168615  0.310053  0.060512   \n",
       "15       P3alpha_a1_15      43199  0.098058  0.168592  0.310239  0.060189   \n",
       "16   EASE_binary_l3000      43199  0.097780  0.168569  0.308734  0.060858   \n",
       "17     EASE_count_l600      43199  0.093752  0.168453  0.309915  0.057772   \n",
       "18    EASE_count_l3000      43199  0.098174  0.168268  0.308688  0.060787   \n",
       "19    EASE_count_l2200      43199  0.097757  0.168013  0.309313  0.060807   \n",
       "20       P3alpha_a0_85      43199  0.097664  0.168013  0.310678  0.060293   \n",
       "21   EASE_binary_l2200      43199  0.098012  0.167990  0.309104  0.060916   \n",
       "22    EASE_binary_l600      43199  0.095905  0.167990  0.309660  0.059069   \n",
       "23    EASE_count_l1200      43199  0.097410  0.167805  0.310609  0.060303   \n",
       "24    EASE_count_l1600      43199  0.097317  0.167342  0.310030  0.060572   \n",
       "25          Popularity      43199  0.003218  0.006019  0.011621  0.001746   \n",
       "\n",
       "     NDCG@10  \n",
       "0   0.086051  \n",
       "1   0.085774  \n",
       "2   0.085906  \n",
       "3   0.085705  \n",
       "4   0.085790  \n",
       "5   0.085325  \n",
       "6   0.085938  \n",
       "7   0.084554  \n",
       "8   0.085685  \n",
       "9   0.085320  \n",
       "10  0.085333  \n",
       "11  0.085076  \n",
       "12  0.085320  \n",
       "13  0.085798  \n",
       "14  0.085428  \n",
       "15  0.085164  \n",
       "16  0.085689  \n",
       "17  0.083179  \n",
       "18  0.085551  \n",
       "19  0.085508  \n",
       "20  0.085120  \n",
       "21  0.085615  \n",
       "22  0.084132  \n",
       "23  0.085076  \n",
       "24  0.085189  \n",
       "25  0.002724  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit and evaluate focused regular models\n",
    "\n",
    "regular_results = []\n",
    "regular_models = {}\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: Popularity\")\n",
    "regular_models[\"Popularity\"] = recommend_popularity\n",
    "regular_results.append(\n",
    "    evaluate_model(recommend_popularity, \"Popularity\", data, user_indices=eval_users, ks=TOP_KS)\n",
    ")\n",
    "\n",
    "for alpha in P3_ALPHA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting P3alpha alpha={alpha}\")\n",
    "    rec = make_p3_recommender(X_binary, alpha=alpha, beta=0.0)\n",
    "    name = f\"P3alpha_a{str(alpha).replace('.', '_')}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(\n",
    "        evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS)\n",
    "    )\n",
    "\n",
    "for alpha, beta in RP3_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting RP3beta alpha={alpha} beta={beta}\")\n",
    "    rec = make_p3_recommender(X_binary, alpha=alpha, beta=beta)\n",
    "    name = f\"RP3beta_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(\n",
    "        evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS)\n",
    "    )\n",
    "\n",
    "for lam in EASE_BINARY_LAMBDAS:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting EASE binary lambda={lam}\")\n",
    "    B_bin = fit_ease(X_binary, lam=lam)\n",
    "    rec_bin = make_ease_recommender(X_binary_dense, B_bin)\n",
    "    name_bin = f\"EASE_binary_l{int(lam)}\"\n",
    "    regular_models[name_bin] = rec_bin\n",
    "    regular_results.append(\n",
    "        evaluate_ease_batched(\n",
    "            B_matrix=B_bin,\n",
    "            X_train_matrix=X_binary,\n",
    "            user_seen_dict=user_seen,\n",
    "            test_item_by_user=test_item_by_user,\n",
    "            model_name=name_bin,\n",
    "            user_indices=eval_users,\n",
    "            ks=TOP_KS,\n",
    "            batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "            user_seen_arrays_dict=user_seen_arrays,\n",
    "        )\n",
    "    )\n",
    "\n",
    "for lam in EASE_COUNT_LAMBDAS:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting EASE count lambda={lam}\")\n",
    "    B_cnt = fit_ease(X_counts, lam=lam)\n",
    "    rec_cnt = make_ease_recommender(X_counts_dense, B_cnt)\n",
    "    name_cnt = f\"EASE_count_l{int(lam)}\"\n",
    "    regular_models[name_cnt] = rec_cnt\n",
    "    regular_results.append(\n",
    "        evaluate_ease_batched(\n",
    "            B_matrix=B_cnt,\n",
    "            X_train_matrix=X_counts,\n",
    "            user_seen_dict=user_seen,\n",
    "            test_item_by_user=test_item_by_user,\n",
    "            model_name=name_cnt,\n",
    "            user_indices=eval_users,\n",
    "            ks=TOP_KS,\n",
    "            batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "            user_seen_arrays_dict=user_seen_arrays,\n",
    "        )\n",
    "    )\n",
    "\n",
    "regular_results_df = (\n",
    "    pd.DataFrame(regular_results)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(regular_results_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0bdd000a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower_t1\n",
      "TwoTower_t1 epoch 1/24 - loss: 6.660124\n",
      "TwoTower_t1 epoch 2/24 - loss: 6.332343\n",
      "TwoTower_t1 epoch 3/24 - loss: 6.329521\n",
      "TwoTower_t1 epoch 4/24 - loss: 6.328155\n",
      "TwoTower_t1 epoch 5/24 - loss: 6.327216\n",
      "TwoTower_t1 epoch 6/24 - loss: 6.326530\n",
      "TwoTower_t1 epoch 7/24 - loss: 6.326110\n",
      "TwoTower_t1 epoch 8/24 - loss: 6.325513\n",
      "TwoTower_t1 epoch 9/24 - loss: 6.325276\n",
      "TwoTower_t1 epoch 10/24 - loss: 6.324973\n",
      "TwoTower_t1 epoch 11/24 - loss: 6.324833\n",
      "TwoTower_t1 epoch 12/24 - loss: 6.324714\n",
      "TwoTower_t1 epoch 13/24 - loss: 6.324426\n",
      "TwoTower_t1 epoch 14/24 - loss: 6.324100\n",
      "TwoTower_t1 epoch 15/24 - loss: 6.324039\n",
      "TwoTower_t1 epoch 16/24 - loss: 6.323955\n",
      "TwoTower_t1 epoch 17/24 - loss: 6.323771\n",
      "TwoTower_t1 epoch 18/24 - loss: 6.323748\n",
      "TwoTower_t1 epoch 19/24 - loss: 6.323497\n",
      "TwoTower_t1 epoch 20/24 - loss: 6.323462\n",
      "TwoTower_t1 epoch 21/24 - loss: 6.323321\n",
      "TwoTower_t1 epoch 22/24 - loss: 6.323294\n",
      "TwoTower_t1 epoch 23/24 - loss: 6.323117\n",
      "TwoTower_t1 epoch 24/24 - loss: 6.323093\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_t1: 100%|██████████| 43199/43199 [00:18<00:00, 2386.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower_t2\n",
      "TwoTower_t2 epoch 1/28 - loss: 6.943108\n",
      "TwoTower_t2 epoch 2/28 - loss: 6.343322\n",
      "TwoTower_t2 epoch 3/28 - loss: 6.335398\n",
      "TwoTower_t2 epoch 4/28 - loss: 6.332158\n",
      "TwoTower_t2 epoch 5/28 - loss: 6.330052\n",
      "TwoTower_t2 epoch 6/28 - loss: 6.328640\n",
      "TwoTower_t2 epoch 7/28 - loss: 6.327775\n",
      "TwoTower_t2 epoch 8/28 - loss: 6.327046\n",
      "TwoTower_t2 epoch 9/28 - loss: 6.326525\n",
      "TwoTower_t2 epoch 10/28 - loss: 6.326367\n",
      "TwoTower_t2 epoch 11/28 - loss: 6.325923\n",
      "TwoTower_t2 epoch 12/28 - loss: 6.325519\n",
      "TwoTower_t2 epoch 13/28 - loss: 6.325029\n",
      "TwoTower_t2 epoch 14/28 - loss: 6.325085\n",
      "TwoTower_t2 epoch 15/28 - loss: 6.324746\n",
      "TwoTower_t2 epoch 16/28 - loss: 6.324592\n",
      "TwoTower_t2 epoch 17/28 - loss: 6.324357\n",
      "TwoTower_t2 epoch 18/28 - loss: 6.324215\n",
      "TwoTower_t2 epoch 19/28 - loss: 6.324162\n",
      "TwoTower_t2 epoch 20/28 - loss: 6.323955\n",
      "TwoTower_t2 epoch 21/28 - loss: 6.323822\n",
      "TwoTower_t2 epoch 22/28 - loss: 6.323712\n",
      "TwoTower_t2 epoch 23/28 - loss: 6.323581\n",
      "TwoTower_t2 epoch 24/28 - loss: 6.323494\n",
      "TwoTower_t2 epoch 25/28 - loss: 6.323167\n",
      "TwoTower_t2 epoch 26/28 - loss: 6.323308\n",
      "TwoTower_t2 epoch 27/28 - loss: 6.323192\n",
      "TwoTower_t2 epoch 28/28 - loss: 6.323008\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_t2: 100%|██████████| 43199/43199 [00:19<00:00, 2204.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training TwoTower_t3\n",
      "TwoTower_t3 epoch 1/30 - loss: 6.848020\n",
      "TwoTower_t3 epoch 2/30 - loss: 6.340398\n",
      "TwoTower_t3 epoch 3/30 - loss: 6.334111\n",
      "TwoTower_t3 epoch 4/30 - loss: 6.331278\n",
      "TwoTower_t3 epoch 5/30 - loss: 6.329516\n",
      "TwoTower_t3 epoch 6/30 - loss: 6.328225\n",
      "TwoTower_t3 epoch 7/30 - loss: 6.327511\n",
      "TwoTower_t3 epoch 8/30 - loss: 6.326832\n",
      "TwoTower_t3 epoch 9/30 - loss: 6.326316\n",
      "TwoTower_t3 epoch 10/30 - loss: 6.325874\n",
      "TwoTower_t3 epoch 11/30 - loss: 6.325562\n",
      "TwoTower_t3 epoch 12/30 - loss: 6.325153\n",
      "TwoTower_t3 epoch 13/30 - loss: 6.324978\n",
      "TwoTower_t3 epoch 14/30 - loss: 6.324683\n",
      "TwoTower_t3 epoch 15/30 - loss: 6.324469\n",
      "TwoTower_t3 epoch 16/30 - loss: 6.324277\n",
      "TwoTower_t3 epoch 17/30 - loss: 6.324086\n",
      "TwoTower_t3 epoch 18/30 - loss: 6.323896\n",
      "TwoTower_t3 epoch 19/30 - loss: 6.323805\n",
      "TwoTower_t3 epoch 20/30 - loss: 6.323606\n",
      "TwoTower_t3 epoch 21/30 - loss: 6.323480\n",
      "TwoTower_t3 epoch 22/30 - loss: 6.323381\n",
      "TwoTower_t3 epoch 23/30 - loss: 6.323274\n",
      "TwoTower_t3 epoch 24/30 - loss: 6.323240\n",
      "TwoTower_t3 epoch 25/30 - loss: 6.323259\n",
      "TwoTower_t3 epoch 26/30 - loss: 6.323004\n",
      "TwoTower_t3 epoch 27/30 - loss: 6.322918\n",
      "TwoTower_t3 epoch 28/30 - loss: 6.322918\n",
      "TwoTower_t3 epoch 29/30 - loss: 6.322792\n",
      "TwoTower_t3 epoch 30/30 - loss: 6.322725\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TwoTower_t3: 100%|██████████| 43199/43199 [00:19<00:00, 2180.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE_v1\n",
      "MultVAE_v1 epoch 1/90 - loss: 151.361780 - anneal: 0.031\n",
      "MultVAE_v1 epoch 10/90 - loss: 93.088502 - anneal: 0.364\n",
      "MultVAE_v1 epoch 20/90 - loss: 95.418273 - anneal: 0.700\n",
      "MultVAE_v1 epoch 30/90 - loss: 94.968472 - anneal: 0.700\n",
      "MultVAE_v1 epoch 40/90 - loss: 94.770533 - anneal: 0.700\n",
      "MultVAE_v1 epoch 50/90 - loss: 94.677405 - anneal: 0.700\n",
      "MultVAE_v1 epoch 60/90 - loss: 94.608560 - anneal: 0.700\n",
      "MultVAE_v1 epoch 70/90 - loss: 94.427493 - anneal: 0.700\n",
      "MultVAE_v1 epoch 80/90 - loss: 94.394787 - anneal: 0.700\n",
      "MultVAE_v1 epoch 90/90 - loss: 94.332030 - anneal: 0.700\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_v1: 100%|██████████| 43199/43199 [00:11<00:00, 3602.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE_v2\n",
      "MultVAE_v2 epoch 1/100 - loss: 151.233939 - anneal: 0.028\n",
      "MultVAE_v2 epoch 10/100 - loss: 92.913051 - anneal: 0.328\n",
      "MultVAE_v2 epoch 20/100 - loss: 95.177418 - anneal: 0.661\n",
      "MultVAE_v2 epoch 30/100 - loss: 96.160018 - anneal: 0.800\n",
      "MultVAE_v2 epoch 40/100 - loss: 95.932998 - anneal: 0.800\n",
      "MultVAE_v2 epoch 50/100 - loss: 95.694546 - anneal: 0.800\n",
      "MultVAE_v2 epoch 60/100 - loss: 95.574383 - anneal: 0.800\n",
      "MultVAE_v2 epoch 70/100 - loss: 95.520040 - anneal: 0.800\n",
      "MultVAE_v2 epoch 80/100 - loss: 95.503385 - anneal: 0.800\n",
      "MultVAE_v2 epoch 90/100 - loss: 95.446835 - anneal: 0.800\n",
      "MultVAE_v2 epoch 100/100 - loss: 95.416922 - anneal: 0.800\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_v2: 100%|██████████| 43199/43199 [00:11<00:00, 3649.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training MultVAE_v3\n",
      "MultVAE_v3 epoch 1/110 - loss: 151.384818 - anneal: 0.025\n",
      "MultVAE_v3 epoch 10/110 - loss: 93.016749 - anneal: 0.298\n",
      "MultVAE_v3 epoch 20/110 - loss: 94.627547 - anneal: 0.601\n",
      "MultVAE_v3 epoch 30/110 - loss: 97.095880 - anneal: 0.904\n",
      "MultVAE_v3 epoch 40/110 - loss: 97.801608 - anneal: 1.000\n",
      "MultVAE_v3 epoch 50/110 - loss: 97.587398 - anneal: 1.000\n",
      "MultVAE_v3 epoch 60/110 - loss: 97.473454 - anneal: 1.000\n",
      "MultVAE_v3 epoch 70/110 - loss: 97.438779 - anneal: 1.000\n",
      "MultVAE_v3 epoch 80/110 - loss: 97.292370 - anneal: 1.000\n",
      "MultVAE_v3 epoch 90/110 - loss: 97.326313 - anneal: 1.000\n",
      "MultVAE_v3 epoch 100/110 - loss: 97.231278 - anneal: 1.000\n",
      "MultVAE_v3 epoch 110/110 - loss: 97.179016 - anneal: 1.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MultVAE_v3: 100%|██████████| 43199/43199 [00:12<00:00, 3521.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training NeuMF_n1\n",
      "NeuMF_n1 epoch 1/28 - loss: 0.687114\n",
      "NeuMF_n1 epoch 2/28 - loss: 0.568033\n",
      "NeuMF_n1 epoch 3/28 - loss: 0.427860\n",
      "NeuMF_n1 epoch 4/28 - loss: 0.305585\n",
      "NeuMF_n1 epoch 5/28 - loss: 0.220109\n",
      "NeuMF_n1 epoch 6/28 - loss: 0.175079\n",
      "NeuMF_n1 epoch 7/28 - loss: 0.141281\n",
      "NeuMF_n1 epoch 8/28 - loss: 0.115696\n",
      "NeuMF_n1 epoch 9/28 - loss: 0.097381\n",
      "NeuMF_n1 epoch 10/28 - loss: 0.087948\n",
      "NeuMF_n1 epoch 11/28 - loss: 0.083066\n",
      "NeuMF_n1 epoch 12/28 - loss: 0.081146\n",
      "NeuMF_n1 epoch 13/28 - loss: 0.080191\n",
      "NeuMF_n1 epoch 14/28 - loss: 0.078523\n",
      "NeuMF_n1 epoch 15/28 - loss: 0.078250\n",
      "NeuMF_n1 epoch 16/28 - loss: 0.077645\n",
      "NeuMF_n1 epoch 17/28 - loss: 0.077888\n",
      "NeuMF_n1 epoch 18/28 - loss: 0.077745\n",
      "NeuMF_n1 epoch 19/28 - loss: 0.076872\n",
      "NeuMF_n1 epoch 20/28 - loss: 0.076329\n",
      "NeuMF_n1 epoch 21/28 - loss: 0.075989\n",
      "NeuMF_n1 epoch 22/28 - loss: 0.076492\n",
      "NeuMF_n1 epoch 23/28 - loss: 0.076606\n",
      "NeuMF_n1 epoch 24/28 - loss: 0.076247\n",
      "NeuMF_n1 epoch 25/28 - loss: 0.075680\n",
      "NeuMF_n1 epoch 26/28 - loss: 0.075585\n",
      "NeuMF_n1 epoch 27/28 - loss: 0.075649\n",
      "NeuMF_n1 epoch 28/28 - loss: 0.075301\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "NeuMF_n1: 100%|██████████| 43199/43199 [00:14<00:00, 2995.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Training NeuMF_n2\n",
      "NeuMF_n2 epoch 1/32 - loss: 0.689794\n",
      "NeuMF_n2 epoch 2/32 - loss: 0.589130\n",
      "NeuMF_n2 epoch 3/32 - loss: 0.511645\n",
      "NeuMF_n2 epoch 4/32 - loss: 0.379542\n",
      "NeuMF_n2 epoch 5/32 - loss: 0.270673\n",
      "NeuMF_n2 epoch 6/32 - loss: 0.204035\n",
      "NeuMF_n2 epoch 7/32 - loss: 0.167692\n",
      "NeuMF_n2 epoch 8/32 - loss: 0.137075\n",
      "NeuMF_n2 epoch 9/32 - loss: 0.113451\n",
      "NeuMF_n2 epoch 10/32 - loss: 0.098932\n",
      "NeuMF_n2 epoch 11/32 - loss: 0.089369\n",
      "NeuMF_n2 epoch 12/32 - loss: 0.083239\n",
      "NeuMF_n2 epoch 13/32 - loss: 0.081376\n",
      "NeuMF_n2 epoch 14/32 - loss: 0.079027\n",
      "NeuMF_n2 epoch 15/32 - loss: 0.078021\n",
      "NeuMF_n2 epoch 16/32 - loss: 0.078233\n",
      "NeuMF_n2 epoch 17/32 - loss: 0.077714\n",
      "NeuMF_n2 epoch 18/32 - loss: 0.076749\n",
      "NeuMF_n2 epoch 19/32 - loss: 0.077281\n",
      "NeuMF_n2 epoch 20/32 - loss: 0.076875\n",
      "NeuMF_n2 epoch 21/32 - loss: 0.076234\n",
      "NeuMF_n2 epoch 22/32 - loss: 0.076514\n",
      "NeuMF_n2 epoch 23/32 - loss: 0.076776\n",
      "NeuMF_n2 epoch 24/32 - loss: 0.076424\n",
      "NeuMF_n2 epoch 25/32 - loss: 0.076194\n",
      "NeuMF_n2 epoch 26/32 - loss: 0.076319\n",
      "NeuMF_n2 epoch 27/32 - loss: 0.075420\n",
      "NeuMF_n2 epoch 28/32 - loss: 0.075476\n",
      "NeuMF_n2 epoch 29/32 - loss: 0.076015\n",
      "NeuMF_n2 epoch 30/32 - loss: 0.074863\n",
      "NeuMF_n2 epoch 31/32 - loss: 0.075362\n",
      "NeuMF_n2 epoch 32/32 - loss: 0.075675\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "NeuMF_n2: 100%|██████████| 43199/43199 [00:16<00:00, 2697.13it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096484</td>\n",
       "      <td>0.169749</td>\n",
       "      <td>0.310933</td>\n",
       "      <td>0.060031</td>\n",
       "      <td>0.085268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098359</td>\n",
       "      <td>0.169610</td>\n",
       "      <td>0.312368</td>\n",
       "      <td>0.060681</td>\n",
       "      <td>0.085785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098312</td>\n",
       "      <td>0.169286</td>\n",
       "      <td>0.312901</td>\n",
       "      <td>0.060905</td>\n",
       "      <td>0.085869</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NeuMF_n1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095419</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.311489</td>\n",
       "      <td>0.057998</td>\n",
       "      <td>0.083309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095674</td>\n",
       "      <td>0.167180</td>\n",
       "      <td>0.312137</td>\n",
       "      <td>0.059018</td>\n",
       "      <td>0.083918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>MultVAE_v2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097919</td>\n",
       "      <td>0.165883</td>\n",
       "      <td>0.307808</td>\n",
       "      <td>0.060146</td>\n",
       "      <td>0.084537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>MultVAE_v1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096715</td>\n",
       "      <td>0.165860</td>\n",
       "      <td>0.309799</td>\n",
       "      <td>0.059773</td>\n",
       "      <td>0.084203</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>TwoTower_t1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094169</td>\n",
       "      <td>0.165860</td>\n",
       "      <td>0.307160</td>\n",
       "      <td>0.058893</td>\n",
       "      <td>0.083500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Model  UsersEval      HR@5     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0  TwoTower_t3      43199  0.096484  0.169749  0.310933  0.060031  0.085268\n",
       "1   MultVAE_v3      43199  0.098359  0.169610  0.312368  0.060681  0.085785\n",
       "2  TwoTower_t2      43199  0.098312  0.169286  0.312901  0.060905  0.085869\n",
       "3     NeuMF_n1      43199  0.095419  0.168013  0.311489  0.057998  0.083309\n",
       "4     NeuMF_n2      43199  0.095674  0.167180  0.312137  0.059018  0.083918\n",
       "5   MultVAE_v2      43199  0.097919  0.165883  0.307808  0.060146  0.084537\n",
       "6   MultVAE_v1      43199  0.096715  0.165860  0.309799  0.059773  0.084203\n",
       "7  TwoTower_t1      43199  0.094169  0.165860  0.307160  0.058893  0.083500"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if not RUN_NEURAL:\n",
    "    neural_results = []\n",
    "    neural_models = {}\n",
    "    neural_results_df = pd.DataFrame(columns=['Model','UsersEval','HR@5','HR@10','HR@20','MRR@10','NDCG@10'])\n",
    "    print('Skipping neural model training because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Fit and evaluate focused neural models\n",
    "\n",
    "    neural_results = []\n",
    "    neural_models = {}\n",
    "\n",
    "    for cfg in TWOTOWER_CONFIGS:\n",
    "        print(\"=\" * 100)\n",
    "        print(f\"Training {cfg['name']}\")\n",
    "        model = train_two_tower(\n",
    "            model_name=cfg[\"name\"],\n",
    "            emb_dim=cfg[\"emb_dim\"],\n",
    "            hidden_dims=cfg[\"hidden_dims\"],\n",
    "            out_dim=cfg[\"out_dim\"],\n",
    "            epochs=cfg[\"epochs\"],\n",
    "            lr=cfg[\"lr\"],\n",
    "            wd=cfg[\"wd\"],\n",
    "            temperature=cfg[\"temperature\"],\n",
    "        )\n",
    "        rec = make_two_tower_recommender(model)\n",
    "        neural_models[cfg[\"name\"]] = rec\n",
    "        neural_results.append(\n",
    "            evaluate_model(rec, cfg[\"name\"], data, user_indices=eval_users, ks=TOP_KS)\n",
    "        )\n",
    "\n",
    "    for cfg in MULTVAE_CONFIGS:\n",
    "        print(\"=\" * 100)\n",
    "        print(f\"Training {cfg['name']}\")\n",
    "        model = train_multvae(\n",
    "            model_name=cfg[\"name\"],\n",
    "            hidden_dim=cfg[\"hidden_dim\"],\n",
    "            latent_dim=cfg[\"latent_dim\"],\n",
    "            dropout=cfg[\"dropout\"],\n",
    "            epochs=cfg[\"epochs\"],\n",
    "            lr=cfg[\"lr\"],\n",
    "            wd=cfg[\"wd\"],\n",
    "            anneal_cap=cfg[\"anneal_cap\"],\n",
    "        )\n",
    "        rec = make_multvae_recommender(model)\n",
    "        neural_models[cfg[\"name\"]] = rec\n",
    "        neural_results.append(\n",
    "            evaluate_model(rec, cfg[\"name\"], data, user_indices=eval_users, ks=TOP_KS)\n",
    "        )\n",
    "\n",
    "    for cfg in NEUMF_CONFIGS:\n",
    "        print(\"=\" * 100)\n",
    "        print(f\"Training {cfg['name']}\")\n",
    "        model = train_neumf(\n",
    "            model_name=cfg[\"name\"],\n",
    "            mf_dim=cfg[\"mf_dim\"],\n",
    "            mlp_dim=cfg[\"mlp_dim\"],\n",
    "            hidden_dims=cfg[\"hidden_dims\"],\n",
    "            epochs=cfg[\"epochs\"],\n",
    "            lr=cfg[\"lr\"],\n",
    "            wd=cfg[\"wd\"],\n",
    "        )\n",
    "        rec = make_neumf_recommender(model)\n",
    "        neural_models[cfg[\"name\"]] = rec\n",
    "        neural_results.append(\n",
    "            evaluate_model(rec, cfg[\"name\"], data, user_indices=eval_users, ks=TOP_KS)\n",
    "        )\n",
    "\n",
    "    neural_results_df = (\n",
    "        pd.DataFrame(neural_results)\n",
    "        .sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "    display(neural_results_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d1ab8431",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating fusion: RRF_regular_3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RRF_regular_3: 100%|██████████| 43199/43199 [01:34<00:00, 458.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating fusion: RRF_hybrid_4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RRF_hybrid_4: 100%|██████████| 43199/43199 [02:39<00:00, 271.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================================================================\n",
      "Evaluating fusion: RRF_hybrid_3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RRF_hybrid_3: 100%|██████████| 43199/43199 [02:21<00:00, 304.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top regular models:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097734</td>\n",
       "      <td>0.169981</td>\n",
       "      <td>0.308433</td>\n",
       "      <td>0.060909</td>\n",
       "      <td>0.086051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>RP3beta_a1_05_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169680</td>\n",
       "      <td>0.308850</td>\n",
       "      <td>0.060639</td>\n",
       "      <td>0.085774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>RP3beta_a1_0_b0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097780</td>\n",
       "      <td>0.169657</td>\n",
       "      <td>0.310887</td>\n",
       "      <td>0.060827</td>\n",
       "      <td>0.085906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169448</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060618</td>\n",
       "      <td>0.085705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097711</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>P3alpha_a1_05</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098197</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.310540</td>\n",
       "      <td>0.060244</td>\n",
       "      <td>0.085325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098544</td>\n",
       "      <td>0.169055</td>\n",
       "      <td>0.309799</td>\n",
       "      <td>0.061045</td>\n",
       "      <td>0.085938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>EASE_count_l800</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095396</td>\n",
       "      <td>0.169009</td>\n",
       "      <td>0.309961</td>\n",
       "      <td>0.059349</td>\n",
       "      <td>0.084554</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>RP3beta_a0_95_b0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097595</td>\n",
       "      <td>0.168985</td>\n",
       "      <td>0.310007</td>\n",
       "      <td>0.060733</td>\n",
       "      <td>0.085685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>EASE_binary_l800</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097618</td>\n",
       "      <td>0.168985</td>\n",
       "      <td>0.310007</td>\n",
       "      <td>0.060290</td>\n",
       "      <td>0.085320</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>P3alpha_a1_0</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060332</td>\n",
       "      <td>0.085333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096924</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310563</td>\n",
       "      <td>0.060030</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0    RP3beta_a1_1_b0_7      43199  0.097734  0.169981  0.308433  0.060909   \n",
       "1   RP3beta_a1_05_b0_6      43199  0.097687  0.169680  0.308850  0.060639   \n",
       "2    RP3beta_a1_0_b0_5      43199  0.097780  0.169657  0.310887  0.060827   \n",
       "3    RP3beta_a1_0_b0_6      43199  0.097687  0.169448  0.308896  0.060618   \n",
       "4    EASE_binary_l1000      43199  0.097711  0.169356  0.309614  0.060787   \n",
       "5        P3alpha_a1_05      43199  0.098197  0.169147  0.310540  0.060244   \n",
       "6    EASE_binary_l1200      43199  0.098544  0.169055  0.309799  0.061045   \n",
       "7      EASE_count_l800      43199  0.095396  0.169009  0.309961  0.059349   \n",
       "8   RP3beta_a0_95_b0_5      43199  0.097595  0.168985  0.310007  0.060733   \n",
       "9     EASE_binary_l800      43199  0.097618  0.168985  0.310007  0.060290   \n",
       "10        P3alpha_a1_0      43199  0.098104  0.168870  0.310609  0.060332   \n",
       "11    EASE_count_l1000      43199  0.096924  0.168870  0.310563  0.060030   \n",
       "\n",
       "     NDCG@10  \n",
       "0   0.086051  \n",
       "1   0.085774  \n",
       "2   0.085906  \n",
       "3   0.085705  \n",
       "4   0.085790  \n",
       "5   0.085325  \n",
       "6   0.085938  \n",
       "7   0.084554  \n",
       "8   0.085685  \n",
       "9   0.085320  \n",
       "10  0.085333  \n",
       "11  0.085076  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top neural models:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096484</td>\n",
       "      <td>0.169749</td>\n",
       "      <td>0.310933</td>\n",
       "      <td>0.060031</td>\n",
       "      <td>0.085268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098359</td>\n",
       "      <td>0.169610</td>\n",
       "      <td>0.312368</td>\n",
       "      <td>0.060681</td>\n",
       "      <td>0.085785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098312</td>\n",
       "      <td>0.169286</td>\n",
       "      <td>0.312901</td>\n",
       "      <td>0.060905</td>\n",
       "      <td>0.085869</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NeuMF_n1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095419</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.311489</td>\n",
       "      <td>0.057998</td>\n",
       "      <td>0.083309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095674</td>\n",
       "      <td>0.167180</td>\n",
       "      <td>0.312137</td>\n",
       "      <td>0.059018</td>\n",
       "      <td>0.083918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>MultVAE_v2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097919</td>\n",
       "      <td>0.165883</td>\n",
       "      <td>0.307808</td>\n",
       "      <td>0.060146</td>\n",
       "      <td>0.084537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>MultVAE_v1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096715</td>\n",
       "      <td>0.165860</td>\n",
       "      <td>0.309799</td>\n",
       "      <td>0.059773</td>\n",
       "      <td>0.084203</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>TwoTower_t1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094169</td>\n",
       "      <td>0.165860</td>\n",
       "      <td>0.307160</td>\n",
       "      <td>0.058893</td>\n",
       "      <td>0.083500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Model  UsersEval      HR@5     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0  TwoTower_t3      43199  0.096484  0.169749  0.310933  0.060031  0.085268\n",
       "1   MultVAE_v3      43199  0.098359  0.169610  0.312368  0.060681  0.085785\n",
       "2  TwoTower_t2      43199  0.098312  0.169286  0.312901  0.060905  0.085869\n",
       "3     NeuMF_n1      43199  0.095419  0.168013  0.311489  0.057998  0.083309\n",
       "4     NeuMF_n2      43199  0.095674  0.167180  0.312137  0.059018  0.083918\n",
       "5   MultVAE_v2      43199  0.097919  0.165883  0.307808  0.060146  0.084537\n",
       "6   MultVAE_v1      43199  0.096715  0.165860  0.309799  0.059773  0.084203\n",
       "7  TwoTower_t1      43199  0.094169  0.165860  0.307160  0.058893  0.083500"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fusion models:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RRF_regular_3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308063</td>\n",
       "      <td>0.060710</td>\n",
       "      <td>0.085596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>RRF_hybrid_3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097340</td>\n",
       "      <td>0.167805</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060563</td>\n",
       "      <td>0.085263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>RRF_hybrid_4</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097109</td>\n",
       "      <td>0.167110</td>\n",
       "      <td>0.308364</td>\n",
       "      <td>0.060333</td>\n",
       "      <td>0.084952</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           Model  UsersEval      HR@5     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0  RRF_regular_3      43199  0.097687  0.168638  0.308063  0.060710  0.085596\n",
       "2   RRF_hybrid_3      43199  0.097340  0.167805  0.308896  0.060563  0.085263\n",
       "1   RRF_hybrid_4      43199  0.097109  0.167110  0.308364  0.060333  0.084952"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall ranking:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>UsersEval</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RP3beta_a1_1_b0_7</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097734</td>\n",
       "      <td>0.169981</td>\n",
       "      <td>0.308433</td>\n",
       "      <td>0.060909</td>\n",
       "      <td>0.086051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>TwoTower_t3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096484</td>\n",
       "      <td>0.169749</td>\n",
       "      <td>0.310933</td>\n",
       "      <td>0.060031</td>\n",
       "      <td>0.085268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>RP3beta_a1_05_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169680</td>\n",
       "      <td>0.308850</td>\n",
       "      <td>0.060639</td>\n",
       "      <td>0.085774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>RP3beta_a1_0_b0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097780</td>\n",
       "      <td>0.169657</td>\n",
       "      <td>0.310887</td>\n",
       "      <td>0.060827</td>\n",
       "      <td>0.085906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>MultVAE_v3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098359</td>\n",
       "      <td>0.169610</td>\n",
       "      <td>0.312368</td>\n",
       "      <td>0.060681</td>\n",
       "      <td>0.085785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>RP3beta_a1_0_b0_6</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.169448</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060618</td>\n",
       "      <td>0.085705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>EASE_binary_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097711</td>\n",
       "      <td>0.169356</td>\n",
       "      <td>0.309614</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>TwoTower_t2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098312</td>\n",
       "      <td>0.169286</td>\n",
       "      <td>0.312901</td>\n",
       "      <td>0.060905</td>\n",
       "      <td>0.085869</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>P3alpha_a1_05</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098197</td>\n",
       "      <td>0.169147</td>\n",
       "      <td>0.310540</td>\n",
       "      <td>0.060244</td>\n",
       "      <td>0.085325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>EASE_binary_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098544</td>\n",
       "      <td>0.169055</td>\n",
       "      <td>0.309799</td>\n",
       "      <td>0.061045</td>\n",
       "      <td>0.085938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>EASE_count_l800</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095396</td>\n",
       "      <td>0.169009</td>\n",
       "      <td>0.309961</td>\n",
       "      <td>0.059349</td>\n",
       "      <td>0.084554</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>RP3beta_a0_95_b0_5</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097595</td>\n",
       "      <td>0.168985</td>\n",
       "      <td>0.310007</td>\n",
       "      <td>0.060733</td>\n",
       "      <td>0.085685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>EASE_binary_l800</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097618</td>\n",
       "      <td>0.168985</td>\n",
       "      <td>0.310007</td>\n",
       "      <td>0.060290</td>\n",
       "      <td>0.085320</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>P3alpha_a1_0</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098104</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060332</td>\n",
       "      <td>0.085333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>EASE_count_l1000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096924</td>\n",
       "      <td>0.168870</td>\n",
       "      <td>0.310563</td>\n",
       "      <td>0.060030</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>P3alpha_a0_95</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097826</td>\n",
       "      <td>0.168800</td>\n",
       "      <td>0.310493</td>\n",
       "      <td>0.060338</td>\n",
       "      <td>0.085320</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>EASE_binary_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098567</td>\n",
       "      <td>0.168731</td>\n",
       "      <td>0.309452</td>\n",
       "      <td>0.060942</td>\n",
       "      <td>0.085798</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>RRF_regular_3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097687</td>\n",
       "      <td>0.168638</td>\n",
       "      <td>0.308063</td>\n",
       "      <td>0.060710</td>\n",
       "      <td>0.085596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>RP3beta_a0_9_b0_4</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097224</td>\n",
       "      <td>0.168615</td>\n",
       "      <td>0.310053</td>\n",
       "      <td>0.060512</td>\n",
       "      <td>0.085428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>P3alpha_a1_15</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098058</td>\n",
       "      <td>0.168592</td>\n",
       "      <td>0.310239</td>\n",
       "      <td>0.060189</td>\n",
       "      <td>0.085164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>EASE_binary_l3000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097780</td>\n",
       "      <td>0.168569</td>\n",
       "      <td>0.308734</td>\n",
       "      <td>0.060858</td>\n",
       "      <td>0.085689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>EASE_count_l600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.093752</td>\n",
       "      <td>0.168453</td>\n",
       "      <td>0.309915</td>\n",
       "      <td>0.057772</td>\n",
       "      <td>0.083179</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>EASE_count_l3000</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098174</td>\n",
       "      <td>0.168268</td>\n",
       "      <td>0.308688</td>\n",
       "      <td>0.060787</td>\n",
       "      <td>0.085551</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>EASE_count_l2200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097757</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.309313</td>\n",
       "      <td>0.060807</td>\n",
       "      <td>0.085508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>P3alpha_a0_85</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097664</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.310678</td>\n",
       "      <td>0.060293</td>\n",
       "      <td>0.085120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>NeuMF_n1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095419</td>\n",
       "      <td>0.168013</td>\n",
       "      <td>0.311489</td>\n",
       "      <td>0.057998</td>\n",
       "      <td>0.083309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>EASE_binary_l2200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.098012</td>\n",
       "      <td>0.167990</td>\n",
       "      <td>0.309104</td>\n",
       "      <td>0.060916</td>\n",
       "      <td>0.085615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>EASE_binary_l600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095905</td>\n",
       "      <td>0.167990</td>\n",
       "      <td>0.309660</td>\n",
       "      <td>0.059069</td>\n",
       "      <td>0.084132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>RRF_hybrid_3</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097340</td>\n",
       "      <td>0.167805</td>\n",
       "      <td>0.308896</td>\n",
       "      <td>0.060563</td>\n",
       "      <td>0.085263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>EASE_count_l1200</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097410</td>\n",
       "      <td>0.167805</td>\n",
       "      <td>0.310609</td>\n",
       "      <td>0.060303</td>\n",
       "      <td>0.085076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>EASE_count_l1600</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097317</td>\n",
       "      <td>0.167342</td>\n",
       "      <td>0.310030</td>\n",
       "      <td>0.060572</td>\n",
       "      <td>0.085189</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>NeuMF_n2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.095674</td>\n",
       "      <td>0.167180</td>\n",
       "      <td>0.312137</td>\n",
       "      <td>0.059018</td>\n",
       "      <td>0.083918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>RRF_hybrid_4</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097109</td>\n",
       "      <td>0.167110</td>\n",
       "      <td>0.308364</td>\n",
       "      <td>0.060333</td>\n",
       "      <td>0.084952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>MultVAE_v2</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.097919</td>\n",
       "      <td>0.165883</td>\n",
       "      <td>0.307808</td>\n",
       "      <td>0.060146</td>\n",
       "      <td>0.084537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>MultVAE_v1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.096715</td>\n",
       "      <td>0.165860</td>\n",
       "      <td>0.309799</td>\n",
       "      <td>0.059773</td>\n",
       "      <td>0.084203</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>TwoTower_t1</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.094169</td>\n",
       "      <td>0.165860</td>\n",
       "      <td>0.307160</td>\n",
       "      <td>0.058893</td>\n",
       "      <td>0.083500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>43199</td>\n",
       "      <td>0.003218</td>\n",
       "      <td>0.006019</td>\n",
       "      <td>0.011621</td>\n",
       "      <td>0.001746</td>\n",
       "      <td>0.002724</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Model  UsersEval      HR@5     HR@10     HR@20    MRR@10  \\\n",
       "0    RP3beta_a1_1_b0_7      43199  0.097734  0.169981  0.308433  0.060909   \n",
       "1          TwoTower_t3      43199  0.096484  0.169749  0.310933  0.060031   \n",
       "2   RP3beta_a1_05_b0_6      43199  0.097687  0.169680  0.308850  0.060639   \n",
       "3    RP3beta_a1_0_b0_5      43199  0.097780  0.169657  0.310887  0.060827   \n",
       "4           MultVAE_v3      43199  0.098359  0.169610  0.312368  0.060681   \n",
       "5    RP3beta_a1_0_b0_6      43199  0.097687  0.169448  0.308896  0.060618   \n",
       "6    EASE_binary_l1000      43199  0.097711  0.169356  0.309614  0.060787   \n",
       "7          TwoTower_t2      43199  0.098312  0.169286  0.312901  0.060905   \n",
       "8        P3alpha_a1_05      43199  0.098197  0.169147  0.310540  0.060244   \n",
       "9    EASE_binary_l1200      43199  0.098544  0.169055  0.309799  0.061045   \n",
       "10     EASE_count_l800      43199  0.095396  0.169009  0.309961  0.059349   \n",
       "11  RP3beta_a0_95_b0_5      43199  0.097595  0.168985  0.310007  0.060733   \n",
       "12    EASE_binary_l800      43199  0.097618  0.168985  0.310007  0.060290   \n",
       "13        P3alpha_a1_0      43199  0.098104  0.168870  0.310609  0.060332   \n",
       "14    EASE_count_l1000      43199  0.096924  0.168870  0.310563  0.060030   \n",
       "15       P3alpha_a0_95      43199  0.097826  0.168800  0.310493  0.060338   \n",
       "16   EASE_binary_l1600      43199  0.098567  0.168731  0.309452  0.060942   \n",
       "17       RRF_regular_3      43199  0.097687  0.168638  0.308063  0.060710   \n",
       "18   RP3beta_a0_9_b0_4      43199  0.097224  0.168615  0.310053  0.060512   \n",
       "19       P3alpha_a1_15      43199  0.098058  0.168592  0.310239  0.060189   \n",
       "20   EASE_binary_l3000      43199  0.097780  0.168569  0.308734  0.060858   \n",
       "21     EASE_count_l600      43199  0.093752  0.168453  0.309915  0.057772   \n",
       "22    EASE_count_l3000      43199  0.098174  0.168268  0.308688  0.060787   \n",
       "23    EASE_count_l2200      43199  0.097757  0.168013  0.309313  0.060807   \n",
       "24       P3alpha_a0_85      43199  0.097664  0.168013  0.310678  0.060293   \n",
       "25            NeuMF_n1      43199  0.095419  0.168013  0.311489  0.057998   \n",
       "26   EASE_binary_l2200      43199  0.098012  0.167990  0.309104  0.060916   \n",
       "27    EASE_binary_l600      43199  0.095905  0.167990  0.309660  0.059069   \n",
       "28        RRF_hybrid_3      43199  0.097340  0.167805  0.308896  0.060563   \n",
       "29    EASE_count_l1200      43199  0.097410  0.167805  0.310609  0.060303   \n",
       "30    EASE_count_l1600      43199  0.097317  0.167342  0.310030  0.060572   \n",
       "31            NeuMF_n2      43199  0.095674  0.167180  0.312137  0.059018   \n",
       "32        RRF_hybrid_4      43199  0.097109  0.167110  0.308364  0.060333   \n",
       "33          MultVAE_v2      43199  0.097919  0.165883  0.307808  0.060146   \n",
       "34          MultVAE_v1      43199  0.096715  0.165860  0.309799  0.059773   \n",
       "35         TwoTower_t1      43199  0.094169  0.165860  0.307160  0.058893   \n",
       "36          Popularity      43199  0.003218  0.006019  0.011621  0.001746   \n",
       "\n",
       "     NDCG@10  \n",
       "0   0.086051  \n",
       "1   0.085268  \n",
       "2   0.085774  \n",
       "3   0.085906  \n",
       "4   0.085785  \n",
       "5   0.085705  \n",
       "6   0.085790  \n",
       "7   0.085869  \n",
       "8   0.085325  \n",
       "9   0.085938  \n",
       "10  0.084554  \n",
       "11  0.085685  \n",
       "12  0.085320  \n",
       "13  0.085333  \n",
       "14  0.085076  \n",
       "15  0.085320  \n",
       "16  0.085798  \n",
       "17  0.085596  \n",
       "18  0.085428  \n",
       "19  0.085164  \n",
       "20  0.085689  \n",
       "21  0.083179  \n",
       "22  0.085551  \n",
       "23  0.085508  \n",
       "24  0.085120  \n",
       "25  0.083309  \n",
       "26  0.085615  \n",
       "27  0.084132  \n",
       "28  0.085263  \n",
       "29  0.085076  \n",
       "30  0.085189  \n",
       "31  0.083918  \n",
       "32  0.084952  \n",
       "33  0.084537  \n",
       "34  0.084203  \n",
       "35  0.083500  \n",
       "36  0.002724  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best overall model: RP3beta_a1_1_b0_7\n"
     ]
    }
   ],
   "source": [
    "# Final combined comparison and optional fusion models\n",
    "\n",
    "fusion_results = []\n",
    "fusion_models = {}\n",
    "\n",
    "def top_model_name(df, prefix):\n",
    "    subset = df[df[\"Model\"].str.startswith(prefix)].copy()\n",
    "    if subset.empty:\n",
    "        return None\n",
    "    subset = subset.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False)\n",
    "    return subset.iloc[0][\"Model\"]\n",
    "\n",
    "if RUN_FUSION:\n",
    "    ease_best = top_model_name(regular_results_df, \"EASE_binary_\")\n",
    "    if ease_best is None:\n",
    "        ease_best = top_model_name(regular_results_df, \"EASE_count_\")\n",
    "    p3_best = top_model_name(regular_results_df, \"P3alpha_\")\n",
    "    rp3_best = top_model_name(regular_results_df, \"RP3beta_\")\n",
    "\n",
    "    if ease_best and p3_best and rp3_best:\n",
    "        print(\"=\" * 100)\n",
    "        print(\"Evaluating fusion: RRF_regular_3\")\n",
    "        rec = make_rrf_recommender(\n",
    "            [\n",
    "                (ease_best, regular_models[ease_best]),\n",
    "                (p3_best, regular_models[p3_best]),\n",
    "                (rp3_best, regular_models[rp3_best]),\n",
    "            ],\n",
    "            fetch_n=FUSION_FETCH_N,\n",
    "            rrf_k=FUSION_RRF_K,\n",
    "        )\n",
    "        fusion_models[\"RRF_regular_3\"] = rec\n",
    "        fusion_results.append(\n",
    "            evaluate_model(rec, \"RRF_regular_3\", data, user_indices=eval_users, ks=TOP_KS)\n",
    "        )\n",
    "\n",
    "    if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "        tw_best = top_model_name(neural_results_df, \"TwoTower_\")\n",
    "        mv_best = top_model_name(neural_results_df, \"MultVAE_\")\n",
    "        nm_best = top_model_name(neural_results_df, \"NeuMF_\")\n",
    "\n",
    "        hybrid_parts = []\n",
    "        for name in [ease_best, rp3_best, tw_best, mv_best]:\n",
    "            if name is None:\n",
    "                continue\n",
    "            if name in regular_models:\n",
    "                hybrid_parts.append((name, regular_models[name]))\n",
    "            elif name in neural_models:\n",
    "                hybrid_parts.append((name, neural_models[name]))\n",
    "\n",
    "        if len(hybrid_parts) >= 3:\n",
    "            print(\"=\" * 100)\n",
    "            print(\"Evaluating fusion: RRF_hybrid_4\")\n",
    "            rec = make_rrf_recommender(\n",
    "                hybrid_parts,\n",
    "                fetch_n=FUSION_FETCH_N,\n",
    "                rrf_k=FUSION_RRF_K,\n",
    "            )\n",
    "            fusion_models[\"RRF_hybrid_4\"] = rec\n",
    "            fusion_results.append(\n",
    "                evaluate_model(rec, \"RRF_hybrid_4\", data, user_indices=eval_users, ks=TOP_KS)\n",
    "            )\n",
    "\n",
    "        hybrid_parts_3 = []\n",
    "        for name in [ease_best, tw_best, mv_best]:\n",
    "            if name is None:\n",
    "                continue\n",
    "            if name in regular_models:\n",
    "                hybrid_parts_3.append((name, regular_models[name]))\n",
    "            elif name in neural_models:\n",
    "                hybrid_parts_3.append((name, neural_models[name]))\n",
    "\n",
    "        if len(hybrid_parts_3) >= 3:\n",
    "            print(\"=\" * 100)\n",
    "            print(\"Evaluating fusion: RRF_hybrid_3\")\n",
    "            rec = make_rrf_recommender(\n",
    "                hybrid_parts_3,\n",
    "                fetch_n=FUSION_FETCH_N,\n",
    "                rrf_k=FUSION_RRF_K,\n",
    "            )\n",
    "            fusion_models[\"RRF_hybrid_3\"] = rec\n",
    "            fusion_results.append(\n",
    "                evaluate_model(rec, \"RRF_hybrid_3\", data, user_indices=eval_users, ks=TOP_KS)\n",
    "            )\n",
    "\n",
    "fusion_results_df = pd.DataFrame(fusion_results)\n",
    "frames = [regular_results_df]\n",
    "if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "    frames.append(neural_results_df)\n",
    "if not fusion_results_df.empty:\n",
    "    frames.append(fusion_results_df)\n",
    "\n",
    "all_results_df = pd.concat(frames, ignore_index=True)\n",
    "all_results_df = all_results_df.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "print(\"Top regular models:\")\n",
    "display(regular_results_df.head(12))\n",
    "\n",
    "if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "    print(\"Top neural models:\")\n",
    "    display(neural_results_df.head(12))\n",
    "else:\n",
    "    print(\"Neural models were skipped.\")\n",
    "\n",
    "if not fusion_results_df.empty:\n",
    "    print(\"Fusion models:\")\n",
    "    display(fusion_results_df.sort_values([\"HR@10\", \"NDCG@10\"], ascending=False))\n",
    "else:\n",
    "    print(\"Fusion models were skipped or could not be built.\")\n",
    "\n",
    "print(\"Overall ranking:\")\n",
    "display(all_results_df)\n",
    "\n",
    "best_model_name = all_results_df.iloc[0][\"Model\"]\n",
    "print(\"Best overall model:\", best_model_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "bed7dcdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example pseudo-user index: 0\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>Country</th>\n",
       "      <th>PriceBin</th>\n",
       "      <th>MileageBin</th>\n",
       "      <th>Condition</th>\n",
       "      <th>AgeBin</th>\n",
       "      <th>user_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | price_(11233.652 to  17482.96] | m...</td>\n",
       "      <td>Australia</td>\n",
       "      <td>price_(11233.652 to  17482.96]</td>\n",
       "      <td>mileage_(-0.001 to  16665.0]</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>age_(1.999 to  4.0]</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             user_id    Country  \\\n",
       "0  Australia | price_(11233.652 to  17482.96] | m...  Australia   \n",
       "\n",
       "                         PriceBin                    MileageBin  \\\n",
       "0  price_(11233.652 to  17482.96]  mileage_(-0.001 to  16665.0]   \n",
       "\n",
       "             Condition               AgeBin  user_idx  \n",
       "0  Certified Pre-Owned  age_(1.999 to  4.0]         0  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rank</th>\n",
       "      <th>recommended_item</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Chrysler :: Voyager :: age_(1.999 to  4.0] :: ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>Chrysler :: Pacifica :: age_(1.999 to  4.0] ::...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Lexus :: ES :: age_(1.999 to  4.0] :: Certifie...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>Land Rover :: Velar :: age_(1.999 to  4.0] :: ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>Jeep :: Cherokee :: age_(1.999 to  4.0] :: Cer...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>Jeep :: Wrangler :: age_(1.999 to  4.0] :: Cer...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>Nissan :: Titan :: age_(1.999 to  4.0] :: Cert...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>Mazda :: Mazda3 :: age_(1.999 to  4.0] :: Cert...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>BMW :: Z4 :: age_(1.999 to  4.0] :: Certified ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>Nissan :: Altima :: age_(1.999 to  4.0] :: Cer...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   rank                                   recommended_item\n",
       "0     1  Chrysler :: Voyager :: age_(1.999 to  4.0] :: ...\n",
       "1     2  Chrysler :: Pacifica :: age_(1.999 to  4.0] ::...\n",
       "2     3  Lexus :: ES :: age_(1.999 to  4.0] :: Certifie...\n",
       "3     4  Land Rover :: Velar :: age_(1.999 to  4.0] :: ...\n",
       "4     5  Jeep :: Cherokee :: age_(1.999 to  4.0] :: Cer...\n",
       "5     6  Jeep :: Wrangler :: age_(1.999 to  4.0] :: Cer...\n",
       "6     7  Nissan :: Titan :: age_(1.999 to  4.0] :: Cert...\n",
       "7     8  Mazda :: Mazda3 :: age_(1.999 to  4.0] :: Cert...\n",
       "8     9  BMW :: Z4 :: age_(1.999 to  4.0] :: Certified ...\n",
       "9    10  Nissan :: Altima :: age_(1.999 to  4.0] :: Cer..."
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show example recommendations from the best overall model\n",
    "\n",
    "all_model_funcs = {}\n",
    "all_model_funcs.update(regular_models)\n",
    "if 'neural_models' in globals():\n",
    "    all_model_funcs.update(neural_models)\n",
    "if 'fusion_models' in globals():\n",
    "    all_model_funcs.update(fusion_models)\n",
    "\n",
    "best_recommender = all_model_funcs[best_model_name]\n",
    "\n",
    "example_user = int(eval_users[0])\n",
    "example_item_ids = best_recommender(example_user, n=10)\n",
    "example_item_names = [item_ids[i] for i in example_item_ids]\n",
    "\n",
    "print(\"Example pseudo-user index:\", example_user)\n",
    "display(user_feature_df[user_feature_df[\"user_idx\"] == example_user])\n",
    "\n",
    "example_df = pd.DataFrame({\n",
    "    \"rank\": np.arange(1, len(example_item_names) + 1),\n",
    "    \"recommended_item\": example_item_names,\n",
    "})\n",
    "display(example_df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26869d31",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "If you still want even harder settings after this run, the next levers are:\n",
    "- raise item granularity again\n",
    "- require lower popularity ceiling for formulation selection\n",
    "- increase negative-sampling pressure for pairwise neural models\n",
    "- tune the strongest 2 to 3 models only instead of running the full zoo"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
