diff --git a/.gitignore b/.gitignore index 7e4ac34..e3ba2b3 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ Software/CNN_feature_extraction_project/**/*.pkl Software/CNN_feature_extraction_project/**/*.edf Software/CNN_feature_extraction_project/**/*.npz Software/CNN_feature_extraction_project/**/*.pt +Software/CNN_feature_extraction_project/**/*.pth Software/CNN_feature_extraction_project/**/*.h5 diff --git a/Software/CNN_feature_extraction_project/models/mustaf/iteration3_EEGnet_07192024/notebook.ipynb b/Software/CNN_feature_extraction_project/models/mustaf/iteration3_EEGnet_07192024/notebook.ipynb new file mode 100644 index 0000000..7f8fe99 --- /dev/null +++ b/Software/CNN_feature_extraction_project/models/mustaf/iteration3_EEGnet_07192024/notebook.ipynb @@ -0,0 +1,953 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fc1fc60b-8acf-4170-a95e-4088681bce37", + "metadata": {}, + "source": [ + "# CNN Classification\n", + "\n", + "We'll use EEGNet to perform binary classification on 32 channel ECog data.\n", + "\n", + "https://arxiv.org/pdf/1611.08024\n" + ] + }, + { + "cell_type": "markdown", + "id": "08546d9b-1e43-4b2b-b3a2-e2f01afd51dc", + "metadata": {}, + "source": [ + "# 1. Load dataset and define training & test sets" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "d9e52569-9cbe-4395-bb5d-d3652d266c81", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total Dataset Size: 7591\n", + "Shape of the signals: (7591, 32, 250)\n", + "Shape of the labels: torch.Size([7591])\n" + ] + } + ], + "source": [ + "import os\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from sklearn.model_selection import train_test_split\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torchvision import transforms\n", + "from torch.autograd import Variable\n", + "from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler\n", + "import numpy as np\n", + "import optuna\n", + "\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "data_dir = '../../../datasets/processed/shuffleboard/'\n", + "\n", + "# inputs\n", + "signal_data_path = data_dir + 'normalized_dataset_500hz_500ms_consecutive_buckets.npz'\n", + "signal_data = np.load(signal_data_path, allow_pickle=True)\n", + "\n", + "signal_array = signal_data['data']\n", + "signals_tensor = torch.tensor(signal_array, dtype=torch.float32)\n", + "signals_tensor = signals_tensor.unsqueeze(1) # Add channel dimension\n", + "\n", + "\n", + "# Print dataset info \n", + "print(\"Total Dataset Size:\", len(dataset))\n", + "print(\"Shape of the signals:\", signal_array.shape)\n", + "print(\"Shape of the labels:\", labels_tensor.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "18a2e377-95f5-460a-ba1c-4d24f8cdfb15", + "metadata": {}, + "source": [ + "# Define Model\n", + "\n", + "Torcheeg's implementation of EEGnet:\n", + "\n", + "https://torcheeg.readthedocs.io/en/v1.1.0/generated/torcheeg.models.EEGNet.html#torcheeg.models.EEGNet" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "fcaec51b-308e-4b7f-945e-cc41fb9cc55e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class Conv2dWithConstraint(nn.Conv2d):\n", + " def __init__(self, *args, max_norm: int = 1, **kwargs):\n", + " self.max_norm = max_norm\n", + " super(Conv2dWithConstraint, self).__init__(*args, **kwargs)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)\n", + " return super(Conv2dWithConstraint, self).forward(x)\n", + "\n", + "\n", + "class EEGNet(nn.Module):\n", + " r'''\n", + " A compact convolutional neural network (EEGNet). For more details, please refer to the following information.\n", + "\n", + " - Paper: Lawhern V J, Solon A J, Waytowich N R, et al. EEGNet: a compact convolutional neural network for EEG-based brain-computer interfaces[J]. Journal of neural engineering, 2018, 15(5): 056013.\n", + " - URL: https://arxiv.org/abs/1611.08024\n", + " - Related Project: https://github.com/braindecode/braindecode/tree/master/braindecode\n", + " '''\n", + " def __init__(self,\n", + " chunk_size: int = 151,\n", + " num_electrodes: int = 60,\n", + " F1: int = 8,\n", + " F2: int = 16,\n", + " D: int = 2,\n", + " num_classes: int = 2,\n", + " kernel_1: int = 64,\n", + " kernel_2: int = 16,\n", + " dropout: float = 0.25):\n", + " super(EEGNet, self).__init__()\n", + " self.F1 = F1\n", + " self.F2 = F2\n", + " self.D = D\n", + " self.chunk_size = chunk_size\n", + " self.num_classes = num_classes\n", + " self.num_electrodes = num_electrodes\n", + " self.kernel_1 = kernel_1\n", + " self.kernel_2 = kernel_2\n", + " self.dropout = dropout\n", + "\n", + " self.block1 = nn.Sequential(\n", + " nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),\n", + " nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),\n", + " Conv2dWithConstraint(self.F1,\n", + " self.F1 * self.D, (self.num_electrodes, 1),\n", + " max_norm=1,\n", + " stride=1,\n", + " padding=(0, 0),\n", + " groups=self.F1,\n", + " bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),\n", + " nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))\n", + "\n", + " self.block2 = nn.Sequential(\n", + " nn.Conv2d(self.F1 * self.D,\n", + " self.F1 * self.D, (1, self.kernel_2),\n", + " stride=1,\n", + " padding=(0, self.kernel_2 // 2),\n", + " bias=False,\n", + " groups=self.F1 * self.D),\n", + " nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),\n", + " nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),\n", + " nn.Dropout(p=dropout))\n", + "\n", + " self.lin = nn.Linear(self.feature_dim(), num_classes, bias=False)\n", + "\n", + " def feature_dim(self):\n", + " with torch.no_grad():\n", + " mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)\n", + "\n", + " mock_eeg = self.block1(mock_eeg)\n", + " mock_eeg = self.block2(mock_eeg)\n", + "\n", + " return self.F2 * mock_eeg.shape[3]\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " r'''\n", + " Args:\n", + " x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 60, 151]`. Here, :obj:`n` corresponds to the batch size, :obj:`60` corresponds to :obj:`num_electrodes`, and :obj:`151` corresponds to :obj:`chunk_size`.\n", + "\n", + " Returns:\n", + " torch.Tensor[number of sample, number of classes]: the predicted probability that the samples belong to the classes.\n", + " '''\n", + " x = self.block1(x)\n", + " x = self.block2(x)\n", + " x = x.flatten(start_dim=1)\n", + " x = self.lin(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "d357966d-2b54-4415-bbc4-432beeef5a44", + "metadata": {}, + "source": [ + "# Train Models" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "7af0ec45-266d-4dd1-917f-9594a500aa2a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Training envelopes Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.39%, Percentage of 0's: 50.61%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.75 1.00 0.86 559\n", + " 1 0.99 0.68 0.81 581\n", + "\n", + " accuracy 0.84 1140\n", + " macro avg 0.87 0.84 0.83 1140\n", + "weighted avg 0.87 0.84 0.83 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "557 | 2\n", + "---+---\n", + "186 | 395\n", + "FN | TP\n", + "\n", + "Best model saved to best_envelopes_model.pth\n", + "----\n", + "\n", + "--- Training rms Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.39%, Percentage of 0's: 50.61%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.93 0.96 0.94 557\n", + " 1 0.96 0.93 0.94 583\n", + "\n", + " accuracy 0.94 1140\n", + " macro avg 0.94 0.94 0.94 1140\n", + "weighted avg 0.94 0.94 0.94 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "532 | 25\n", + "---+---\n", + "40 | 543\n", + "FN | TP\n", + "\n", + "Best model saved to best_rms_model.pth\n", + "----\n", + "\n", + "--- Training variance Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.65%, Percentage of 0's: 50.35%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.53 0.99 0.69 599\n", + " 1 0.82 0.03 0.05 541\n", + "\n", + " accuracy 0.54 1140\n", + " macro avg 0.68 0.51 0.37 1140\n", + "weighted avg 0.67 0.54 0.39 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "596 | 3\n", + "---+---\n", + "527 | 14\n", + "FN | TP\n", + "\n", + "Best model saved to best_variance_model.pth\n", + "----\n", + "\n", + "--- Training std_dev Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.89%, Percentage of 0's: 50.11%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.64 0.19 0.30 570\n", + " 1 0.53 0.89 0.66 570\n", + "\n", + " accuracy 0.54 1140\n", + " macro avg 0.58 0.54 0.48 1140\n", + "weighted avg 0.58 0.54 0.48 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "110 | 460\n", + "---+---\n", + "61 | 509\n", + "FN | TP\n", + "\n", + "Best model saved to best_std_dev_model.pth\n", + "----\n", + "\n", + "--- Training spectral_edge_density Model ---\n", + "Distribution of Labels: Percentage of 1's: 50.03%, Percentage of 0's: 49.97%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.55 0.90 0.68 581\n", + " 1 0.68 0.23 0.34 559\n", + "\n", + " accuracy 0.57 1140\n", + " macro avg 0.61 0.56 0.51 1140\n", + "weighted avg 0.61 0.57 0.51 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "522 | 59\n", + "---+---\n", + "432 | 127\n", + "FN | TP\n", + "\n", + "Best model saved to best_spectral_edge_density_model.pth\n", + "----\n", + "\n", + "--- Training derivatives Model ---\n", + "Distribution of Labels: Percentage of 1's: 50.18%, Percentage of 0's: 49.82%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.65 0.97 0.78 564\n", + " 1 0.94 0.49 0.64 576\n", + "\n", + " accuracy 0.73 1140\n", + " macro avg 0.79 0.73 0.71 1140\n", + "weighted avg 0.79 0.73 0.71 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "545 | 19\n", + "---+---\n", + "294 | 282\n", + "FN | TP\n", + "\n", + "Best model saved to best_derivatives_model.pth\n", + "----\n", + "\n", + "--- Training centroids Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.77%, Percentage of 0's: 50.23%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.50 0.51 0.51 583\n", + " 1 0.48 0.47 0.47 557\n", + "\n", + " accuracy 0.49 1140\n", + " macro avg 0.49 0.49 0.49 1140\n", + "weighted avg 0.49 0.49 0.49 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "297 | 286\n", + "---+---\n", + "295 | 262\n", + "FN | TP\n", + "\n", + "Best model saved to best_centroids_model.pth\n", + "----\n", + "\n", + "--- Training phases Model ---\n", + "Distribution of Labels: Percentage of 1's: 50.34%, Percentage of 0's: 49.66%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.51 0.97 0.67 577\n", + " 1 0.64 0.06 0.10 563\n", + "\n", + " accuracy 0.52 1140\n", + " macro avg 0.58 0.51 0.39 1140\n", + "weighted avg 0.58 0.52 0.39 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "559 | 18\n", + "---+---\n", + "531 | 32\n", + "FN | TP\n", + "\n", + "Best model saved to best_phases_model.pth\n", + "----\n", + "\n", + "--- Training beta_band_power Model ---\n", + "Distribution of Labels: Percentage of 1's: 50.15%, Percentage of 0's: 49.85%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.56 0.23 0.33 568\n", + " 1 0.52 0.82 0.64 572\n", + "\n", + " accuracy 0.53 1140\n", + " macro avg 0.54 0.53 0.48 1140\n", + "weighted avg 0.54 0.53 0.48 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "132 | 436\n", + "---+---\n", + "103 | 469\n", + "FN | TP\n", + "\n", + "Best model saved to best_beta_band_power_model.pth\n", + "----\n", + "\n", + "--- Training average_signal_shapes Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.30%, Percentage of 0's: 50.70%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.97 0.90 0.93 571\n", + " 1 0.91 0.97 0.94 569\n", + "\n", + " accuracy 0.93 1140\n", + " macro avg 0.94 0.93 0.93 1140\n", + "weighted avg 0.94 0.93 0.93 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "514 | 57\n", + "---+---\n", + "18 | 551\n", + "FN | TP\n", + "\n", + "Best model saved to best_average_signal_shapes_model.pth\n", + "----\n", + "\n", + "--- Training analytic_signals Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.30%, Percentage of 0's: 50.70%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.99 0.73 0.84 579\n", + " 1 0.78 0.99 0.87 561\n", + "\n", + " accuracy 0.86 1140\n", + " macro avg 0.88 0.86 0.86 1140\n", + "weighted avg 0.88 0.86 0.85 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "422 | 157\n", + "---+---\n", + " 6 | 555\n", + "FN | TP\n", + "\n", + "Best model saved to best_analytic_signals_model.pth\n", + "----\n", + "\n", + "--- Training fft_results Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.64%, Percentage of 0's: 50.36%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.98 0.45 0.61 604\n", + " 1 0.61 0.99 0.76 536\n", + "\n", + " accuracy 0.70 1140\n", + " macro avg 0.80 0.72 0.69 1140\n", + "weighted avg 0.81 0.70 0.68 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "270 | 334\n", + "---+---\n", + " 6 | 530\n", + "FN | TP\n", + "\n", + "Best model saved to best_fft_results_model.pth\n", + "----\n", + "\n", + "--- Training magnitudes Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.52%, Percentage of 0's: 50.48%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.74 0.77 0.75 578\n", + " 1 0.75 0.72 0.73 562\n", + "\n", + " accuracy 0.74 1140\n", + " macro avg 0.75 0.74 0.74 1140\n", + "weighted avg 0.75 0.74 0.74 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "446 | 132\n", + "---+---\n", + "159 | 403\n", + "FN | TP\n", + "\n", + "Best model saved to best_magnitudes_model.pth\n", + "----\n", + "\n", + "--- Training average_distance Model ---\n", + "Distribution of Labels: Percentage of 1's: 50.26%, Percentage of 0's: 49.74%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.48 0.02 0.03 579\n", + " 1 0.49 0.98 0.65 561\n", + "\n", + " accuracy 0.49 1140\n", + " macro avg 0.48 0.50 0.34 1140\n", + "weighted avg 0.48 0.49 0.34 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "10 | 569\n", + "---+---\n", + "11 | 550\n", + "FN | TP\n", + "\n", + "Best model saved to best_average_distance_model.pth\n", + "----\n", + "\n", + "--- Training average_peak_height Model ---\n", + "Distribution of Labels: Percentage of 1's: 2.15%, Percentage of 0's: 97.85%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.98 1.00 0.99 1115\n", + " 1 0.00 0.00 0.00 25\n", + "\n", + " accuracy 0.98 1140\n", + " macro avg 0.49 0.50 0.49 1140\n", + "weighted avg 0.96 0.98 0.97 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "1113 | 2\n", + "---+---\n", + "25 | 0\n", + "FN | TP\n", + "\n", + "Best model saved to best_average_peak_height_model.pth\n", + "----\n", + "\n", + "--- Training peak_counts Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.45%, Percentage of 0's: 50.55%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.65 0.12 0.20 580\n", + " 1 0.51 0.93 0.66 560\n", + "\n", + " accuracy 0.52 1140\n", + " macro avg 0.58 0.53 0.43 1140\n", + "weighted avg 0.58 0.52 0.43 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "70 | 510\n", + "---+---\n", + "37 | 523\n", + "FN | TP\n", + "\n", + "Best model saved to best_peak_counts_model.pth\n", + "----\n", + "\n", + "--- Training spectral_entropy Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.02%, Percentage of 0's: 50.98%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.63 0.27 0.38 563\n", + " 1 0.54 0.85 0.66 577\n", + "\n", + " accuracy 0.56 1140\n", + " macro avg 0.59 0.56 0.52 1140\n", + "weighted avg 0.59 0.56 0.52 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "151 | 412\n", + "---+---\n", + "89 | 488\n", + "FN | TP\n", + "\n", + "Best model saved to best_spectral_entropy_model.pth\n", + "----\n", + "\n", + "--- Training evolution_rate Model ---\n", + "Distribution of Labels: Percentage of 1's: 49.23%, Percentage of 0's: 50.77%\n", + "\n", + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.51 0.99 0.67 575\n", + " 1 0.59 0.02 0.03 565\n", + "\n", + " accuracy 0.51 1140\n", + " macro avg 0.55 0.50 0.35 1140\n", + "weighted avg 0.55 0.51 0.35 1140\n", + "\n", + "\n", + "Confusion Matrix:\n", + "----------------\n", + "TN | FP\n", + "---+---\n", + "568 | 7\n", + "---+---\n", + "555 | 10\n", + "FN | TP\n", + "\n", + "Best model saved to best_evolution_rate_model.pth\n", + "----\n", + "\n" + ] + } + ], + "source": [ + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader, random_split\n", + "from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, classification_report\n", + "import optuna\n", + "import os\n", + "import warnings\n", + "\n", + "# Don't want to see all of those warnings ... \n", + "optuna.logging.set_verbosity(optuna.logging.CRITICAL)\n", + "\n", + "\n", + "# features where the output from from extraction was 32 values (one per channel) before we took the average value over channels\n", + "features = [\n", + " \"envelopes\",\n", + " \"rms\",\n", + " \"variance\",\n", + " \"std_dev\",\n", + " \"spectral_edge_density\",\n", + " \"derivatives\",\n", + " \"centroids\",\n", + " \"phases\",\n", + " \"beta_band_power\",\n", + " \"average_signal_shapes\",\n", + " \"analytic_signals\",\n", + " \"fft_results\",\n", + " \"magnitudes\",\n", + " \"average_distance\",\n", + " \"average_peak_height\",\n", + " \"peak_counts\",\n", + " \"spectral_entropy\",\n", + " \"evolution_rate\"\n", + "]\n", + "\n", + "for feature_name in features:\n", + "\n", + " # output labels \n", + " labels_tensor = torch.load(os.path.join(data_dir, '{}_labels_500ms_500hz_tensor.pt'.format(feature_name)))\n", + " labels_tensor = labels_tensor.long()\n", + "\n", + "\n", + " dataset = TensorDataset(signals_tensor, labels_tensor)\n", + "\n", + " print(\"--- Training {} Model ---\".format(feature_name))\n", + " print(f\"Distribution of Labels: Percentage of 1's: {labels_tensor.float().mean()*100:.2f}%, Percentage of 0's: {100 - labels_tensor.float().mean()*100:.2f}%\")\n", + "\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + " total_size = len(dataset)\n", + " train_size = int(0.7 * total_size)\n", + " val_size = int(0.15 * total_size)\n", + " test_size = total_size - train_size - val_size\n", + "\n", + " train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])\n", + "\n", + " def create_dataloaders(batch_size):\n", + " return (\n", + " DataLoader(train_dataset, batch_size=batch_size, shuffle=True),\n", + " DataLoader(val_dataset, batch_size=batch_size),\n", + " DataLoader(test_dataset, batch_size=batch_size)\n", + " )\n", + "\n", + " def create_model(params):\n", + " return EEGNet(\n", + " chunk_size=250,\n", + " num_electrodes=32,\n", + " num_classes=2,\n", + " dropout=params['dropout']\n", + " ).to(device)\n", + "\n", + " def objective(trial):\n", + " params = {\n", + " 'dropout': trial.suggest_discrete_uniform('dropout', 0.1, 0.7, 0.1),\n", + " 'num_epochs': trial.suggest_int('num_epochs', 20, 100, step=10)\n", + " }\n", + "\n", + " model = create_model(params)\n", + " criterion = nn.CrossEntropyLoss()\n", + " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "\n", + " train_loader, val_loader, _ = create_dataloaders(32)\n", + "\n", + " best_val_accuracy = 0\n", + " for epoch in range(params['num_epochs']):\n", + " model.train()\n", + " for inputs, labels in train_loader:\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + " optimizer.zero_grad()\n", + " outputs = model(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " model.eval()\n", + " correct = total = 0\n", + " with torch.no_grad():\n", + " for inputs, labels in val_loader:\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + " outputs = model(inputs)\n", + " _, predicted = outputs.max(1)\n", + " total += labels.size(0)\n", + " correct += predicted.eq(labels).sum().item()\n", + "\n", + " val_accuracy = 100. * correct / total\n", + " if val_accuracy > best_val_accuracy:\n", + " best_val_accuracy = val_accuracy\n", + " torch.save(model.state_dict(), '{}_best_model_in_trial.pth'.format(feature_name))\n", + "\n", + " trial.report(val_accuracy, epoch)\n", + " if trial.should_prune():\n", + " raise optuna.exceptions.TrialPruned()\n", + "\n", + " return best_val_accuracy\n", + "\n", + " study = optuna.create_study(direction='maximize')\n", + " study.optimize(objective, n_trials=50)\n", + "\n", + " best_params = study.best_params\n", + "\n", + "\n", + " # Evaluate best feature model on test set \n", + " best_model = create_model(best_params)\n", + " best_model.load_state_dict( torch.load('{}_best_model_in_trial.pth'.format(feature_name)))\n", + " best_model.eval()\n", + "\n", + " \n", + " all_labels = []\n", + " all_predictions = []\n", + "\n", + " with torch.no_grad():\n", + " for inputs, labels in DataLoader(test_dataset, batch_size=32):\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + " outputs = best_model(inputs)\n", + " _, predicted = outputs.max(1)\n", + " all_labels.extend(labels.cpu().numpy())\n", + " all_predictions.extend(predicted.cpu().numpy())\n", + "\n", + " cm = confusion_matrix(all_labels, all_predictions)\n", + "\n", + " \n", + " report = classification_report(all_labels, all_predictions, target_names=['0', '1'], digits=2)\n", + " print(\"\\nClassification Report:\")\n", + " print(report)\n", + "\n", + " \n", + " print(\"\\nConfusion Matrix:\")\n", + " print(\"----------------\")\n", + " print(\"TN | FP\")\n", + " print(\"---+---\")\n", + " print(f\"{cm[0][0]:2d} | {cm[0][1]:2d}\")\n", + " print(\"---+---\")\n", + " print(f\"{cm[1][0]:2d} | {cm[1][1]:2d}\")\n", + " print(\"FN | TP\")\n", + " \n", + " model_save_path = f'best_{feature_name}_model.pth'\n", + " torch.save(best_model.state_dict(), model_save_path)\n", + " print(f'\\nBest model saved to {model_save_path}')\n", + " print(\"----\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "65ecbac2-449d-42c7-be09-e936cf760f0a", + "metadata": {}, + "source": [ + "# Summary of Results\n" + ] + }, + { + "cell_type": "markdown", + "id": "04b11fba-0d4f-40fb-bc52-d1a7bab1e905", + "metadata": {}, + "source": [ + "| Feature Model | Accuracy |\n", + "|---------------|----------|\n", + "| rms | 0.94 |\n", + "| average_signal_shapes | 0.93 |\n", + "| analytic_signals | 0.86 |\n", + "| envelopes | 0.84 |\n", + "| magnitudes | 0.74 |\n", + "| derivatives | 0.73 |\n", + "| fft_results | 0.70 |\n", + "| spectral_edge_density | 0.57 |\n", + "| spectral_entropy | 0.56 |\n", + "| variance | 0.54 |\n", + "| std_dev | 0.54 |\n", + "| beta_band_power | 0.53 |\n", + "| peak_counts | 0.52 |\n", + "| phases | 0.52 |\n", + "| evolution_rate | 0.51 |\n", + "| average_distance | 0.49 |\n", + "| centroids | 0.49 |" + ] + }, + { + "cell_type": "markdown", + "id": "2c4e5308-6a8d-4767-b161-08df7b7beefc", + "metadata": { + "tags": [] + }, + "source": [ + "### Label Distributions\n", + "\n", + "For each feature model, the distribution of labels was balanced.\n", + "\n", + "### Excellent\n", + "\n", + "The EEGNet architecture performs very well for predicting increases/decreases associated with the rms and average_signal_shapes features, with 94% and 96% accuracy respectively after basic hyper-parameter tuning. \n", + "\n", + "### Good\n", + "\n", + "We saw good performance with analytic_signals (86%) and envelopes (84%). \n", + "\n", + "With the analytic_signals model, there were a significant amount false positive degrading its performance, and for the envelopes model, the amount of false negatives degraded its performance.\n", + "\n", + "\n", + "### Moderate\n", + "\n", + "magnitudes\t0.74\n", + "\n", + "derivatives\t0.73\n", + "\n", + "fft_results\t0.70\n", + "\n", + "### Poor\n", + "\n", + "The following features performed close or at the performance of random guessing:\n", + "\n", + "spectral_edge_density\t0.57\n", + "\n", + "spectral_entropy\t0.56\n", + "\n", + "variance\t0.54\n", + "\n", + "std_dev\t0.54\n", + "\n", + "beta_band_power\t0.53\n", + "\n", + "peak_counts\t0.52\n", + "\n", + "phases\t0.52\n", + "\n", + "evolution_rate\t0.51\n", + "\n", + "average_distance\t0.49\n", + "\n", + "centroids\t0.49\n", + "\n", + "\n", + "### Next Steps \n", + "1. One option is to continue hyper-parameter tuning because I've done very basic tuning.\n", + "2. Feature engineer the input signals. At the moment, each input signal is a 500ms block of concatenated 250ms blocks; maybe we can try different ways of creating the inputs other than concatenation, which may yield better performance. \n", + "3. Explore other models\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}