"outputs": [
{
"data": {
"text/html": [
"<style>\n",
" \n",
" @import url('http://fonts.googleapis.com/css?family=Source+Code+Pro');\n",
" \n",
" @import url('http://fonts.googleapis.com/css?family=Kameron');\n",
" @import url('http://fonts.googleapis.com/css?family=Crimson+Text');\n",
" \n",
" @import url('http://fonts.googleapis.com/css?family=Lato');\n",
" @import url('http://fonts.googleapis.com/css?family=Source+Sans+Pro');\n",
" \n",
" @import url('http://fonts.googleapis.com/css?family=Lora'); \n",
"\n",
" \n",
" body {\n",
" font-family: 'Lora', Consolas, sans-serif;\n",
" \n",
" -webkit-print-color-adjust: exact important !;\n",
" \n",
" \n",
" \n",
" }\n",
" \n",
" .alert-block {\n",
" width: 95%;\n",
" margin: auto;\n",
" }\n",
" \n",
" .rendered_html code\n",
" {\n",
" color: black;\n",
" background: #eaf0ff;\n",
" background: #f5f5f5; \n",
" padding: 1pt;\n",
" font-family: 'Source Code Pro', Consolas, monocco, monospace;\n",
" }\n",
" \n",
" p {\n",
" line-height: 140%;\n",
" }\n",
" \n",
" strong code {\n",
" background: red;\n",
" }\n",
" \n",
" .rendered_html strong code\n",
" {\n",
" background: #f5f5f5;\n",
" }\n",
" \n",
" .CodeMirror pre {\n",
" font-family: 'Source Code Pro', monocco, Consolas, monocco, monospace;\n",
" }\n",
" \n",
" .cm-s-ipython span.cm-keyword {\n",
" font-weight: normal;\n",
" }\n",
" \n",
" strong {\n",
" background: #f5f5f5;\n",
" margin-top: 4pt;\n",
" margin-bottom: 4pt;\n",
" padding: 2pt;\n",
" border: 0.5px solid #a0a0a0;\n",
" font-weight: bold;\n",
" color: darkred;\n",
" }\n",
" \n",
" \n",
" div #notebook {\n",
" # font-size: 10pt; \n",
" line-height: 145%;\n",
" }\n",
" \n",
" li {\n",
" line-height: 145%;\n",
" }\n",
"\n",
" div.output_area pre {\n",
" background: #fff9d8 !important;\n",
" padding: 5pt;\n",
" \n",
" -webkit-print-color-adjust: exact; \n",
" \n",
" }\n",
" \n",
" \n",
" \n",
" h1, h2, h3, h4 {\n",
" font-family: Kameron, arial;\n",
"\n",
" }\n",
" \n",
" div#maintoolbar {display: none !important;}\n",
"\n",
" div#site { \n",
" border-top: 20px solid #1F407A; \n",
" border-right: 20px solid #1F407A; \n",
" margin-bottom: 0;\n",
" padding-bottom: 0;\n",
" }\n",
" div#toc-wrapper { \n",
" border-left: 20px solid #1F407A; \n",
" border-top: 20px solid #1F407A; \n",
"\n",
" }\n",
"\n",
" body {\n",
" margin-botton:10px;\n",
" }\n",
"\n",
"</style>\n",
" <script>\n",
"IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
" return false;\n",
"}\n",
" </script>\n",
"\n",
"\n",
"<footer id=\"attribution\" style=\"float:left; color:#1F407A; background:#fff; font-family: helvetica;\">\n",
" Copyright (C) 2019-2021 Scientific IT Services of ETH Zurich,\n",
" <p>\n",
" Contributing Authors:\n",
" Dr. Tarun Chadha,\n",
" Dr. Franziska Oschmann,\n",
" Dr. Mikolaj Rybinski,\n",
" Dr. Uwe Schmitt.\n",
" </p<\n",
"</footer>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# IGNORE THIS CELL WHICH CUSTOMIZES LAYOUT AND STYLING OF THE NOTEBOOK !\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
"import warnings\n",
"warnings.filterwarnings('ignore', category=FutureWarning)\n",
"warnings.filterwarnings = lambda *a, **kw: None\n",
"from IPython.core.display import HTML; HTML(open(\"custom.html\", \"r\").read())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chapter 9: Use case - prediction of arm movements\n",
"\n",
"<center>\n",
"<figure>\n",
"<table><tr>\n",
"<td> <img src=\"./images/eeg_cap.png\" style=\"width: 400px;\"/> </td>\n",
"<td> <img src=\"./images/arm_movement.png\" style=\"width: 400px;\"/> </td>\n",
"</tr></table>\n",
"<figcaption>Setup of an EEG-experiment.</figcaption>\n",
"</figure>\n",
"</center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<center>\n",
"<figure>\n",
" <img src=\"./images/eeg_electrode_numbering.jpg\" width=35%/> \n",
" <figcaption>Arrangement of electrodes on head.</figcaption>\n",
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This data contains EEG recordings of one subject performing **grasp-and-lift (GAL)** trials. \n",
"There is **1 subject** in total, **8 series** of trials for this subject, and approximately **30 trials** within each series. The number of trials varies for each series.\n",
"\n",
"For each **GAL**, you are tasked to detect 6 events:\n",
"\n",
"- HandStart\n",
"- FirstDigitTouch\n",
"- BothStartLoadPhase\n",
"- LiftOff\n",
"- Replace\n",
"- BothReleased\n",
"\n",
"These events always occur in the same order. In this dataset, there are two files for the subject + series combination:\n",
"the ***_data.csv** files contain the raw 32 channels EEG data (sampling rate 500Hz)\n",
"the ***_events.csv** files contains the ground truth frame-wise labels for all events\n",
"\n",
"\n",
"Detailed information about the data can be found here:\n",
"Luciw MD, Jarocka E, Edin BB (2014) Multi-channel EEG recordings during 3,936 grasp and lift trials with varying weight and friction. Scientific Data 1:140047. www.nature.com/articles/sdata201447\n",
"\n",
"*Description from https://www.kaggle.com/c/grasp-and-lift-eeg-detection/data*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<center>\n",
"<figure>\n",
" <img src=\"./images/eeg_signal_preprocessing.png\" title=\"made at imgflip.com\" width=75%/> \n",
" <figcaption>Preprocessing steps for EEG-signals.</figcaption>\n",
"</figure>\n",
"</center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data can be found in: `/data/eeg_use_case` and contains:\n",
"\n",
"- 8 series of recorded EEG data\n",
"\n",
"- 8 series of events of arm movements\n",
"\n",
"Load the EEG data and the events:\n",
"- combine all EEG series in one array (size: (total number of time series, number of channels))\n",
"- combine all events in one array (size: (total number of time series, number of different arm movement))\n",
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div class=\"alert alert-block alert-warning\">\n",
" <i class=\"fa fa-info-circle\"></i> <strong>Filter strings with the lambda-operator</strong> \n",
" The lambda-operator allows to build hidden functions, which are basically functions without a name. These hidden functions have any number of parameters, execute an expression and return the value of this expression. The lambda operator can be applied in the following way to filter the filenames:\n",
" \n",
" all_data_files = list(filter(lambda x: '_data' in x, os.listdir(path)))\n",
"</div>"
]
},
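As a small illustration of the filtering step, the sketch below applies `filter` with a lambda to a hard-coded list of file names, so it runs without the dataset. The file names are made up for this example; only the `_data` and `_events` suffixes follow the description above. A list comprehension is shown as an equivalent alternative.

```python
# Hypothetical file names, standing in for the result of os.listdir(path)
file_names = [
    "subj1_series2_events.csv",
    "subj1_series1_data.csv",
    "README.txt",
    "subj1_series2_data.csv",
    "subj1_series1_events.csv",
]

# filter() keeps every element for which the lambda returns True
data_files = sorted(filter(lambda name: "_data" in name, file_names))

# the same selection written as a generator expression
event_files = sorted(name for name in file_names if "_events" in name)

print(data_files)
print(event_files)
```

Sorting both lists keeps the data file and the event file of each series at the same position, which matters when the two are loaded in parallel.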
"metadata": {},
"outputs": [],
"source": [
"def load_data(file_names, path):\n",
" # read the csv file and drop the id column\n",
" dfs = []\n",
" for f in file_names:\n",
" df = pd.read_csv(path + f).drop('id', axis = 1)\n",
"metadata": {},
"outputs": [],
"source": [
"# define path and list of all data and event files\n",
"import os\n",
"import pandas as pd\n",
"\n",
"path = 'data/eeg_use_case/' \n",
"\n",
"all_data_files = list(filter(lambda x: '_data' in x, os.listdir(path)))\n",
"all_event_files = list(filter(lambda x: '_events' in x, os.listdir(path)))\n",
"\n",
"all_data_sort = np.sort(all_data_files)\n",
"all_event_sort = np.sort(all_event_files)"
"metadata": {},
"outputs": [],
"source": [
"# load all data and event files\n",
"all_data = np.concatenate(load_data(all_data_sort, path))\n",
"all_events = np.concatenate(load_data(all_event_sort, path))"
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize the EEG-data and events and pay attention to:\n",
"- the EEG traces (plt.plot())\n",
"- the number of detected arm movements (plt.hist())\n",
"execution_count": 6,
"metadata": {
"tags": [
"solution"
]
},
"image/png": "\n",
"<Figure size 504x720 with 5 Axes>"
"image/png": {
"height": 598,
"width": 466
},
"needs_background": "light"
},
"output_type": "display_data"
}
],
"columns = pd.read_csv(path + all_data_sort[0]).columns[1:]\n",
"start = np.where(all_events == 1)[0][0]\n",
"\n",
"plt.figure(figsize = (7,10))\n",
"plt.subplots_adjust(hspace = 0.3)\n",
"for i, ch in enumerate(ix):\n",
" ax = plt.subplot(5,1,i+1)\n",
" ax.plot(all_data[(start-500):(start+3500), ch], linewidth = 1.5, color = cols[i], label = labels[i])\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" ax.set_yticks([])\n",
" ax.set_xticks([])\n",
" ax.legend(loc='upper left', bbox_to_anchor= (0, 1.1), fontsize = 14)\n",
" ax.set_ylim(-500,3000)\n",
" \n",
"ax = plt.subplot(5,1,5)\n",
"ax.spines['right'].set_visible(False)\n",
"ax.spines['top'].set_visible(False)\n",
"ax.spines['left'].set_visible(False)\n",
"ax.set_yticks([])\n",
"ax.set_xticks([])\n",
"ax.plot(all_events[(start-500):(start+3500)], linewidth = 2)\n",
"ax.set_xticks(np.arange(0,4100,1000))\n",
"ax.set_xticklabels(['0', '2', '4', '6', '8'], fontsize = 14)\n",
"ax.set_xlabel('Time [sec]', fontsize = 14)\n",
"ax.set_ylim(0.1,1)\n",
"lgd = ax.legend(['1', '2', '3', '4', '5', '6'],\n",
" loc='lower left', bbox_to_anchor= (0.85, 0.1), ncol=2, \n",
" borderaxespad=0, frameon=True, fontsize = 12)"
"execution_count": 7,
"metadata": {
"tags": [
"solution"
]
},
"image/png": "\n",
"text/plain": [
"<Figure size 720x504 with 6 Axes>"
]
},
"metadata": {
"image/png": {
"height": 432,
"width": 622
},
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize = (10,7))\n",
"plt.subplots_adjust(wspace = 0.5)\n",
"plt.subplots_adjust(hspace = 0.5)\n",
"for i, e in enumerate(all_events.T):\n",
" plt.subplot(2,3,i+1)\n",
" plt.hist(e, [0, 0.5, 1, 1.5])\n",
" plt.xticks([0.25, 1.25], ['no event', 'event'], fontsize = 14)\n",
" plt.yticks([500000, 1000000], [r'$5 \\cdot 10^{5}$', r'$1 \\cdot 10^{6}$'], fontsize = 14) \n",
" plt.title('movement ' + str(i+1), fontsize = 14)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The purpose of the feature extraction is to extract time-dependent features from the EEG data. To do so, a sliding window containing **500 datapoints** each is used. **Three consecutive time windows** each predict the event in the following time step.\n",
"Extract time-dependend features from the EEG-data:\n",
"\n",
"- define the start and end points of a sliding window with a length of **500 datapoints** and a **step size of 2**\n",
"- loop through those start and end points\n",
"- per iteration:\n",
" - take **three consecutive time windows** (window_1 = data[start:end,:], window_2 = data[start+500:end+500,:],\n",
" - compute the **average power** per window (power: square of the signal)\n",
" - combine the three arrays containing the average power to one array"
{
"cell_type": "markdown",
"metadata": {},
"source": [
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.46 ms, sys: 2.75 ms, total: 5.22 ms\n",
"Wall time: 4.24 ms\n"
"step_size = 2\n",
"num_feat = 3\n",
"num_win = int((all_data.shape[0] - (win_size * num_feat))/step_size)\n",
"ix_start = np.arange(0, num_win*step_size - win_size*num_feat, step_size)\n",
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Compute the mean power per time window"
"metadata": {},
"outputs": [],
"source": [
"def mean_pow(y):\n",
" return np.mean(y**2, axis = 0)"
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 10s, sys: 2.4 s, total: 1min 13s\n",
"Wall time: 1min 13s\n"
"for start, end in zip(ix_start, ix_end):\n",
" pow_1 = mean_pow(all_data[start:end, :])\n",
" pow_2 = mean_pow(all_data[start+500:end+500, :])\n",
" pow_3 = mean_pow(all_data[start+1000:end+1000, :])\n",
" data_filt.append(np.hstack([pow_1, pow_2, pow_3]))\n",
"data_filt = np.array(data_filt)\n",
"events_filt = np.array([all_events[end + 1501, :] for end in ix_end])"
]
},
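The window-based feature extraction can be tried out on a small random array before running it on the full recording. The sketch below is one possible way to write it, not the notebook's own solution: the function name `window_power_features`, its parameters and the toy data are assumptions made for illustration, and the window length is shortened to 5 samples so the result is easy to inspect.

```python
import numpy as np


def window_power_features(data, win_size, num_windows, step_size):
    """Mean power of `num_windows` consecutive windows, stacked per start index.

    data has shape (time points, channels). Hypothetical helper for illustration.
    """
    span = win_size * num_windows
    # only start indices where all consecutive windows fit into the data
    starts = np.arange(0, data.shape[0] - span + 1, step_size)
    features = []
    for s in starts:
        powers = [
            np.mean(data[s + k * win_size : s + (k + 1) * win_size] ** 2, axis=0)
            for k in range(num_windows)
        ]
        # one feature row: [power window 1 | power window 2 | power window 3]
        features.append(np.hstack(powers))
    return np.array(features), starts


rng = np.random.default_rng(0)
toy = rng.normal(size=(40, 2))  # 40 time points, 2 channels

features, starts = window_power_features(toy, win_size=5, num_windows=3, step_size=2)
print(features.shape)
```

Each row holds `num_windows * channels` values, so 32 channels and three windows give 96 features per sample.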
{
"cell_type": "markdown",
"metadata": {},
"source": [
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exercise section\n",
"\n",
"1. Split the data into a train and test dataset.\n",
"\n",
"2. Define a pipeline which includes:\n",
" - PCA to reduce the data to 10 dimensions\n",
" - Scaling of the data\n",
" - a classifier (LogisticRegression)\n",
"3. Choose an appropriate parametrization of the classifier according to the <strong>imbalance</strong> of the data (see lecture 6).\n",
"4. Transfer the multi-class classification problem into a one-vs-rest classification (start with only one arm movement).\n",
"5. Use cross-validation to test the model performance (cv = 5).\n",
"<br>(hint: use cross_val_predict to evaluate the model performance using the test dataset)\n",
"6. Use the ROC-AUC curve and the confusion matrix for the evaluation of the model.\n",
"7. Visualize the model performance by plotting the true and predicted hand movements.\n",
"8. Once you evaluated the model performance, make predictions based on the test dataset.\n",
"<br>(hint: you have to train your pipeline first)\n",
"<br>\n",
"<br>\n",
"9. Repeat the above named steps for another classifier (Random Forest) and compare the results. \n",
"10. Once your training works, train classifiers for all different arm movements.\n",
"\n",
"<div class=\"alert alert-block alert-warning\">\n",
" <i class=\"fa fa-info-circle\"></i> <strong>ROC (Receiver Operating Characteristics) curve</strong> \n",
" <p>A classifier can produce four different types of results:</p>\n",
" <p>- <strong>true positive</strong> (arm movement was observed and predicted)</p>\n",
" <p>- <strong>true negative</strong> (arm movement was not observed and not predicted)</p>\n",
" <p>- <strong>false positive</strong> (arm movement was not observed but predicted)</p>\n",
" <p>- <strong>false negative</strong> (arm movement was observed but not predicted)</p>\n",
" <p>\n",
" <figure>\n",
" <img src=\"./images/evaluation-measures-for-roc.png\" title=\"made at imgflip.com\" width=50%/>\n",
" </figure>\n",
" </p>\n",
" <p>\n",
" These four possible outcomes also determine the following values:</p>\n",
" <p>- <strong>recall/sensitivity</strong>: true positive rate (should be high) </p>\n",
" <p>- <strong>specificity</strong>: true negative rate (should be low) </p>\n",
" <p>- <strong>precision</strong>: positive predictive value </p> \n",
" <br>\n",
" <p> <strong>f1-score</strong> = $\\frac{precision \\cdot recall}{precision + recall}$</p>\n",
" <br>\n",
" <p>The <strong>ROC curve</strong> plots the sensitivity against (1 - specificity):</p>\n",
" <p>\n",
" <figure>\n",
" <img src=\"./images/a-roc-curve-connecting-points.png\" title=\"made at imgflip.com\" width=30%/>\n",
" </figure>\n",
" </p>\n",
" <p>\n",
" <p> As the sensitivity should be high and the specificity should be low the ROC-curve for different classifier performances looks as follows:\n",
" </p>\n",
" <p>\n",
" <center>\n",
" <figure>\n",
" <table><tr>\n",
" <td> <img src=\"./images/a-roc-curve-of-a-random-classifier.png\" style=\"width: 400px;\"/> </td>\n",
" <td> <img src=\"./images/a-roc-curve-of-a-perfect-classifier.png\" style=\"width: 400px;\"/> </td>\n",
" </tr></table>\n",
" </figure>\n",
" </center>\n",
" </p>\n",
" <p>\n",
" The metric <strong>'roc-auc'</strong> describes the area under the ROC-curve. Thus, the higher this values is the better is the performance of the classifier.\n",
" </p>\n",
" <p> All figures are from: https://classeval.wordpress.com/introduction/introduction-to-the-roc-receiver-operating-characteristics-plot/\n",
" </p>\n",
" \n",
" \n",
" \n",
"\n",
"</div>\n",
"\n",
"<div class=\"alert alert-block alert-warning\">\n",
" <i class=\"fa fa-info-circle\"></i> <strong>One-vs-rest classification</strong>\n",
" <p> Multiclass classification can also be tranferred to multiple binary classification problems. One strategy is called One-vs-rest, where one classifier is trained per class. In our case this means that for each arm movement one classifier is trained by considering only the labels of the respective arm movement.\n",
" </p>\n",
]
},
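The evaluation measures listed in the box above follow directly from the four outcome counts. The sketch below computes them for made-up counts; the function name `classification_metrics` and the numbers are assumptions for illustration, and the counts are assumed to be non-zero so that no division by zero occurs.

```python
def classification_metrics(tp, fp, fn, tn):
    """Return recall, specificity, precision and F1 from the four outcome counts."""
    recall = tp / (tp + fn)          # true positive rate (sensitivity)
    specificity = tn / (tn + fp)     # true negative rate
    precision = tp / (tp + fp)       # positive predictive value
    f1 = 2 * precision * recall / (precision + recall)
    return {
        "recall": recall,
        "specificity": specificity,
        "precision": precision,
        "f1": f1,
    }


# made-up counts for an imbalanced problem: 10 events among 100 samples
metrics = classification_metrics(tp=8, fp=4, fn=2, tn=86)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

With imbalanced data such as the arm movement labels, specificity can stay high while precision drops, which is why several measures are inspected together.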
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# split of the data\n",
"# from sklearn.model_selection import train_test_split\n",
"# ..."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# make pipeline\n",
"# from sklearn.pipeline import make_pipeline\n",
"# from sklearn.decomposition import PCA\n",
"# from sklearn.preprocessing import StandardScaler\n",
"# from sklearn.linear_model import LogisticRegression\n",
"# from sklearn.ensemble import RandomForestClassifier\n",
"# p = make_pipeline(...)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# training of model\n",
"# from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"# from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"# preds = []\n",
"# for i in range(#nr of arm movements):\n",
"# y_pred = cross_val_predict(...)\n",
"# preds.append(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [],
"source": [
"# split of the data\n",
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(data_filt, events_filt,\\\n",
" test_size = 0.33, shuffle = True)"
]
},
{
"cell_type": "markdown",
"source": [
"#### Pipeline with single classifier"
]
},
{
"cell_type": "code",
"metadata": {
"tags": [
"solution"
]
},
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"p_lr = make_pipeline(PCA(10), StandardScaler(), LogisticRegression(class_weight = 'balanced', solver = 'lbfgs'))\n",
"p_rf = make_pipeline(PCA(10), StandardScaler(), RandomForestClassifier(class_weight = 'balanced', n_estimators = 10))"
]
},
{
"cell_type": "code",
"metadata": {
"tags": [
"solution"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Results for arm movement number 1:\n",
"confusion matrix: \n",
"[[218386 244225]\n",
" [ 3973 8912]]\n",
"roc-auc score: 0.5818648069870105\n",
"\n",
"Results for arm movement number 2:\n",
"confusion matrix: \n",
"[[174936 287630]\n",
" [ 4637 8293]]\n",
"roc-auc score: 0.5097813376479718\n",
"\n",
"Results for arm movement number 3:\n",
"confusion matrix: \n",
"[[159938 302671]\n",
" [ 4337 8550]]\n",
"roc-auc score: 0.504594855856682\n",
"\n",
"Results for arm movement number 4:\n",
"confusion matrix: \n",
"[[155437 307076]\n",
" [ 3248 9735]]\n",
"roc-auc score: 0.5429486250708738\n",
"\n",
"Results for arm movement number 5:\n",
"confusion matrix: \n",
"[[235343 227177]\n",
" [ 1491 11485]]\n",
"roc-auc score: 0.696961643702174\n",
"\n",
"Results for arm movement number 6:\n",
"confusion matrix: \n",
"[[241339 221083]\n",
" [ 1310 11764]]\n",
"roc-auc score: 0.7108516020754958\n",
"CPU times: user 3min 49s, sys: 29.7 s, total: 4min 19s\n",
"Wall time: 1min 25s\n"
]
}
],
"source": [
"%%time\n",
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"for i in range(6):\n",
" y_pred = cross_val_predict(p_lr, X_train, y_train[:,i], cv = 5)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_train[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_train[:,i], y_pred)))\n",
" print()"
]
},
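`cross_val_predict` returns hard class labels by default, so the ROC curve behind the scores printed above has a single operating point. In that case the area under the curve reduces to the mean of sensitivity and specificity. The sketch below checks this against the first confusion matrix printed above, using the scikit-learn layout `[[TN, FP], [FN, TP]]`; the helper name is an assumption for illustration.

```python
def auc_from_hard_labels(tn, fp, fn, tp):
    """Area under a ROC curve with one operating point: (sensitivity + specificity) / 2."""
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return (sensitivity + specificity) / 2


# confusion matrix reported for arm movement number 1 (logistic regression, cv = 5)
auc_movement_1 = auc_from_hard_labels(tn=218386, fp=244225, fn=3973, tp=8912)
print(auc_movement_1)
```

To obtain a full ROC curve instead, one option is to pass class probabilities to `roc_auc_score`, for example from `cross_val_predict(..., method='predict_proba')`.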
{
"cell_type": "code",
"execution_count": 18,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Results for arm movement number 1:\n",
"confusion matrix: \n",
"[[107238 120422]\n",
" [ 1998 4542]]\n",
"roc-auc score: 0.58276997647385\n",
"\n",
"Results for arm movement number 2:\n",
"confusion matrix: \n",
"[[ 85820 141885]\n",
" [ 2311 4184]]\n",
"roc-auc score: 0.5105394949122397\n",
"\n",
"Results for arm movement number 3:\n",
"confusion matrix: \n",
"[[ 78560 149102]\n",
" [ 2212 4326]]\n",
"roc-auc score: 0.503371597290901\n",
"\n",
"Results for arm movement number 4:\n",
"confusion matrix: \n",
"[[ 76499 151224]\n",
" [ 1600 4877]]\n",
"roc-auc score: 0.5444510551690052\n",
"\n",
"Results for arm movement number 5:\n",
"confusion matrix: \n",
"[[115536 112140]\n",
" [ 736 5788]]\n",
"roc-auc score: 0.697321871090943\n",
"\n",
"Results for arm movement number 6:\n",
"confusion matrix: \n",
"[[118587 109187]\n",
" [ 631 5795]]\n",
"roc-auc score: 0.7112198275415272\n",
"\n",
"CPU times: user 58.3 s, sys: 7.01 s, total: 1min 5s\n",
"Wall time: 21.6 s\n"
]
}
],
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"preds_lr = []\n",
" p_lr.fit(X_train, y_train[:,i])\n",
" y_pred = p_lr.predict(X_test)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_test[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_test[:,i], y_pred)))\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Results for arm movement number 1:\n",
"confusion matrix: \n",
"[[462533 78]\n",
" [ 782 12103]]\n",
"roc-auc score: 0.9695703330836672\n",
"\n",
"Results for arm movement number 2:\n",
"confusion matrix: \n",
"[[462473 93]\n",
" [ 850 12080]]\n",
"roc-auc score: 0.9670301775944556\n",
"\n",
"Results for arm movement number 3:\n",
"confusion matrix: \n",
"[[462515 94]\n",
" [ 792 12095]]\n",
"roc-auc score: 0.9691697610560872\n",
"\n",
"Results for arm movement number 4:\n",
"confusion matrix: \n",
"[[462434 79]\n",
" [ 797 12186]]\n",
"roc-auc score: 0.9692206125539191\n",
"\n",
"Results for arm movement number 5:\n",
"confusion matrix: \n",
"[[462397 123]\n",
" [ 560 12416]]\n",
"roc-auc score: 0.9782887343799201\n",
"\n",
"Results for arm movement number 6:\n",
"confusion matrix: \n",
"[[462303 119]\n",
" [ 533 12541]]\n",
"roc-auc score: 0.979487361470148\n",
"\n",
"CPU times: user 7min 2s, sys: 23.5 s, total: 7min 25s\n",
"Wall time: 5min 29s\n"
]
}
],
"source": [
"%%time\n",
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score\n",
"for i in range(6):\n",
" y_pred = cross_val_predict(p_rf, X_train, y_train[:,i], cv = 5)\n",
" print('Results for arm movement number ' + str(i+1) + ':')\n",
" print('confusion matrix: ')\n",
" print(confusion_matrix(y_train[:,i], y_pred))\n",
" print('roc-auc score: ' + str(roc_auc_score(y_train[:,i], y_pred)))\n",
"metadata": {
"tags": [
"solution"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Results for arm movement number 1:\n",
"confusion matrix: \n",
"[[227622 38]\n",
" [ 318 6222]]\n",
"roc-auc score: 0.9756046156065661\n",
"\n",
"Results for arm movement number 2:\n",
"confusion matrix: \n",
"[[227661 44]\n",
" [ 327 6168]]\n",
"roc-auc score: 0.9747301736024179\n",
"\n",
"Results for arm movement number 3:\n",
"confusion matrix: \n",
"[[227629 33]\n",
" [ 293 6245]]\n",
"roc-auc score: 0.9775200600803725\n",
"\n",
"Results for arm movement number 4:\n",
"confusion matrix: \n",
"[[227683 40]\n",
" [ 276 6201]]\n",
"roc-auc score: 0.9786060137414901\n",
"\n",
"Results for arm movement number 5:\n",
"confusion matrix: \n",
"[[227615 61]\n",
" [ 238 6286]]\n",
"roc-auc score: 0.9816256943550609\n",
"\n",
"Results for arm movement number 6:\n",
"confusion matrix: \n",
"[[227734 40]\n",
" [ 219 6207]]\n",
"roc-auc score: 0.9828720442725605\n",
"\n",
"CPU times: user 1min 52s, sys: 5.29 s, total: 1min 57s\n",
"Wall time: 1min 29s\n"