{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "72a8c1e8",
   "metadata": {},
   "source": [
    "# Car recommendation notebook (segment-based)\n",
    "\n",
    "This version reformulates the car task as a **regular recommender system**.\n",
    "\n",
    "Because the dataset has only one purchase per person and no real user history, we create **pseudo-users as market segments** built from:\n",
    "- `Country`\n",
    "- `Condition`\n",
    "- `Price` bin\n",
    "- `Mileage` bin\n",
    "- `Year` bin\n",
    "\n",
    "Each segment has a history of purchased car models.  \n",
    "The task becomes:\n",
    "\n",
    "**Given a segment's past purchases, recommend the next car model that is likely to be relevant for that segment.**\n",
    "\n",
    "Models in this notebook:\n",
    "- Regular: Popularity, ItemKNN, EASE\n",
    "- Neural: NeuMF, MultVAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c6399ce0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch device: cuda\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from pathlib import Path\n",
    "import math\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Torch device:\", DEVICE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c980574b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Paths and core settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "# Segment construction\n",
    "N_PRICE_BINS = 8\n",
    "N_MILEAGE_BINS = 8\n",
    "N_YEAR_BINS = 6\n",
    "\n",
    "# Filtering\n",
    "MIN_ITEMS_PER_SEGMENT = 8\n",
    "MIN_SEGMENTS_PER_ITEM = 8\n",
    "\n",
    "# Ranking\n",
    "TOP_KS = [5, 10, 20]\n",
    "\n",
    "# ItemKNN\n",
    "ITEM_KNN_NEIGHBORS = 30\n",
    "\n",
    "# EASE\n",
    "EASE_LAMBDA = 200.0\n",
    "\n",
    "# NeuMF\n",
    "NEUMF_NEG_PER_POS = 4\n",
    "NEUMF_EPOCHS = 10\n",
    "NEUMF_BATCH_SIZE = 4096\n",
    "NEUMF_LR = 1e-3\n",
    "NEUMF_MF_DIM = 32\n",
    "NEUMF_MLP_DIM = 64\n",
    "\n",
    "# MultVAE\n",
    "MULTVAE_EPOCHS = 60\n",
    "MULTVAE_BATCH_SIZE = 256\n",
    "MULTVAE_LR = 1e-3\n",
    "MULTVAE_HIDDEN_DIM = 256\n",
    "MULTVAE_LATENT_DIM = 64\n",
    "MULTVAE_KL_ANNEAL_EPOCHS = 20\n",
    "\n",
    "print(\"Using dataset:\", CSV_PATH)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ddf27c4b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape: (1000000, 11)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  \n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil  \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy  \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK  \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico  \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Brand          object\n",
       "Model          object\n",
       "Year            int64\n",
       "Price         float64\n",
       "Mileage         int64\n",
       "Color          object\n",
       "Condition      object\n",
       "First Name     object\n",
       "Last Name      object\n",
       "Address        object\n",
       "Country        object\n",
       "dtype: object"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Load data\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "print(\"Shape:\", df.shape)\n",
    "display(df.head())\n",
    "display(df.dtypes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "98e34d8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Raw aggregated interactions: (632818, 3)\n",
      "Unique segments: 11520\n",
      "Unique items   : 88\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>segment_id</th>\n",
       "      <th>item_id</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Australia | Certified Pre-Owned | p0 | m0 | y0</td>\n",
       "      <td>Audi :: A3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Australia | Certified Pre-Owned | p0 | m0 | y0</td>\n",
       "      <td>Audi :: A4</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Australia | Certified Pre-Owned | p0 | m0 | y0</td>\n",
       "      <td>Audi :: A6</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Australia | Certified Pre-Owned | p0 | m0 | y0</td>\n",
       "      <td>Audi :: Q5</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Australia | Certified Pre-Owned | p0 | m0 | y0</td>\n",
       "      <td>Audi :: Q7</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       segment_id     item_id  count\n",
       "0  Australia | Certified Pre-Owned | p0 | m0 | y0  Audi :: A3      3\n",
       "1  Australia | Certified Pre-Owned | p0 | m0 | y0  Audi :: A4      4\n",
       "2  Australia | Certified Pre-Owned | p0 | m0 | y0  Audi :: A6      3\n",
       "3  Australia | Certified Pre-Owned | p0 | m0 | y0  Audi :: Q5      1\n",
       "4  Australia | Certified Pre-Owned | p0 | m0 | y0  Audi :: Q7      2"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Create item ids and market segments\n",
    "\n",
    "work = df.copy()\n",
    "\n",
    "work[\"Brand\"] = work[\"Brand\"].astype(str).str.strip()\n",
    "work[\"Model\"] = work[\"Model\"].astype(str).str.strip()\n",
    "work[\"Color\"] = work[\"Color\"].astype(str).str.strip()\n",
    "work[\"Condition\"] = work[\"Condition\"].astype(str).str.strip()\n",
    "work[\"Country\"] = work[\"Country\"].astype(str).str.strip()\n",
    "\n",
    "work[\"item_id\"] = work[\"Brand\"] + \" :: \" + work[\"Model\"]\n",
    "\n",
    "# Quantile bins keep segment sizes more balanced\n",
    "work[\"PriceBin\"] = pd.qcut(\n",
    "    work[\"Price\"],\n",
    "    q=N_PRICE_BINS,\n",
    "    labels=[f\"p{i}\" for i in range(N_PRICE_BINS)],\n",
    "    duplicates=\"drop\"\n",
    ").astype(str)\n",
    "\n",
    "work[\"MileageBin\"] = pd.qcut(\n",
    "    work[\"Mileage\"],\n",
    "    q=N_MILEAGE_BINS,\n",
    "    labels=[f\"m{i}\" for i in range(N_MILEAGE_BINS)],\n",
    "    duplicates=\"drop\"\n",
    ").astype(str)\n",
    "\n",
    "work[\"YearBin\"] = pd.qcut(\n",
    "    work[\"Year\"],\n",
    "    q=N_YEAR_BINS,\n",
    "    labels=[f\"y{i}\" for i in range(N_YEAR_BINS)],\n",
    "    duplicates=\"drop\"\n",
    ").astype(str)\n",
    "\n",
    "segment_cols = [\"Country\", \"Condition\", \"PriceBin\", \"MileageBin\", \"YearBin\"]\n",
    "work[\"segment_id\"] = work[segment_cols].astype(str).agg(\" | \".join, axis=1)\n",
    "\n",
    "# Aggregate to implicit interactions: segment -> item purchase count\n",
    "interactions = (\n",
    "    work.groupby([\"segment_id\", \"item_id\"], as_index=False)\n",
    "    .size()\n",
    "    .rename(columns={\"size\": \"count\"})\n",
    ")\n",
    "\n",
    "print(\"Raw aggregated interactions:\", interactions.shape)\n",
    "print(\"Unique segments:\", interactions[\"segment_id\"].nunique())\n",
    "print(\"Unique items   :\", interactions[\"item_id\"].nunique())\n",
    "display(interactions.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f7a2f8a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After filtering:\n",
      "Interactions  : (632818, 3)\n",
      "Segments      : 11520\n",
      "Items         : 88\n",
      "Min items/seg : 38\n",
      "Min segs/item : 6988\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "item_id\n",
       "Chrysler :: Voyager       9182\n",
       "Chrysler :: 300           9137\n",
       "Chrysler :: Pacifica      9109\n",
       "Chevrolet :: Silverado    7271\n",
       "Hyundai :: Santa Fe       7218\n",
       "Jeep :: Cherokee          7217\n",
       "Toyota :: RAV4            7210\n",
       "Toyota :: Tacoma          7209\n",
       "Dodge :: Durango          7197\n",
       "Kia :: Forte              7194\n",
       "Volkswagen :: Jetta       7193\n",
       "Lexus :: ES               7189\n",
       "Hyundai :: Sonata         7187\n",
       "Chevrolet :: Trax         7185\n",
       "Ford :: Mustang           7185\n",
       "dtype: int64"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Iterative filtering so every segment and every item has enough support\n",
    "\n",
    "for _ in range(5):\n",
    "    segment_sizes = interactions.groupby(\"segment_id\").size()\n",
    "    keep_segments = segment_sizes[segment_sizes >= MIN_ITEMS_PER_SEGMENT].index\n",
    "    interactions = interactions[interactions[\"segment_id\"].isin(keep_segments)].copy()\n",
    "\n",
    "    item_sizes = interactions.groupby(\"item_id\").size()\n",
    "    keep_items = item_sizes[item_sizes >= MIN_SEGMENTS_PER_ITEM].index\n",
    "    interactions = interactions[interactions[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "segment_sizes = interactions.groupby(\"segment_id\").size()\n",
    "item_sizes = interactions.groupby(\"item_id\").size()\n",
    "\n",
    "print(\"After filtering:\")\n",
    "print(\"Interactions  :\", interactions.shape)\n",
    "print(\"Segments      :\", interactions['segment_id'].nunique())\n",
    "print(\"Items         :\", interactions['item_id'].nunique())\n",
    "print(\"Min items/seg :\", int(segment_sizes.min()))\n",
    "print(\"Min segs/item :\", int(item_sizes.min()))\n",
    "display(item_sizes.sort_values(ascending=False).head(15))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "82319390",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train interactions: (621298, 5)\n",
      "Test interactions : (11520, 5)\n",
      "Users: 11520\n",
      "Items: 88\n",
      "Matrix shape: (11520, 88)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Encode users/items and create leave-one-out split\n",
    "\n",
    "segment_encoder = LabelEncoder()\n",
    "item_encoder = LabelEncoder()\n",
    "\n",
    "interactions[\"user_idx\"] = segment_encoder.fit_transform(interactions[\"segment_id\"])\n",
    "interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "num_users = interactions[\"user_idx\"].nunique()\n",
    "num_items = interactions[\"item_idx\"].nunique()\n",
    "\n",
    "rng = np.random.default_rng(RANDOM_STATE)\n",
    "\n",
    "test_idx = []\n",
    "for user_idx, g in interactions.groupby(\"user_idx\"):\n",
    "    picked = g.sample(n=1, random_state=int(rng.integers(1_000_000))).index[0]\n",
    "    test_idx.append(picked)\n",
    "\n",
    "test_interactions = interactions.loc[test_idx].copy()\n",
    "train_interactions = interactions.drop(index=test_idx).copy()\n",
    "\n",
    "print(\"Train interactions:\", train_interactions.shape)\n",
    "print(\"Test interactions :\", test_interactions.shape)\n",
    "print(\"Users:\", num_users)\n",
    "print(\"Items:\", num_items)\n",
    "\n",
    "# Sparse matrices\n",
    "rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "X_binary = csr_matrix((np.ones_like(vals), (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "\n",
    "# Dense versions are safe here because the item catalog is small\n",
    "X_counts_dense = X_counts.toarray().astype(np.float32)\n",
    "X_binary_dense = X_binary.toarray().astype(np.float32)\n",
    "\n",
    "user_seen = {\n",
    "    int(uid): set(g[\"item_idx\"].tolist())\n",
    "    for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "}\n",
    "\n",
    "test_item_by_user = {\n",
    "    int(uid): int(item)\n",
    "    for uid, item in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "}\n",
    "\n",
    "item_ids = item_encoder.classes_\n",
    "global_popularity = np.asarray(X_counts.sum(axis=0)).ravel()\n",
    "global_pop_rank = np.argsort(global_popularity)[::-1]\n",
    "\n",
    "print(\"Matrix shape:\", X_binary.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3bd6dd5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Metrics\n",
    "\n",
    "def hit_rate_at_k(recs, true_item, k):\n",
    "    return 1.0 if true_item in recs[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1.0)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(model_func, model_name, user_ids, ks=(5, 10, 20)):\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr = 0.0\n",
    "    ndcg = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    max_k = max(ks)\n",
    "\n",
    "    print(\"=\" * 80)\n",
    "    print(\"Evaluating:\", model_name)\n",
    "\n",
    "    for uid in tqdm(user_ids):\n",
    "        true_item = test_item_by_user.get(int(uid))\n",
    "        if true_item is None:\n",
    "            continue\n",
    "\n",
    "        recs = model_func(int(uid), n=max_k)\n",
    "        if not isinstance(recs, list):\n",
    "            recs = list(recs)\n",
    "\n",
    "        valid += 1\n",
    "        for k in ks:\n",
    "            hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "        mrr += mrr_at_k(recs, true_item, 10)\n",
    "        ndcg += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    row = {\n",
    "        \"Model\": model_name,\n",
    "        \"HR@5\": hits[5] / valid if valid else 0.0,\n",
    "        \"HR@10\": hits[10] / valid if valid else 0.0,\n",
    "        \"HR@20\": hits[20] / valid if valid else 0.0,\n",
    "        \"MRR@10\": mrr / valid if valid else 0.0,\n",
    "        \"NDCG@10\": ndcg / valid if valid else 0.0,\n",
    "    }\n",
    "    return row\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d69f57f9",
   "metadata": {},
   "source": [
    "## Regular model 1: Popularity baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cb2afb07",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def recommend_popularity(user_idx, n=10):\n",
    "    seen = user_seen.get(user_idx, set())\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33bc78c5",
   "metadata": {},
   "source": [
    "## Regular model 2: ItemKNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d1c2e1c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Item-item KNN on binary segment-item interactions\n",
    "\n",
    "item_user = X_binary.T.tocsr()\n",
    "\n",
    "item_knn = NearestNeighbors(\n",
    "    metric=\"cosine\",\n",
    "    algorithm=\"brute\",\n",
    "    n_neighbors=min(ITEM_KNN_NEIGHBORS, num_items),\n",
    "    n_jobs=-1,\n",
    ")\n",
    "item_knn.fit(item_user)\n",
    "\n",
    "knn_distances, knn_indices = item_knn.kneighbors(\n",
    "    item_user,\n",
    "    n_neighbors=min(ITEM_KNN_NEIGHBORS, num_items)\n",
    ")\n",
    "knn_similarities = (1.0 - knn_distances).astype(np.float32)\n",
    "\n",
    "user_item_strength = {\n",
    "    int(uid): {\n",
    "        int(item): float(cnt)\n",
    "        for item, cnt in zip(g[\"item_idx\"], g[\"count\"])\n",
    "    }\n",
    "    for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "}\n",
    "\n",
    "def recommend_itemknn(user_idx, n=10):\n",
    "    seen = user_seen.get(user_idx, set())\n",
    "    strengths = user_item_strength.get(user_idx, {})\n",
    "\n",
    "    scores = np.zeros(num_items, dtype=np.float32)\n",
    "\n",
    "    for item_idx, cnt in strengths.items():\n",
    "        weight = np.log1p(cnt)\n",
    "        nbrs = knn_indices[item_idx]\n",
    "        sims = knn_similarities[item_idx]\n",
    "\n",
    "        for nbr_idx, sim in zip(nbrs, sims):\n",
    "            if nbr_idx == item_idx:\n",
    "                continue\n",
    "            scores[nbr_idx] += weight * sim\n",
    "\n",
    "    if seen:\n",
    "        scores[list(seen)] = -np.inf\n",
    "\n",
    "    n = min(n, num_items - len(seen))\n",
    "    if n <= 0:\n",
    "        return []\n",
    "\n",
    "    top = np.argpartition(scores, -n)[-n:]\n",
    "    top = top[np.argsort(scores[top])[::-1]]\n",
    "    return top.tolist()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d9b5413",
   "metadata": {},
   "source": [
    "## Regular model 3: EASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "758ab370",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# EASE on binary interactions\n",
    "\n",
    "G = X_binary_dense.T @ X_binary_dense\n",
    "G[np.diag_indices(num_items)] += EASE_LAMBDA\n",
    "\n",
    "P = np.linalg.inv(G)\n",
    "B = -P / np.diag(P)\n",
    "B[np.diag_indices(num_items)] = 0.0\n",
    "\n",
    "def recommend_ease(user_idx, n=10):\n",
    "    scores = X_binary_dense[user_idx] @ B\n",
    "    seen = user_seen.get(user_idx, set())\n",
    "\n",
    "    if seen:\n",
    "        scores[list(seen)] = -np.inf\n",
    "\n",
    "    n = min(n, num_items - len(seen))\n",
    "    if n <= 0:\n",
    "        return []\n",
    "\n",
    "    top = np.argpartition(scores, -n)[-n:]\n",
    "    top = top[np.argsort(scores[top])[::-1]]\n",
    "    return top.tolist()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "887cfa54",
   "metadata": {},
   "source": [
    "## Neural model 1: NeuMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "452e132d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuMF epoch 1/10 - loss: 0.5025\n",
      "NeuMF epoch 2/10 - loss: 0.4987\n",
      "NeuMF epoch 3/10 - loss: 0.4974\n",
      "NeuMF epoch 4/10 - loss: 0.4952\n",
      "NeuMF epoch 5/10 - loss: 0.4914\n",
      "NeuMF epoch 6/10 - loss: 0.4857\n",
      "NeuMF epoch 7/10 - loss: 0.4782\n",
      "NeuMF epoch 8/10 - loss: 0.4694\n",
      "NeuMF epoch 9/10 - loss: 0.4596\n",
      "NeuMF epoch 10/10 - loss: 0.4495\n"
     ]
    }
   ],
   "source": [
    "\n",
    "class ImplicitPairDataset(Dataset):\n",
    "    def __init__(self, user_seen_dict, num_items, neg_per_pos=4, seed=42):\n",
    "        self.user_seen_dict = user_seen_dict\n",
    "        self.num_items = num_items\n",
    "        self.neg_per_pos = neg_per_pos\n",
    "        self.seed = seed\n",
    "        self.rng = np.random.default_rng(seed)\n",
    "        self.positive_pairs = [\n",
    "            (int(u), int(i))\n",
    "            for u, items in user_seen_dict.items()\n",
    "            for i in items\n",
    "        ]\n",
    "        self.samples = []\n",
    "        self.refresh()\n",
    "\n",
    "    def refresh(self):\n",
    "        samples = []\n",
    "        for u, i in self.positive_pairs:\n",
    "            samples.append((u, i, 1.0))\n",
    "\n",
    "            negs = 0\n",
    "            pos_set = self.user_seen_dict[u]\n",
    "            while negs < self.neg_per_pos:\n",
    "                j = int(self.rng.integers(self.num_items))\n",
    "                if j not in pos_set:\n",
    "                    samples.append((u, j, 0.0))\n",
    "                    negs += 1\n",
    "\n",
    "        self.samples = samples\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        u, i, y = self.samples[idx]\n",
    "        return (\n",
    "            torch.tensor(u, dtype=torch.long),\n",
    "            torch.tensor(i, dtype=torch.long),\n",
    "            torch.tensor(y, dtype=torch.float32),\n",
    "        )\n",
    "\n",
    "class NeuMF(nn.Module):\n",
    "    def __init__(self, num_users, num_items, mf_dim=32, mlp_dim=64):\n",
    "        super().__init__()\n",
    "        self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "        self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "\n",
    "        self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "        self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(mlp_dim * 2, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.15),\n",
    "            nn.Linear(128, 64),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.out = nn.Linear(mf_dim + 64, 1)\n",
    "\n",
    "    def forward(self, user_idx, item_idx):\n",
    "        mf_part = self.user_mf(user_idx) * self.item_mf(item_idx)\n",
    "        mlp_part = self.mlp(\n",
    "            torch.cat([self.user_mlp(user_idx), self.item_mlp(item_idx)], dim=1)\n",
    "        )\n",
    "        x = torch.cat([mf_part, mlp_part], dim=1)\n",
    "        return self.out(x).squeeze(1)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def score_all_items(self, user_idx):\n",
    "        self.eval()\n",
    "        device = next(self.parameters()).device\n",
    "        item_idx = torch.arange(self.item_mf.num_embeddings, device=device)\n",
    "        user_vec = torch.full_like(item_idx, fill_value=int(user_idx))\n",
    "        scores = self.forward(user_vec, item_idx)\n",
    "        return scores.detach().cpu().numpy()\n",
    "\n",
    "neumf_model = NeuMF(\n",
    "    num_users=num_users,\n",
    "    num_items=num_items,\n",
    "    mf_dim=NEUMF_MF_DIM,\n",
    "    mlp_dim=NEUMF_MLP_DIM,\n",
    ").to(DEVICE)\n",
    "\n",
    "neumf_dataset = ImplicitPairDataset(\n",
    "    user_seen_dict=user_seen,\n",
    "    num_items=num_items,\n",
    "    neg_per_pos=NEUMF_NEG_PER_POS,\n",
    "    seed=RANDOM_STATE,\n",
    ")\n",
    "\n",
    "neumf_loader = DataLoader(\n",
    "    neumf_dataset,\n",
    "    batch_size=NEUMF_BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    "    pin_memory=(DEVICE == \"cuda\"),\n",
    ")\n",
    "\n",
    "neumf_optimizer = torch.optim.Adam(neumf_model.parameters(), lr=NEUMF_LR)\n",
    "neumf_criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "for epoch in range(NEUMF_EPOCHS):\n",
    "    neumf_dataset.refresh()\n",
    "    neumf_model.train()\n",
    "    total_loss = 0.0\n",
    "\n",
    "    for user_idx, item_idx, target in neumf_loader:\n",
    "        user_idx = user_idx.to(DEVICE)\n",
    "        item_idx = item_idx.to(DEVICE)\n",
    "        target = target.to(DEVICE)\n",
    "\n",
    "        neumf_optimizer.zero_grad()\n",
    "        logits = neumf_model(user_idx, item_idx)\n",
    "        loss = neumf_criterion(logits, target)\n",
    "        loss.backward()\n",
    "        neumf_optimizer.step()\n",
    "\n",
    "        total_loss += loss.item() * len(target)\n",
    "\n",
    "    epoch_loss = total_loss / len(neumf_dataset)\n",
    "    print(f\"NeuMF epoch {epoch + 1}/{NEUMF_EPOCHS} - loss: {epoch_loss:.4f}\")\n",
    "\n",
    "@torch.no_grad()\n",
    "def recommend_neumf(user_idx, n=10):\n",
    "    scores = neumf_model.score_all_items(user_idx)\n",
    "    seen = user_seen.get(user_idx, set())\n",
    "\n",
    "    if seen:\n",
    "        scores[list(seen)] = -np.inf\n",
    "\n",
    "    n = min(n, num_items - len(seen))\n",
    "    if n <= 0:\n",
    "        return []\n",
    "\n",
    "    top = np.argpartition(scores, -n)[-n:]\n",
    "    top = top[np.argsort(scores[top])[::-1]]\n",
    "    return top.tolist()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3acaa6c4",
   "metadata": {},
   "source": [
    "## Neural model 2: MultVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "af3565cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MultVAE epoch 1/60 - loss: 241.3210 - anneal: 0.050\n",
      "MultVAE epoch 10/60 - loss: 241.6675 - anneal: 0.500\n",
      "MultVAE epoch 20/60 - loss: 241.4924 - anneal: 1.000\n",
      "MultVAE epoch 30/60 - loss: 241.4797 - anneal: 1.000\n",
      "MultVAE epoch 40/60 - loss: 241.4659 - anneal: 1.000\n",
      "MultVAE epoch 50/60 - loss: 241.4500 - anneal: 1.000\n",
      "MultVAE epoch 60/60 - loss: 241.4394 - anneal: 1.000\n"
     ]
    }
   ],
   "source": [
    "\n",
    "class MultVAE(nn.Module):\n",
    "    def __init__(self, num_items, hidden_dim=256, latent_dim=64):\n",
    "        super().__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(num_items, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, hidden_dim),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_dim, num_items),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        h = self.encoder(x)\n",
    "        mu = self.mu(h)\n",
    "        logvar = self.logvar(h)\n",
    "\n",
    "        if self.training:\n",
    "            std = torch.exp(0.5 * logvar)\n",
    "            eps = torch.randn_like(std)\n",
    "            z = mu + eps * std\n",
    "        else:\n",
    "            z = mu\n",
    "\n",
    "        logits = self.decoder(z)\n",
    "        return logits, mu, logvar\n",
    "\n",
    "def multvae_loss_fn(logits, x, mu, logvar, anneal=1.0):\n",
    "    recon = -(torch.log_softmax(logits, dim=1) * x).sum(dim=1).mean()\n",
    "    kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))\n",
    "    return recon + anneal * kl, recon.detach(), kl.detach()\n",
    "\n",
    "# Build the MultVAE model and optimizer; hyperparameters come from the config cell.\n",
    "multvae_model = MultVAE(\n",
    "    num_items=num_items,\n",
    "    hidden_dim=MULTVAE_HIDDEN_DIM,\n",
    "    latent_dim=MULTVAE_LATENT_DIM,\n",
    ").to(DEVICE)\n",
    "\n",
    "multvae_optimizer = torch.optim.Adam(multvae_model.parameters(), lr=MULTVAE_LR)\n",
    "\n",
    "# Full user-item matrix stays on CPU; mini-batches are moved to DEVICE per step.\n",
    "X_multvae = torch.tensor(X_binary_dense, dtype=torch.float32)\n",
    "\n",
    "for epoch in range(MULTVAE_EPOCHS):\n",
    "    multvae_model.train()\n",
    "    # Fresh random user order each epoch.\n",
    "    perm = torch.randperm(X_multvae.size(0))\n",
    "    total_loss = 0.0\n",
    "\n",
    "    # KL weight ramps linearly to 1.0 over the first MULTVAE_KL_ANNEAL_EPOCHS epochs.\n",
    "    anneal = min(1.0, (epoch + 1) / MULTVAE_KL_ANNEAL_EPOCHS)\n",
    "\n",
    "    for start in range(0, X_multvae.size(0), MULTVAE_BATCH_SIZE):\n",
    "        idx = perm[start:start + MULTVAE_BATCH_SIZE]\n",
    "        batch = X_multvae[idx].to(DEVICE)\n",
    "\n",
    "        multvae_optimizer.zero_grad()\n",
    "        logits, mu, logvar = multvae_model(batch)\n",
    "        loss, recon, kl = multvae_loss_fn(logits, batch, mu, logvar, anneal=anneal)\n",
    "        loss.backward()\n",
    "        multvae_optimizer.step()\n",
    "\n",
    "        # Weight by batch size so the epoch average is exact with a ragged last batch.\n",
    "        total_loss += loss.item() * batch.size(0)\n",
    "\n",
    "    epoch_loss = total_loss / X_multvae.size(0)\n",
    "    if (epoch + 1) % 10 == 0 or epoch == 0:\n",
    "        print(f\"MultVAE epoch {epoch + 1}/{MULTVAE_EPOCHS} - loss: {epoch_loss:.4f} - anneal: {anneal:.3f}\")\n",
    "\n",
    "@torch.no_grad()\n",
    "def recommend_multvae(user_idx, n=10):\n",
    "    \"\"\"Return up to `n` unseen item indices for `user_idx`, best MultVAE score first.\"\"\"\n",
    "    multvae_model.eval()\n",
    "\n",
    "    # Score every item by decoding this segment's interaction vector.\n",
    "    row = X_binary_dense[user_idx:user_idx + 1]\n",
    "    batch = torch.tensor(row, dtype=torch.float32, device=DEVICE)\n",
    "    logits, _, _ = multvae_model(batch)\n",
    "    scores = logits[0].detach().cpu().numpy()\n",
    "\n",
    "    # Mask out items this segment has already interacted with.\n",
    "    seen = user_seen.get(user_idx, set())\n",
    "    if seen:\n",
    "        scores[list(seen)] = -np.inf\n",
    "\n",
    "    # Never ask for more items than remain unseen.\n",
    "    k = min(n, num_items - len(seen))\n",
    "    if k <= 0:\n",
    "        return []\n",
    "\n",
    "    # argpartition pulls the top-k unordered; argsort then ranks them descending.\n",
    "    candidates = np.argpartition(scores, -k)[-k:]\n",
    "    ranked = candidates[np.argsort(scores[candidates])[::-1]]\n",
    "    return ranked.tolist()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1e88dc7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example segment:\n",
      "France | Certified Pre-Owned | p2 | m4 | y0\n",
      "\n",
      "Seen items (first 15):\n",
      "['Audi :: A6', 'Audi :: Q7', 'BMW :: 5 Series', 'BMW :: X5', 'BMW :: Z4', 'Chevrolet :: Equinox', 'Chevrolet :: Malibu', 'Chevrolet :: Silverado', 'Chevrolet :: Trax', 'Chrysler :: 300', 'Chrysler :: Pacifica', 'Chrysler :: Voyager', 'Dodge :: Challenger', 'Dodge :: Journey', 'Ford :: Escape']\n",
      "\n",
      "Popularity:\n",
      "['Kia :: Optima', 'Kia :: Forte', 'Hyundai :: Tucson', 'Toyota :: Tacoma', 'Volkswagen :: Tiguan', 'Dodge :: Charger', 'Volkswagen :: Jetta', 'Honda :: CR-V', 'Toyota :: RAV4', 'Audi :: A4']\n",
      "\n",
      "ItemKNN:\n",
      "['Kia :: Forte', 'Toyota :: RAV4', 'Toyota :: Tacoma', 'Dodge :: Durango', 'Volkswagen :: Jetta', 'Hyundai :: Tucson', 'Jeep :: Renegade', 'Kia :: Optima', 'Land Rover :: Velar', 'Audi :: A3']\n",
      "\n",
      "EASE:\n",
      "['Dodge :: Ram 1500', 'Honda :: CR-V', 'Volkswagen :: Passat', 'Audi :: A4', 'Mazda :: Mazda6', 'Kia :: Optima', 'Volkswagen :: Atlas', 'Toyota :: RAV4', 'Nissan :: Titan', 'Hyundai :: Kona']\n",
      "\n",
      "NeuMF:\n",
      "['Kia :: Forte', 'Audi :: Q5', 'Nissan :: Murano', 'Dodge :: Ram 1500', 'Mazda :: CX-5', 'Mazda :: Mazda6', 'BMW :: 3 Series', 'Volkswagen :: Passat', 'Ford :: Explorer', 'Toyota :: Camry']\n",
      "\n",
      "MultVAE:\n",
      "['Volkswagen :: Jetta', 'Dodge :: Durango', 'Toyota :: RAV4', 'Kia :: Optima', 'Hyundai :: Tucson', 'Nissan :: Murano', 'Kia :: Forte', 'Honda :: CR-V', 'Ford :: Explorer', 'Toyota :: Tacoma']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Show one segment's purchase history and each model's top-10 recommendations.\n",
    "\n",
    "example_user = int(train_interactions[\"user_idx\"].sample(1, random_state=RANDOM_STATE).iloc[0])\n",
    "example_segment = segment_encoder.inverse_transform([example_user])[0]\n",
    "seen_items = [item_ids[i] for i in sorted(user_seen[example_user])[:15]]\n",
    "\n",
    "print(\"Example segment:\")\n",
    "print(example_segment)\n",
    "print()\n",
    "print(\"Seen items (first 15):\")\n",
    "print(seen_items)\n",
    "\n",
    "# One (label, recommender) pair per model keeps the output logic in one place.\n",
    "recommenders = [\n",
    "    (\"Popularity\", recommend_popularity),\n",
    "    (\"ItemKNN\", recommend_itemknn),\n",
    "    (\"EASE\", recommend_ease),\n",
    "    (\"NeuMF\", recommend_neumf),\n",
    "    (\"MultVAE\", recommend_multvae),\n",
    "]\n",
    "\n",
    "for label, recommend_fn in recommenders:\n",
    "    print()\n",
    "    print(f\"{label}:\")\n",
    "    print([item_ids[i] for i in recommend_fn(example_user, n=10)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0c6507f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: Popularity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11520/11520 [00:00<00:00, 107535.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: ItemKNN\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11520/11520 [00:08<00:00, 1348.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: EASE\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11520/11520 [00:00<00:00, 74883.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: NeuMF\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11520/11520 [00:03<00:00, 2949.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: MultVAE\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11520/11520 [00:02<00:00, 3972.26it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ItemKNN</td>\n",
       "      <td>0.173611</td>\n",
       "      <td>0.323785</td>\n",
       "      <td>0.611806</td>\n",
       "      <td>0.108583</td>\n",
       "      <td>0.157806</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MultVAE</td>\n",
       "      <td>0.176302</td>\n",
       "      <td>0.321007</td>\n",
       "      <td>0.610243</td>\n",
       "      <td>0.110243</td>\n",
       "      <td>0.158543</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Popularity</td>\n",
       "      <td>0.170660</td>\n",
       "      <td>0.321007</td>\n",
       "      <td>0.611806</td>\n",
       "      <td>0.109012</td>\n",
       "      <td>0.157498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EASE</td>\n",
       "      <td>0.169271</td>\n",
       "      <td>0.317535</td>\n",
       "      <td>0.605816</td>\n",
       "      <td>0.107657</td>\n",
       "      <td>0.155620</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NeuMF</td>\n",
       "      <td>0.159809</td>\n",
       "      <td>0.307031</td>\n",
       "      <td>0.603038</td>\n",
       "      <td>0.091376</td>\n",
       "      <td>0.140766</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Model      HR@5     HR@10     HR@20    MRR@10   NDCG@10\n",
       "0     ItemKNN  0.173611  0.323785  0.611806  0.108583  0.157806\n",
       "1     MultVAE  0.176302  0.321007  0.610243  0.110243  0.158543\n",
       "2  Popularity  0.170660  0.321007  0.611806  0.109012  0.157498\n",
       "3        EASE  0.169271  0.317535  0.605816  0.107657  0.155620\n",
       "4       NeuMF  0.159809  0.307031  0.603038  0.091376  0.140766"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Run the shared evaluation harness over every model and rank by HR@10 / NDCG@10.\n",
    "\n",
    "eval_users = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "model_registry = [\n",
    "    (recommend_popularity, \"Popularity\"),\n",
    "    (recommend_itemknn, \"ItemKNN\"),\n",
    "    (recommend_ease, \"EASE\"),\n",
    "    (recommend_neumf, \"NeuMF\"),\n",
    "    (recommend_multvae, \"MultVAE\"),\n",
    "]\n",
    "\n",
    "results = [\n",
    "    evaluate_model(recommend_fn, label, eval_users, ks=TOP_KS)\n",
    "    for recommend_fn, label in model_registry\n",
    "]\n",
    "\n",
    "results_df = (\n",
    "    pd.DataFrame(results)\n",
    "    .sort_values([\"HR@10\", \"NDCG@10\"], ascending=False)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(results_df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdc8372a",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "If the results are still weak, the next iteration should focus on:\n",
    "1. better segment design\n",
    "2. stronger weighting of interaction counts\n",
    "3. tuning EASE regularization\n",
    "4. tuning NeuMF negative sampling and epochs\n",
    "5. tuning MultVAE latent size and KL annealing\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
