"source": [
"# IGNORE THIS CELL WHICH CUSTOMIZES LAYOUT AND STYLING OF THE NOTEBOOK !\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
"import warnings\n",
"warnings.filterwarnings('ignore', category=FutureWarning)\n",
"warnings.filterwarnings = lambda *a, **kw: None\n",
"from IPython.core.display import HTML; HTML(open(\"custom.html\", \"r\").read())"
]
},
{
"cell_type": "code",
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chapter 9: Use case - prediction of arm movements\n",
"\n",
"<center>\n",
"<figure>\n",
"<table><tr>\n",
"<td> <img src=\"./images/eeg_cap.png\" style=\"width: 400px;\"/> </td>\n",
"<td> <img src=\"./images/arm_movement.png\" style=\"width: 400px;\"/> </td>\n",
"</tr></table>\n",
"<figcaption>Setup of an EEG-experiment.</figcaption>\n",
"</figure>\n",
"</center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<center>\n",
"<figure>\n",
" <img src=\"./images/eeg_electrode_numbering.jpg\" width=35%/> \n",
" <figcaption>Arrangement of electrodes on head.</figcaption>\n",
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This data contains EEG recordings of one subject performing **grasp-and-lift (GAL)** trials. \n",
"There is **1 subject** in total, **8 series** of trials for this subject, and approximately **30 trials** within each series. The number of trials varies for each series.\n",
"\n",
"For each **GAL**, you are tasked to detect 6 events:\n",
"\n",
"- HandStart\n",
"- FirstDigitTouch\n",
"- BothStartLoadPhase\n",
"- LiftOff\n",
"- Replace\n",
"- BothReleased\n",
"\n",
"These events always occur in the same order. In this dataset, there are two files for the subject + series combination:\n",
"the ***_data.csv** files contain the raw 32 channels EEG data (sampling rate 500Hz)\n",
"the ***_events.csv** files contains the ground truth frame-wise labels for all events\n",
"\n",
"\n",
"Detailed information about the data can be found here:\n",
"Luciw MD, Jarocka E, Edin BB (2014) Multi-channel EEG recordings during 3,936 grasp and lift trials with varying weight and friction. Scientific Data 1:140047. www.nature.com/articles/sdata201447\n",
"\n",
"*Description from https://www.kaggle.com/c/grasp-and-lift-eeg-detection/data*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<center>\n",
"<figure>\n",
" <img src=\"./images/eeg_signal_preprocessing.png\" title=\"made at imgflip.com\" width=75%/> \n",
" <figcaption>Preprocessing steps for EEG-signals.</figcaption>\n",
"</figure>\n",
"</center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data can be found in: `/data/eeg_use_case` and contains:\n",
"\n",
"- 8 series of recorded EEG data\n",
"\n",
"- 8 series of events of arm movements\n",
"\n",
"Load the EEG data and the events:\n",
"- combine all EEG series in one array (size: (total number of time series, number of channels))\n",
"- combine all events in one array (size: (total number of time series, number of different arm movement))\n",
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div class=\"alert alert-block alert-warning\">\n",
" <i class=\"fa fa-info-circle\"></i> <strong>Filter strings with the lambda-operator</strong> \n",
" The lambda-operator allows to build hidden functions, which are basically functions without a name. These hidden functions have any number of parameters, execute an expression and return the value of this expression. The lambda operator can be applied in the following way to filter the filenames:\n",
" \n",
" all_data_files = list(filter(lambda x: '_data' in x, os.listdir(path)))\n",
"</div>"
]
},
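As a small illustration of the filter described above, the following sketch applies the same lambda to a hand-written list of file names instead of `os.listdir(path)`. The names in `example_files` are made up for this example; an equivalent list comprehension is shown for comparison.

```python
# hypothetical file names, standing in for os.listdir(path)
example_files = ['series2_events.csv', 'series1_data.csv',
                 'series1_events.csv', 'series2_data.csv', 'README.txt']

# keep only the names containing '_data', using an anonymous function
data_files = list(filter(lambda x: '_data' in x, example_files))

# the same selection written as a list comprehension
data_files_lc = [x for x in example_files if '_data' in x]

# sort both lists so that data and event files of the same series line up
data_files = sorted(data_files)
event_files = sorted(x for x in example_files if '_events' in x)

print(data_files)
print(event_files)
```

Sorting matters because the order returned by `os.listdir` is not guaranteed, and the data and event arrays are later concatenated series by series.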
"metadata": {},
"outputs": [],
"source": [
"def load_data(file_names, path):\n",
" # read the csv file and drop the id column\n",
" dfs = []\n",
" for f in file_names:\n",
" df = pd.read_csv(path + f).drop('id', axis = 1)\n",
"metadata": {},
"outputs": [],
"source": [
"# define path and list of all data and event files\n",
"import os\n",
"import pandas as pd\n",
"\n",
"path = 'data/eeg_use_case/' \n",
"\n",
"all_data_files = list(filter(lambda x: '_data' in x, os.listdir(path)))\n",
"all_event_files = list(filter(lambda x: '_events' in x, os.listdir(path)))\n",
"\n",
"all_data_sort = np.sort(all_data_files)\n",
"all_event_sort = np.sort(all_event_files)"
"metadata": {},
"outputs": [],
"source": [
"# load all data and event files\n",
"all_data = np.concatenate(load_data(all_data_sort, path))\n",
"all_events = np.concatenate(load_data(all_event_sort, path))"
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize the EEG-data and events and pay attention to:\n",
"- the EEG traces (plt.plot())\n",
"- the number of detected arm movements (plt.hist())\n",
"metadata": {
"tags": [
"solution"
]
},
"columns = pd.read_csv(path + all_data_sort[0]).columns[1:]\n",
"start = np.where(all_events == 1)[0][0]\n",
"\n",
"plt.figure(figsize = (7,10))\n",
"plt.subplots_adjust(hspace = 0.3)\n",
"for i, ch in enumerate(ix):\n",
" ax = plt.subplot(5,1,i+1)\n",
" ax.plot(all_data[(start-500):(start+3500), ch], linewidth = 1.5, color = cols[i], label = labels[i])\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" ax.set_yticks([])\n",
" ax.set_xticks([])\n",
" ax.legend(loc='upper left', bbox_to_anchor= (0, 1.1), fontsize = 14)\n",
" ax.set_ylim(-500,3000)\n",
" \n",
"ax = plt.subplot(5,1,5)\n",
"ax.spines['right'].set_visible(False)\n",
"ax.spines['top'].set_visible(False)\n",
"ax.spines['left'].set_visible(False)\n",
"ax.set_yticks([])\n",
"ax.set_xticks([])\n",
"ax.plot(all_events[(start-500):(start+3500)], linewidth = 2)\n",
"ax.set_xticks(np.arange(0,4100,1000))\n",
"ax.set_xticklabels(['0', '2', '4', '6', '8'], fontsize = 14)\n",
"ax.set_xlabel('Time [sec]', fontsize = 14)\n",
"ax.set_ylim(0.1,1)\n",
"lgd = ax.legend(['1', '2', '3', '4', '5', '6'],\n",
" loc='lower left', bbox_to_anchor= (0.85, 0.1), ncol=2, \n",
" borderaxespad=0, frameon=True, fontsize = 12)"
"metadata": {
"tags": [
"solution"
]
},
"source": [
"plt.figure(figsize = (10,7))\n",
"plt.subplots_adjust(wspace = 0.5)\n",
"plt.subplots_adjust(hspace = 0.5)\n",
"for i, e in enumerate(all_events.T):\n",
" plt.subplot(2,3,i+1)\n",
" plt.hist(e, [0, 0.5, 1, 1.5])\n",
" plt.xticks([0.25, 1.25], ['no event', 'event'], fontsize = 14)\n",
" plt.yticks([500000, 1000000], [r'$5 \\cdot 10^{5}$', r'$1 \\cdot 10^{6}$'], fontsize = 14) \n",
" plt.title('movement ' + str(i+1), fontsize = 14)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The purpose of the feature extraction is to extract time-dependent features from the EEG data. To do so, a sliding window containing **500 datapoints** each is used. **Three consecutive time windows** each predict the event in the following time step.\n",
"Extract time-dependend features from the EEG-data:\n",
"\n",
"- define the start and end points of a sliding window with a length of **500 datapoints** and a **step size of 2**\n",
"- loop through those start and end points\n",
"- per iteration:\n",
" - take **three consecutive time windows** (window_1 = data[start:end,:], window_2 = data[start+500:end+500,:],\n",
" - compute the **average power** per window (power: square of the signal)\n",
" - combine the three arrays containing the average power to one array"
{
"cell_type": "markdown",
"metadata": {},
"source": [
"step_size = 2\n",
"num_feat = 3\n",
"num_win = int((all_data.shape[0] - (win_size * num_feat))/step_size)\n",
"ix_start = np.arange(0, num_win*step_size - win_size*num_feat, step_size)\n",
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Compute the mean power per time window"
"metadata": {},
"outputs": [],
"source": [
"def mean_pow(y):\n",
" return np.mean(y**2, axis = 0)"
"for start, end in zip(ix_start, ix_end):\n",
" pow_1 = mean_pow(all_data[start:end, :])\n",
" pow_2 = mean_pow(all_data[start+500:end+500, :])\n",
" pow_3 = mean_pow(all_data[start+1000:end+1000, :])\n",
" data_filt.append(np.hstack([pow_1, pow_2, pow_3]))\n",
"data_filt = np.array(data_filt)\n",
"events_filt = np.array([all_events[end + 1501, :] for end in ix_end])"
]
},
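The feature extraction above can be tried out on a small synthetic array before running it on the full recording. The following is a minimal sketch under stated assumptions: the array `data`, its size (40 samples, 2 channels) and the short window length of 5 are made up for illustration, and the window end is assumed to be `start + win_size`.

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic stand-in for all_data: 40 samples, 2 channels (assumed sizes)
data = rng.normal(size=(40, 2))

win_size = 5    # samples per window (500 in the notebook)
num_feat = 3    # number of consecutive windows per feature vector
step_size = 2

def mean_pow(y):
    # average power per channel: mean of the squared signal
    return np.mean(y**2, axis=0)

# last start index for which all three windows still fit into the data
last_start = data.shape[0] - win_size * num_feat
ix_start = np.arange(0, last_start + 1, step_size)
ix_end = ix_start + win_size

features = []
for start, end in zip(ix_start, ix_end):
    # shift the window by 0, 1 and 2 window lengths
    pows = [mean_pow(data[start + k * win_size:end + k * win_size, :])
            for k in range(num_feat)]
    features.append(np.hstack(pows))
features = np.array(features)

# rows: window positions, columns: num_feat * number of channels
print(features.shape)
```

The notebook additionally pairs each feature vector with an event label taken from a later sample; that step is left out of this sketch.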
{
"cell_type": "markdown",
"metadata": {},
"source": [
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exercise section\n",
"\n",
"1. Split the data into a train and test dataset.\n",
"\n",
"2. Define a pipeline which includes:\n",
" - PCA to reduce the data to 10 dimensions\n",
" - Scaling of the data\n",
" - a classifier (LogisticRegression)\n",
"3. Choose an appropriate parametrization of the classifier according to the <strong>imbalance</strong> of the data (see lecture 6).\n",
"4. Transfer the multi-class classification problem into a one-vs-rest classification (start with only one arm movement).\n",
"5. Use cross-validation to test the model performance (cv = 5).\n",
"<br>(hint: use cross_val_predict to evaluate the model performance using the test dataset)\n",
"6. Use the ROC-AUC curve and the confusion matrix for the evaluation of the model.\n",
"7. Visualize the model performance by plotting the true and predicted hand movements.\n",
"8. Once you evaluated the model performance, make predictions based on the test dataset.\n",
"<br>(hint: you have to train your pipeline first)\n",
"<br>\n",
"<br>\n",
"9. Repeat the above named steps for another classifier (Random Forest) and compare the results. \n",
"10. Once your training works, train classifiers for all different arm movements.\n",
"\n",
"<div class=\"alert alert-block alert-warning\">\n",
" <i class=\"fa fa-info-circle\"></i> <strong>ROC (Receiver Operating Characteristics) curve</strong> \n",
" <p>A classifier can produce four different types of results:</p>\n",
" <p>- <strong>true positive</strong> (arm movement was observed and predicted)</p>\n",
" <p>- <strong>true negative</strong> (arm movement was not observed and not predicted)</p>\n",
" <p>- <strong>false positive</strong> (arm movement was not observed but predicted)</p>\n",
" <p>- <strong>false negative</strong> (arm movement was observed but not predicted)</p>\n",
" <p>\n",
" <figure>\n",
" <img src=\"./images/evaluation-measures-for-roc.png\" title=\"made at imgflip.com\" width=50%/>\n",
" </figure>\n",
" </p>\n",
" <p>\n",
" These four possible outcomes also determine the following values:</p>\n",
" <p>- <strong>recall/sensitivity</strong>: true positive rate (should be high) </p>\n",
" <p>- <strong>specificity</strong>: true negative rate (should be low) </p>\n",
" <p>- <strong>precision</strong>: positive predictive value </p> \n",
" <br>\n",
" <p> <strong>f1-score</strong> = $\\frac{precision \\cdot recall}{precision + recall}$</p>\n",
" <br>\n",
" <p>The <strong>ROC curve</strong> plots the sensitivity against (1 - specificity):</p>\n",
" <p>\n",
" <figure>\n",
" <img src=\"./images/a-roc-curve-connecting-points.png\" title=\"made at imgflip.com\" width=30%/>\n",
" </figure>\n",
" </p>\n",
" <p>\n",
" <p> As the sensitivity should be high and the specificity should be low the ROC-curve for different classifier performances looks as follows:\n",
" </p>\n",
" <p>\n",
" <center>\n",
" <figure>\n",
" <table><tr>\n",
" <td> <img src=\"./images/a-roc-curve-of-a-random-classifier.png\" style=\"width: 400px;\"/> </td>\n",
" <td> <img src=\"./images/a-roc-curve-of-a-perfect-classifier.png\" style=\"width: 400px;\"/> </td>\n",
" </tr></table>\n",
" </figure>\n",
" </center>\n",
" </p>\n",
" <p>\n",
" The metric <strong>'roc-auc'</strong> describes the area under the ROC-curve. Thus, the higher this values is the better is the performance of the classifier.\n",
" </p>\n",
" <p> All figures are from: https://classeval.wordpress.com/introduction/introduction-to-the-roc-receiver-operating-characteristics-plot/\n",
" </p>\n",
" \n",
" \n",
" \n",
"\n",
"</div>\n",
"\n",
"<div class=\"alert alert-block alert-warning\">\n",
" <i class=\"fa fa-info-circle\"></i> <strong>One-vs-rest classification</strong>\n",
" <p> Multiclass classification can also be tranferred to multiple binary classification problems. One strategy is called One-vs-rest, where one classifier is trained per class. In our case this means that for each arm movement one classifier is trained by considering only the labels of the respective arm movement.\n",
" </p>\n",
]
},
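To make the quantities from the ROC box concrete, here is a minimal sketch that counts the four outcomes by hand and derives recall, specificity, precision and the f1-score from them. The label lists `y_true` and `y_pred` are made up for illustration; in the exercise, `confusion_matrix` from scikit-learn provides the same four counts.

```python
# made-up labels for one arm movement: 1 = movement, 0 = no movement
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 1, 1, 1, 1, 0]

pairs = list(zip(y_true, y_pred))
tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # observed and predicted
tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # not observed, not predicted
fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # not observed but predicted
fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # observed but not predicted

recall = tp / (tp + fn)          # true positive rate (sensitivity)
specificity = tn / (tn + fp)     # true negative rate
precision = tp / (tp + fp)       # positive predictive value
fpr = 1 - specificity            # x-axis of the ROC curve
f1 = 2 * precision * recall / (precision + recall)

print('tp, tn, fp, fn:', tp, tn, fp, fn)
print('recall:', recall, 'specificity:', specificity)
print('precision:', precision, 'f1:', f1)
```

A single pair of (fpr, recall) values corresponds to one point of the ROC curve; the full curve arises from varying the decision threshold of the classifier.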
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# split of the data\n",
"# from sklearn.model_selection import train_test_split\n",
"# ..."
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# make pipeline\n",
"# from sklearn.pipeline import make_pipeline\n",
"# from sklearn.decomposition import PCA\n",
"# from sklearn.preprocessing import StandardScaler\n",
"# from sklearn.linear_model import LogisticRegression\n",
"# from sklearn.ensemble import RandomForestClassifier\n",
"# p = make_pipeline(...)"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# training of model\n",
"# from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"# from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"# preds = []\n",
"# for i in range(#nr of arm movements):\n",
"# y_pred = cross_val_predict(...)\n",
"# preds.append(y_pred)"
]
},
{
"cell_type": "code",
"metadata": {
"tags": [
"solution"
]
},
"outputs": [],
"source": [
"# split of the data\n",
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(data_filt, events_filt,\\\n",
" test_size = 0.33, shuffle = True)"
]
},
{
"cell_type": "markdown",
"source": [
"#### Pipeline with single classifier"
]
},
{
"cell_type": "code",
"metadata": {
"tags": [
"solution"
]
},
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"p_lr = make_pipeline(PCA(10), StandardScaler(), LogisticRegression(class_weight = 'balanced', solver = 'lbfgs'))\n",
"p_rf = make_pipeline(PCA(10), StandardScaler(), RandomForestClassifier(class_weight = 'balanced', n_estimators = 10))"
]
},
{
"cell_type": "code",
"metadata": {
"tags": [
"solution"
]
},
"source": [
"%%time\n",
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"for i in range(6):\n",
" y_pred = cross_val_predict(p_lr, X_train, y_train[:,i], cv = 5)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_train[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_train[:,i], y_pred)))\n",
" print()"
]
},
{
"cell_type": "code",
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"preds_lr = []\n",
" p_lr.fit(X_train, y_train[:,i])\n",
" y_pred = p_lr.predict(X_test)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_test[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_test[:,i], y_pred)))\n",
" print()"
]
},
{
"cell_type": "code",
"source": [
"%%time\n",
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"for i in range(6):\n",
" y_pred = cross_val_predict(p_rf, X_train, y_train[:,i], cv = 5)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_train[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_train[:,i], y_pred)))\n",
"metadata": {
"tags": [
"solution"
]
},
"source": [
"%%time\n",
"preds_rf = []\n",
"for i in range(6):\n",
" p_rf.fit(X_train, y_train[:,i])\n",
" y_pred = p_rf.predict(X_test)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_test[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_test[:,i], y_pred)))\n",
" print()"
]
},
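In the cells above, `roc_auc_score` receives the hard 0/1 predictions returned by `predict`. A ROC curve is normally traced by sweeping a threshold over continuous scores, so the area can differ once the scores are thresholded first. The following NumPy sketch illustrates this on made-up values; `roc_points`, `y_true` and `scores` are hypothetical names and numbers, not part of the notebook. With scikit-learn, such scores could for example be taken from a pipeline's `predict_proba` output.

```python
import numpy as np

# made-up labels and continuous classifier scores
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
scores = np.array([0.1, 0.2, 0.3, 0.6, 0.4, 0.7, 0.8, 0.9])

def roc_points(y_true, scores):
    # sort by decreasing score and sweep the threshold over distinct values
    order = np.argsort(-scores, kind='stable')
    s, y = scores[order], y_true[order]
    distinct = np.r_[np.where(np.diff(s))[0], len(s) - 1]
    tps = np.cumsum(y)[distinct]        # true positives per threshold
    fps = np.cumsum(1 - y)[distinct]    # false positives per threshold
    tpr = np.concatenate([[0], tps / tps[-1]])
    fpr = np.concatenate([[0], fps / fps[-1]])
    return fpr, tpr

def area(fpr, tpr):
    # trapezoidal rule for the area under the curve
    return np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2)

# area computed from the continuous scores
auc_scores = area(*roc_points(y_true, scores))

# area computed after thresholding the scores at 0.5
hard = (scores >= 0.5).astype(float)
auc_hard = area(*roc_points(y_true, hard))

print('AUC from scores:', auc_scores)
print('AUC from hard labels:', auc_hard)
```

With hard labels the curve consists of a single intermediate point, so the value reduces to the mean of sensitivity and specificity at that one threshold.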
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Visualization of model results"
]
},
{
"cell_type": "code",
"plt.subplots_adjust(wspace = 0.5)\n",
"plt.subplots_adjust(hspace = 0.5)\n",
" plt.plot(y_test[800:1050, i],label = 'observation')\n",
" plt.plot(preds_lr[i][800:1050], ':', label = 'prediction')\n",
" plt.xticks([0, 250], ['0', '0.5'], fontsize = 14)\n",
" plt.xlabel('Time [sec]', fontsize = 14)\n",
" plt.yticks([]) \n",
" plt.title('movement ' + str(i+1), fontsize = 14)\n",
"plt.legend(loc = 1);"
"plt.figure(figsize = (15,7))\n",
"plt.subplots_adjust(wspace = 0.5)\n",
"plt.subplots_adjust(hspace = 0.5)\n",
"for i in range(6):\n",
" plt.subplot(2,3,i+1)\n",
" plt.plot(y_test[800:1050, i],label = 'observation')\n",
" plt.plot(preds_rf[i][800:1050], ':', label = 'prediction')\n",
" plt.xticks([0, 250], ['0', '0.5'], fontsize = 14)\n",
" plt.xlabel('Time [sec]', fontsize = 14)\n",
" plt.yticks([]) \n",
" plt.title('movement ' + str(i+1), fontsize = 14)\n",
"plt.legend(loc = 1);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Your own machine learning application"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For those of you who did not bring your own data, you can have a look at the following datasets:\n",
"- San Francisco Crime Classification (https://www.kaggle.com/c/sf-crime)\n",
"- Forest Cover Type (https://www.kaggle.com/c/forest-cover-type-prediction)"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "288px"
},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}