{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f43dd523",
   "metadata": {},
   "source": [
    "# Car recommendation notebook (initial version)\n",
    "\n",
    "## Task framing\n",
    "\n",
    "This dataset lacks the rich, repeated per-user histories found in datasets such as MovieLens, so the car task is framed as a **context-aware top-N recommendation** problem:\n",
    "\n",
    "- **Input/query**: a desired car profile and customer context  \n",
    "  (`Year`, `Price`, `Mileage`, `Color`, `Condition`, `Country`)\n",
    "- **Target item**: car **Brand + Model**\n",
    "- **Output**: top-N recommended car models for that query\n",
    "\n",
    "Each row is treated as one (query, item) pair, which is the closest honest recommender formulation the current schema supports.\n",
    "\n",
    "## Model set in this initial version\n",
    "\n",
    "### Regular models\n",
    "1. Popularity by country\n",
    "2. Context KNN\n",
    "3. Linear classifier (one-vs-rest logistic regression trained with SGD)\n",
    "\n",
    "### Neural models\n",
    "4. MLP classifier\n",
    "5. Two-tower retrieval model\n",
    "\n",
    "## Metrics\n",
    "The models are evaluated with:\n",
    "- `HitRate@5`, `HitRate@10`, `HitRate@20`\n",
    "- `MRR@10`\n",
    "- `NDCG@10`"
   ]
  },
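  {
   "cell_type": "markdown",
   "id": "a7c4e1f0",
   "metadata": {},
   "source": [
    "The metrics listed above can be written out explicitly. The following is a sketch that assumes each query has exactly one relevant item (the row's own `item_id`) and writes $r$ for its 1-based position in the ranked list, with $r = \\infty$ if it is absent. The symbol $r$ and the indicator $\\mathbf{1}[\\cdot]$ are notation introduced here, not names from the notebook code.\n",
    "\n",
    "$$\\mathrm{HitRate@}k = \\mathbf{1}[r \\le k]$$\n",
    "\n",
    "$$\\mathrm{MRR@}k = \\frac{\\mathbf{1}[r \\le k]}{r}$$\n",
    "\n",
    "$$\\mathrm{NDCG@}k = \\frac{\\mathbf{1}[r \\le k]}{\\log_2(r + 1)}$$\n",
    "\n",
    "Each quantity is averaged over all evaluated queries. With a single relevant item the ideal DCG equals 1, so NDCG needs no further normalization. For example, a true item ranked third contributes 1 to `HitRate@5`, 1/3 to `MRR@10`, and $1/\\log_2 4 = 0.5$ to `NDCG@10`."
   ]
  },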
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2188f21b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch device: cuda\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import math\n",
    "import random\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "random.seed(RANDOM_STATE)\n",
    "torch.manual_seed(RANDOM_STATE)\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Torch device:\", DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3883bf1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset: /home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\n"
     ]
    }
   ],
   "source": [
    "# Paths and core settings\n",
    "\n",
    "BASE_DIR = Path.cwd()\n",
    "\n",
    "DATA_PATH = BASE_DIR / \"car_sales_dataset_with_person_details.csv\"\n",
    "\n",
    "# Optional absolute fallback for your current local setup\n",
    "if not DATA_PATH.exists():\n",
    "    fallback = Path(\"/home/konnilol/Documents/uni/kursovaya-sem5/car_sales_dataset_with_person_details.csv\")\n",
    "    if fallback.exists():\n",
    "        DATA_PATH = fallback\n",
    "\n",
    "if not DATA_PATH.exists():\n",
    "    raise FileNotFoundError(f\"CSV file not found: {DATA_PATH}\")\n",
    "\n",
    "TARGET_TOP_N = 20\n",
    "\n",
    "# Keep only items with enough observations so top-N evaluation is meaningful\n",
    "MIN_ITEM_SUPPORT = 20\n",
    "\n",
    "# Limit evaluation set for faster iteration\n",
    "MAX_EVAL_ROWS = 5000\n",
    "\n",
    "# KNN settings\n",
    "KNN_NEIGHBORS = 100\n",
    "\n",
    "# Neural settings\n",
    "BATCH_SIZE = 512\n",
    "MLP_EPOCHS = 8\n",
    "TWOTOWER_EPOCHS = 8\n",
    "LEARNING_RATE = 1e-3\n",
    "EMBED_DIM = 32\n",
    "\n",
    "print(\"Using dataset:\", DATA_PATH.resolve())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c41c1b7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape: (1000000, 11)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>First Name</th>\n",
       "      <th>Last Name</th>\n",
       "      <th>Address</th>\n",
       "      <th>Country</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Emily</td>\n",
       "      <td>Harris</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>Brazil</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>John</td>\n",
       "      <td>Harris</td>\n",
       "      <td>101 Maple Dr</td>\n",
       "      <td>Italy</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Karen</td>\n",
       "      <td>Wilson</td>\n",
       "      <td>202 Birch Blvd</td>\n",
       "      <td>UK</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Susan</td>\n",
       "      <td>Martinez</td>\n",
       "      <td>123 Main St</td>\n",
       "      <td>Mexico</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>Charles</td>\n",
       "      <td>Miller</td>\n",
       "      <td>456 Oak Ave</td>\n",
       "      <td>USA</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition First Name Last Name         Address Country  \n",
       "0  Certified Pre-Owned      Emily    Harris     456 Oak Ave  Brazil  \n",
       "1  Certified Pre-Owned       John    Harris    101 Maple Dr   Italy  \n",
       "2  Certified Pre-Owned      Karen    Wilson  202 Birch Blvd      UK  \n",
       "3                 Used      Susan  Martinez     123 Main St  Mexico  \n",
       "4                 Used    Charles    Miller     456 Oak Ave     USA  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Brand          object\n",
      "Model          object\n",
      "Year            int64\n",
      "Price         float64\n",
      "Mileage         int64\n",
      "Color          object\n",
      "Condition      object\n",
      "First Name     object\n",
      "Last Name      object\n",
      "Address        object\n",
      "Country        object\n",
      "dtype: object\n"
     ]
    }
   ],
   "source": [
    "# Load data\n",
    "\n",
    "df = pd.read_csv(DATA_PATH)\n",
    "\n",
    "print(\"Shape:\", df.shape)\n",
    "display(df.head())\n",
    "print(df.dtypes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f8209587",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows after cleanup: 1000000\n",
      "Unique items before filtering: 88\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>Country</th>\n",
       "      <th>item_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Honda</td>\n",
       "      <td>Civic</td>\n",
       "      <td>2023</td>\n",
       "      <td>25627.20</td>\n",
       "      <td>58513</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Brazil</td>\n",
       "      <td>Honda :: Civic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>Mazda3</td>\n",
       "      <td>2000</td>\n",
       "      <td>12027.14</td>\n",
       "      <td>60990</td>\n",
       "      <td>Brown</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>Italy</td>\n",
       "      <td>Mazda :: Mazda3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Mazda</td>\n",
       "      <td>CX-5</td>\n",
       "      <td>2014</td>\n",
       "      <td>49194.93</td>\n",
       "      <td>1703</td>\n",
       "      <td>Green</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>UK</td>\n",
       "      <td>Mazda :: CX-5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Hyundai</td>\n",
       "      <td>Tucson</td>\n",
       "      <td>2003</td>\n",
       "      <td>11955.94</td>\n",
       "      <td>25353</td>\n",
       "      <td>Silver</td>\n",
       "      <td>Used</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>Hyundai :: Tucson</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Land Rover</td>\n",
       "      <td>Range Rover</td>\n",
       "      <td>2012</td>\n",
       "      <td>10910.01</td>\n",
       "      <td>76854</td>\n",
       "      <td>Orange</td>\n",
       "      <td>Used</td>\n",
       "      <td>USA</td>\n",
       "      <td>Land Rover :: Range Rover</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Brand        Model  Year     Price  Mileage   Color  \\\n",
       "0       Honda        Civic  2023  25627.20    58513   Green   \n",
       "1       Mazda       Mazda3  2000  12027.14    60990   Brown   \n",
       "2       Mazda         CX-5  2014  49194.93     1703   Green   \n",
       "3     Hyundai       Tucson  2003  11955.94    25353  Silver   \n",
       "4  Land Rover  Range Rover  2012  10910.01    76854  Orange   \n",
       "\n",
       "             Condition Country                    item_id  \n",
       "0  Certified Pre-Owned  Brazil             Honda :: Civic  \n",
       "1  Certified Pre-Owned   Italy            Mazda :: Mazda3  \n",
       "2  Certified Pre-Owned      UK              Mazda :: CX-5  \n",
       "3                 Used  Mexico          Hyundai :: Tucson  \n",
       "4                 Used     USA  Land Rover :: Range Rover  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Define recommendation target and query features\n",
    "\n",
    "required_cols = [\n",
    "    \"Brand\", \"Model\", \"Year\", \"Price\", \"Mileage\",\n",
    "    \"Color\", \"Condition\", \"Country\"\n",
    "]\n",
    "\n",
    "missing = [c for c in required_cols if c not in df.columns]\n",
    "if missing:\n",
    "    raise ValueError(f\"Missing required columns: {missing}\")\n",
    "\n",
    "df = df[required_cols].copy()\n",
    "df = df.dropna().reset_index(drop=True)\n",
    "\n",
    "# Item definition: Brand + Model\n",
    "df[\"item_id\"] = df[\"Brand\"].astype(str).str.strip() + \" :: \" + df[\"Model\"].astype(str).str.strip()\n",
    "\n",
    "# Query features used by recommenders\n",
    "num_cols = [\"Year\", \"Price\", \"Mileage\"]\n",
    "cat_cols = [\"Color\", \"Condition\", \"Country\"]\n",
    "feature_cols = num_cols + cat_cols\n",
    "target_col = \"item_id\"\n",
    "\n",
    "print(\"Rows after cleanup:\", len(df))\n",
    "print(\"Unique items before filtering:\", df[target_col].nunique())\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0a48f3ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rows after item filtering: 1000000\n",
      "Unique items after filtering: 88\n",
      "Top items:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "item_id\n",
       "Chrysler :: Voyager         18572\n",
       "Chrysler :: 300             18497\n",
       "Chrysler :: Pacifica        18297\n",
       "Chevrolet :: Silverado      11391\n",
       "Mercedes-Benz :: E-Class    11301\n",
       "Hyundai :: Sonata           11298\n",
       "Kia :: Optima               11255\n",
       "Mercedes-Benz :: GLC        11249\n",
       "Jeep :: Cherokee            11245\n",
       "Toyota :: Tacoma            11241\n",
       "Name: count, dtype: int64"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train shape: (800000, 9)\n",
      "Test shape : (5000, 9)\n"
     ]
    }
   ],
   "source": [
    "# Filter rare items and create train/test split\n",
    "\n",
    "item_counts = df[target_col].value_counts()\n",
    "kept_items = item_counts[item_counts >= MIN_ITEM_SUPPORT].index\n",
    "\n",
    "df_model = df[df[target_col].isin(kept_items)].copy().reset_index(drop=True)\n",
    "\n",
    "print(\"Rows after item filtering:\", len(df_model))\n",
    "print(\"Unique items after filtering:\", df_model[target_col].nunique())\n",
    "print(\"Top items:\")\n",
    "display(df_model[target_col].value_counts().head(10))\n",
    "\n",
    "if df_model[target_col].nunique() < 5:\n",
    "    raise ValueError(\n",
    "        \"Too few supported items after filtering. Lower MIN_ITEM_SUPPORT and rerun.\"\n",
    "    )\n",
    "\n",
    "train_df, test_df = train_test_split(\n",
    "    df_model,\n",
    "    test_size=0.2,\n",
    "    random_state=RANDOM_STATE,\n",
    "    stratify=df_model[target_col]\n",
    ")\n",
    "\n",
    "train_df = train_df.reset_index(drop=True)\n",
    "test_df = test_df.reset_index(drop=True)\n",
    "\n",
    "if MAX_EVAL_ROWS is not None and len(test_df) > MAX_EVAL_ROWS:\n",
    "    test_df = test_df.sample(MAX_EVAL_ROWS, random_state=RANDOM_STATE).reset_index(drop=True)\n",
    "\n",
    "print(\"Train shape:\", train_df.shape)\n",
    "print(\"Test shape :\", test_df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "63c8b73f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of items: 88\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>avg_year</th>\n",
       "      <th>avg_price</th>\n",
       "      <th>avg_mileage</th>\n",
       "      <th>main_color</th>\n",
       "      <th>main_condition</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>item_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Audi :: A3</th>\n",
       "      <td>Audi</td>\n",
       "      <td>A3</td>\n",
       "      <td>2011.970831</td>\n",
       "      <td>42423.440536</td>\n",
       "      <td>98860.997852</td>\n",
       "      <td>Gray</td>\n",
       "      <td>Used</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Audi :: A4</th>\n",
       "      <td>Audi</td>\n",
       "      <td>A4</td>\n",
       "      <td>2011.916124</td>\n",
       "      <td>42401.075803</td>\n",
       "      <td>99596.451157</td>\n",
       "      <td>Black</td>\n",
       "      <td>Used</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Audi :: A6</th>\n",
       "      <td>Audi</td>\n",
       "      <td>A6</td>\n",
       "      <td>2012.033868</td>\n",
       "      <td>42415.865483</td>\n",
       "      <td>99910.641023</td>\n",
       "      <td>Gray</td>\n",
       "      <td>Used</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Audi :: Q5</th>\n",
       "      <td>Audi</td>\n",
       "      <td>Q5</td>\n",
       "      <td>2011.936859</td>\n",
       "      <td>42487.546798</td>\n",
       "      <td>99363.572943</td>\n",
       "      <td>Orange</td>\n",
       "      <td>New</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Audi :: Q7</th>\n",
       "      <td>Audi</td>\n",
       "      <td>Q7</td>\n",
       "      <td>2012.034886</td>\n",
       "      <td>42450.890595</td>\n",
       "      <td>100324.285568</td>\n",
       "      <td>Red</td>\n",
       "      <td>New</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           Brand Model     avg_year     avg_price    avg_mileage main_color  \\\n",
       "item_id                                                                       \n",
       "Audi :: A3  Audi    A3  2011.970831  42423.440536   98860.997852       Gray   \n",
       "Audi :: A4  Audi    A4  2011.916124  42401.075803   99596.451157      Black   \n",
       "Audi :: A6  Audi    A6  2012.033868  42415.865483   99910.641023       Gray   \n",
       "Audi :: Q5  Audi    Q5  2011.936859  42487.546798   99363.572943     Orange   \n",
       "Audi :: Q7  Audi    Q7  2012.034886  42450.890595  100324.285568        Red   \n",
       "\n",
       "           main_condition  \n",
       "item_id                    \n",
       "Audi :: A3           Used  \n",
       "Audi :: A4           Used  \n",
       "Audi :: A6           Used  \n",
       "Audi :: Q5            New  \n",
       "Audi :: Q7            New  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build item vocabulary and helper tables\n",
    "\n",
    "label_encoder = LabelEncoder()\n",
    "label_encoder.fit(train_df[target_col])\n",
    "\n",
    "train_df = train_df[train_df[target_col].isin(label_encoder.classes_)].copy()\n",
    "test_df = test_df[test_df[target_col].isin(label_encoder.classes_)].copy()\n",
    "\n",
    "train_df[\"item_idx\"] = label_encoder.transform(train_df[target_col])\n",
    "test_df[\"item_idx\"] = label_encoder.transform(test_df[target_col])\n",
    "\n",
    "all_items = label_encoder.classes_.tolist()\n",
    "num_items = len(all_items)\n",
    "\n",
    "item_popularity = train_df[target_col].value_counts()\n",
    "global_popularity = item_popularity.index.tolist()\n",
    "\n",
    "country_item_popularity = (\n",
    "    train_df.groupby([\"Country\", target_col])\n",
    "    .size()\n",
    "    .reset_index(name=\"cnt\")\n",
    "    .sort_values([\"Country\", \"cnt\"], ascending=[True, False])\n",
    ")\n",
    "\n",
    "country_top_items = {\n",
    "    country: grp[target_col].tolist()\n",
    "    for country, grp in country_item_popularity.groupby(\"Country\")\n",
    "}\n",
    "\n",
    "item_meta = (\n",
    "    train_df.groupby(target_col)\n",
    "    .agg(\n",
    "        Brand=(\"Brand\", \"first\"),\n",
    "        Model=(\"Model\", \"first\"),\n",
    "        avg_year=(\"Year\", \"mean\"),\n",
    "        avg_price=(\"Price\", \"mean\"),\n",
    "        avg_mileage=(\"Mileage\", \"mean\"),\n",
    "        main_color=(\"Color\", lambda s: s.mode().iloc[0] if not s.mode().empty else s.iloc[0]),\n",
    "        main_condition=(\"Condition\", lambda s: s.mode().iloc[0] if not s.mode().empty else s.iloc[0]),\n",
    "    )\n",
    ")\n",
    "\n",
    "print(\"Number of items:\", num_items)\n",
    "display(item_meta.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "08042e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics: every query has exactly one relevant item (the row's own item_id)\n",
    "\n",
    "def hit_rate_at_k(true_item, ranked_items, k):\n",
    "    # 1.0 if the true item is among the first k recommendations, else 0.0\n",
    "    return 1.0 if true_item in ranked_items[:k] else 0.0\n",
    "\n",
    "def mrr_at_k(true_item, ranked_items, k):\n",
    "    # Reciprocal of the true item's 1-based rank within the top k, else 0.0\n",
    "    topk = ranked_items[:k]\n",
    "    if true_item in topk:\n",
    "        rank = topk.index(true_item) + 1\n",
    "        return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "def ndcg_at_k(true_item, ranked_items, k):\n",
    "    # With a single relevant item the ideal DCG is 1, so NDCG = 1 / log2(rank + 1)\n",
    "    topk = ranked_items[:k]\n",
    "    if true_item in topk:\n",
    "        rank = topk.index(true_item) + 1\n",
    "        return 1.0 / math.log2(rank + 1)\n",
    "    return 0.0\n",
    "\n",
    "def evaluate_model(test_frame, recommend_fn, ks=(5, 10, 20), show_progress=True):\n",
    "    rows = []\n",
    "\n",
    "    iterator = test_frame.itertuples(index=False)\n",
    "    if show_progress:\n",
    "        iterator = tqdm(iterator, total=len(test_frame))\n",
    "\n",
    "    for row in iterator:\n",
    "        ranked = recommend_fn(row, max(ks))\n",
    "        true_item = getattr(row, target_col)\n",
    "\n",
    "        result = {\"true_item\": true_item}\n",
    "        for k in ks:\n",
    "            result[f\"HR@{k}\"] = hit_rate_at_k(true_item, ranked, k)\n",
    "        result[\"MRR@10\"] = mrr_at_k(true_item, ranked, 10)\n",
    "        result[\"NDCG@10\"] = ndcg_at_k(true_item, ranked, 10)\n",
    "        rows.append(result)\n",
    "\n",
    "    detail_df = pd.DataFrame(rows)\n",
    "    summary = detail_df.drop(columns=[\"true_item\"]).mean().to_dict()\n",
    "    return summary, detail_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3b7574aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 1: popularity by country\n",
    "\n",
    "def dedupe_fill(primary, fallback, n):\n",
    "    out = []\n",
    "    seen = set()\n",
    "    for item in primary + fallback:\n",
    "        if item not in seen:\n",
    "            out.append(item)\n",
    "            seen.add(item)\n",
    "        if len(out) >= n:\n",
    "            break\n",
    "    return out\n",
    "\n",
    "def recommend_popularity(row, n=10):\n",
    "    country = getattr(row, \"Country\")\n",
    "    country_items = country_top_items.get(country, [])\n",
    "    return dedupe_fill(country_items, global_popularity, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "69d53903",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KNN train matrix shape: (800000, 26)\n"
     ]
    }
   ],
   "source": [
    "# Regular model 2: context KNN on query features\n",
    "\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        (\"num\", Pipeline([\n",
    "            (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
    "            (\"scaler\", StandardScaler())\n",
    "        ]), num_cols),\n",
    "        (\"cat\", Pipeline([\n",
    "            (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n",
    "            (\"onehot\", OneHotEncoder(handle_unknown=\"ignore\"))\n",
    "        ]), cat_cols),\n",
    "    ]\n",
    ")\n",
    "\n",
    "X_train_ctx = preprocessor.fit_transform(train_df[feature_cols])\n",
    "\n",
    "knn = NearestNeighbors(\n",
    "    n_neighbors=min(KNN_NEIGHBORS, len(train_df)),\n",
    "    metric=\"cosine\",\n",
    "    algorithm=\"brute\",\n",
    "    n_jobs=-1\n",
    ")\n",
    "knn.fit(X_train_ctx)\n",
    "\n",
    "train_items_array = train_df[target_col].to_numpy()\n",
    "train_countries_array = train_df[\"Country\"].to_numpy()\n",
    "\n",
    "print(\"KNN train matrix shape:\", X_train_ctx.shape)\n",
    "\n",
    "def recommend_knn(row, n=10):\n",
    "    query_df = pd.DataFrame([{c: getattr(row, c) for c in feature_cols}])\n",
    "    q = preprocessor.transform(query_df)\n",
    "\n",
    "    distances, indices = knn.kneighbors(q, n_neighbors=min(KNN_NEIGHBORS, len(train_df)))\n",
    "    distances = distances[0]\n",
    "    indices = indices[0]\n",
    "\n",
    "    scores = {}\n",
    "    for dist, idx in zip(distances, indices):\n",
    "        item = train_items_array[idx]\n",
    "        country_bonus = 0.15 if train_countries_array[idx] == getattr(row, \"Country\") else 0.0\n",
    "        weight = (1.0 / (dist + 1e-6)) + country_bonus\n",
    "        scores[item] = scores.get(item, 0.0) + weight\n",
    "\n",
    "    ranked = [item for item, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]\n",
    "    fallback = country_top_items.get(getattr(row, \"Country\"), [])\n",
    "    return dedupe_fill(ranked, fallback + global_popularity, n)"
   ]
  },
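  {
   "cell_type": "markdown",
   "id": "b2d9f6a3",
   "metadata": {},
   "source": [
    "The neighbor voting in `recommend_knn` can be summarized as a formula. This is a sketch read off the code above; the symbols $\\mathcal{N}(q)$, $d_j$, $c_j$ and $c_q$ are notation introduced here.\n",
    "\n",
    "$$\\mathrm{score}(i \\mid q) = \\sum_{j \\in \\mathcal{N}(q),\\ \\mathrm{item}_j = i} \\left( \\frac{1}{d_j + 10^{-6}} + 0.15 \\cdot \\mathbf{1}[c_j = c_q] \\right)$$\n",
    "\n",
    "Here $\\mathcal{N}(q)$ is the set of `KNN_NEIGHBORS` nearest training rows under cosine distance, $d_j$ is the distance from the query to neighbor $j$, and $c_j$, $c_q$ are the countries of the neighbor and of the query. Items are ranked by descending score, and any remaining slots are filled from the country-level and then the global popularity lists. Because cosine distance lies in $[0, 2]$, the inverse-distance term is never much below 0.5, so the 0.15 country bonus acts only as a mild tie-breaker."
   ]
  },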
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "026f2236",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular model 3: linear classifier\n",
    "# Note: SGDClassifier with log_loss fits one-vs-rest logistic models, not a joint softmax\n",
    "\n",
    "linear_model = SGDClassifier(\n",
    "    loss=\"log_loss\",\n",
    "    penalty=\"l2\",\n",
    "    alpha=1e-5,\n",
    "    max_iter=50,\n",
    "    tol=1e-3,\n",
    "    random_state=RANDOM_STATE\n",
    ")\n",
    "\n",
    "linear_model.fit(X_train_ctx, train_df[\"item_idx\"])\n",
    "\n",
    "def recommend_linear(row, n=10):\n",
    "    query_df = pd.DataFrame([{c: getattr(row, c) for c in feature_cols}])\n",
    "    q = preprocessor.transform(query_df)\n",
    "\n",
    "    # predict_proba columns follow linear_model.classes_, so map positions back to item_idx\n",
    "    probs = linear_model.predict_proba(q)[0]\n",
    "    top_idx = np.argsort(probs)[::-1][:n]\n",
    "    return label_encoder.inverse_transform(linear_model.classes_[top_idx]).tolist()"
   ]
  },
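  {
   "cell_type": "markdown",
   "id": "c5e8a1b4",
   "metadata": {},
   "source": [
    "Because `SGDClassifier` with `loss=\"log_loss\"` trains one binary logistic model per item, its multiclass probabilities differ from a softmax. The scikit-learn documentation describes them as one-vs-rest estimates combined by simple normalization, which can be sketched as follows (the symbols $w_i$, $b_i$ and $\\sigma$ are notation introduced here):\n",
    "\n",
    "$$p(i \\mid x) = \\frac{\\sigma(w_i^\\top x + b_i)}{\\sum_{j} \\sigma(w_j^\\top x + b_j)}, \\qquad \\sigma(z) = \\frac{1}{1 + e^{-z}}$$\n",
    "\n",
    "A joint softmax model would instead use $p(i \\mid x) = e^{w_i^\\top x + b_i} / \\sum_j e^{w_j^\\top x + b_j}$. Either form produces a ranking over items, which is all the top-N evaluation requires."
   ]
  },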
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e25ed13c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neural train size: 800000\n",
      "Neural test size : 5000\n",
      "Items: 88\n"
     ]
    }
   ],
   "source": [
    "# Tabular encoders for neural models\n",
    "\n",
    "def build_vocab(series):\n",
    "    values = [\"[UNK]\"] + sorted(series.astype(str).unique().tolist())\n",
    "    return {v: i for i, v in enumerate(values)}\n",
    "\n",
    "color_vocab = build_vocab(train_df[\"Color\"])\n",
    "condition_vocab = build_vocab(train_df[\"Condition\"])\n",
    "country_vocab = build_vocab(train_df[\"Country\"])\n",
    "\n",
    "num_means = train_df[num_cols].mean()\n",
    "num_stds = train_df[num_cols].std().replace(0, 1.0)\n",
    "\n",
    "def encode_cat(value, vocab):\n",
    "    return vocab.get(str(value), 0)\n",
    "\n",
    "def encode_numeric_frame(frame):\n",
    "    arr = frame[num_cols].copy()\n",
    "    arr = (arr - num_means) / num_stds\n",
    "    return arr.astype(np.float32).to_numpy()\n",
    "\n",
    "def encode_categorical_frame(frame):\n",
    "    color = frame[\"Color\"].astype(str).map(lambda x: encode_cat(x, color_vocab)).astype(np.int64).to_numpy()\n",
    "    condition = frame[\"Condition\"].astype(str).map(lambda x: encode_cat(x, condition_vocab)).astype(np.int64).to_numpy()\n",
    "    country = frame[\"Country\"].astype(str).map(lambda x: encode_cat(x, country_vocab)).astype(np.int64).to_numpy()\n",
    "    return color, condition, country\n",
    "\n",
    "train_num = encode_numeric_frame(train_df)\n",
    "test_num = encode_numeric_frame(test_df)\n",
    "\n",
    "train_color, train_condition, train_country = encode_categorical_frame(train_df)\n",
    "test_color, test_condition, test_country = encode_categorical_frame(test_df)\n",
    "\n",
    "y_train = train_df[\"item_idx\"].astype(np.int64).to_numpy()\n",
    "y_test = test_df[\"item_idx\"].astype(np.int64).to_numpy()\n",
    "\n",
    "print(\"Neural train size:\", len(train_df))\n",
    "print(\"Neural test size :\", len(test_df))\n",
    "print(\"Items:\", num_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "96f23867",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset classes\n",
    "\n",
    "class CarQueryDataset(Dataset):\n",
    "    def __init__(self, num_x, color_x, condition_x, country_x, y):\n",
    "        self.num_x = torch.tensor(num_x, dtype=torch.float32)\n",
    "        self.color_x = torch.tensor(color_x, dtype=torch.long)\n",
    "        self.condition_x = torch.tensor(condition_x, dtype=torch.long)\n",
    "        self.country_x = torch.tensor(country_x, dtype=torch.long)\n",
    "        self.y = torch.tensor(y, dtype=torch.long)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return (\n",
    "            self.num_x[idx],\n",
    "            self.color_x[idx],\n",
    "            self.condition_x[idx],\n",
    "            self.country_x[idx],\n",
    "            self.y[idx]\n",
    "        )\n",
    "\n",
    "train_dataset = CarQueryDataset(train_num, train_color, train_condition, train_country, y_train)\n",
    "test_dataset = CarQueryDataset(test_num, test_color, test_condition, test_country, y_test)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "79a7b0bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP epoch 1/8 - loss: 4.4729\n",
      "MLP epoch 2/8 - loss: 4.4719\n",
      "MLP epoch 3/8 - loss: 4.4717\n",
      "MLP epoch 4/8 - loss: 4.4716\n",
      "MLP epoch 5/8 - loss: 4.4714\n",
      "MLP epoch 6/8 - loss: 4.4712\n",
      "MLP epoch 7/8 - loss: 4.4710\n",
      "MLP epoch 8/8 - loss: 4.4708\n"
     ]
    }
   ],
   "source": [
    "# Neural model 1: MLP classifier\n",
    "\n",
    "class QueryEncoder(nn.Module):\n",
    "    def __init__(self, n_colors, n_conditions, n_countries, embed_dim=16, num_dim=3, hidden_dim=128, out_dim=128):\n",
    "        super().__init__()\n",
    "        self.color_emb = nn.Embedding(n_colors, embed_dim)\n",
    "        self.condition_emb = nn.Embedding(n_conditions, embed_dim)\n",
    "        self.country_emb = nn.Embedding(n_countries, embed_dim)\n",
    "\n",
    "        input_dim = num_dim + embed_dim * 3\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(input_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, out_dim),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "\n",
    "    def forward(self, num_x, color_x, condition_x, country_x):\n",
    "        x = torch.cat([\n",
    "            num_x,\n",
    "            self.color_emb(color_x),\n",
    "            self.condition_emb(condition_x),\n",
    "            self.country_emb(country_x)\n",
    "        ], dim=1)\n",
    "        return self.net(x)\n",
    "\n",
    "class MLPClassifierRecommender(nn.Module):\n",
    "    def __init__(self, n_colors, n_conditions, n_countries, n_items, embed_dim=16, hidden_dim=128):\n",
    "        super().__init__()\n",
    "        self.encoder = QueryEncoder(\n",
    "            n_colors=n_colors,\n",
    "            n_conditions=n_conditions,\n",
    "            n_countries=n_countries,\n",
    "            embed_dim=embed_dim,\n",
    "            hidden_dim=hidden_dim,\n",
    "            out_dim=hidden_dim\n",
    "        )\n",
    "        self.head = nn.Linear(hidden_dim, n_items)\n",
    "\n",
    "    def forward(self, num_x, color_x, condition_x, country_x):\n",
    "        z = self.encoder(num_x, color_x, condition_x, country_x)\n",
    "        return self.head(z)\n",
    "\n",
    "mlp_model = MLPClassifierRecommender(\n",
    "    n_colors=len(color_vocab),\n",
    "    n_conditions=len(condition_vocab),\n",
    "    n_countries=len(country_vocab),\n",
    "    n_items=num_items,\n",
    "    embed_dim=EMBED_DIM,\n",
    "    hidden_dim=128\n",
    ").to(DEVICE)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(mlp_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "for epoch in range(MLP_EPOCHS):\n",
    "    mlp_model.train()\n",
    "    total_loss = 0.0\n",
    "\n",
    "    for num_x, color_x, condition_x, country_x, y in train_loader:\n",
    "        num_x = num_x.to(DEVICE)\n",
    "        color_x = color_x.to(DEVICE)\n",
    "        condition_x = condition_x.to(DEVICE)\n",
    "        country_x = country_x.to(DEVICE)\n",
    "        y = y.to(DEVICE)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = mlp_model(num_x, color_x, condition_x, country_x)\n",
    "        loss = criterion(logits, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item() * len(y)\n",
    "\n",
    "    epoch_loss = total_loss / len(train_loader.dataset)\n",
    "    print(f\"MLP epoch {epoch+1}/{MLP_EPOCHS} - loss: {epoch_loss:.4f}\")\n",
    "\n",
    "@torch.no_grad()\n",
    "def recommend_mlp(row, n=10):\n",
    "    mlp_model.eval()\n",
    "\n",
    "    num_x = torch.tensor([[\n",
    "        (float(getattr(row, \"Year\")) - num_means[\"Year\"]) / num_stds[\"Year\"],\n",
    "        (float(getattr(row, \"Price\")) - num_means[\"Price\"]) / num_stds[\"Price\"],\n",
    "        (float(getattr(row, \"Mileage\")) - num_means[\"Mileage\"]) / num_stds[\"Mileage\"],\n",
    "    ]], dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "    color_x = torch.tensor([encode_cat(getattr(row, \"Color\"), color_vocab)], dtype=torch.long, device=DEVICE)\n",
    "    condition_x = torch.tensor([encode_cat(getattr(row, \"Condition\"), condition_vocab)], dtype=torch.long, device=DEVICE)\n",
    "    country_x = torch.tensor([encode_cat(getattr(row, \"Country\"), country_vocab)], dtype=torch.long, device=DEVICE)\n",
    "\n",
    "    logits = mlp_model(num_x, color_x, condition_x, country_x)[0]\n",
    "    top_idx = torch.topk(logits, k=min(n, num_items)).indices.cpu().numpy()\n",
    "    return label_encoder.inverse_transform(top_idx).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ea7c18a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TwoTower epoch 1/8 - loss: 4.4788\n",
      "TwoTower epoch 2/8 - loss: 4.4774\n",
      "TwoTower epoch 3/8 - loss: 4.4773\n",
      "TwoTower epoch 4/8 - loss: 4.4773\n",
      "TwoTower epoch 5/8 - loss: 4.4773\n",
      "TwoTower epoch 6/8 - loss: 4.4773\n",
      "TwoTower epoch 7/8 - loss: 4.4773\n",
      "TwoTower epoch 8/8 - loss: 4.4773\n"
     ]
    }
   ],
   "source": [
    "# Neural model 2: two-tower retrieval\n",
    "\n",
    "class TwoTowerRecommender(nn.Module):\n",
    "    def __init__(self, n_colors, n_conditions, n_countries, n_items, embed_dim=32, hidden_dim=128):\n",
    "        super().__init__()\n",
    "        self.query_encoder = QueryEncoder(\n",
    "            n_colors=n_colors,\n",
    "            n_conditions=n_conditions,\n",
    "            n_countries=n_countries,\n",
    "            embed_dim=embed_dim,\n",
    "            hidden_dim=hidden_dim,\n",
    "            out_dim=embed_dim\n",
    "        )\n",
    "        self.item_embedding = nn.Embedding(n_items, embed_dim)\n",
    "\n",
    "    def forward(self, num_x, color_x, condition_x, country_x):\n",
    "        q = self.query_encoder(num_x, color_x, condition_x, country_x)\n",
    "        item_emb = self.item_embedding.weight\n",
    "        scores = q @ item_emb.T\n",
    "        return scores\n",
    "\n",
    "twotower_model = TwoTowerRecommender(\n",
    "    n_colors=len(color_vocab),\n",
    "    n_conditions=len(condition_vocab),\n",
    "    n_countries=len(country_vocab),\n",
    "    n_items=num_items,\n",
    "    embed_dim=EMBED_DIM,\n",
    "    hidden_dim=128\n",
    ").to(DEVICE)\n",
    "\n",
    "criterion_tt = nn.CrossEntropyLoss()\n",
    "optimizer_tt = torch.optim.Adam(twotower_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "for epoch in range(TWOTOWER_EPOCHS):\n",
    "    twotower_model.train()\n",
    "    total_loss = 0.0\n",
    "\n",
    "    for num_x, color_x, condition_x, country_x, y in train_loader:\n",
    "        num_x = num_x.to(DEVICE)\n",
    "        color_x = color_x.to(DEVICE)\n",
    "        condition_x = condition_x.to(DEVICE)\n",
    "        country_x = country_x.to(DEVICE)\n",
    "        y = y.to(DEVICE)\n",
    "\n",
    "        optimizer_tt.zero_grad()\n",
    "        scores = twotower_model(num_x, color_x, condition_x, country_x)\n",
    "        loss = criterion_tt(scores, y)\n",
    "        loss.backward()\n",
    "        optimizer_tt.step()\n",
    "\n",
    "        total_loss += loss.item() * len(y)\n",
    "\n",
    "    epoch_loss = total_loss / len(train_loader.dataset)\n",
    "    print(f\"TwoTower epoch {epoch+1}/{TWOTOWER_EPOCHS} - loss: {epoch_loss:.4f}\")\n",
    "\n",
    "@torch.no_grad()\n",
    "def recommend_twotower(row, n=10):\n",
    "    twotower_model.eval()\n",
    "\n",
    "    num_x = torch.tensor([[\n",
    "        (float(getattr(row, \"Year\")) - num_means[\"Year\"]) / num_stds[\"Year\"],\n",
    "        (float(getattr(row, \"Price\")) - num_means[\"Price\"]) / num_stds[\"Price\"],\n",
    "        (float(getattr(row, \"Mileage\")) - num_means[\"Mileage\"]) / num_stds[\"Mileage\"],\n",
    "    ]], dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "    color_x = torch.tensor([encode_cat(getattr(row, \"Color\"), color_vocab)], dtype=torch.long, device=DEVICE)\n",
    "    condition_x = torch.tensor([encode_cat(getattr(row, \"Condition\"), condition_vocab)], dtype=torch.long, device=DEVICE)\n",
    "    country_x = torch.tensor([encode_cat(getattr(row, \"Country\"), country_vocab)], dtype=torch.long, device=DEVICE)\n",
    "\n",
    "    scores = twotower_model(num_x, color_x, condition_x, country_x)[0]\n",
    "    top_idx = torch.topk(scores, k=min(n, num_items)).indices.cpu().numpy()\n",
    "    return label_encoder.inverse_transform(top_idx).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fece0740",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example query:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Year</th>\n",
       "      <th>Price</th>\n",
       "      <th>Mileage</th>\n",
       "      <th>Color</th>\n",
       "      <th>Condition</th>\n",
       "      <th>Country</th>\n",
       "      <th>item_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1501</th>\n",
       "      <td>2007</td>\n",
       "      <td>5180.49</td>\n",
       "      <td>84626</td>\n",
       "      <td>Black</td>\n",
       "      <td>Certified Pre-Owned</td>\n",
       "      <td>France</td>\n",
       "      <td>BMW :: Z4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      Year    Price Mileage  Color            Condition Country    item_id\n",
       "1501  2007  5180.49   84626  Black  Certified Pre-Owned  France  BMW :: Z4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Popularity:\n",
      "['Chrysler :: Pacifica', 'Chrysler :: 300', 'Chrysler :: Voyager', 'Lexus :: NX', 'Kia :: Sportage', 'Volkswagen :: Atlas', 'Chevrolet :: Trax', 'Ford :: F-150', 'Hyundai :: Santa Fe', 'Honda :: Civic']\n",
      "KNN:\n",
      "['Kia :: Seltos', 'BMW :: 3 Series', 'Chevrolet :: Trax', 'Audi :: Q7', 'Subaru :: Legacy', 'Audi :: Q5', 'Ford :: Mustang', 'Toyota :: Tacoma', 'Ford :: Fusion', 'Hyundai :: Kona']\n",
      "Linear:\n",
      "['Chrysler :: Pacifica', 'Chrysler :: 300', 'Dodge :: Journey', 'Dodge :: Ram 1500', 'Subaru :: Forester', 'Land Rover :: Velar', 'Lexus :: IS', 'Chevrolet :: Equinox', 'Hyundai :: Sonata', 'Mercedes-Benz :: E-Class']\n",
      "MLP:\n",
      "['Chrysler :: 300', 'Chrysler :: Pacifica', 'Chrysler :: Voyager', 'Toyota :: Tacoma', 'Hyundai :: Santa Fe', 'Volkswagen :: Tiguan', 'Subaru :: Crosstrek', 'Land Rover :: Range Rover', 'Nissan :: Murano', 'Jeep :: Wrangler']\n",
      "TwoTower:\n",
      "['BMW :: X3', 'BMW :: 5 Series', 'Audi :: Q7', 'BMW :: 3 Series', 'Audi :: A4', 'Audi :: A3', 'Audi :: A6', 'Audi :: Q5', 'BMW :: X5', 'BMW :: Z4']\n"
     ]
    }
   ],
   "source": [
    "# Example recommendations for one test query\n",
    "\n",
    "example_row = test_df.sample(1, random_state=RANDOM_STATE).iloc[0]\n",
    "\n",
    "print(\"Example query:\")\n",
    "display(example_row[feature_cols + [target_col]].to_frame().T)\n",
    "\n",
    "print(\"Popularity:\")\n",
    "print(recommend_popularity(example_row, 10))\n",
    "\n",
    "print(\"KNN:\")\n",
    "print(recommend_knn(example_row, 10))\n",
    "\n",
    "print(\"Linear:\")\n",
    "print(recommend_linear(example_row, 10))\n",
    "\n",
    "print(\"MLP:\")\n",
    "print(recommend_mlp(example_row, 10))\n",
    "\n",
    "print(\"TwoTower:\")\n",
    "print(recommend_twotower(example_row, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "261628ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: PopularityByCountry\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5000/5000 [00:00<00:00, 211334.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: ContextKNN\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5000/5000 [05:25<00:00, 15.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: LinearSoftmax\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5000/5000 [00:14<00:00, 352.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: MLPClassifier\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5000/5000 [00:02<00:00, 1680.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Evaluating: TwoTower\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5000/5000 [00:02<00:00, 1806.94it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>HR@5</th>\n",
       "      <th>HR@10</th>\n",
       "      <th>HR@20</th>\n",
       "      <th>MRR@10</th>\n",
       "      <th>NDCG@10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>LinearSoftmax</td>\n",
       "      <td>0.0720</td>\n",
       "      <td>0.1300</td>\n",
       "      <td>0.2366</td>\n",
       "      <td>0.040116</td>\n",
       "      <td>0.060743</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MLPClassifier</td>\n",
       "      <td>0.0708</td>\n",
       "      <td>0.1278</td>\n",
       "      <td>0.2394</td>\n",
       "      <td>0.040576</td>\n",
       "      <td>0.060644</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>PopularityByCountry</td>\n",
       "      <td>0.0702</td>\n",
       "      <td>0.1260</td>\n",
       "      <td>0.2454</td>\n",
       "      <td>0.042500</td>\n",
       "      <td>0.061681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>TwoTower</td>\n",
       "      <td>0.0686</td>\n",
       "      <td>0.1248</td>\n",
       "      <td>0.2458</td>\n",
       "      <td>0.038380</td>\n",
       "      <td>0.058247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ContextKNN</td>\n",
       "      <td>0.0606</td>\n",
       "      <td>0.1208</td>\n",
       "      <td>0.2362</td>\n",
       "      <td>0.034730</td>\n",
       "      <td>0.054437</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Model    HR@5   HR@10   HR@20    MRR@10   NDCG@10\n",
       "0        LinearSoftmax  0.0720  0.1300  0.2366  0.040116  0.060743\n",
       "1        MLPClassifier  0.0708  0.1278  0.2394  0.040576  0.060644\n",
       "2  PopularityByCountry  0.0702  0.1260  0.2454  0.042500  0.061681\n",
       "3             TwoTower  0.0686  0.1248  0.2458  0.038380  0.058247\n",
       "4           ContextKNN  0.0606  0.1208  0.2362  0.034730  0.054437"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Evaluate all models\n",
    "\n",
    "results = []\n",
    "\n",
    "model_registry = {\n",
    "    \"PopularityByCountry\": recommend_popularity,\n",
    "    \"ContextKNN\": recommend_knn,\n",
    "    \"LinearSoftmax\": recommend_linear,\n",
    "    \"MLPClassifier\": recommend_mlp,\n",
    "    \"TwoTower\": recommend_twotower,\n",
    "}\n",
    "\n",
    "for model_name, fn in model_registry.items():\n",
    "    print(\"=\" * 80)\n",
    "    print(\"Evaluating:\", model_name)\n",
    "    summary, detail = evaluate_model(test_df, fn, ks=(5, 10, 20), show_progress=True)\n",
    "    summary[\"Model\"] = model_name\n",
    "    results.append(summary)\n",
    "\n",
    "results_df = pd.DataFrame(results)[[\"Model\", \"HR@5\", \"HR@10\", \"HR@20\", \"MRR@10\", \"NDCG@10\"]]\n",
    "results_df = results_df.sort_values([\"HR@10\", \"NDCG@10\"], ascending=False).reset_index(drop=True)\n",
    "\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78e97eac",
   "metadata": {},
   "source": [
    "## Notes for the next iteration\n",
    "\n",
    "Likely next improvements:\n",
    "1. Better item definition: compare `Brand + Model` vs `Brand + Model + Condition`\n",
    "2. More query features: add price buckets, age, mileage buckets\n",
    "3. Hard-negative training for the neural retrieval model\n",
    "4. Better regular model: CatBoost or LightGBM ranking/classification\n",
    "5. Better neural model: deeper MLP, residual blocks, feature crosses, or a stronger retrieval loss"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
