From 4ec6923b884683197a8dd24ef44301ec68b50338 Mon Sep 17 00:00:00 2001
From: Uwe Schmitt <uwe.schmitt@id.ethz.ch>
Date: Fri, 5 Feb 2021 16:20:38 +0100
Subject: [PATCH] introduced MNIST example to show case cross validation

---
 03_overfitting_and_cross_validation.ipynb | 284 +++++++++++++++++-----
 1 file changed, 219 insertions(+), 65 deletions(-)

diff --git a/03_overfitting_and_cross_validation.ipynb b/03_overfitting_and_cross_validation.ipynb
index 14bb4dd..26e78cb 100644
--- a/03_overfitting_and_cross_validation.ipynb
+++ b/03_overfitting_and_cross_validation.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -103,15 +103,15 @@
        "    div#maintoolbar {display: none !important;}\n",
        "    /*\n",
        "\n",
-       "    div#site {\n",
-       "        border-top: 20px solid #1F407A;\n",
-       "        border-right: 20px solid #1F407A;\n",
+       "    div#site { \n",
+       "        border-top: 20px solid #1F407A; \n",
+       "        border-right: 20px solid #1F407A; \n",
        "        margin-bottom: 0;\n",
        "        padding-bottom: 0;\n",
        "    }\n",
-       "    div#toc-wrapper {\n",
-       "        border-left: 20px solid #1F407A;\n",
-       "        border-top: 20px solid #1F407A;\n",
+       "    div#toc-wrapper { \n",
+       "        border-left: 20px solid #1F407A; \n",
+       "        border-top: 20px solid #1F407A; \n",
        "\n",
        "    }\n",
        "\n",
@@ -143,7 +143,7 @@
        "<IPython.core.display.HTML object>"
       ]
      },
-     "execution_count": 25,
+     "execution_count": 1,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -836,9 +836,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The `cross_val_score` as used in the previous code example works as follows:\n",
+    "The `cross_val_score` as used in the previous code example works internally as follows:\n",
     "\n",
-    "0. split training data in four chunks\n",
+    "- split training data into four chunks\n",
     "- learn `classifier` on chunk `1, 2, 3`, apply classifier to chunk `4` and compute score `s1`\n",
     "- learn `classifier` on chunk `1, 2, 4`, apply classifier to chunk `3` and compute score `s2`\n",
     "- learn `classifier` on chunk `1, 3, 4`, apply classifier to chunk `2` and compute score `s3`\n",
