Skip to content
Snippets Groups Projects
create_datasets.py.ipynb 155 KiB
Newer Older
  • Learn to ignore specific revisions
  • schmittu's avatar
    schmittu committed
    {
     "cells": [
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 2,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [],
       "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "\n",
        "import pylab\n",
        "import matplotlib\n",
        "\n",
        "%matplotlib inline\n",
        "\n",
        "np.random.seed(44)"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 3,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [],
       "source": [
        "# the three directly sampled features, followed by \"fruitiness\", which is\n",
        "# derived from the others in sample_features and therefore redundant:\n",
        "feature_names = [\"alcohol_content\", \"bitterness\", \"darkness\", \"fruitiness\"]\n",
        "\n",
        "beer_kinds = [\"pils\", \"pale ale\", \"stout\"]\n",
        "\n",
        "# centers of the sampled features, one tuple per beer kind, in the order\n",
        "# (alcohol_content, bitterness, darkness):\n",
        "centers = {\n",
        "    \"pils\": (4.5, 0.5, 1),\n",
        "    \"pale ale\": (4.5, 0.6, 2),\n",
        "    \"stout\": (5, 0.3, 5),\n",
        "}\n",
        "\n",
        "# std deviations of the sampled features, same layout as centers:\n",
        "deviations = {\n",
        "    \"pils\": (0.3, 0.2, 0.5),\n",
        "    \"pale ale\": (0.5, 0.1, 0.5),\n",
        "    \"stout\": (0.3, 0.3, 1),\n",
        "}\n",
        "\n",
        "def sample_features(kind):\n",
        "    \"\"\"Draw one random feature row for a beer of the given kind.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    kind : str\n",
        "        Key into the module level `centers` and `deviations` dicts.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    list of float\n",
        "        One value per entry of `centers[kind]`, each drawn as\n",
        "        mean + stddev * standard normal noise and clipped at 0.0,\n",
        "        followed by the derived fruitiness value (also clipped at 0.0).\n",
        "\n",
        "    Raises\n",
        "    ------\n",
        "    KeyError\n",
        "        If `kind` is not a key of `centers` / `deviations`.\n",
        "    \"\"\"\n",
        "    means = centers[kind]\n",
        "    stddevs = deviations[kind]\n",
        "    # negative draws are clipped to zero\n",
        "    features = [max(0.0, m + s * np.random.randn()) for (m, s) in zip(means, stddevs)]\n",
        "    # fruitiness correlates with bitterness (features[1]) and negatively with\n",
        "    # darkness (features[2]), plus some noise:\n",
        "    fruitiness = 0.5 * (features[1] - .01 * features[2] + 0.06 * np.random.randn())\n",
        "    # float lower bound, so the row never mixes int 0 with float values\n",
        "    features.append(max(0.0, fruitiness))\n",
        "    return features"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 4,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/plain": [
           "[4.274815584823238, 0.763271464942364, 1.6230700143217152, 0.3253729101618156]"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 4,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "sample_features(\"pils\")"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 5,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
           "      <th>fruitiness</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>count</th>\n",
           "      <td>300.000000</td>\n",
           "      <td>300.000000</td>\n",
           "      <td>300.000000</td>\n",
           "      <td>300.000000</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>mean</th>\n",
           "      <td>4.705936</td>\n",
           "      <td>0.466646</td>\n",
           "      <td>2.587510</td>\n",
           "      <td>0.221585</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>std</th>\n",
           "      <td>0.448071</td>\n",
           "      <td>0.225136</td>\n",
           "      <td>1.741583</td>\n",
           "      <td>0.116405</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>min</th>\n",
           "      <td>3.073993</td>\n",
           "      <td>0.000000</td>\n",
           "      <td>0.000000</td>\n",
           "      <td>0.000000</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>25%</th>\n",
           "      <td>4.421718</td>\n",
           "      <td>0.287010</td>\n",
           "      <td>1.192515</td>\n",
           "      <td>0.137466</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>50%</th>\n",
           "      <td>4.714367</td>\n",
           "      <td>0.499811</td>\n",
           "      <td>2.012838</td>\n",
           "      <td>0.242206</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>75%</th>\n",
           "      <td>5.005725</td>\n",
           "      <td>0.626256</td>\n",
           "      <td>4.075562</td>\n",
           "      <td>0.303578</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>max</th>\n",
           "      <td>5.955272</td>\n",
           "      <td>1.080170</td>\n",
           "      <td>7.221285</td>\n",
           "      <td>0.535315</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
          "text/plain": [
           "       alcohol_content  bitterness    darkness  fruitiness\n",
           "count       300.000000  300.000000  300.000000  300.000000\n",
           "mean          4.705936    0.466646    2.587510    0.221585\n",
           "std           0.448071    0.225136    1.741583    0.116405\n",
           "min           3.073993    0.000000    0.000000    0.000000\n",
           "25%           4.421718    0.287010    1.192515    0.137466\n",
           "50%           4.714367    0.499811    2.012838    0.242206\n",
           "75%           5.005725    0.626256    4.075562    0.303578\n",
           "max           5.955272    1.080170    7.221285    0.535315"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 5,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "# rows per beer kind:\n",
        "N = 100\n",
        "\n",
        "# number of rows to sample for each entry of beer_kinds:\n",
        "ns = (N,) * len(beer_kinds)\n",
        "\n",
        "rows = []\n",
        "for n, kind in zip(ns, beer_kinds):\n",
        "    rows.extend(sample_features(kind) for _ in range(n))\n",
        "\n",
        "rows = np.array(rows)\n",
        "\n",
        "# one column per entry of feature_names (the beer kind itself is not stored):\n",
        "features = pd.DataFrame(rows, columns=feature_names)\n",
        "# shift fruitiness so that its smallest value is exactly 0:\n",
        "features[\"fruitiness\"] -= features[\"fruitiness\"].min()\n",
        "\n",
        "# shuffle rows, see\n",
        "# https://stackoverflow.com/questions/29576430/shuffle-dataframe-rows\n",
        "features = features.sample(frac=1).reset_index(drop=True)\n",
        "\n",
        "features.describe()"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 6,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
           "      <th>fruitiness</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>0</th>\n",
           "      <td>3.739295</td>\n",
           "      <td>0.422503</td>\n",
           "      <td>0.989463</td>\n",
           "      <td>0.215791</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>1</th>\n",
           "      <td>4.207849</td>\n",
           "      <td>0.841668</td>\n",
           "      <td>0.928626</td>\n",
           "      <td>0.380420</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>2</th>\n",
           "      <td>4.709494</td>\n",
           "      <td>0.322037</td>\n",
           "      <td>5.374682</td>\n",
           "      <td>0.145231</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>3</th>\n",
           "      <td>4.684743</td>\n",
           "      <td>0.434315</td>\n",
           "      <td>4.072805</td>\n",
           "      <td>0.191321</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>4</th>\n",
           "      <td>4.148710</td>\n",
           "      <td>0.570586</td>\n",
           "      <td>1.461568</td>\n",
           "      <td>0.260218</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
          "text/plain": [
           "   alcohol_content  bitterness  darkness  fruitiness\n",
           "0         3.739295    0.422503  0.989463    0.215791\n",
           "1         4.207849    0.841668  0.928626    0.380420\n",
           "2         4.709494    0.322037  5.374682    0.145231\n",
           "3         4.684743    0.434315  4.072805    0.191321\n",
           "4         4.148710    0.570586  1.461568    0.260218"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 6,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "features.head()"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 7,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "(300,)\n",
          "149 good\n",
          "150 bad\n"
         ]
        },
        {
         "data": {
    
    schmittu's avatar
    schmittu committed
          "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADrVJREFUeJzt3X2MZXV9x/H3p6xi1baAjNsti51tJRpqykMmBKJ/WPCBpwBNiMGQdptusv/YFFsSu0jSxKR/LGnjQxNr3Yhl01CFopYNWC1dMU2TBh18QGDdsuJSd7OwQws+tEnr6rd/3LNxXGa8Z2bu3Xvnx/uVTOaeh5n72V/mfubs75xzJ1WFJGn9+7lJB5AkjYaFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhqxoc9OSQ4C3wd+BByrqrkkZwB3AbPAQeAdVfXceGJKkoZZyRH6b1XV+VU11y3vAPZW1TnA3m5ZkjQh6XOnaHeEPldVzy5atx94c1UdSbIJ+GJVve5nfZ8zzzyzZmdn15ZYkl5kHn744WerambYfr2mXIAC/ilJAR+tql3Axqo60m1/Gtg47JvMzs4yPz/f8yklSQBJnuqzX99Cf1NVHU7yauCBJN9cvLGqqiv7pYJsB7YDvOY1r+n5dJKkleo1h15Vh7vPR4HPABcBz3RTLXSfjy7ztbuqaq6q5mZmhv6PQZK0SkMLPckrkvzC8cfA24BHgT3A1m63rcC94wopSRquz5TLRuAzSY7v/3dV9bkkXwbuTrINeAp4x/hiSpKGGVroVfUkcN4S6/8TuGwcoSRJK+edopLUCAtdkhphoUtSIyx0SWpE3xuLXrRmd9zfe9+DO68aYxJJ+tk8QpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSI3xzrhHq+0ZevomXpHHwCF2SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1Aj/wMUE+IcwJI1D7yP0JKck+WqS+7rlLUkeSnIgyV1JXjq+mJKkYVYy5XITsG/R8m3AB6rqtcBzwLZRBpMkrUyvQk+yGbgK+Fi3HOBS4J5ul93AdeMIKEnqp+8R+geB9wA/7pZfBTxfVce65UPAWUt9YZLtSeaTzC8sLKwprCRpeUMLPcnVwNGqeng1T1BVu6pqrqrmZmZmVvMtJEk99LnK5Y3ANUmuBF4G/CLwIeC0JBu6o/TNwOHxxZQkDTP0CL2qbqmqzVU1C9wAfKGqbgQeBK7vdtsK3Du2lJKkodZyY9GfAH+c5ACDOfXbRxNJkrQaK7qxqKq+CHyxe/wkcNHoI0mSVsNb/yWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGrFh0gG0vNkd9/fa7+DOq8acRNJ64BG6JDXCQpekRljoktSIF+0cet/5aUlaL4YeoSd5WZIvJfl6kseSvK9bvyXJQ0kOJLkryUvHH1eStJw+Uy7/C1xaVecB5wOXJ7kYuA34QFW9FngO2Da+mJKkYYYWeg38oFt8SfdRwKXAPd363cB1Y0koSeql10nRJKck+RpwFHgA+BbwfFUd63Y5BJy1zNduTzKfZH5hYWEUmSVJS+hV6FX1o6o6H9gMXAS8vu8TVNWuqpqrqrmZmZlVxpQkDbOiyxar6nngQeAS4LQkx6+S2QwcHnE2SdIK9LnKZSbJad3jnwfeCuxjUOzXd7ttBe4dV0hJ0nB9rkPfBOxOcgqDXwB3V9V9
SR4HPpnkz4CvArePMackaYihhV5VjwAXLLH+SQbz6ZKkKeCt/5LUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqRJ93W9SUm91xf6/9Du68asxJJE2SR+iS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjRha6EnOTvJgkseTPJbkpm79GUkeSPJE9/n08ceVJC2nzxH6MeDmqjoXuBh4V5JzgR3A3qo6B9jbLUuSJmRooVfVkar6Svf4+8A+4CzgWmB3t9tu4LpxhZQkDbeiOfQks8AFwEPAxqo60m16Gtg40mSSpBXpXehJXgl8Cnh3VX1v8baqKqCW+brtSeaTzC8sLKwprCRpeb0KPclLGJT5nVX16W71M0k2dds3AUeX+tqq2lVVc1U1NzMzM4rMkqQl9LnKJcDtwL6qev+iTXuArd3jrcC9o48nSeprQ4993gj8DvCNJF/r1r0X2AncnWQb8BTwjvFElCT1MbTQq+pfgSyz+bLRxpEkrZZ3ikpSIyx0SWqEhS5JjehzUlSNmN1xf6/9Du68asxJJI2DR+iS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpEUMLPcnHkxxN8uiidWckeSDJE93n08cbU5I0TJ8j9DuAy09YtwPYW1XnAHu7ZUnSBA0t9Kr6F+C/Tlh9LbC7e7wbuG7EuSRJK7TaOfSNVXWke/w0sHFEeSRJq7Tmk6JVVUAttz3J9iTzSeYXFhbW+nSSpGWsttCfSbIJoPt8dLkdq2pXVc1V1dzMzMwqn06SNMxqC30PsLV7vBW4dzRxJEmr1eeyxU8A/wa8LsmhJNuAncBbkzwBvKVbliRN0IZhO1TVO5fZdNmIs4zE7I77Jx1BkibCO0UlqREWuiQ1wkKXpEYMnUPXi0/f8xAHd1415iSSVsIjdElqhIUuSY2w0CWpERa6JDXCk6I6KTzRKo2fR+iS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXC69C1apP8YyJe1y69kEfoktQIC12SGmGhS1IjLHRJaoQnRTVVRn2i1ZOnejHxCF2SGmGhS1IjLHRJasS6mUOf5E0s0nEr+Tl0Xl4nm0foktQIC12SGmGhS1Ij1s0curTevNiugff8wuR5hC5JjbDQJakRFrokNcI5dIn18cc6Rq3vPPZ6GJtJzclPW741HaEnuTzJ/iQHkuwYVShJ0sqtutCTnAJ8GLgCOBd4Z5JzRxVMkrQyazlCvwg4UFVPVtX/AZ8Erh1NLEnSSq2l0M8CvrNo+VC3TpI0AWM/KZpkO7C9W/xBkv0r+PIzgWdHn2qszHxymHmNcluv3caSuedzr/b7TdU4Q69/77DMv9rnedZS6IeBsxctb+7W/ZSq2gXsWs0TJJmvqrnVxZsMM58cZj45zHxyjCrzWqZcvgyck2RLkpcCNwB71hpIkrQ6qz5Cr6pjSf4A+DxwCvDxqnpsZMkkSSuypjn0qvos8NkRZVnKqqZqJszMJ4eZTw4znxwjyZyqGsX3kSRNmO/lIkmNmMpCXw9vKZDk7CQPJnk8yWNJburWn5HkgSRPdJ9Pn3TWEyU5JclXk9zXLW9J8lA33nd1J7mnRpLTktyT5JtJ9iW5ZNrHOckfdT8Xjyb5RJKXTds4J/l4kqNJHl20bslxzcBfdtkfSXLhFGX+8+5n45Ekn0ly2qJtt3SZ9yd5+7RkXrTt5iSV5MxueU3jPHWFvo7eUuAY
cHNVnQtcDLyry7kD2FtV5wB7u+VpcxOwb9HybcAHquq1wHPAtomkWt6HgM9V1euB8xhkn9pxTnIW8IfAXFW9gcFFAzcwfeN8B3D5CeuWG9crgHO6j+3AR05SxhPdwQszPwC8oap+E/h34BaA7vV4A/Ab3df8VdcvJ9sdvDAzSc4G3gb8x6LVaxvnqpqqD+AS4POLlm8Bbpl0rh657wXeCuwHNnXrNgH7J53thJybGbxQLwXuA8LghoYNS43/pD+AXwK+TXe+Z9H6qR1nfnIX9RkMLjy4D3j7NI4zMAs8OmxcgY8C71xqv0lnPmHbbwN3do9/qjsYXJF3ybRkBu5hcIByEDhzFOM8dUforMO3FEgyC1wAPARsrKoj3aangY0TirWcDwLvAX7cLb8KeL6qjnXL0zbeW4AF4G+6aaKPJXkFUzzOVXUY+AsGR15HgO8CDzPd43zccuO6Xl6Xvw/8Y/d4ajMnuRY4XFVfP2HTmjJPY6GvK0leCXwKeHdVfW/xthr8ip2ay4iSXA0craqHJ51lBTYAFwIfqaoLgP/mhOmVKRzn0xm8Ud0W4FeAV7DEf7mn3bSN6zBJbmUwFXrnpLP8LEleDrwX+NNRf+9pLPRebykwDZK8hEGZ31lVn+5WP5NkU7d9E3B0UvmW8EbgmiQHGbw75qUM5qdPS3L8noRpG+9DwKGqeqhbvodBwU/zOL8F+HZVLVTVD4FPMxj7aR7n45Yb16l+XSb5PeBq4MbuFxFMb+ZfZ/DL/uvda3Ez8JUkv8waM09joa+LtxRIEuB2YF9VvX/Rpj3A1u7xVgZz61Ohqm6pqs1VNctgXL9QVTcCDwLXd7tNW+ange8keV236jLgcaZ4nBlMtVyc5OXdz8nxzFM7zossN657gN/trsK4GPjuoqmZiUpyOYNpxGuq6n8WbdoD3JDk1CRbGJxo/NIkMi5WVd+oqldX1Wz3WjwEXNj9rK9tnCdxgqDHCYQrGZyt/hZw66TzLJPxTQz+O/oI8LXu40oGc9J7gSeAfwbOmHTWZfK/Gbive/xrDH7QDwB/D5w66XwnZD0fmO/G+h+A06d9nIH3Ad8EHgX+Fjh12sYZ+ASDOf4fdqWybblxZXDy/MPda/IbDK7gmZbMBxjMOx9/Hf71ov1v7TLvB66YlswnbD/IT06KrmmcvVNUkhoxjVMukqRVsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWrE/wPGhsdyapPyewAAAABJRU5ErkJggg==\n",
    
    schmittu's avatar
    schmittu committed
          "text/plain": [
           "<Figure size 432x288 with 1 Axes>"
          ]
         },
         "metadata": {
          "needs_background": "light"
         },
         "output_type": "display_data"
        }
       ],
       "source": [
        "# compute score which we use for assigning class label:\n",
        "\n",
        "weights_uwe = np.array((-1, 1.5, -2, 1.5))\n",
        "# select the feature columns by name, so that this cell can be re-run after\n",
        "# the \"is_yummy\" column has been added to features further below:\n",
        "scores = np.array(features[feature_names] @ weights_uwe)\n",
        "\n",
        "# add some non linear term to make svm work better than logistic regression:\n",
        "scores = scores + 1 + 0.8 * features[\"alcohol_content\"] ** 2 * (\n",
        "    1 + features[\"bitterness\"] * features[\"darkness\"]\n",
        ")\n",
        "\n",
        "print(scores.shape)\n",
        "\n",
        "# histogram of the scores before noise is added:\n",
        "pylab.hist(scores, bins=30)\n",
        "\n",
        "# add some noise:\n",
        "scores += .1 * np.random.randn(len(scores))\n",
        "\n",
        "# threshold is the (upper) median of scores, so we get a nearly balanced data set:\n",
        "thresh = sorted(scores)[len(scores) // 2]\n",
        "\n",
        "# move some low scored beers towards the \"center\":\n",
        "lowlim = sorted(scores)[len(scores) // 10]\n",
        "scores[scores < lowlim] += 0.4 * np.median(scores)\n",
        "\n",
        "good = (scores > thresh)\n",
        "# everything which is not good is labeled 0 below, including the beer whose\n",
        "# score equals thresh, so count it as bad:\n",
        "bad = ~good\n",
        "\n",
        "print(sum(good), \"good\")\n",
        "print(sum(bad), \"bad\")\n",
        "\n",
        "# one label per score: 1 for good beers, 0 otherwise\n",
        "labels = np.zeros(len(scores), dtype=int)\n",
        "labels[good] = 1\n",
        "\n",
        "features[\"is_yummy\"] = labels"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 8,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
    
    schmittu's avatar
    schmittu committed
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
    
    schmittu's avatar
    schmittu committed
           "      <th>fruitiness</th>\n",
    
    schmittu's avatar
    schmittu committed
           "      <th>label</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>0</th>\n",
           "      <td>3.739295</td>\n",
           "      <td>0.422503</td>\n",
           "      <td>0.989463</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>0.215791</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>class_0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>1</th>\n",
           "      <td>4.207849</td>\n",
           "      <td>0.841668</td>\n",
           "      <td>0.928626</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>0.380420</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>class_0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>2</th>\n",
           "      <td>4.709494</td>\n",
           "      <td>0.322037</td>\n",
           "      <td>5.374682</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>0.145231</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>class_1</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>3</th>\n",
           "      <td>4.684743</td>\n",
           "      <td>0.434315</td>\n",
           "      <td>4.072805</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>0.191321</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>class_1</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>4</th>\n",
           "      <td>4.148710</td>\n",
           "      <td>0.570586</td>\n",
           "      <td>1.461568</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>0.260218</td>\n",
    
    schmittu's avatar
    schmittu committed
           "      <td>class_0</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
    
    schmittu's avatar
    schmittu committed
          "text/plain": [
    
    schmittu's avatar
    schmittu committed
           "   alcohol_content  bitterness  darkness  fruitiness    label\n",
           "0         3.739295    0.422503  0.989463    0.215791  class_0\n",
           "1         4.207849    0.841668  0.928626    0.380420  class_0\n",
           "2         4.709494    0.322037  5.374682    0.145231  class_1\n",
           "3         4.684743    0.434315  4.072805    0.191321  class_1\n",
           "4         4.148710    0.570586  1.461568    0.260218  class_0"
    
    schmittu's avatar
    schmittu committed
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 8,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
    
    schmittu's avatar
    schmittu committed
         "output_type": "execute_result"
    
    schmittu's avatar
    schmittu committed
        }
       ],
       "source": [
        "import seaborn as sns\n",
        "\n",
        "sns.set(style=\"ticks\")\n",
        "\n",
        "# frame for plotting: features without its last column, plus a string\n",
        "# valued \"label\" column built from the numeric labels\n",
        "for_plot = features.iloc[:, :-1].assign(\n",
        "    label=[\"class_\" + str(li) for li in labels]\n",
        ")\n",
        "\n",
    
    schmittu's avatar
    schmittu committed
        "for_plot.head()\n",
        "\n",
        "# sns.pairplot(for_plot, hue=\"label\", diag_kind=\"hist\");"
    
    schmittu's avatar
    schmittu committed
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 9,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [],
       "source": [
        "# number of leading rows (features was shuffled before) written to the\n",
        "# learning data set; all remaining rows go to the evaluation data set:\n",
        "N_LEARN = 225\n",
        "\n",
        "learn = features.iloc[:N_LEARN, :]\n",
        "learn.to_csv(\"beers.csv\", index=False)\n",
        "for_eval = features.iloc[N_LEARN:, :]\n",
        "for_eval.to_csv(\"beers_eval.csv\", index=False)"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 10,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [],
       "source": [
        "from sklearn.linear_model import LogisticRegression\n",
    
    schmittu's avatar
    schmittu committed
        "from sklearn.svm import SVC"
    
    schmittu's avatar
    schmittu committed
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 11,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
           "      <th>fruitiness</th>\n",
           "      <th>is_yummy</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>0</th>\n",
           "      <td>3.739295</td>\n",
           "      <td>0.422503</td>\n",
           "      <td>0.989463</td>\n",
           "      <td>0.215791</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>1</th>\n",
           "      <td>4.207849</td>\n",
           "      <td>0.841668</td>\n",
           "      <td>0.928626</td>\n",
           "      <td>0.380420</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>2</th>\n",
           "      <td>4.709494</td>\n",
           "      <td>0.322037</td>\n",
           "      <td>5.374682</td>\n",
           "      <td>0.145231</td>\n",
           "      <td>1</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>3</th>\n",
           "      <td>4.684743</td>\n",
           "      <td>0.434315</td>\n",
           "      <td>4.072805</td>\n",
           "      <td>0.191321</td>\n",
           "      <td>1</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>4</th>\n",
           "      <td>4.148710</td>\n",
           "      <td>0.570586</td>\n",
           "      <td>1.461568</td>\n",
           "      <td>0.260218</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
          "text/plain": [
           "   alcohol_content  bitterness  darkness  fruitiness  is_yummy\n",
           "0         3.739295    0.422503  0.989463    0.215791         0\n",
           "1         4.207849    0.841668  0.928626    0.380420         0\n",
           "2         4.709494    0.322037  5.374682    0.145231         1\n",
           "3         4.684743    0.434315  4.072805    0.191321         1\n",
           "4         4.148710    0.570586  1.461568    0.260218         0"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 11,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "beers = pd.read_csv(\"beers.csv\")\n",
        "beers.head()"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 12,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
           "      <th>fruitiness</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>0</th>\n",
           "      <td>3.739295</td>\n",
           "      <td>0.422503</td>\n",
           "      <td>0.989463</td>\n",
           "      <td>0.215791</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>1</th>\n",
           "      <td>4.207849</td>\n",
           "      <td>0.841668</td>\n",
           "      <td>0.928626</td>\n",
           "      <td>0.380420</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>2</th>\n",
           "      <td>4.709494</td>\n",
           "      <td>0.322037</td>\n",
           "      <td>5.374682</td>\n",
           "      <td>0.145231</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>3</th>\n",
           "      <td>4.684743</td>\n",
           "      <td>0.434315</td>\n",
           "      <td>4.072805</td>\n",
           "      <td>0.191321</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>4</th>\n",
           "      <td>4.148710</td>\n",
           "      <td>0.570586</td>\n",
           "      <td>1.461568</td>\n",
           "      <td>0.260218</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
          "text/plain": [
           "   alcohol_content  bitterness  darkness  fruitiness\n",
           "0         3.739295    0.422503  0.989463    0.215791\n",
           "1         4.207849    0.841668  0.928626    0.380420\n",
           "2         4.709494    0.322037  5.374682    0.145231\n",
           "3         4.684743    0.434315  4.072805    0.191321\n",
           "4         4.148710    0.570586  1.461568    0.260218"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 12,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "# split the loaded data into the class label column and the feature\n",
        "# columns (every column except the last one)\n",
        "labels = beers[\"is_yummy\"]\n",
        "features = beers.iloc[:, :-1]\n",
        "features.head()"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "# First use of classifiers"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 13,
    
    schmittu's avatar
    schmittu committed
       "metadata": {
        "scrolled": true
       },
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "0.8311111111111111\n"
         ]
        }
       ],
       "source": [
        "# fit a logistic regression classifier and predict on the same data\n",
        "model = LogisticRegression().fit(features, labels)\n",
        "predicted = model.predict(features)\n",
        "\n",
        "# fraction (between 0 and 1) of samples whose prediction matches the label\n",
        "percent_correct = np.mean(predicted == labels)\n",
        "print(percent_correct)"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 14,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "0.9111111111111111\n"
         ]
        }
       ],
       "source": [
        "# same experiment with a support vector classifier:\n",
        "# train on the learning set and score on that very same data\n",
        "model = SVC().fit(features, labels)\n",
        "predicted = model.predict(features)\n",
        "\n",
        "# fraction (between 0 and 1) of labels the model reproduces\n",
        "percent_correct = np.sum(labels == predicted) / len(labels)\n",
        "print(percent_correct)"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 15,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
           "      <th>fruitiness</th>\n",
           "      <th>is_yummy</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>0</th>\n",
           "      <td>4.381306</td>\n",
           "      <td>0.365976</td>\n",
           "      <td>1.159893</td>\n",
           "      <td>0.168321</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>1</th>\n",
           "      <td>5.540088</td>\n",
           "      <td>0.282582</td>\n",
           "      <td>5.077826</td>\n",
           "      <td>0.129492</td>\n",
           "      <td>1</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>2</th>\n",
           "      <td>5.306264</td>\n",
           "      <td>0.109893</td>\n",
           "      <td>6.159705</td>\n",
           "      <td>0.033846</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>3</th>\n",
           "      <td>4.479080</td>\n",
           "      <td>0.414778</td>\n",
           "      <td>1.101224</td>\n",
           "      <td>0.228998</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>4</th>\n",
           "      <td>3.789652</td>\n",
           "      <td>0.661923</td>\n",
           "      <td>1.477141</td>\n",
           "      <td>0.280621</td>\n",
           "      <td>0</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
          "text/plain": [
           "   alcohol_content  bitterness  darkness  fruitiness  is_yummy\n",
           "0         4.381306    0.365976  1.159893    0.168321         0\n",
           "1         5.540088    0.282582  5.077826    0.129492         1\n",
           "2         5.306264    0.109893  6.159705    0.033846         0\n",
           "3         4.479080    0.414778  1.101224    0.228998         0\n",
           "4         3.789652    0.661923  1.477141    0.280621         0"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 15,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "# read the evaluation data set from disk and preview its first rows\n",
        "beers_eval = pd.read_csv(\"beers_eval.csv\")\n",
        "beers_eval.head(n=5)"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 16,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "data": {
          "text/html": [
           "<div>\n",
           "<style scoped>\n",
           "    .dataframe tbody tr th:only-of-type {\n",
           "        vertical-align: middle;\n",
           "    }\n",
           "\n",
           "    .dataframe tbody tr th {\n",
           "        vertical-align: top;\n",
           "    }\n",
           "\n",
           "    .dataframe thead th {\n",
           "        text-align: right;\n",
           "    }\n",
           "</style>\n",
           "<table border=\"1\" class=\"dataframe\">\n",
           "  <thead>\n",
           "    <tr style=\"text-align: right;\">\n",
           "      <th></th>\n",
           "      <th>alcohol_content</th>\n",
           "      <th>bitterness</th>\n",
           "      <th>darkness</th>\n",
           "      <th>fruitiness</th>\n",
           "    </tr>\n",
           "  </thead>\n",
           "  <tbody>\n",
           "    <tr>\n",
           "      <th>0</th>\n",
           "      <td>4.381306</td>\n",
           "      <td>0.365976</td>\n",
           "      <td>1.159893</td>\n",
           "      <td>0.168321</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>1</th>\n",
           "      <td>5.540088</td>\n",
           "      <td>0.282582</td>\n",
           "      <td>5.077826</td>\n",
           "      <td>0.129492</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>2</th>\n",
           "      <td>5.306264</td>\n",
           "      <td>0.109893</td>\n",
           "      <td>6.159705</td>\n",
           "      <td>0.033846</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>3</th>\n",
           "      <td>4.479080</td>\n",
           "      <td>0.414778</td>\n",
           "      <td>1.101224</td>\n",
           "      <td>0.228998</td>\n",
           "    </tr>\n",
           "    <tr>\n",
           "      <th>4</th>\n",
           "      <td>3.789652</td>\n",
           "      <td>0.661923</td>\n",
           "      <td>1.477141</td>\n",
           "      <td>0.280621</td>\n",
           "    </tr>\n",
           "  </tbody>\n",
           "</table>\n",
           "</div>"
          ],
          "text/plain": [
           "   alcohol_content  bitterness  darkness  fruitiness\n",
           "0         4.381306    0.365976  1.159893    0.168321\n",
           "1         5.540088    0.282582  5.077826    0.129492\n",
           "2         5.306264    0.109893  6.159705    0.033846\n",
           "3         4.479080    0.414778  1.101224    0.228998\n",
           "4         3.789652    0.661923  1.477141    0.280621"
          ]
         },
    
    schmittu's avatar
    schmittu committed
         "execution_count": 16,
    
    schmittu's avatar
    schmittu committed
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "# the label `is_yummy` is the last column, all columns before it are features\n",
        "labels_eval = beers_eval[\"is_yummy\"]\n",
        "features_eval = beers_eval.iloc[:, :-1]\n",
        "features_eval.head()"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "# apply classifiers to test data set"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 17,
    
    schmittu's avatar
    schmittu committed
       "metadata": {
        "scrolled": true
       },
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "LogisticRegression\n",
          "on learning set: 0.8311111111111111\n",
          "on eval set    : 0.76\n",
          "\n",
          "SVC\n",
          "on learning set: 0.9111111111111111\n",
          "on eval set    : 0.8933333333333333\n",
          "\n"
         ]
        }
       ],
       "source": [
        "# fit a classifier on the learning set, then score it on both data sets:\n",
        "\n",
        "def check(model):\n",
        "    \"\"\"Fit `model` and print its accuracy on the learning and the eval set.\n",
        "\n",
        "    Works on the notebook-level variables `features` / `labels` (learning set)\n",
        "    and `features_eval` / `labels_eval` (eval set). Each printed score is the\n",
        "    fraction of correctly predicted labels, i.e. a value between 0 and 1.\n",
        "    \"\"\"\n",
        "\n",
        "    def fraction_correct(inputs, expected):\n",
        "        # share of rows where the prediction equals the expected label\n",
        "        hits = model.predict(inputs) == expected\n",
        "        return np.sum(hits) / len(expected)\n",
        "\n",
        "    print(model.__class__.__qualname__)\n",
        "    model.fit(features, labels)\n",
        "\n",
        "    print(\"on learning set:\", fraction_correct(features, labels))\n",
        "    print(\"on eval set    :\", fraction_correct(features_eval, labels_eval))\n",
        "    print()\n",
        "\n",
        "\n",
        "check(LogisticRegression())\n",
        "check(SVC())"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "# cross validation"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 18,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [],
       "source": [
        "# stack learning and eval data row-wise into one combined data set;\n",
        "# the original row labels are kept, so index values repeat after the concat\n",
        "\n",
        "full_features = pd.concat([features, features_eval], axis=0)\n",
        "full_labels = pd.concat([labels, labels_eval], axis=0)"
       ]
      },
      {
       "cell_type": "code",
    
    schmittu's avatar
    schmittu committed
       "execution_count": 19,
    
    schmittu's avatar
    schmittu committed
       "metadata": {},
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [