{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2e584ae2",
   "metadata": {},
   "source": [
    "# Car recommender notebook — hard formulation, broader model zoo, GPU-heavy neural search\n",
    "\n",
    "This notebook is the corrected follow-up to the earlier car recommender runs.\n",
    "\n",
    "Main changes:\n",
    "- screens out trivial task formulations automatically\n",
    "- uses only harder pseudo-user / item constructions\n",
    "- evaluates a larger set of regular and neural recommenders\n",
    "- keeps the neural training path GPU-friendly with manual on-device batching\n",
    "- produces one combined ranking table for comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dd295bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import gc\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "# Thread settings for NumPy / SciPy BLAS backends.\n",
    "# These must be set before importing numpy/scipy to have the best chance of taking effect.\n",
    "CPU_THREADS = min(16, os.cpu_count() or 1)\n",
    "os.environ[\"OMP_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"MKL_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"NUMEXPR_NUM_THREADS\"] = str(CPU_THREADS)\n",
    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = str(CPU_THREADS)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn.preprocessing import LabelEncoder, normalize\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "\n",
    "print(\"CPU thread target:\", CPU_THREADS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77a4eb83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths and high-level settings\n",
    "\n",
    "BASE_DIR = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5\")\n",
    "CSV_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "if not CSV_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Dataset not found: {CSV_PATH}\")\n",
    "\n",
    "CURRENT_YEAR = 2026\n",
    "\n",
    "# Binning\n",
    "N_PRICE_BINS = 12\n",
    "N_MILEAGE_BINS = 12\n",
    "N_AGE_BINS = 10\n",
    "\n",
    "# Interaction filtering\n",
    "MIN_ITEMS_PER_USER = 5\n",
    "MIN_USERS_PER_ITEM = 10\n",
    "MAX_FILTER_ITERS = 10\n",
    "\n",
    "# If a formulation is too sparse under the default thresholds, progressively relax them.\n",
    "FILTER_SCHEDULE = [\n",
    "    (5, 10),\n",
    "    (4, 8),\n",
    "    (3, 5),\n",
    "    (2, 3),\n",
    "]\n",
    "\n",
    "# Non-triviality constraints for automatic formulation selection\n",
    "MIN_ITEMS_FOR_VALID_FORMULATION = 500\n",
    "MAX_POP_HR10_FOR_VALID_FORMULATION = 0.15\n",
    "MAX_AVG_INTERACTIONS_PER_USER = 50\n",
    "\n",
    "# Ranking metrics\n",
    "TOP_KS = [5, 10, 20]\n",
    "\n",
    "# Regular model search grids\n",
    "ITEMKNN_NEIGHBORS_GRID = [100, 200, 300]\n",
    "EASE_LAMBDA_GRID = [100.0, 200.0, 500.0, 1000.0, 2000.0]\n",
    "P3_ALPHA_GRID = [0.5, 1.0]\n",
    "RP3_GRID = [(0.8, 0.3), (1.0, 0.6)]\n",
    "\n",
    "# Neural training defaults\n",
    "SOFTMAX_EPOCHS = 20\n",
    "TWOTOWER_EPOCHS = 20\n",
    "AUTOENC_EPOCHS = 80\n",
    "PAIRWISE_EPOCHS = 24\n",
    "\n",
    "SOFTMAX_BATCH_SIZE = 16384\n",
    "TWOTOWER_BATCH_SIZE = 8192\n",
    "AUTOENC_BATCH_SIZE = 4096\n",
    "PAIRWISE_BATCH_SIZE = 32768\n",
    "\n",
    "USE_AMP = False  # initialized safely before optional torch import; updated later if neural models are enabled\n",
    "\n",
    "# Harder candidate formulations only.\n",
    "# These all keep the richer item granularity and avoid the trivial 88-item setup.\n",
    "FORMULATIONS = {\n",
    "    \"H1_country_price_mileage_cond_age__brand_model_age_cond\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\"],\n",
    "    },\n",
    "    \"H2_country_price_mileage_cond_age_color__brand_model_age_cond_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"AgeBin\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H3_country_price_mileage_age_color__brand_model_age_cond_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"AgeBin\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "    \"H4_country_price_mileage_cond_color__brand_model_age_cond_color\": {\n",
    "        \"user_cols\": [\"Country\", \"PriceBin\", \"MileageBin\", \"Condition\", \"Color\"],\n",
    "        \"item_cols\": [\"Brand\", \"Model\", \"AgeBin\", \"Condition\", \"Color\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "print(\"CSV:\", CSV_PATH)\n",
    "print(\"Formulations:\", list(FORMULATIONS.keys()))\n",
    "\n",
    "EASE_EVAL_BATCH_SIZE = 2048"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44983c26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime toggles\n",
    "\n",
    "# Set to True after you confirm that importing torch works in your environment.\n",
    "# You already tested this successfully in a terminal, so you can enable it here when you want the neural models.\n",
    "RUN_NEURAL = True\n",
    "\n",
    "# If you later need to disable the neural section quickly, set RUN_NEURAL = False.\n",
    "PREFER_CUDA = True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4c1a8ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and clean data\n",
    "\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "\n",
    "expected_cols = [\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\"]\n",
    "missing = [c for c in expected_cols if c not in df.columns]\n",
    "if missing:\n",
    "    raise ValueError(f\"Missing expected columns: {missing}\")\n",
    "\n",
    "for col in [\"Brand\", \"Model\", \"Color\", \"Condition\", \"Country\"]:\n",
    "    df[col] = df[col].astype(str).str.strip().replace({\"\": \"Unknown\", \"nan\": \"Unknown\"})\n",
    "\n",
    "df[\"Year\"] = pd.to_numeric(df[\"Year\"], errors=\"coerce\")\n",
    "df[\"Price\"] = pd.to_numeric(df[\"Price\"], errors=\"coerce\")\n",
    "df[\"Mileage\"] = pd.to_numeric(df[\"Mileage\"], errors=\"coerce\")\n",
    "\n",
    "df = df.dropna(subset=[\"Year\", \"Price\", \"Mileage\"]).copy()\n",
    "\n",
    "df = df[df[\"Year\"].between(1990, CURRENT_YEAR)].copy()\n",
    "df = df[df[\"Price\"] > 0].copy()\n",
    "df = df[df[\"Mileage\"] >= 0].copy()\n",
    "\n",
    "df[\"Age\"] = (CURRENT_YEAR - df[\"Year\"]).clip(lower=0, upper=50)\n",
    "\n",
    "def make_qbin(series, n_bins, prefix):\n",
    "    cat = pd.qcut(series, q=n_bins, duplicates=\"drop\")\n",
    "    return cat.astype(str).str.replace(\",\", \" to \", regex=False).map(lambda x: f\"{prefix}_{x}\")\n",
    "\n",
    "df[\"PriceBin\"] = make_qbin(df[\"Price\"], N_PRICE_BINS, \"price\")\n",
    "df[\"MileageBin\"] = make_qbin(df[\"Mileage\"], N_MILEAGE_BINS, \"mileage\")\n",
    "df[\"AgeBin\"] = make_qbin(df[\"Age\"], N_AGE_BINS, \"age\")\n",
    "\n",
    "print(\"Rows:\", len(df))\n",
    "display(df.head())\n",
    "display(df[[\"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\", \"Color\", \"Condition\", \"Country\", \"Age\", \"PriceBin\", \"MileageBin\", \"AgeBin\"]].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "071498f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "\n",
    "def join_cols(frame, cols, sep):\n",
    "    return frame[cols].astype(str).agg(sep.join, axis=1)\n",
    "\n",
    "def iterative_filter(interactions, min_items_per_user=5, min_users_per_item=10, max_iters=10):\n",
    "    out = interactions.copy()\n",
    "    for _ in range(max_iters):\n",
    "        old_n = out.shape[0]\n",
    "\n",
    "        user_sizes = out.groupby(\"user_id\").size()\n",
    "        keep_users = user_sizes[user_sizes >= min_items_per_user].index\n",
    "        out = out[out[\"user_id\"].isin(keep_users)].copy()\n",
    "\n",
    "        item_sizes = out.groupby(\"item_id\").size()\n",
    "        keep_items = item_sizes[item_sizes >= min_users_per_item].index\n",
    "        out = out[out[\"item_id\"].isin(keep_items)].copy()\n",
    "\n",
    "        if out.shape[0] == old_n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "\n",
    "def build_formulation(work_df, user_cols, item_cols, formulation_name):\n",
    "    tmp = work_df.copy()\n",
    "\n",
    "    tmp[\"user_id\"] = join_cols(tmp, user_cols, \" | \")\n",
    "    tmp[\"item_id\"] = join_cols(tmp, item_cols, \" :: \")\n",
    "\n",
    "    raw_interactions = (\n",
    "        tmp.groupby([\"user_id\", \"item_id\"], as_index=False)\n",
    "           .size()\n",
    "           .rename(columns={\"size\": \"count\"})\n",
    "    )\n",
    "\n",
    "    interactions = None\n",
    "    used_thresholds = None\n",
    "\n",
    "    for min_items_per_user, min_users_per_item in FILTER_SCHEDULE:\n",
    "        candidate = iterative_filter(\n",
    "            raw_interactions,\n",
    "            min_items_per_user=min_items_per_user,\n",
    "            min_users_per_item=min_users_per_item,\n",
    "            max_iters=MAX_FILTER_ITERS\n",
    "        ).reset_index(drop=True)\n",
    "\n",
    "        if not candidate.empty:\n",
    "            interactions = candidate\n",
    "            used_thresholds = (min_items_per_user, min_users_per_item)\n",
    "            break\n",
    "\n",
    "    if interactions is None or interactions.empty:\n",
    "        raise ValueError(\n",
    "            f\"{formulation_name} produced no interactions after filtering, \"\n",
    "            f\"even after trying FILTER_SCHEDULE={FILTER_SCHEDULE}\"\n",
    "        )\n",
    "\n",
    "    valid_users = set(interactions[\"user_id\"])\n",
    "    valid_items = set(interactions[\"item_id\"])\n",
    "\n",
    "    user_encoder = LabelEncoder()\n",
    "    item_encoder = LabelEncoder()\n",
    "\n",
    "    interactions[\"user_idx\"] = user_encoder.fit_transform(interactions[\"user_id\"])\n",
    "    interactions[\"item_idx\"] = item_encoder.fit_transform(interactions[\"item_id\"])\n",
    "\n",
    "    num_users = interactions[\"user_idx\"].nunique()\n",
    "    num_items = interactions[\"item_idx\"].nunique()\n",
    "\n",
    "    user_feature_df = (\n",
    "        tmp[tmp[\"user_id\"].isin(valid_users)][[\"user_id\"] + user_cols]\n",
    "        .drop_duplicates(\"user_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    user_feature_df[\"user_idx\"] = user_encoder.transform(user_feature_df[\"user_id\"])\n",
    "    user_feature_df = user_feature_df.sort_values(\"user_idx\").reset_index(drop=True)\n",
    "\n",
    "    item_feature_df = (\n",
    "        tmp[tmp[\"item_id\"].isin(valid_items)][[\"item_id\"] + item_cols]\n",
    "        .drop_duplicates(\"item_id\")\n",
    "        .copy()\n",
    "    )\n",
    "    item_feature_df[\"item_idx\"] = item_encoder.transform(item_feature_df[\"item_id\"])\n",
    "    item_feature_df = item_feature_df.sort_values(\"item_idx\").reset_index(drop=True)\n",
    "\n",
    "    # Weighted leave-one-out split\n",
    "    rng = np.random.default_rng(RANDOM_STATE)\n",
    "    test_indices = []\n",
    "    for uid, g in interactions.groupby(\"user_idx\"):\n",
    "        weights = g[\"count\"].to_numpy(dtype=np.float64)\n",
    "        probs = weights / weights.sum()\n",
    "        picked = rng.choice(g.index.to_numpy(), size=1, replace=False, p=probs)[0]\n",
    "        test_indices.append(picked)\n",
    "\n",
    "    test_interactions = interactions.loc[test_indices].copy().reset_index(drop=True)\n",
    "    train_interactions = interactions.drop(index=test_indices).copy().reset_index(drop=True)\n",
    "\n",
    "    rows = train_interactions[\"user_idx\"].to_numpy()\n",
    "    cols = train_interactions[\"item_idx\"].to_numpy()\n",
    "    vals = train_interactions[\"count\"].astype(np.float32).to_numpy()\n",
    "\n",
    "    X_counts = csr_matrix((vals, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)\n",
    "    X_binary = X_counts.copy()\n",
    "    X_binary.data = np.ones_like(X_binary.data, dtype=np.float32)\n",
    "\n",
    "    user_seen = {\n",
    "        int(uid): set(g[\"item_idx\"].astype(int).tolist())\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    user_strength = {\n",
    "        int(uid): {int(i): float(c) for i, c in zip(g[\"item_idx\"], g[\"count\"])}\n",
    "        for uid, g in train_interactions.groupby(\"user_idx\")\n",
    "    }\n",
    "\n",
    "    test_item_by_user = {\n",
    "        int(uid): int(i)\n",
    "        for uid, i in zip(test_interactions[\"user_idx\"], test_interactions[\"item_idx\"])\n",
    "    }\n",
    "\n",
    "    global_pop_rank = (\n",
    "        train_interactions.groupby(\"item_idx\")[\"count\"]\n",
    "        .sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .index.to_numpy()\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"name\": formulation_name,\n",
    "        \"user_cols\": user_cols,\n",
    "        \"item_cols\": item_cols,\n",
    "        \"interactions\": interactions,\n",
    "        \"train_interactions\": train_interactions,\n",
    "        \"test_interactions\": test_interactions,\n",
    "        \"user_feature_df\": user_feature_df,\n",
    "        \"item_feature_df\": item_feature_df,\n",
    "        \"user_encoder\": user_encoder,\n",
    "        \"item_encoder\": item_encoder,\n",
    "        \"num_users\": num_users,\n",
    "        \"num_items\": num_items,\n",
    "        \"X_counts\": X_counts,\n",
    "        \"X_binary\": X_binary,\n",
    "        \"user_seen\": user_seen,\n",
    "        \"user_strength\": user_strength,\n",
    "        \"test_item_by_user\": test_item_by_user,\n",
    "        \"global_pop_rank\": global_pop_rank,\n",
    "        \"item_ids\": item_encoder.classes_.tolist(),\n",
    "        \"user_ids\": user_encoder.classes_.tolist(),\n",
    "        \"used_thresholds\": used_thresholds,\n",
    "    }\n",
    "\n",
    "def print_bundle_summary(bundle):\n",
    "\n",
    "    avg_train_per_user = bundle[\"train_interactions\"].shape[0] / max(bundle[\"num_users\"], 1)\n",
    "    print(\"Formulation:\", bundle[\"name\"])\n",
    "    print(\"User cols  :\", bundle[\"user_cols\"])\n",
    "    print(\"Item cols  :\", bundle[\"item_cols\"])\n",
    "    print(\"Users      :\", bundle[\"num_users\"])\n",
    "    print(\"Items      :\", bundle[\"num_items\"])\n",
    "    print(\"Thresholds :\", bundle.get(\"used_thresholds\"))\n",
    "    print(\"Train rows :\", bundle[\"train_interactions\"].shape[0])\n",
    "    print(\"Test rows  :\", bundle[\"test_interactions\"].shape[0])\n",
    "    print(\"Avg train interactions/user:\", round(avg_train_per_user, 3))\n",
    "    print(\"Matrix shape:\", bundle[\"X_binary\"].shape)\n",
    "\n",
    "def topn_from_scores(scores, seen, n):\n",
    "    scores = np.asarray(scores, dtype=np.float32).copy()\n",
    "    if seen:\n",
    "        seen_idx = np.fromiter(seen, dtype=np.int32)\n",
    "        scores[seen_idx] = -np.inf\n",
    "    n = min(int(n), scores.shape[0])\n",
    "    if n <= 0:\n",
    "        return []\n",
    "    idx = np.argpartition(scores, -n)[-n:]\n",
    "    idx = idx[np.argsort(scores[idx])[::-1]]\n",
    "    return idx.astype(int).tolist()\n",
    "\n",
    "def topn_from_torch_scores(scores, seen, n):\n",
    "    x = scores.clone()\n",
    "    if seen:\n",
    "        idx = torch.tensor(list(seen), dtype=torch.long, device=x.device)\n",
    "        x[idx] = -1e9\n",
    "    k = min(int(n), int(x.shape[0]))\n",
    "    return torch.topk(x, k=k).indices.detach().cpu().tolist()\n",
    "\n",
    "def hit_rate_at_k(recs, true_item, k):\n",
    "    return 1.0 if true_item in recs[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(recs, true_item, k):\n",
    "    recs_k = recs[:k]\n",
    "    if true_item in recs_k:\n",
    "        rank = recs_k.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(recommend_fn, model_name, bundle, user_indices=None, ks=(5, 10, 20)):\n",
    "    test_item_by_user = bundle[\"test_item_by_user\"]\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr10 = 0.0\n",
    "    ndcg10 = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    for uid in tqdm(user_indices, desc=model_name):\n",
    "        uid = int(uid)\n",
    "        true_item = test_item_by_user.get(uid, None)\n",
    "        if true_item is None:\n",
    "            continue\n",
    "\n",
    "        recs = recommend_fn(uid, n=max(ks))\n",
    "        valid += 1\n",
    "\n",
    "        for k in ks:\n",
    "            hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "        mrr10 += mrr_at_k(recs, true_item, 10)\n",
    "        ndcg10 += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    out = {\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": valid,\n",
    "    }\n",
    "    for k in ks:\n",
    "        out[f\"HR@{k}\"] = hits[k] / max(valid, 1)\n",
    "    out[\"MRR@10\"] = mrr10 / max(valid, 1)\n",
    "    out[\"NDCG@10\"] = ndcg10 / max(valid, 1)\n",
    "    return out\n",
    "\n",
    "def evaluate_ease_batched(\n",
    "    B_matrix,\n",
    "    X_train_matrix,\n",
    "    user_seen_dict,\n",
    "    test_item_by_user,\n",
    "    model_name,\n",
    "    user_indices=None,\n",
    "    ks=(5, 10, 20),\n",
    "    batch_size=2048,\n",
    "):\n",
    "    if user_indices is None:\n",
    "        user_indices = np.array(sorted(test_item_by_user.keys()), dtype=np.int32)\n",
    "    else:\n",
    "        user_indices = np.asarray(user_indices, dtype=np.int32)\n",
    "\n",
    "    max_k = max(ks)\n",
    "    hits = {k: 0.0 for k in ks}\n",
    "    mrr10 = 0.0\n",
    "    ndcg10 = 0.0\n",
    "    valid = 0\n",
    "\n",
    "    for start in tqdm(range(0, len(user_indices), batch_size), desc=model_name):\n",
    "        batch_uids = user_indices[start:start + batch_size]\n",
    "\n",
    "        # Sparse row slice -> dense batch -> batched matrix multiply.\n",
    "        # This is much faster than scoring users one by one in Python.\n",
    "        X_batch = X_train_matrix[batch_uids].toarray().astype(np.float32, copy=False)\n",
    "        scores = X_batch @ B_matrix\n",
    "        scores = np.asarray(scores, dtype=np.float32)\n",
    "\n",
    "        for row_idx, uid in enumerate(batch_uids):\n",
    "            seen = user_seen_dict.get(int(uid), None)\n",
    "            if seen:\n",
    "                seen_idx = np.fromiter(seen, dtype=np.int32)\n",
    "                scores[row_idx, seen_idx] = -np.inf\n",
    "\n",
    "        top_idx = np.argpartition(scores, -max_k, axis=1)[:, -max_k:]\n",
    "        top_scores = np.take_along_axis(scores, top_idx, axis=1)\n",
    "        order = np.argsort(top_scores, axis=1)[:, ::-1]\n",
    "        recs_batch = np.take_along_axis(top_idx, order, axis=1)\n",
    "\n",
    "        for row_idx, uid in enumerate(batch_uids):\n",
    "            uid = int(uid)\n",
    "            true_item = test_item_by_user.get(uid, None)\n",
    "            if true_item is None:\n",
    "                continue\n",
    "\n",
    "            recs = recs_batch[row_idx].astype(int).tolist()\n",
    "            valid += 1\n",
    "\n",
    "            for k in ks:\n",
    "                hits[k] += hit_rate_at_k(recs, true_item, k)\n",
    "            mrr10 += mrr_at_k(recs, true_item, 10)\n",
    "            ndcg10 += ndcg_at_k(recs, true_item, 10)\n",
    "\n",
    "    out = {\n",
    "        \"Model\": model_name,\n",
    "        \"UsersEval\": valid,\n",
    "    }\n",
    "    for k in ks:\n",
    "        out[f\"HR@{k}\"] = hits[k] / max(valid, 1)\n",
    "    out[\"MRR@10\"] = mrr10 / max(valid, 1)\n",
    "    out[\"NDCG@10\"] = ndcg10 / max(valid, 1)\n",
    "    return out\n",
    "\n",
    "def bm25_weight(X, K1=100, B=0.8):\n",
    "    X = X.tocoo(copy=True).astype(np.float32)\n",
    "    N = float(X.shape[0])\n",
    "\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    avgdl = row_sums.mean() + 1e-8\n",
    "\n",
    "    df = np.bincount(X.col, minlength=X.shape[1]).astype(np.float32)\n",
    "    idf = np.log((N - df + 0.5) / (df + 0.5))\n",
    "    idf = np.maximum(idf, 0)\n",
    "\n",
    "    denom = X.data + K1 * (1 - B + B * row_sums[X.row] / avgdl)\n",
    "    data = X.data * (K1 + 1) / denom * idf[X.col]\n",
    "\n",
    "    return csr_matrix((data, (X.row, X.col)), shape=X.shape, dtype=np.float32)\n",
    "\n",
    "def l1_row_normalize(X):\n",
    "    X = X.tocsr().astype(np.float32)\n",
    "    row_sums = np.asarray(X.sum(axis=1)).ravel()\n",
    "    inv = np.zeros_like(row_sums, dtype=np.float32)\n",
    "    mask = row_sums > 0\n",
    "    inv[mask] = 1.0 / row_sums[mask]\n",
    "    return sp.diags(inv) @ X\n",
    "\n",
    "def fit_p3alpha(X_binary, alpha=1.0, beta=0.0):\n",
    "    Pui = l1_row_normalize(X_binary).power(alpha).tocsr()\n",
    "    Piu = l1_row_normalize(X_binary.T).power(alpha).tocsr()\n",
    "    S = (Piu @ Pui).astype(np.float32).tocsr()\n",
    "    S.setdiag(0.0)\n",
    "    S.eliminate_zeros()\n",
    "\n",
    "    if beta > 0:\n",
    "        item_degree = np.asarray(X_binary.sum(axis=0)).ravel().astype(np.float32)\n",
    "        penalty = np.ones_like(item_degree, dtype=np.float32)\n",
    "        mask = item_degree > 0\n",
    "        penalty[mask] = np.power(item_degree[mask], -beta)\n",
    "        S = S @ sp.diags(penalty.astype(np.float32))\n",
    "\n",
    "    return S.tocsr()\n",
    "\n",
    "def make_sparse_similarity_recommender(X_matrix, S_matrix, user_seen_dict):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_matrix.getrow(uid) @ S_matrix\n",
    "        scores = np.asarray(scores.todense()).ravel()\n",
    "        seen = user_seen_dict.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5b918a2",
   "metadata": {},
   "source": [
    "## Quick formulation screening\n",
    "\n",
    "The notebook will build several harder pseudo-user / item formulations, evaluate cheap baselines on each one, and automatically reject trivial setups.\n",
    "\n",
    "The main selection rule is:\n",
    "- enough items\n",
    "- popularity baseline not too strong\n",
    "- average interactions per pseudo-user not too high\n",
    "- strong EASE screen score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d081b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Build all candidate formulations\n",
    "\n",
    "bundles = {}\n",
    "summary_rows = []\n",
    "failed_formulations = []\n",
    "\n",
    "for name, cfg in FORMULATIONS.items():\n",
    "    print(\"=\" * 100)\n",
    "    try:\n",
    "        bundle = build_formulation(df, cfg[\"user_cols\"], cfg[\"item_cols\"], name)\n",
    "    except Exception as e:\n",
    "        print(f\"Skipping {name}: {type(e).__name__}: {e}\")\n",
    "        failed_formulations.append({\n",
    "            \"Formulation\": name,\n",
    "            \"Status\": \"failed\",\n",
    "            \"Reason\": f\"{type(e).__name__}: {e}\",\n",
    "        })\n",
    "        continue\n",
    "\n",
    "    bundles[name] = bundle\n",
    "    print_bundle_summary(bundle)\n",
    "\n",
    "    summary_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"TrainRows\": bundle[\"train_interactions\"].shape[0],\n",
    "        \"AvgTrainPerUser\": bundle[\"train_interactions\"].shape[0] / bundle[\"num_users\"],\n",
    "        \"Thresholds\": bundle.get(\"used_thresholds\"),\n",
    "    })\n",
    "\n",
    "if not bundles:\n",
    "    raise RuntimeError(\n",
    "        \"All candidate formulations failed. Try relaxing FILTER_SCHEDULE, reducing bin counts, \"\n",
    "        \"or simplifying item/user definitions.\"\n",
    "    )\n",
    "\n",
    "summary_df = pd.DataFrame(summary_rows).sort_values(\n",
    "    [\"Items\", \"Users\"], ascending=False\n",
    ").reset_index(drop=True)\n",
    "display(summary_df)\n",
    "\n",
    "if failed_formulations:\n",
    "    failed_df = pd.DataFrame(failed_formulations)\n",
    "    display(failed_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da9f1404",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cheap screen:\n",
    "# 1) global popularity\n",
    "# 2) EASE with lambda=200 on binary interactions\n",
    "# Then reject trivial formulations automatically.\n",
    "\n",
    "screen_rows = []\n",
    "\n",
    "for name, bundle in bundles.items():\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Screening formulation:\", name)\n",
    "\n",
    "    def rec_pop(uid, n=10, bundle=bundle):\n",
    "        seen = bundle[\"user_seen\"].get(uid, set())\n",
    "        return [int(i) for i in bundle[\"global_pop_rank\"] if int(i) not in seen][:n]\n",
    "\n",
    "    pop_res = evaluate_model(rec_pop, f\"{name} / Pop\", bundle, ks=TOP_KS)\n",
    "\n",
    "    Xb = bundle[\"X_binary\"]\n",
    "    G = (Xb.T @ Xb).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += 200.0\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    B = B.astype(np.float32)\n",
    "\n",
    "    ease_res = evaluate_ease_batched(\n",
    "        B_matrix=B,\n",
    "        X_train_matrix=Xb,\n",
    "        user_seen_dict=bundle[\"user_seen\"],\n",
    "        test_item_by_user=bundle[\"test_item_by_user\"],\n",
    "        model_name=f\"{name} / EASE200_batched\",\n",
    "        user_indices=None,\n",
    "        ks=TOP_KS,\n",
    "        batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "    )\n",
    "\n",
    "    avg_train_per_user = bundle[\"train_interactions\"].shape[0] / bundle[\"num_users\"]\n",
    "\n",
    "    eligible = (\n",
    "        (bundle[\"num_items\"] >= MIN_ITEMS_FOR_VALID_FORMULATION) and\n",
    "        (pop_res[\"HR@10\"] <= MAX_POP_HR10_FOR_VALID_FORMULATION) and\n",
    "        (avg_train_per_user <= MAX_AVG_INTERACTIONS_PER_USER)\n",
    "    )\n",
    "\n",
    "    screen_score = (\n",
    "        ease_res[\"HR@10\"]\n",
    "        - 0.75 * pop_res[\"HR@10\"]\n",
    "        + 0.01 * np.log1p(bundle[\"num_items\"])\n",
    "    )\n",
    "\n",
    "    screen_rows.append({\n",
    "        \"Formulation\": name,\n",
    "        \"Users\": bundle[\"num_users\"],\n",
    "        \"Items\": bundle[\"num_items\"],\n",
    "        \"AvgTrainPerUser\": avg_train_per_user,\n",
    "        \"Pop_HR@10\": pop_res[\"HR@10\"],\n",
    "        \"EASE200_HR@10\": ease_res[\"HR@10\"],\n",
    "        \"EASE200_NDCG@10\": ease_res[\"NDCG@10\"],\n",
    "        \"Eligible\": eligible,\n",
    "        \"ScreenScore\": screen_score,\n",
    "    })\n",
    "\n",
    "screen_df = pd.DataFrame(screen_rows).sort_values(\n",
    "    [\"Eligible\", \"ScreenScore\", \"EASE200_HR@10\", \"Items\"],\n",
    "    ascending=[False, False, False, False]\n",
    ").reset_index(drop=True)\n",
    "\n",
    "display(screen_df)\n",
    "\n",
    "if screen_df[\"Eligible\"].any():\n",
    "    BEST_FORMULATION = screen_df.loc[screen_df[\"Eligible\"]].iloc[0][\"Formulation\"]\n",
    "else:\n",
    "    BEST_FORMULATION = screen_df.iloc[0][\"Formulation\"]\n",
    "\n",
    "print(\"Selected formulation:\", BEST_FORMULATION)\n",
    "data = bundles[BEST_FORMULATION]\n",
    "print_bundle_summary(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab89fdb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare reusable arrays and feature tables for the selected formulation\n",
    "\n",
    "user_feature_df = data[\"user_feature_df\"].copy()\n",
    "item_feature_df = data[\"item_feature_df\"].copy()\n",
    "\n",
    "train_interactions = data[\"train_interactions\"].copy()\n",
    "test_interactions = data[\"test_interactions\"].copy()\n",
    "\n",
    "X_counts = data[\"X_counts\"].tocsr()\n",
    "X_binary = data[\"X_binary\"].tocsr()\n",
    "\n",
    "num_users = data[\"num_users\"]\n",
    "num_items = data[\"num_items\"]\n",
    "\n",
    "user_seen = data[\"user_seen\"]\n",
    "user_strength = data[\"user_strength\"]\n",
    "test_item_by_user = data[\"test_item_by_user\"]\n",
    "global_pop_rank = data[\"global_pop_rank\"]\n",
    "item_ids = data[\"item_ids\"]\n",
    "\n",
    "X_binary_dense = X_binary.toarray().astype(np.float32)\n",
    "X_counts_dense = X_counts.toarray().astype(np.float32)\n",
    "\n",
    "eval_users = np.array(sorted(test_item_by_user.keys()))\n",
    "\n",
    "print(\"Final working matrix:\", X_binary.shape)\n",
    "display(user_feature_df.head())\n",
    "display(item_feature_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa1814d8",
   "metadata": {},
   "source": [
    "## Regular model zoo\n",
    "\n",
    "This section includes:\n",
    "- popularity baselines\n",
    "- item-based collaborative filtering\n",
    "- content-based KNN\n",
    "- graph-style recommenders\n",
    "- EASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1afb92e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Popularity baselines\n",
    "\n",
    "def recommend_popularity(user_idx, n=10):\n",
    "    seen = user_seen.get(int(user_idx), set())\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]\n",
    "\n",
    "country_col = \"Country\" if \"Country\" in user_feature_df.columns else None\n",
    "user_country = {}\n",
    "country_pop_rank = {}\n",
    "\n",
    "if country_col is not None:\n",
    "    user_country = dict(zip(user_feature_df[\"user_idx\"], user_feature_df[country_col]))\n",
    "    train_with_country = train_interactions.merge(\n",
    "        user_feature_df[[\"user_idx\", country_col]],\n",
    "        on=\"user_idx\",\n",
    "        how=\"left\"\n",
    "    )\n",
    "\n",
    "    for country, g in train_with_country.groupby(country_col):\n",
    "        country_pop_rank[country] = (\n",
    "            g.groupby(\"item_idx\")[\"count\"]\n",
    "             .sum()\n",
    "             .sort_values(ascending=False)\n",
    "             .index.to_numpy()\n",
    "        )\n",
    "\n",
    "def recommend_country_popularity(user_idx, n=10):\n",
    "    uid = int(user_idx)\n",
    "    seen = user_seen.get(uid, set())\n",
    "    country = user_country.get(uid, None)\n",
    "\n",
    "    if country in country_pop_rank:\n",
    "        recs = [int(i) for i in country_pop_rank[country] if int(i) not in seen][:n]\n",
    "        if len(recs) >= n:\n",
    "            return recs\n",
    "\n",
    "    return [int(i) for i in global_pop_rank if int(i) not in seen][:n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89251180",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ItemKNN variants\n",
    "\n",
    "def fit_itemknn_recommender(X_matrix, neighbors=100):\n",
    "    item_user = X_matrix.T.tocsr()\n",
    "\n",
    "    knn = NearestNeighbors(\n",
    "        metric=\"cosine\",\n",
    "        algorithm=\"brute\",\n",
    "        n_neighbors=min(neighbors, item_user.shape[0]),\n",
    "        n_jobs=-1\n",
    "    )\n",
    "    knn.fit(item_user)\n",
    "\n",
    "    distances, indices = knn.kneighbors(item_user, n_neighbors=min(neighbors, item_user.shape[0]))\n",
    "    similarities = (1.0 - distances).astype(np.float32)\n",
    "    return indices, similarities\n",
    "\n",
    "def make_itemknn_recommender(X_matrix, neighbors=100, use_strength=True):\n",
    "    neighbor_idx, neighbor_sim = fit_itemknn_recommender(X_matrix, neighbors=neighbors)\n",
    "\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        seen = user_seen.get(uid, set())\n",
    "        strength_map = user_strength.get(uid, {})\n",
    "        scores = {}\n",
    "\n",
    "        for item_idx in seen:\n",
    "            weight = float(strength_map.get(item_idx, 1.0)) if use_strength else 1.0\n",
    "            nbrs = neighbor_idx[item_idx]\n",
    "            sims = neighbor_sim[item_idx]\n",
    "\n",
    "            for j, sim in zip(nbrs, sims):\n",
    "                j = int(j)\n",
    "                if j == item_idx or j in seen:\n",
    "                    continue\n",
    "                scores[j] = scores.get(j, 0.0) + float(sim) * weight\n",
    "\n",
    "        if not scores:\n",
    "            return recommend_popularity(uid, n=n)\n",
    "\n",
    "        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:n]\n",
    "        return [int(i) for i, _ in ranked]\n",
    "\n",
    "    return recommend\n",
    "\n",
    "X_bm25 = bm25_weight(X_counts, K1=100, B=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24fed1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Content-based KNN on item attributes\n",
    "\n",
    "item_feature_cols = data[\"item_cols\"]\n",
    "item_feature_onehot = pd.get_dummies(item_feature_df[item_feature_cols].astype(str), sparse=True)\n",
    "item_feature_sparse = csr_matrix(item_feature_onehot.sparse.to_coo()).astype(np.float32)\n",
    "\n",
    "item_feature_norm = normalize(item_feature_sparse, norm=\"l2\", axis=1)\n",
    "content_sim = (item_feature_norm @ item_feature_norm.T).astype(np.float32).toarray()\n",
    "np.fill_diagonal(content_sim, 0.0)\n",
    "\n",
    "def recommend_content_knn(user_idx, n=10):\n",
    "    uid = int(user_idx)\n",
    "    seen = user_seen.get(uid, set())\n",
    "    user_vec = X_counts.getrow(uid).toarray().ravel().astype(np.float32)\n",
    "    scores = user_vec @ content_sim\n",
    "    return topn_from_scores(scores, seen, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee1b5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph recommenders: P3alpha and RP3beta\n",
    "\n",
    "def make_p3_recommender(X_input, alpha=1.0, beta=0.0):\n",
    "    S = fit_p3alpha(X_input, alpha=alpha, beta=beta)\n",
    "    return make_sparse_similarity_recommender(X_input, S, user_seen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "500d74ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EASE\n",
    "\n",
    "def fit_ease(X_matrix, lam):\n",
    "    G = (X_matrix.T @ X_matrix).toarray().astype(np.float32)\n",
    "    G[np.diag_indices_from(G)] += lam\n",
    "    P = np.linalg.inv(G)\n",
    "    B = -P / np.diag(P)\n",
    "    np.fill_diagonal(B, 0.0)\n",
    "    return B.astype(np.float32)\n",
    "\n",
    "def make_ease_recommender(X_train_dense, B_matrix):\n",
    "    def recommend(user_idx, n=10):\n",
    "        uid = int(user_idx)\n",
    "        scores = X_train_dense[uid] @ B_matrix\n",
    "        seen = user_seen.get(uid, set())\n",
    "        return topn_from_scores(scores, seen, n)\n",
    "    return recommend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af5374c9",
   "metadata": {},
   "source": [
    "## Neural model zoo\n",
    "\n",
    "This section favors models that actually keep the GPU busy:\n",
    "- full-softmax MLPs\n",
    "- two-tower retrieval\n",
    "- dense autoencoders\n",
    "- pairwise embedding models\n",
    "- NeuMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "366f6ed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional PyTorch setup for neural models\n",
    "\n",
    "HAS_TORCH = False\n",
    "torch = None\n",
    "nn = None\n",
    "F = None\n",
    "device = None\n",
    "\n",
    "if RUN_NEURAL:\n",
    "    import torch\n",
    "    import torch.nn as nn\n",
    "    import torch.nn.functional as F\n",
    "\n",
    "    HAS_TORCH = True\n",
    "    torch.manual_seed(RANDOM_STATE)\n",
    "\n",
    "    device = torch.device(\"cuda\" if PREFER_CUDA and torch.cuda.is_available() else \"cpu\")\n",
    "    USE_AMP = (device.type == \"cuda\")\n",
    "    print(\"Torch device:\", device)\n",
    "    print(\"USE_AMP:\", USE_AMP)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(RANDOM_STATE)\n",
    "        torch.backends.cuda.matmul.allow_tf32 = True\n",
    "        torch.backends.cudnn.allow_tf32 = True\n",
    "        torch.backends.cudnn.benchmark = True\n",
    "        try:\n",
    "            torch.set_float32_matmul_precision(\"high\")\n",
    "        except Exception:\n",
    "            pass\n",
    "else:\n",
    "    USE_AMP = False\n",
    "    print(\"RUN_NEURAL = False -> skipping PyTorch import and all neural models.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89dad4e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Feature encoders for neural models\n",
    "\n",
    "    user_cols = data[\"user_cols\"]\n",
    "    item_cols = data[\"item_cols\"]\n",
    "\n",
    "    feature_encoders = {}\n",
    "    for col in sorted(set(user_cols + item_cols)):\n",
    "        enc = LabelEncoder()\n",
    "        values = pd.concat([\n",
    "            user_feature_df[col].astype(str) if col in user_feature_df.columns else pd.Series(dtype=str),\n",
    "            item_feature_df[col].astype(str) if col in item_feature_df.columns else pd.Series(dtype=str),\n",
    "        ], ignore_index=True)\n",
    "        enc.fit(values)\n",
    "        feature_encoders[col] = enc\n",
    "\n",
    "    user_feature_arrays = {\n",
    "        col: feature_encoders[col].transform(user_feature_df[col].astype(str))\n",
    "        for col in user_cols\n",
    "    }\n",
    "    item_feature_arrays = {\n",
    "        col: feature_encoders[col].transform(item_feature_df[col].astype(str))\n",
    "        for col in item_cols\n",
    "    }\n",
    "\n",
    "    user_feature_tensors = {\n",
    "        col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "        for col, arr in user_feature_arrays.items()\n",
    "    }\n",
    "    item_feature_tensors = {\n",
    "        col: torch.tensor(arr, dtype=torch.long, device=device)\n",
    "        for col, arr in item_feature_arrays.items()\n",
    "    }\n",
    "\n",
    "    positive_user_idx = torch.tensor(train_interactions[\"user_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "    positive_item_idx = torch.tensor(train_interactions[\"item_idx\"].to_numpy(), dtype=torch.long, device=device)\n",
    "    positive_weight = torch.tensor(train_interactions[\"count\"].astype(np.float32).to_numpy(), dtype=torch.float32, device=device)\n",
    "\n",
    "    X_binary_dense_tensor = torch.tensor(X_binary_dense, dtype=torch.float32, device=device)\n",
    "    X_counts_dense_tensor = torch.tensor(X_counts_dense, dtype=torch.float32, device=device)\n",
    "\n",
    "    n_positive_rows = int(positive_user_idx.shape[0])\n",
    "    print(\"Positive interaction rows:\", n_positive_rows)\n",
    "    print(\"Dense matrix shape:\", tuple(X_binary_dense_tensor.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fe48d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Manual batch iterators\n",
    "\n",
    "    def iterate_positive_batches(batch_size, shuffle=True):\n",
    "        n = n_positive_rows\n",
    "        order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            yield positive_user_idx[idx], positive_item_idx[idx], positive_weight[idx]\n",
    "\n",
    "    def iterate_dense_user_batches(batch_size, shuffle=True):\n",
    "        n = int(X_binary_dense_tensor.shape[0])\n",
    "        order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            yield X_binary_dense_tensor[idx], idx\n",
    "\n",
    "    def make_user_feature_batch(user_idx_batch):\n",
    "        return {col: user_feature_tensors[col][user_idx_batch] for col in user_cols}\n",
    "\n",
    "    def make_item_feature_batch(item_idx_batch):\n",
    "        return {col: item_feature_tensors[col][item_idx_batch] for col in item_cols}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f30d69b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Shared building blocks\n",
    "\n",
    "    class FeatureEmbeddingBlock(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, emb_dim):\n",
    "            super().__init__()\n",
    "            self.feature_names = list(feature_cardinalities.keys())\n",
    "            self.embs = nn.ModuleDict({\n",
    "                name: nn.Embedding(card, emb_dim)\n",
    "                for name, card in feature_cardinalities.items()\n",
    "            })\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            parts = [self.embs[name](feature_batch[name]) for name in self.feature_names]\n",
    "            return torch.cat(parts, dim=-1)\n",
    "\n",
    "    class MLPBlock(nn.Module):\n",
    "        def __init__(self, input_dim, hidden_dims, dropout=0.1, final_dim=None, final_activation=False):\n",
    "            super().__init__()\n",
    "            layers = []\n",
    "            last = input_dim\n",
    "            for h in hidden_dims:\n",
    "                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]\n",
    "                last = h\n",
    "            if final_dim is not None:\n",
    "                layers.append(nn.Linear(last, final_dim))\n",
    "                last = final_dim\n",
    "                if final_activation:\n",
    "                    layers.append(nn.ReLU())\n",
    "            self.net = nn.Sequential(*layers)\n",
    "            self.output_dim = last\n",
    "\n",
    "        def forward(self, x):\n",
    "            return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c79d1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 1: full-softmax feature MLP\n",
    "\n",
    "    class SoftmaxSegmentMLP(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, num_items, emb_dim=64, hidden_dims=(512, 512, 256), dropout=0.15):\n",
    "            super().__init__()\n",
    "            self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "            input_dim = len(feature_cardinalities) * emb_dim\n",
    "            self.mlp = MLPBlock(input_dim, hidden_dims, dropout=dropout)\n",
    "            self.out = nn.Linear(self.mlp.output_dim, num_items)\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            x = self.encoder(feature_batch)\n",
    "            x = self.mlp(x)\n",
    "            return self.out(x)\n",
    "\n",
    "    def train_softmax_model(model_name, emb_dim=64, hidden_dims=(512, 512, 256), epochs=20, lr=2e-3, wd=1e-5):\n",
    "        feature_cardinalities = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "        model = SoftmaxSegmentMLP(\n",
    "            feature_cardinalities=feature_cardinalities,\n",
    "            num_items=num_items,\n",
    "            emb_dim=emb_dim,\n",
    "            hidden_dims=hidden_dims,\n",
    "            dropout=0.15,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "        criterion = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "            for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(SOFTMAX_BATCH_SIZE, shuffle=True):\n",
    "                feat_batch = make_user_feature_batch(user_idx_batch)\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model(feat_batch)\n",
    "                    loss_vec = criterion(logits, item_idx_batch)\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_softmax_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "            feat_batch = make_user_feature_batch(u)\n",
    "            logits = model(feat_batch).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(logits, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2fc4897",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 2: feature two-tower with in-batch negatives\n",
    "\n",
    "    class TowerEncoder(nn.Module):\n",
    "        def __init__(self, feature_cardinalities, emb_dim=64, hidden_dims=(512, 256), out_dim=128, dropout=0.1):\n",
    "            super().__init__()\n",
    "            self.encoder = FeatureEmbeddingBlock(feature_cardinalities, emb_dim)\n",
    "            input_dim = len(feature_cardinalities) * emb_dim\n",
    "            self.net = MLPBlock(input_dim, hidden_dims, dropout=dropout, final_dim=out_dim)\n",
    "\n",
    "        def forward(self, feature_batch):\n",
    "            x = self.encoder(feature_batch)\n",
    "            x = self.net(x)\n",
    "            return F.normalize(x, dim=-1)\n",
    "\n",
    "    class TwoTowerModel(nn.Module):\n",
    "        def __init__(self, user_feature_cards, item_feature_cards, emb_dim=64, hidden_dims=(512, 256), out_dim=128):\n",
    "            super().__init__()\n",
    "            self.user_tower = TowerEncoder(user_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "            self.item_tower = TowerEncoder(item_feature_cards, emb_dim=emb_dim, hidden_dims=hidden_dims, out_dim=out_dim)\n",
    "\n",
    "    def train_two_tower(model_name, emb_dim=64, hidden_dims=(512, 256), out_dim=128, epochs=20, lr=2e-3, wd=1e-5, temperature=0.07):\n",
    "        user_feature_cards = {col: len(feature_encoders[col].classes_) for col in user_cols}\n",
    "        item_feature_cards = {col: len(feature_encoders[col].classes_) for col in item_cols}\n",
    "\n",
    "        model = TwoTowerModel(\n",
    "            user_feature_cards=user_feature_cards,\n",
    "            item_feature_cards=item_feature_cards,\n",
    "            emb_dim=emb_dim,\n",
    "            hidden_dims=hidden_dims,\n",
    "            out_dim=out_dim,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "            for user_idx_batch, item_idx_batch, weight_batch in iterate_positive_batches(TWOTOWER_BATCH_SIZE, shuffle=True):\n",
    "                user_batch = make_user_feature_batch(user_idx_batch)\n",
    "                item_batch = make_item_feature_batch(item_idx_batch)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    user_vec = model.user_tower(user_batch)\n",
    "                    item_vec = model.item_tower(item_batch)\n",
    "                    logits = (user_vec @ item_vec.T) / temperature\n",
    "                    targets = torch.arange(logits.shape[0], device=device)\n",
    "                    loss_vec = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_two_tower_recommender(model):\n",
    "        model.eval()\n",
    "        all_item_idx = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "        item_matrix_parts = []\n",
    "        with torch.no_grad():\n",
    "            for start in range(0, num_items, 4096):\n",
    "                idx = all_item_idx[start:start + 4096]\n",
    "                batch = make_item_feature_batch(idx)\n",
    "                item_matrix_parts.append(model.item_tower(batch))\n",
    "        item_matrix = torch.cat(item_matrix_parts, dim=0)\n",
    "\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = torch.tensor([uid], dtype=torch.long, device=device)\n",
    "            user_batch = make_user_feature_batch(u)\n",
    "            user_vec = model.user_tower(user_batch)\n",
    "            scores = (user_vec @ item_matrix.T).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "918da172",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 3: dense autoencoders\n",
    "\n",
    "    class MultDAE(nn.Module):\n",
    "        def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.2):\n",
    "            super().__init__()\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Sequential(\n",
    "                nn.Linear(num_items, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, latent_dim),\n",
    "                nn.Tanh(),\n",
    "            )\n",
    "            self.decoder = nn.Sequential(\n",
    "                nn.Linear(latent_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, num_items),\n",
    "            )\n",
    "\n",
    "        def forward(self, x):\n",
    "            z = self.encoder(self.dropout(x))\n",
    "            logits = self.decoder(z)\n",
    "            return logits\n",
    "\n",
    "    class MultVAE(nn.Module):\n",
    "        def __init__(self, num_items, hidden_dim=1024, latent_dim=256, dropout=0.3):\n",
    "            super().__init__()\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Sequential(\n",
    "                nn.Linear(num_items, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "            )\n",
    "            self.mu = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.decoder = nn.Sequential(\n",
    "                nn.Linear(latent_dim, hidden_dim),\n",
    "                nn.Tanh(),\n",
    "                nn.Linear(hidden_dim, num_items),\n",
    "            )\n",
    "\n",
    "        def encode(self, x):\n",
    "            h = self.encoder(self.dropout(x))\n",
    "            return self.mu(h), self.logvar(h)\n",
    "\n",
    "        def reparameterize(self, mu, logvar):\n",
    "            if self.training:\n",
    "                std = torch.exp(0.5 * logvar)\n",
    "                eps = torch.randn_like(std)\n",
    "                return mu + eps * std\n",
    "            return mu\n",
    "\n",
    "        def forward(self, x):\n",
    "            mu, logvar = self.encode(x)\n",
    "            z = self.reparameterize(mu, logvar)\n",
    "            logits = self.decoder(z)\n",
    "            return logits, mu, logvar\n",
    "\n",
    "    def train_multdae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.2, epochs=80, lr=1e-3, wd=0.0):\n",
    "        model = MultDAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_rows = 0\n",
    "            for batch_x, _ in iterate_dense_user_batches(AUTOENC_BATCH_SIZE, shuffle=True):\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model(batch_x)\n",
    "                    loss = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1).mean()\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                total_rows += batch_x.shape[0]\n",
    "\n",
    "            if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "                print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def train_multvae(model_name, hidden_dim=1024, latent_dim=256, dropout=0.3, epochs=80, lr=1e-3, wd=0.0, anneal_cap=1.0):\n",
    "        model = MultVAE(num_items=num_items, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout).to(device)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        steps_per_epoch = max(1, math.ceil(num_users / AUTOENC_BATCH_SIZE))\n",
    "        total_steps = max(1, epochs * steps_per_epoch)\n",
    "        step = 0\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_rows = 0\n",
    "\n",
    "            for batch_x, _ in iterate_dense_user_batches(AUTOENC_BATCH_SIZE, shuffle=True):\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                anneal = min(anneal_cap, step / max(total_steps * 0.3, 1))\n",
    "\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits, mu, logvar = model(batch_x)\n",
    "                    recon = -(F.log_softmax(logits, dim=1) * batch_x).sum(dim=1)\n",
    "                    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)\n",
    "                    loss = (recon + anneal * kl).mean()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * batch_x.shape[0]\n",
    "                total_rows += batch_x.shape[0]\n",
    "                step += 1\n",
    "\n",
    "            if (epoch + 1) == 1 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:\n",
    "                print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_rows, 1):.6f} - anneal: {anneal:.3f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_multdae_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            x = X_binary_dense_tensor[uid:uid + 1]\n",
    "            logits = model(x).squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(logits, seen, n)\n",
    "        return recommend\n",
    "\n",
    "    def make_multvae_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            x = X_binary_dense_tensor[uid:uid + 1]\n",
    "            logits, _, _ = model(x)\n",
    "            scores = logits.squeeze(0)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d7bc556",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 4: BPR matrix factorization\n",
    "\n",
    "    class BPRMF(nn.Module):\n",
    "        def __init__(self, num_users, num_items, dim=128):\n",
    "            super().__init__()\n",
    "            self.user_emb = nn.Embedding(num_users, dim)\n",
    "            self.item_emb = nn.Embedding(num_items, dim)\n",
    "            self.item_bias = nn.Embedding(num_items, 1)\n",
    "\n",
    "            nn.init.normal_(self.user_emb.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_emb.weight, std=0.02)\n",
    "            nn.init.zeros_(self.item_bias.weight)\n",
    "\n",
    "        def forward(self, user_idx, pos_idx, neg_idx):\n",
    "            u = self.user_emb(user_idx)\n",
    "            p = self.item_emb(pos_idx)\n",
    "            n = self.item_emb(neg_idx)\n",
    "            pb = self.item_bias(pos_idx).squeeze(-1)\n",
    "            nb = self.item_bias(neg_idx).squeeze(-1)\n",
    "\n",
    "            pos_scores = (u * p).sum(dim=1) + pb\n",
    "            neg_scores = (u * n).sum(dim=1) + nb\n",
    "            return pos_scores, neg_scores\n",
    "\n",
    "    def train_bprmf(model_name, dim=128, epochs=24, lr=2e-3, wd=1e-6):\n",
    "        model = BPRMF(num_users=num_users, num_items=num_items, dim=dim).to(device)\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "\n",
    "            for user_idx_batch, pos_item_batch, weight_batch in iterate_positive_batches(PAIRWISE_BATCH_SIZE, shuffle=True):\n",
    "                neg_item_batch = torch.randint(0, num_items, size=pos_item_batch.shape, device=device)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    pos_scores, neg_scores = model(user_idx_batch, pos_item_batch, neg_item_batch)\n",
    "                    loss_vec = -F.logsigmoid(pos_scores - neg_scores)\n",
    "                    loss = (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(weight_batch.sum().item())\n",
    "                total_w += float(weight_batch.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_bprmf_recommender(model):\n",
    "        item_matrix = model.item_emb.weight\n",
    "        item_bias = model.item_bias.weight.squeeze(-1)\n",
    "\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            u = model.user_emb.weight[uid]\n",
    "            scores = item_matrix @ u + item_bias\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48fb928d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Neural model 5: NeuMF\n",
    "\n",
    "    class NeuMF(nn.Module):\n",
    "        def __init__(self, num_users, num_items, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128)):\n",
    "            super().__init__()\n",
    "            self.user_mf = nn.Embedding(num_users, mf_dim)\n",
    "            self.item_mf = nn.Embedding(num_items, mf_dim)\n",
    "            self.user_mlp = nn.Embedding(num_users, mlp_dim)\n",
    "            self.item_mlp = nn.Embedding(num_items, mlp_dim)\n",
    "\n",
    "            layers = []\n",
    "            last = mlp_dim * 2\n",
    "            for h in hidden_dims:\n",
    "                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(0.1)]\n",
    "                last = h\n",
    "            self.mlp = nn.Sequential(*layers)\n",
    "            self.out = nn.Linear(last + mf_dim, 1)\n",
    "\n",
    "            nn.init.normal_(self.user_mf.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_mf.weight, std=0.02)\n",
    "            nn.init.normal_(self.user_mlp.weight, std=0.02)\n",
    "            nn.init.normal_(self.item_mlp.weight, std=0.02)\n",
    "\n",
    "        def score(self, user_idx, item_idx):\n",
    "            mf_u = self.user_mf(user_idx)\n",
    "            mf_i = self.item_mf(item_idx)\n",
    "            mf = mf_u * mf_i\n",
    "\n",
    "            mlp_u = self.user_mlp(user_idx)\n",
    "            mlp_i = self.item_mlp(item_idx)\n",
    "            mlp = self.mlp(torch.cat([mlp_u, mlp_i], dim=-1))\n",
    "\n",
    "            x = torch.cat([mf, mlp], dim=-1)\n",
    "            return self.out(x).squeeze(-1)\n",
    "\n",
    "    def train_neumf(model_name, mf_dim=64, mlp_dim=128, hidden_dims=(256, 128), epochs=24, lr=2e-3, wd=1e-6):\n",
    "        model = NeuMF(num_users=num_users, num_items=num_items, mf_dim=mf_dim, mlp_dim=mlp_dim, hidden_dims=hidden_dims).to(device)\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)\n",
    "        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)\n",
    "\n",
    "        model.train()\n",
    "        for epoch in range(epochs):\n",
    "            total_loss = 0.0\n",
    "            total_w = 0.0\n",
    "\n",
    "            for user_idx_batch, pos_item_batch, weight_batch in iterate_positive_batches(PAIRWISE_BATCH_SIZE, shuffle=True):\n",
    "                neg_item_batch = torch.randint(0, num_items, size=pos_item_batch.shape, device=device)\n",
    "\n",
    "                user_cat = torch.cat([user_idx_batch, user_idx_batch], dim=0)\n",
    "                item_cat = torch.cat([pos_item_batch, neg_item_batch], dim=0)\n",
    "                target = torch.cat([\n",
    "                    torch.ones_like(pos_item_batch, dtype=torch.float32),\n",
    "                    torch.zeros_like(neg_item_batch, dtype=torch.float32),\n",
    "                ], dim=0)\n",
    "                sample_weight = torch.cat([weight_batch, weight_batch], dim=0)\n",
    "\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):\n",
    "                    logits = model.score(user_cat, item_cat)\n",
    "                    loss_vec = F.binary_cross_entropy_with_logits(logits, target, reduction=\"none\")\n",
    "                    loss = (loss_vec * sample_weight).sum() / sample_weight.sum()\n",
    "\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(optimizer)\n",
    "                scaler.update()\n",
    "\n",
    "                total_loss += float(loss.item()) * float(sample_weight.sum().item())\n",
    "                total_w += float(sample_weight.sum().item())\n",
    "\n",
    "            print(f\"{model_name} epoch {epoch + 1}/{epochs} - loss: {total_loss / max(total_w, 1e-8):.6f}\")\n",
    "\n",
    "        return model.eval()\n",
    "\n",
    "    def make_neumf_recommender(model):\n",
    "        @torch.no_grad()\n",
    "        def recommend(user_idx, n=10):\n",
    "            uid = int(user_idx)\n",
    "            users = torch.full((num_items,), uid, dtype=torch.long, device=device)\n",
    "            items = torch.arange(num_items, dtype=torch.long, device=device)\n",
    "            scores = model.score(users, items)\n",
    "            seen = user_seen.get(uid, set())\n",
    "            return topn_from_torch_scores(scores, seen, n)\n",
    "        return recommend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e63b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit and evaluate regular models\n",
    "\n",
    "regular_results = []\n",
    "regular_models = {}\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: Popularity\")\n",
    "regular_models[\"Popularity\"] = recommend_popularity\n",
    "regular_results.append(evaluate_model(recommend_popularity, \"Popularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: CountryPopularity\")\n",
    "regular_models[\"CountryPopularity\"] = recommend_country_popularity\n",
    "regular_results.append(evaluate_model(recommend_country_popularity, \"CountryPopularity\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "print(\"=\" * 100)\n",
    "print(\"Evaluating regular baseline: ContentKNN\")\n",
    "regular_models[\"ContentKNN\"] = recommend_content_knn\n",
    "regular_results.append(evaluate_model(recommend_content_knn, \"ContentKNN\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN binary neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_binary, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_binary_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for k in ITEMKNN_NEIGHBORS_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting ItemKNN BM25 neighbors={k}\")\n",
    "    rec = make_itemknn_recommender(X_bm25, neighbors=k, use_strength=True)\n",
    "    name = f\"ItemKNN_bm25_k{k}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for alpha in P3_ALPHA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting P3alpha alpha={alpha}\")\n",
    "    rec = make_p3_recommender(X_binary, alpha=alpha, beta=0.0)\n",
    "    name = f\"P3alpha_a{str(alpha).replace('.', '_')}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for alpha, beta in RP3_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting RP3beta alpha={alpha} beta={beta}\")\n",
    "    rec = make_p3_recommender(X_binary, alpha=alpha, beta=beta)\n",
    "    name = f\"RP3beta_a{str(alpha).replace('.', '_')}_b{str(beta).replace('.', '_')}\"\n",
    "    regular_models[name] = rec\n",
    "    regular_results.append(evaluate_model(rec, name, data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "for lam in EASE_LAMBDA_GRID:\n",
    "    print(\"=\" * 100)\n",
    "    print(f\"Fitting EASE binary lambda={lam}\")\n",
    "    B_bin = fit_ease(X_binary, lam=lam)\n",
    "    rec_bin = make_ease_recommender(X_binary_dense, B_bin)\n",
    "    name_bin = f\"EASE_binary_l{int(lam)}\"\n",
    "    regular_models[name_bin] = rec_bin\n",
    "    regular_results.append(\n",
    "        evaluate_ease_batched(\n",
    "            B_matrix=B_bin,\n",
    "            X_train_matrix=X_binary,\n",
    "            user_seen_dict=user_seen,\n",
    "            test_item_by_user=test_item_by_user,\n",
    "            model_name=name_bin,\n",
    "            user_indices=eval_users,\n",
    "            ks=TOP_KS,\n",
    "            batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    print(f\"Fitting EASE count lambda={lam}\")\n",
    "    B_cnt = fit_ease(X_counts, lam=lam)\n",
    "    rec_cnt = make_ease_recommender(X_counts_dense, B_cnt)\n",
    "    name_cnt = f\"EASE_count_l{int(lam)}\"\n",
    "    regular_models[name_cnt] = rec_cnt\n",
    "    regular_results.append(\n",
    "        evaluate_ease_batched(\n",
    "            B_matrix=B_cnt,\n",
    "            X_train_matrix=X_counts,\n",
    "            user_seen_dict=user_seen,\n",
    "            test_item_by_user=test_item_by_user,\n",
    "            model_name=name_cnt,\n",
    "            user_indices=eval_users,\n",
    "            ks=TOP_KS,\n",
    "            batch_size=EASE_EVAL_BATCH_SIZE,\n",
    "        )\n",
    "    )\n",
    "\n",
    "regular_results_df = pd.DataFrame(regular_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "display(regular_results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bdd000a",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RUN_NEURAL:\n",
    "    neural_results = []\n",
    "    neural_models = {}\n",
    "    neural_results_df = pd.DataFrame(columns=['Model','UsersEval','HR@5','HR@10','HR@20','MRR@10','NDCG@10'])\n",
    "    print('Skipping neural model training because RUN_NEURAL = False')\n",
    "else:\n",
    "    # Fit and evaluate neural models\n",
    "\n",
    "    neural_results = []\n",
    "    neural_models = {}\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training SoftmaxMLP_base\")\n",
    "    softmax_base = train_softmax_model(\n",
    "        model_name=\"SoftmaxMLP_base\",\n",
    "        emb_dim=64,\n",
    "        hidden_dims=(512, 512, 256),\n",
    "        epochs=SOFTMAX_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-5,\n",
    "    )\n",
    "    rec = make_softmax_recommender(softmax_base)\n",
    "    neural_models[\"SoftmaxMLP_base\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"SoftmaxMLP_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training SoftmaxMLP_large\")\n",
    "    softmax_large = train_softmax_model(\n",
    "        model_name=\"SoftmaxMLP_large\",\n",
    "        emb_dim=96,\n",
    "        hidden_dims=(1024, 1024, 512, 256),\n",
    "        epochs=SOFTMAX_EPOCHS,\n",
    "        lr=1.5e-3,\n",
    "        wd=1e-5,\n",
    "    )\n",
    "    rec = make_softmax_recommender(softmax_large)\n",
    "    neural_models[\"SoftmaxMLP_large\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"SoftmaxMLP_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training TwoTower_base\")\n",
    "    twotower_base = train_two_tower(\n",
    "        model_name=\"TwoTower_base\",\n",
    "        emb_dim=64,\n",
    "        hidden_dims=(512, 256),\n",
    "        out_dim=128,\n",
    "        epochs=TWOTOWER_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-5,\n",
    "        temperature=0.07,\n",
    "    )\n",
    "    rec = make_two_tower_recommender(twotower_base)\n",
    "    neural_models[\"TwoTower_base\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"TwoTower_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training TwoTower_large\")\n",
    "    twotower_large = train_two_tower(\n",
    "        model_name=\"TwoTower_large\",\n",
    "        emb_dim=96,\n",
    "        hidden_dims=(1024, 512, 256),\n",
    "        out_dim=192,\n",
    "        epochs=TWOTOWER_EPOCHS,\n",
    "        lr=1.5e-3,\n",
    "        wd=1e-5,\n",
    "        temperature=0.05,\n",
    "    )\n",
    "    rec = make_two_tower_recommender(twotower_large)\n",
    "    neural_models[\"TwoTower_large\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"TwoTower_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training MultDAE\")\n",
    "    multdae = train_multdae(\n",
    "        model_name=\"MultDAE\",\n",
    "        hidden_dim=1024,\n",
    "        latent_dim=256,\n",
    "        dropout=0.2,\n",
    "        epochs=AUTOENC_EPOCHS,\n",
    "        lr=1e-3,\n",
    "        wd=0.0,\n",
    "    )\n",
    "    rec = make_multdae_recommender(multdae)\n",
    "    neural_models[\"MultDAE\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"MultDAE\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training MultVAE_base\")\n",
    "    multvae_base = train_multvae(\n",
    "        model_name=\"MultVAE_base\",\n",
    "        hidden_dim=1024,\n",
    "        latent_dim=256,\n",
    "        dropout=0.25,\n",
    "        epochs=AUTOENC_EPOCHS,\n",
    "        lr=1e-3,\n",
    "        wd=0.0,\n",
    "    )\n",
    "    rec = make_multvae_recommender(multvae_base)\n",
    "    neural_models[\"MultVAE_base\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"MultVAE_base\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training MultVAE_large\")\n",
    "    multvae_large = train_multvae(\n",
    "        model_name=\"MultVAE_large\",\n",
    "        hidden_dim=1536,\n",
    "        latent_dim=384,\n",
    "        dropout=0.30,\n",
    "        epochs=AUTOENC_EPOCHS,\n",
    "        lr=8e-4,\n",
    "        wd=0.0,\n",
    "    )\n",
    "    rec = make_multvae_recommender(multvae_large)\n",
    "    neural_models[\"MultVAE_large\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"MultVAE_large\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training BPRMF\")\n",
    "    bprmf = train_bprmf(\n",
    "        model_name=\"BPRMF\",\n",
    "        dim=128,\n",
    "        epochs=PAIRWISE_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-6,\n",
    "    )\n",
    "    rec = make_bprmf_recommender(bprmf)\n",
    "    neural_models[\"BPRMF\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"BPRMF\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    print(\"=\" * 100)\n",
    "    print(\"Training NeuMF\")\n",
    "    neumf = train_neumf(\n",
    "        model_name=\"NeuMF\",\n",
    "        mf_dim=64,\n",
    "        mlp_dim=128,\n",
    "        hidden_dims=(256, 128),\n",
    "        epochs=PAIRWISE_EPOCHS,\n",
    "        lr=2e-3,\n",
    "        wd=1e-6,\n",
    "    )\n",
    "    rec = make_neumf_recommender(neumf)\n",
    "    neural_models[\"NeuMF\"] = rec\n",
    "    neural_results.append(evaluate_model(rec, \"NeuMF\", data, user_indices=eval_users, ks=TOP_KS))\n",
    "\n",
    "    neural_results_df = pd.DataFrame(neural_results).sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "    display(neural_results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1ab8431",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final combined comparison\n",
    "\n",
    "frames = [regular_results_df]\n",
    "if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "    frames.append(neural_results_df)\n",
    "\n",
    "all_results_df = pd.concat(frames, ignore_index=True)\n",
    "all_results_df = all_results_df.sort_values([\"HR@10\", \"NDCG@10\", \"MRR@10\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "print(\"Top regular models:\")\n",
    "display(regular_results_df.head(12))\n",
    "\n",
    "if isinstance(neural_results_df, pd.DataFrame) and not neural_results_df.empty:\n",
    "    print(\"Top neural models:\")\n",
    "    display(neural_results_df.head(12))\n",
    "else:\n",
    "    print(\"Neural models were skipped.\")\n",
    "\n",
    "print(\"Overall ranking:\")\n",
    "display(all_results_df)\n",
    "\n",
    "best_model_name = all_results_df.iloc[0][\"Model\"]\n",
    "print(\"Best overall model:\", best_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed7dcdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show example recommendations from the best overall model\n",
    "\n",
    "all_model_funcs = {}\n",
    "all_model_funcs.update(regular_models)\n",
    "if 'neural_models' in globals():\n",
    "    all_model_funcs.update(neural_models)\n",
    "\n",
    "best_recommender = all_model_funcs[best_model_name]\n",
    "\n",
    "example_user = int(eval_users[0])\n",
    "example_item_ids = best_recommender(example_user, n=10)\n",
    "example_item_names = [item_ids[i] for i in example_item_ids]\n",
    "\n",
    "print(\"Example pseudo-user index:\", example_user)\n",
    "display(user_feature_df[user_feature_df[\"user_idx\"] == example_user])\n",
    "\n",
    "example_df = pd.DataFrame({\n",
    "    \"rank\": np.arange(1, len(example_item_names) + 1),\n",
    "    \"recommended_item\": example_item_names,\n",
    "})\n",
    "display(example_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26869d31",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "If you still want even harder settings after this run, the next levers are:\n",
    "- raise item granularity again\n",
    "- require lower popularity ceiling for formulation selection\n",
    "- increase negative-sampling pressure for pairwise neural models\n",
    "- tune the strongest 2 to 3 models only instead of running the full zoo"
   ]
  }
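  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e4f7a2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch only, not part of the evaluated runs above: one way to raise\n",
    "# negative-sampling pressure for the pairwise models. It reuses globals from\n",
    "# the training cells (num_items, device, the BPRMF class); NUM_NEGATIVES and\n",
    "# multi_negative_bpr_loss are hypothetical names introduced here.\n",
    "\n",
    "if not RUN_NEURAL:\n",
    "    print('Skipping cell because RUN_NEURAL = False')\n",
    "else:\n",
    "    NUM_NEGATIVES = 5  # negatives drawn per positive, instead of 1\n",
    "\n",
    "    def sample_negatives(pos_item_batch, num_negatives=NUM_NEGATIVES):\n",
    "        # Uniform draws of shape (batch, num_negatives); as in the loops above,\n",
    "        # rare collisions with true positives are left uncorrected.\n",
    "        return torch.randint(0, num_items, size=(pos_item_batch.shape[0], num_negatives), device=device)\n",
    "\n",
    "    def multi_negative_bpr_loss(model, user_idx_batch, pos_item_batch, weight_batch):\n",
    "        # Score each (user, positive) pair against every sampled negative and\n",
    "        # average the BPR loss over the extra negative dimension.\n",
    "        neg = sample_negatives(pos_item_batch)                         # (B, K)\n",
    "        u = model.user_emb(user_idx_batch)                             # (B, d)\n",
    "        p = model.item_emb(pos_item_batch)                             # (B, d)\n",
    "        pos_scores = (u * p).sum(dim=1) + model.item_bias(pos_item_batch).squeeze(-1)\n",
    "        n = model.item_emb(neg)                                        # (B, K, d)\n",
    "        neg_scores = (u.unsqueeze(1) * n).sum(dim=-1) + model.item_bias(neg).squeeze(-1)\n",
    "        loss_mat = -F.logsigmoid(pos_scores.unsqueeze(1) - neg_scores)  # (B, K)\n",
    "        loss_vec = loss_mat.mean(dim=1)\n",
    "        return (loss_vec * weight_batch).sum() / weight_batch.sum()\n",
    "\n",
    "    # Dropping this loss into train_bprmf in place of its single-negative loss\n",
    "    # would be the minimal change; a higher NUM_NEGATIVES sharpens ranking\n",
    "    # pressure at a roughly linear cost in compute per step."
   ]
  }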
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