@@ -1216,30 +1216,96 @@
    "source": [
     "### Demonstration\n",
     "\n",
-    "We introduce the `train_test_split` function from `sklearn.model_selection` in the following example.\n",
-    "\n",
-    "It splits features and labels in a given proportion. Usually this is randomized, so that you get different results for every function invocation. To get the same result every time we use `random_state=..` (with arbitrary number) below:"
+    "We demonstrate what we explained before using scikit-learn's `digits` dataset, a small MNIST-like collection of 8x8 grayscale digit images."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 54,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "digits data set shape: (1797, 8, 8)\n",
+      "feature matrix shape: (1797, 64)\n"
+     ]
+    }
+   ],
    "source": [
-    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from sklearn.datasets import load_digits\n",
+    "from sklearn.metrics import classification_report\n",
+    "from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split\n",
+    "from sklearn.svm import SVC\n",
     "\n",
-    "beer = pd.read_csv(\"data/beers.csv\")\n",
-    "beer_eval = pd.read_csv(\"data/beers_eval.csv\")\n",
-    "all_beer = pd.concat((beer, beer_eval))\n",
+    "digits = load_digits()\n",
     "\n",
-    "features = all_beer.iloc[:, :-1]\n",
-    "labels = all_beer.iloc[:, -1]"
+    "print(\"digits data set shape:\", digits.images.shape)\n",
+    "\n",
+    "# flatten images of shape N_SAMPLES x 8 x 8\n",
+    "# to N_SAMPLES x 64:\n",
+    "labels = digits.target\n",
+    "n_samples = len(labels)\n",
+    "\n",
+    "features = digits.images.reshape((n_samples, -1))\n",
+    "print(\"feature matrix shape:\", features.shape)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAAD3CAYAAADmIkO7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAJMklEQVR4nO3df6hfdR3H8efbpmNiugtZ4crJJsYQZLj9VURC4s8si3aziMofbQViv6BpUFmELTBC+8eQqEhKHCEqarZbWUL4hyNZDZFylaamrV1NbWjqpz/OufHd5e6Hd9u5r+0+H3Dh7ny/3/P5fL/3PL/nfM+97FRrDUl5jpjrCUiamXFKoYxTCmWcUijjlEIZpxRqXsdZVVur6oy5nsfeVNU3qmp7Vf3jAK3vHVX1p6p6vqourKq7q+rjB2LdOnDK33Nmq6oTgYeBpa21pw/QOn8J3N5au+5ArO9AqqqVwPeBFcBDwKWttQfnck5zZV7vOQ8RJwL/mk2YVbVgNzctBbbu16z2b/zd3f8o4DbgJmAM+BFwW798/mmtzdsv4K/Amf33VwMb6TaM54A/AKcAVwFPA48BZ4089mK6d/bngG3Aumnr/iLwJPAEcBnQgJP72xYC1wKPAk8BNwCLZpjfmcBO4FXgeeCH/fL30sX1DHAvsGLac1oPbAFeBBZMW+cj/fp29utc2K/jsv721wHfBrYDfwEu7+e+YPprNvK63dR/f1J/30v75/bbfvkl/Ws1CdxDdxQw08/jLOBx+iO6ftmjwDlzva3MxZd7zl1dAPyY7l3793Qb0hHAEuDrwPdG7vs08B7gWLpQv1NVpwNU1TnA5+niOhk4Y9o4G+jCX9nfvgT4yvTJtNYmgHOBJ1prx7TWPlFVpwA/BT4LHA/cBdwxbe/yYeB8YHFr7eVp61xOt8Ff0K/zxWnDfrIfcyVwOnDhDK/T3ryL7rD07Kp6H/Al4AP9fO/r5z+TU4Etra+yt6VfPu8Y567ua63d02/QG+k2pg2ttf8CNwMnVdVigNbana21R1rnN8AvgHf26xkHftBa29pa+w/d3gWAqipgLfC51tqO1tpzwDXARfs4xw8Bd7bWNvXzuhZYBLx95D7Xt9Yea63tnMVrMA5c11r7e2ttku6N5LW6urX2Qj/+p4BvttYe6l/Xa4CVVbV0hscdAzw7bdmzwOtnMYdDnnHu6qmR73cC21trr4z8G7oNiKo6t6rur6odVfUMcB7whv4+J9AdBk8Z/f544Ghgc1U90z/25/3yfXEC8Lepf7TWXu3Xv2Q3471We5r7vhp9zFLgupHnugModp3vlOfpjkRGHUv30WHeMc5ZqKqFwM/o9lpvaq0tpju8rP4uTwJvGXnIW0e+304X+qmttcX913GttWP2cfgn6Db4qblUv/7HR+6zP6fg9zR3gBfo3lymvHmGdYyO/xjd5/HFI1+LWmu/m+FxW4HT+uc05TQGOHmVyDhn5yi6Eyn/BF6uqnPpTmZMuQW4uKpWVNXRwJenbuj3dDfSfUZ9I0BVLamqs/dx7FuA86vq3VV1JPAFuhM/M23ss3EL8Jl+TovpTi6NehC4qKqOrKrVwAf3sr4bgKuq6lSAqjquqtbs5r73Aq8AV1TVwqq6vF/+q9f+NA59xjkL/efEK+g25EngI8DtI7ffDVwP/Br4M3B/f9PUyZf1U8ur6t/ABPC2fRz7YeCjwHfp9sIX0J3ceWn/ntX/3Uj3+XkL3Umxu4CX6aKB7o1mOd3z/hrwk73M91bgW8DN/XP9I90Jp5nu+xLdCaiP0Z2JvgS48AA+t0OKf4QwgKpaQbdRLpx+9jRdf1RwQ2ttphM4Oojccx4kVfX+/tBsjG7PccehEGZVLaqq86pqQVUtAb4K3DrX85qPjPPgWUf3u9BH6A4JPz2309lnRXe4Okl3WPsQM/wOVgefh7VSKPecUqg9/mFyVR2Wu9U1a3Z3Jv/g2LBhNn9kMzsTExO
DjXXllVcONtbk5ORgYw2ttVYzLXfPKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTinUHi/HcLga8vIIAMuWLRtsrLGxscHG2rFjx2BjjY+PDzYWwMaNGwcdbybuOaVQximFMk4plHFKoYxTCmWcUijjlEIZpxTKOKVQximFMk4plHFKoYxTCmWcUijjlEIZpxTKOKVQximFMk4plHFKoYxTCmWcUijjlEIZpxTKOKVQximFirkcw6pVqwYba8jLIwAsX758sLG2bds22FibNm0abKwhtw/wcgyS9sA4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYWKuVbK2NjYYGNt3rx5sLFg2OuXDGno13G+cc8phTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKdS8vBzDxMTEYGMdzob8mU1OTg42Vgr3nFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCxVyOYcj/bn/VqlWDjTW0IS+RMOTruHHjxsHGSuGeUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFKoaq3t/saq3d94gC1btmyooXjggQcGGwtg3bp1g421Zs2awcYa8me2evXqwcYaWmutZlrunlMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQhmnFMo4pVDGKYUyTimUcUqhjFMKZZxSKOOUQsVcK2VIa9euHXS89evXDzbW5s2bBxtrfHx8sLEOZ14rRTrEGKcUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUyjilUMYphTJOKZRxSqGMUwplnFIo45RCGacUao+XY5A0d9xzSqGMUwplnFIo45RCGacUyjilUP8DHpy7RLLJO2wAAAAASUVORK5CYII=\n",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAqsAAAAoCAYAAADZoLOuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAFuklEQVR4nO3dXYhcdxnH8e9vowZraF1trSY1hqRUai+0JuCVmItNtL5cCGaVXpQGahZvii9goqAGEd2IUvHKEIoViy+Nii9IqVlQsRe9SCql1hdwV2tst4lpt9gXKU19vJizOC7b7Jl1J3s2+X5u9sw5z/9/ntmZOfPjzNnZVBWSJElSF42sdgOSJEnSizGsSpIkqbMMq5IkSeosw6okSZI6y7AqSZKkzjKsSpIkqbMMq5I6KclDSXaudh9LSfKFJGeSPJZkS5JK8pLV7mtQa7l3SRc2w6qkTqqq66rqV6vdx7kk2Qx8AnhTVb12wLE7k/x9wbqDSe5cyR4laa0zrErS8m0GHq+q06vdiCRdqAyrkjopyV+TjDXLB5McTXJnkqeSPJjkmiSfSnI6yckku/vG7k3yh6Z2JsnEgrk/mWQ2yaNJbmk+/r662bY+yVeS/C3JqSTfSPLyRfobA44BG5M8neSORWoW7SPJK4C7+8Y+neRG4NPAB5vbDzS1lyW5ven3keayg3XNtpuT3Nv0O5fkL0lu6Nv/ucaua8adSTIDvOf/ebwkaVgMq5LWivcB3wZGgd8C99A7hm0CPg8c7qs9DbwXuBTYC9yW5K0ASd4FfBwYA64Gdi7YzyRwDfCWZvsm4LMLm6mqKeAG4NGq2lBVNy/S86J9VNUzC8ZuqKrvAF8Evt/cfnMzxx3A2aaX64HdwC19+3gb8CfgcuDLwO1J0mLsh5vergd2AB9YpH9JWnWGVUlrxW+q6p6qOgscBa4AJqvqeeB7wJYkrwSoqp9X1XT1/Br4BfD2Zp5x4JtV9VBVPQscnN9BE/L2AR+rqieq6il6AfJDy2l4iT6WlORK4N3AR6vqmeZyg9sW9PNwVR2pqheAbwGvA65sMXYc+FpVnayqJ4AvLec+StKw+VefktaKU33L/wLONAFt/jbABuDJ5qPwz9E7QzoCXAI82NRsBI73zXWyb/mKpvbEf09OEmDdchpeoo823gC8FJjt62dkQc+PzS9U1bNN3QbgVUuM3bhgnocH6EuSzhvDqqQLSpL1wA+Bm4CfVNXzSX5ML3QCzAJX9Q15fd/yGXrB97qqemTIfdQiwxauOwk8B1zenFEexFJjZ/nf+755wPkl6bzwMgBJF5qXAeuBfwBnm7Obu/u23wXsTXJtkkuAz8xvqKp/A0foXVv6GoAkm5K8cwh9nAJeneSyBeu2JBlp+pmld+nAV5NcmmQkybYk71hq5y3G3gXcmuSqJKPAgWXcR0kaOsOqpAtKc53prfTC2BxwI/DTvu13A18Hfgn8Gbiv2fRc83P//Pok/wSmgDcOoY8/At8FZpI8mWQjvWtxAR5Pcn+zfBO94Pv7Zp4f0LsutY1zjT1C74/UHgDuB3406H2UpPMhVYt9EiVJF4ck1wK/A9Yv46N2SdKQeWZV0kUnyfub71MdBQ4BPzOoSlI3GVYlXYwm6H0H6jTwAvCR1W1HkvRivAxAkiRJneWZVUmSJHXWOb9nNcmKn3bds2dPq7rJycnWc05NTbWqO3Cg/TezzM3Nta5dTdPT063qtm7d2nrOo0ePLl1E+8cSYHx8fEX3PYjt27e3qjt+/PjSRY1t27a1qpuZmWk957Fjx1rVnThxovWcgzzn2xobG2tVN8hreMeOHcttZ0X239ZK/z4HeQ21fR4P4zE/fPjw0kWNtsfO1XxuAuzfv79V3a5du1rPOTo62qqu7Wsd2h8TDx061HrOttq+bwxy7JyYmGhVN8hro22fwzjODMO+ffta17Z9Hg/yvtH2/XoYqiqLrffMqiRJkjrLsCpJkqT
OMqxKkiSpswyrkiRJ6izDqiRJkjrLsCpJkqTOMqxKkiSpswyrkiRJ6izDqiRJkjrLsCpJkqTOStWK/0dVSZIkaUV4ZlWSJEmdZViVJElSZxlWJUmS1FmGVUmSJHWWYVWSJEmdZViVJElSZ/0H/V7EskkLf6EAAAAASUVORK5CYII=\n",
+      "text/plain": [
+       "<Figure size 864x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "plt.imshow(digits.images[0], cmap=\"gray\", )\n",
+    "plt.axis('off')\n",
+    "plt.title(f\"image for figure {labels[0]}\")\n",
+    "fig = plt.figure(figsize=(12, 4))\n",
+    "ax = plt.imshow(features[0][None, :], cmap=\"gray\")\n",
+    "plt.title(\"image flattened\")\n",
+    "plt.axis('off');"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We introduce the `train_test_split` function from `sklearn.model_selection` in the following example.\n",
+    "\n",
+    "It splits features and labels in a given proportion. Usually this is randomized, so that you get different results for every function invocation. To get the same result every time we use `random_state=..` (with arbitrary number) below:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
    "metadata": {},
    "outputs": [
     {
@@ -1247,16 +1313,44 @@
      "output_type": "stream",
      "text": [
       "# Whole dataset \n",
-      "number of all samples: 300\n",
-      "proportion of yummy samples: 0.49666666666666665\n",
+      "number of all samples: 1797\n",
+      "proportion of images for class 0: 0.099\n",
+      "proportion of images for class 1: 0.101\n",
+      "proportion of images for class 2: 0.098\n",
+      "proportion of images for class 3: 0.102\n",
+      "proportion of images for class 4: 0.101\n",
+      "proportion of images for class 5: 0.101\n",
+      "proportion of images for class 6: 0.101\n",
+      "proportion of images for class 7: 0.1\n",
+      "proportion of images for class 8: 0.097\n",
+      "proportion of images for class 9: 0.1\n",
       "\n",
       "# Cross-validation dataset \n",
-      "number of all samples: 240\n",
-      "proportion of yummy samples: 0.49583333333333335\n",
+      "number of all samples: 1437\n",
+      "proportion of images for class 0: 0.099\n",
+      "proportion of images for class 1: 0.102\n",
+      "proportion of images for class 2: 0.099\n",
+      "proportion of images for class 3: 0.102\n",
+      "proportion of images for class 4: 0.101\n",
+      "proportion of images for class 5: 0.101\n",
+      "proportion of images for class 6: 0.101\n",
+      "proportion of images for class 7: 0.1\n",
+      "proportion of images for class 8: 0.097\n",
+      "proportion of images for class 9: 0.1\n",
       "\n",
       "# Validation dataset \n",
-      "number of all samples: 60\n",
-      "proportion of yummy samples: 0.5\n"
+      "number of all samples: 360\n",
+      "proportion of images for class 0: 0.1\n",
+      "proportion of images for class 1: 0.1\n",
+      "proportion of images for class 2: 0.097\n",
+      "proportion of images for class 3: 0.103\n",
+      "proportion of images for class 4: 0.1\n",
+      "proportion of images for class 5: 0.103\n",
+      "proportion of images for class 6: 0.1\n",
+      "proportion of images for class 7: 0.1\n",
+      "proportion of images for class 8: 0.097\n",
+      "proportion of images for class 9: 0.1\n",
+      "\n"
      ]
     }
    ],
@@ -1274,39 +1368,85 @@
     "    labels_validation,\n",
     ") = train_test_split(features, labels, test_size=0.2, stratify=labels, random_state=42)\n",
     "\n",
+    "def report(labels):\n",
+    "    print(\"number of all samples:\", len(labels))\n",
+    "    for number in range(10):\n",
+    "        print(f\"proportion of images for class {number}:\", round(sum(labels == number) / len(labels), 3))\n",
+    "    print()\n",
+    "\n",
     "print(\"# Whole dataset \")\n",
-    "print(\"number of all samples:\", len(labels))\n",
-    "print(\"proportion of yummy samples:\", sum(labels == 1) / len(labels))\n",
-    "print()\n",
+    "report(labels)\n",
     "print(\"# Cross-validation dataset \")\n",
-    "print(\"number of all samples:\", len(labels_crosseval))\n",
-    "print(\n",
-    "    \"proportion of yummy samples:\", sum(labels_crosseval == 1) / len(labels_crosseval)\n",
-    ")\n",
-    "print()\n",
+    "report(labels_crosseval)\n",
     "print(\"# Validation dataset \")\n",
-    "print(\"number of all samples:\", len(labels_validation))\n",
-    "print(\n",
-    "    \"proportion of yummy samples:\", sum(labels_validation == 1) / len(labels_validation)\n",
-    ")"
+    "report(labels_validation)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
+    "As you can see the splits maintained the distribution of all classes `0` to `9`.\n",
+    "\n",
+    "\n",
+    "\n",
     "Moreover, we introduce use of explicit speficiation of a cross-validation method: `StratifiedKFold` from `sklearn.model_selection`. \n",
     "\n",
-    "This allows us to spilt data during cross validation in the same way as we did with `train_test_split`, i.e. \n",
+    "`StratifiedKFold` allows us to split data during cross validation in the same way as we did with `train_test_split`, i.e. \n",
     "\n",
-    "a) with data shufflling before split, and \n",
+    "1. with data shuffling before split, and \n",
+    "2. **preserving class-proportions of samples**. \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "score = 0.103 +/- 0.011, C = 1.0e+00,  gamma = 1.0e-01\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sklearn.model_selection import StratifiedKFold\n",
     "\n",
-    "b) perserving class-proportions of samples, "
+    "cross_validator = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)\n",
+    "C = 1\n",
+    "gamma = 0.1\n",
+    "classifier = SVC(C=C, gamma=gamma)\n",
+    "test_scores = cross_val_score(\n",
+    "    classifier,\n",
+    "    features_crosseval,\n",
+    "    labels_crosseval,\n",
+    "    scoring=\"accuracy\",\n",
+    "    cv=cross_validator,\n",
+    ")  # cv arg is now different\n",
+    "print(\n",
+    "    \"score = {:.3f} +/- {:.3f}, C = {:.1e},  gamma = {:.1e}\".format(\n",
+    "        test_scores.mean(), test_scores.std(), C, gamma\n",
+    "    )\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can now try to use this approach to tune the **hyper-parameters** `C` and `gamma` of the `SVC` classifier.\n",
+    "\n",
+    "Remember:\n",
+    "1. A classifier learns parameters\n",
+    "2. Hyper-parameters control how a classifier learns."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 58,
    "metadata": {},
    "outputs": [
     {
@@ -1314,21 +1454,21 @@
      "output_type": "stream",
      "text": [
       "OPTIMIZE HYPERPARAMETERS\n",
-      "score = 0.725 +/- 0.086, C =   0.1,  gamma =   0.1\n",
-      "score = 0.796 +/- 0.096, C =   0.1,  gamma =   1.0\n",
-      "score = 0.867 +/- 0.064, C =   0.1,  gamma =  10.0\n",
-      "score = 0.512 +/- 0.037, C =   0.1,  gamma = 100.0\n",
-      "score = 0.842 +/- 0.083, C =   1.0,  gamma =   0.1\n",
-      "score = 0.908 +/- 0.055, C =   1.0,  gamma =   1.0\n",
-      "score = 0.892 +/- 0.046, C =   1.0,  gamma =  10.0\n",
-      "score = 0.754 +/- 0.057, C =   1.0,  gamma = 100.0\n",
-      "score = 0.938 +/- 0.038, C =  10.0,  gamma =   0.1\n",
-      "score = 0.950 +/- 0.036, C =  10.0,  gamma =   1.0\n",
-      "score = 0.925 +/- 0.041, C =  10.0,  gamma =  10.0\n",
-      "score = 0.754 +/- 0.057, C =  10.0,  gamma = 100.0\n",
+      "score = 0.900 +/- 0.029, C = 1.0e-01,  gamma = 1.0e-04\n",
+      "score = 0.962 +/- 0.012, C = 1.0e-01,  gamma = 1.0e-03\n",
+      "score = 0.125 +/- 0.056, C = 1.0e-01,  gamma = 1.0e-02\n",
+      "score = 0.099 +/- 0.003, C = 1.0e-01,  gamma = 1.0e-01\n",
+      "score = 0.971 +/- 0.010, C = 1.0e+00,  gamma = 1.0e-04\n",
+      "score = 0.991 +/- 0.007, C = 1.0e+00,  gamma = 1.0e-03\n",
+      "score = 0.836 +/- 0.028, C = 1.0e+00,  gamma = 1.0e-02\n",
+      "score = 0.103 +/- 0.011, C = 1.0e+00,  gamma = 1.0e-01\n",
+      "score = 0.986 +/- 0.011, C = 1.0e+01,  gamma = 1.0e-04\n",
+      "score = 0.990 +/- 0.006, C = 1.0e+01,  gamma = 1.0e-03\n",
+      "score = 0.843 +/- 0.026, C = 1.0e+01,  gamma = 1.0e-02\n",
+      "score = 0.104 +/- 0.014, C = 1.0e+01,  gamma = 1.0e-01\n",
       "\n",
       "BEST RESULT CROSS VALIDATION\n",
-      "score = 0.950 +/- 0.036, C = 10.0,  gamma = 1.0\n"
+      "score = 0.991 +/- 0.007, C = 1.0e+00,  gamma = 1.0e-03\n"
      ]
     }
    ],
@@ -1346,8 +1486,8 @@
     "\n",
     "print(\"OPTIMIZE HYPERPARAMETERS\")\n",
     "# selected classifier hyperparameters to optimize\n",
-    "SVC_C_values = (0.1, 1, 10)\n",
-    "SVC_gamma_values = (0.1, 1, 10, 100)\n",
+    "SVC_C_values = (1e-1, 1, 10)\n",
+    "SVC_gamma_values = (0.0001, 0.001, 0.01, 0.1)\n",
     "\n",
     "for C in SVC_C_values:\n",
     "    for gamma in SVC_gamma_values:\n",
@@ -1360,7 +1500,7 @@
     "            cv=cross_validator,\n",
     "        )  # cv arg is now different\n",
     "        print(\n",
-    "            \"score = {:.3f} +/- {:.3f}, C = {:5.1f},  gamma = {:5.1f}\".format(\n",
+    "            \"score = {:.3f} +/- {:.3f}, C = {:.1e},  gamma = {:.1e}\".format(\n",
     "                test_scores.mean(), test_scores.std(), C, gamma\n",
     "            )\n",
     "        )\n",
@@ -1375,15 +1515,22 @@
     "print()\n",
     "print(\"BEST RESULT CROSS VALIDATION\")\n",
     "print(\n",
+    "    \"score = {:.3f} +/- {:.3f}, C = {:.1e},  gamma = {:.1e}\".format(\n",
     "        best_score_mean, best_score_std, best_C, best_gamma\n",
     "    )\n",
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally we evaluate our tuned classifier on the validation data set:"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 60,
    "metadata": {},
    "outputs": [
     {
@@ -1391,7 +1538,7 @@
      "output_type": "stream",
      "text": [
       "VALIDATION\n",
-      "score = 0.967\n"
+      "score = 0.989\n"
      ]
     }
    ],
@@ -1422,7 +1569,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 2,
    "metadata": {
     "tags": [
      "solution"
@@ -1597,7 +1744,7 @@
     "\n",
     "2. Run cross-validation for the `SVC` classifier applied to the `\"data/spiral.csv\"` data set. Try different `C` and `gamma` values.\n",
     "\n",
-    "2. Optional exercise: implement same strategy for the iris data set introduced in script 1."
+    "3. Implement the same strategy for the iris data set introduced in script 1."
    ]
   },
   {
@@ -1900,6 +2047,13 @@
     "# Here, SVC is robust"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
-- 
GitLab