{ "cells": [ { "cell_type": "markdown", "id": "4375c2f8-249a-4f6d-8089-2b647bd83ecb", "metadata": {}, "source": [ "# Lecture 3 - Neural network deep dive" ] }, { "cell_type": "markdown", "id": "07d6c97c-1780-40a8-a739-f4c962182986", "metadata": {}, "source": [ "> A deep dive into optimising neural networks with stochastic gradient descent" ] }, { "cell_type": "markdown", "id": "dc5f2762-c6b8-4368-8042-336c40a448d8", "metadata": {}, "source": [ "## Learning objectives\n", "\n", "* Understand how to implement neural networks from scratch\n", "* Understand all the ingredients needed to define a `Learner` in fastai" ] }, { "cell_type": "markdown", "id": "47dfe094-686e-40e7-8eaa-0d4991a2243c", "metadata": {}, "source": [ "## References\n", "\n", "* Chapter 4 of [_Deep Learning for Coders with fastai & PyTorch_](https://github.com/fastai/fastbook) by Jeremy Howard and Sylvain Gugger.\n", "* [What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html#what-is-torch-nn-really) by Jeremy Howard." 
] }, { "cell_type": "markdown", "id": "f0ee1908-d0d4-455b-b758-cd48ac5a377b", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "9f53116e-b38d-43c3-bcb5-8bc9636a0b57", "metadata": {}, "outputs": [], "source": [ "# Uncomment and run this cell if using Colab, Kaggle etc\n", "# %pip install fastai==2.6.0 datasets" ] }, { "cell_type": "markdown", "id": "699fa5d4-0c13-4e79-821c-e2701c4279c3", "metadata": { "tags": [] }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 103, "id": "73261150-b566-48b0-812f-293e97564d7c", "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "from datasets import load_dataset\n", "from fastai.tabular.all import *\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import MinMaxScaler\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "id": "62c70888-089f-42ec-904a-a271a66c5060", "metadata": {}, "outputs": [], "source": [ "import datasets\n", "\n", "# Suppress logs to keep things tidy\n", "datasets.logging.set_verbosity_error()" ] }, { "cell_type": "markdown", "id": "230c5b5f-3b6c-4c99-a4aa-94393105d7a4", "metadata": {}, "source": [ "## The dataset" ] }, { "cell_type": "markdown", "id": "0cb5955a-ecce-4d19-8f2b-08581a0fdac6", "metadata": {}, "source": [ "In lecture 2, we focused on optimising simple functions with stochastic gradient descent. Let's now tackle a real-world problem using neural networks! We'll use the $N$-subjettiness dataset from lecture 1 that represents jets in terms of $\\tau_N^{(\\beta)}$ variables that measure the radiation about $N$ axes in the jet according to an angular exponent $\\beta>0$. 
As usual, we'll load the dataset from the Hugging Face Hub and convert it to a Pandas `DataFrame` via the `to_pandas()` method:" ] }, { "cell_type": "code", "execution_count": 3, "id": "d67c6c7d-f7fc-44bf-8ba0-448d66495554", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fb4a1359f144a2ba5a44a814e7b586b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
| | pT | mass | tau_1_0.5 | tau_1_1 | tau_1_2 | tau_2_0.5 | tau_2_1 | tau_2_2 | tau_3_0.5 | tau_3_1 | ... | tau_4_0.5 | tau_4_1 | tau_4_2 | tau_5_0.5 | tau_5_1 | tau_5_2 | tau_6_0.5 | tau_6_1 | tau_6_2 | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 543.633944 | 25.846792 | 0.165122 | 0.032661 | 0.002262 | 0.048830 | 0.003711 | 0.000044 | 0.030994 | 0.001630 | ... | 0.024336 | 0.001115 | 0.000008 | 0.004252 | 0.000234 | 7.706005e-07 | 0.000000 | 0.000000 | 0.000000e+00 | 0 |
| 1 | 452.411860 | 13.388679 | 0.162938 | 0.027598 | 0.000876 | 0.095902 | 0.015461 | 0.000506 | 0.079750 | 0.009733 | ... | 0.056854 | 0.005454 | 0.000072 | 0.044211 | 0.004430 | 6.175314e-05 | 0.037458 | 0.003396 | 3.670517e-05 | 0 |
| 2 | 429.495258 | 32.021091 | 0.244436 | 0.065901 | 0.005557 | 0.155202 | 0.038807 | 0.002762 | 0.123285 | 0.025339 | ... | 0.078205 | 0.012678 | 0.000567 | 0.052374 | 0.005935 | 9.395772e-05 | 0.037572 | 0.002932 | 2.237277e-05 | 0 |
| 3 | 512.675443 | 6.684734 | 0.102580 | 0.011369 | 0.000170 | 0.086306 | 0.007760 | 0.000071 | 0.068169 | 0.005386 | ... | 0.044705 | 0.002376 | 0.000008 | 0.027895 | 0.001364 | 4.400042e-06 | 0.009012 | 0.000379 | 6.731099e-07 | 0 |
| 4 | 527.956859 | 133.985415 | 0.407009 | 0.191839 | 0.065169 | 0.291460 | 0.105479 | 0.029753 | 0.209341 | 0.049187 | ... | 0.143768 | 0.033249 | 0.003689 | 0.135407 | 0.029054 | 2.593460e-03 | 0.110805 | 0.023179 | 2.202088e-03 | 0 |

5 rows × 21 columns
| | pT | mass | tau_1_0.5 | tau_1_1 | tau_1_2 | tau_2_0.5 | tau_2_1 | tau_2_2 | tau_3_0.5 | tau_3_1 | ... | tau_4_0.5 | tau_4_1 | tau_4_2 | tau_5_0.5 | tau_5_1 | tau_5_2 | tau_6_0.5 | tau_6_1 | tau_6_2 | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | ... | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 | 908250.000000 |
| mean | 487.107393 | 88.090520 | 0.366716 | 0.198446 | 0.319559 | 0.222759 | 0.079243 | 0.072535 | 0.148137 | 0.035372 | ... | 0.112024 | 0.022150 | 0.008670 | 0.088400 | 0.015329 | 0.004875 | 0.070679 | 0.011019 | 0.002914 | 0.500366 |
| std | 48.568267 | 48.393646 | 0.186922 | 0.339542 | 2.003898 | 0.110955 | 0.125155 | 0.674091 | 0.072627 | 0.051869 | ... | 0.059393 | 0.032004 | 0.155468 | 0.051949 | 0.022866 | 0.107641 | 0.046571 | 0.017133 | 0.078247 | 0.500000 |
| min | 225.490387 | -0.433573 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 452.879289 | 39.958178 | 0.224456 | 0.058381 | 0.006443 | 0.139269 | 0.025638 | 0.001565 | 0.094603 | 0.013308 | ... | 0.069037 | 0.007949 | 0.000188 | 0.051012 | 0.004936 | 0.000079 | 0.036142 | 0.002977 | 0.000033 | 0.000000 |
| 50% | 485.894050 | 99.887418 | 0.380172 | 0.166016 | 0.045887 | 0.222763 | 0.061597 | 0.008788 | 0.148810 | 0.028501 | ... | 0.110220 | 0.017609 | 0.000787 | 0.086045 | 0.011755 | 0.000387 | 0.067797 | 0.008028 | 0.000193 | 1.000000 |
| 75% | 520.506446 | 126.518545 | 0.477122 | 0.240550 | 0.074417 | 0.299708 | 0.108207 | 0.022441 | 0.196156 | 0.046588 | ... | 0.151137 | 0.029990 | 0.002006 | 0.121905 | 0.021089 | 0.001103 | 0.100437 | 0.015359 | 0.000635 | 1.000000 |
| max | 647.493145 | 299.211555 | 2.431888 | 6.013309 | 37.702422 | 2.218956 | 5.392683 | 33.352249 | 1.917912 | 4.502011 | ... | 1.616280 | 3.753716 | 21.161948 | 1.407356 | 3.158352 | 17.645603 | 1.388879 | 3.127371 | 17.340970 | 1.000000 |

8 rows × 21 columns
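
The summary statistics show that the features live on very different scales: `pT` ranges over hundreds, while most of the $\tau_N^{(\beta)}$ columns have means well below 1. The notebook imports `MinMaxScaler`, although how it is applied is not shown in this part. As a rough illustration of what min-max scaling computes, here is a minimal pure-Python sketch. The helper `min_max_scale` is hypothetical and written only for this example; the bounds are copied from the `min` and `max` rows of the table, and the jet is the first row shown by `head()`.

```python
def min_max_scale(value, lo, hi):
    """Linearly map `value` from the interval [lo, hi] onto [0, 1]."""
    if hi == lo:
        # A constant feature carries no information; map it to 0.0 (a choice made for this sketch)
        return 0.0
    return (value - lo) / (hi - lo)


# Bounds copied from the `min` and `max` rows of the summary table
PT_MIN, PT_MAX = 225.490387, 647.493145
MASS_MIN, MASS_MAX = -0.433573, 299.211555

# First jet from the preview of the DataFrame
first_jet = {"pT": 543.633944, "mass": 25.846792}

scaled = {
    "pT": min_max_scale(first_jet["pT"], PT_MIN, PT_MAX),
    "mass": min_max_scale(first_jet["mass"], MASS_MIN, MASS_MAX),
}
print(scaled)  # both values now lie between 0 and 1
```

The bounds here come from the full dataset purely for illustration. In practice one would usually estimate them on the training split only and reuse them for the validation and test splits, so that no information leaks from held-out data.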
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.251400 | 0.311221 | 0.834794 | 00:13 |
| 1 | 0.243241 | 0.369533 | 0.796215 | 00:13 |
| 2 | 0.242842 | 0.313126 | 0.863372 | 00:13 |
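
The `accuracy` column reports the fraction of validation jets whose predicted label matches the true one. The exact form of the model's output is not shown here, so the following is only a sketch under an assumption: the model yields one probability per jet for the positive class, and a threshold of 0.5 turns it into a label. The function `binary_accuracy` and the toy numbers are made up for illustration.

```python
def binary_accuracy(probs, labels, threshold=0.5):
    """Fraction of examples where the thresholded probability equals the label."""
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty batch")
    # Turn each probability into a hard 0/1 prediction
    preds = [1 if p >= threshold else 0 for p in probs]
    # Count how many predictions agree with the true labels
    correct = sum(int(pred == y) for pred, y in zip(preds, labels))
    return correct / len(labels)


# Toy batch: four jets with made-up probabilities and labels
probs = [0.9, 0.2, 0.6, 0.4]
labels = [1, 0, 0, 1]
print(binary_accuracy(probs, labels))  # 0.5
```

Accuracy only looks at which side of the threshold a prediction falls on, whereas the loss also depends on how confident the prediction is, so the two columns need not move in lockstep from one epoch to the next.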