Classifying flaring stars with stella: a convolutional neural network for TESS#


Learning Goals#

In this tutorial, you will build, compile, and train a convolutional neural network (CNN) to classify astronomical data in vector form. By the end, you will have a working example of a simple CNN in Keras.

Introduction#

CNNs are a class of machine learning (ML) algorithms that learn spatial (or, here, temporal) features directly from structured data such as images and time series. In this notebook, you will walk through the basic steps of applying a CNN to data:

  1. Load the data and visualize a sample of the data.

  2. Divide the data into training, validation, and testing sets.

  3. Build a CNN in Keras.

  4. Compile the CNN.

  5. Train the CNN to perform a classification task.

  6. Evaluate the CNN performance on test data with a confusion matrix.

  7. Build a new, unlabeled dataset and apply the CNN.

CNNs can be applied to a wide range of vector analysis tasks, including classification and regression. Here, we will build, compile, and train a CNN to classify whether a star has undergone a flaring event based on its observed Transiting Exoplanet Survey Satellite (TESS) 2-minute light curve, and to locate the flaring events within the time series. This work is based on the model described in the stella software package.

NOTE: The stella team provides publicly available code and documentation demonstrating the architecture and optimal performance of this model, which we encourage you to check out! The goal of this notebook is to step through the model building and training process.

Imports#

This notebook uses the following:

  • numpy to handle array functions

  • tarfile for unpacking files

  • astropy for accessing FITS files

  • matplotlib.pyplot for plotting data

  • stella for generating the training set and processing data

  • keras for building the CNN

  • sklearn for model performance metrics

  • lightkurve.search for extracting light curves

For the stella package, please install the development version (see their documentation for instructions).

For other packages, you can install them using pip or conda.

# arrays
import numpy as np

# unpacking files
import tarfile

# fits
from astropy.io import fits
from astropy.utils.data import download_file

# plotting
import matplotlib.pyplot as plt
from matplotlib import ticker

# stella CNN functions
import stella

# keras
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Flatten, Dense, Activation, Dropout, BatchNormalization
from keras.layers import Conv1D, MaxPooling1D

# sklearn for performance metrics
from sklearn import metrics

# lightkurve
from lightkurve.search import search_lightcurve

# from IPython import get_ipython
# get_ipython().run_line_magic('matplotlib', 'notebook')

# set random seed for reproducibility 
np.random.seed(42)

Convolutional Neural Network (CNN) for Vector Classification#

1. Download the training data using stella#

Load the sample of TESS light curves (input vectors) and flare classifications (output labels) used to train the CNN from the stella package. stella pre-processes the light curves and splits the data into training, test, and validation sets.

The training set contains stars observed at 2-minute cadence in TESS Sectors 1 and 2, classified by hand and presented as a flare catalog by Günther et al. 2020. The light curves are processed into examples of length 200 cadences, where each flaring event, if present, is located at the center of the example. The full sample of light curves contains 8694 positive classes (flare) and 35896 negative classes (no flare). For this notebook, we will download a subset of this sample. These data are described in Feinstein et al. 2020.

The CNN will be used to predict the presence of flaring events as a function of observing cadence. The input to the CNN is a light curve (time, flux, and flux error) and the output is a “probability light curve”: the probability (a value between 0 and 1) that the measurement at each cadence belongs to a flaring event (1 = flare, 0 = no flare). In other words, the CNN performs a classification task at each cadence. Before a probability light curve can be produced, stella first pre-processes the input light curve by assembling it into examples of length 200 cadences, so that the model can predict a value for the flare probability at each valid cadence.
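To make this windowing concrete, here is a minimal sketch (not the stella implementation itself; the array names are illustrative) of how a light curve can be sliced into 200-cadence examples, one centered on each cadence that has 100 neighbors on either side:

# illustrative sketch of the 200-cadence windowing used to build examples
cadence_length = 200
half = cadence_length // 2

flux = np.random.normal(1.0, 0.01, size=1000)  # stand-in for a real light curve

examples = []
for center in range(half, len(flux) - half):
    examples.append(flux[center - half:center + half])
examples = np.array(examples)
print(examples.shape)  # (n_valid_cadences, 200)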

%%time
file_url = 'https://archive.stsci.edu/hlsps/hellouniverse/hellouniverse_stella_500.tar.gz'

# open file
file = tarfile.open(download_file(file_url, cache=True))
  
# extracting file
file.extractall('.')
file.close()
CPU times: user 209 ms, sys: 64.1 ms, total: 273 ms
Wall time: 1.57 s
# build train, test, validation dataset, "ds"
data_dir = './hellouniverse_stella_500/'

ds = stella.FlareDataSet(fn_dir=data_dir,
                         catalog=data_dir+'Guenther_2020_flare_catalog.txt')
Reading in training set files.
100%|██████████| 62/62 [00:00<00:00, 317.18it/s]

502 positive classes (flare)
1342 negative classes (no flare)
37.0% class imbalance

2. Prepare the training, validation, and test sets#

The stella dataset includes training, test, and validation light curves (input vectors) and flare labels (output labels). stella applies the necessary pre-processing to the light curves for input to the CNN model. For more on how these are constructed, see Feinstein et al. 2020. For our purposes (i.e., building the stella CNN from scratch to illustrate its structure and function), we first need to remove all light curves with NaN-valued inputs from the training, test, and validation sets. To do this, we loop through the data and select only light curves without NaNs.

# remove light curves with NaNs from training, test, and validation data
def remove_nans(input_data):
    '''Return the indices of examples that contain no NaNs.'''

    idx = []
    for k in range(np.shape(input_data)[0]):
        # keep the example only if no cadence is NaN
        if not np.isnan(input_data[k, :, :]).any():
            idx.append(k)
    return idx


# find indices in train, test and validation sets without NaNs
idx_train = remove_nans(ds.train_data)
idx_test = remove_nans(ds.test_data)
idx_val = remove_nans(ds.val_data)

ds.train_data = ds.train_data[idx_train]
ds.train_labels = ds.train_labels[idx_train]

ds.test_data = ds.test_data[idx_test]
ds.test_labels = ds.test_labels[idx_test]

ds.val_data = ds.val_data[idx_val]
ds.val_labels = ds.val_labels[idx_val]
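
As a quick, optional sanity check, we can confirm that no NaNs remain and inspect the size and class balance of each set (the loop below is just for illustration):

# sanity check: confirm the cleaned arrays contain no NaNs, inspect set sizes
for name, data, labels in [('train', ds.train_data, ds.train_labels),
                           ('val', ds.val_data, ds.val_labels),
                           ('test', ds.test_data, ds.test_labels)]:
    assert not np.isnan(data).any()
    print('{}: {} examples, {:.0f} flares'.format(name, data.shape[0], labels.sum()))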

To visualize the structure of the lightcurves in the training set, we plot a random selection of 16 examples:

# select random light curve indices:
example_ids = np.random.choice(len(ds.train_labels), 16)

# pull the lightcurves and labels for these selections
example_lightcurves = [ds.train_data[j] for j in example_ids]
example_labels = [ds.train_labels[j] for j in example_ids]


# initialize your figure
fig = plt.figure(figsize=(10, 10))

# loop through the randomly selected images and plot with labels
colors = {1: 'r', 0: 'k'}
titles = {1: 'Flare', 0: 'Non-flare'}
for i in range(len(example_ids)):
    plt.subplot(4, 4, i + 1)
    plt.plot(example_lightcurves[i], color=colors[example_labels[i]])
    plt.title(titles[example_labels[i]])
    plt.xlabel('Cadences')
    
plt.tight_layout()
plt.show()
[Figure: a 4x4 grid of randomly selected training light curves, plotted in red and titled “Flare” or in black and titled “Non-flare”]

3. Build a CNN in Keras#

Here, we will build the CNN model described in Feinstein et al. 2020 and implemented in stella from scratch.

Further details about Conv1D, MaxPooling1D, BatchNormalization, Dropout, and Dense layers can be found in the Keras Layers Documentation. Further details about the relu and sigmoid activation functions used here can be found in the Keras Activation Function Documentation.
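For intuition before we build the model: the final layer uses a sigmoid activation, which squashes any real-valued network output into the range (0, 1) so it can be read as a flare probability. A minimal numpy illustration:

# the sigmoid function maps any real-valued output to a probability in (0, 1)
z = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])  # example raw network outputs
print(1.0 / (1.0 + np.exp(-z)))            # ~[0.018, 0.269, 0.500, 0.731, 0.982]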

# ------------------------------------------------------------------------------
# generate the model architecture
# Written for Keras 2
# ------------------------------------------------------------------------------

seed = 2
np.random.seed(seed)

kernel1 = 16   # kernel size of the first convolutional layer
kernel2 = 64   # kernel size of the second convolutional layer
dense = 32     # number of units in the fully-connected layer
dropout = 0.1  # dropout fraction

# Define architecture for model
data_shape = np.shape(ds.train_data)
input_shape = (np.shape(ds.train_data)[1], 1)  # (200 cadences, 1 channel)

x_in = Input(shape=input_shape)

# first convolutional block: Conv1D(filters, kernel_size), pooling, and dropout
c0 = Conv1D(7, kernel1, activation='relu', padding='same')(x_in)
b0 = MaxPooling1D(pool_size=2)(c0)
d0 = Dropout(dropout)(b0)

# second convolutional block
c1 = Conv1D(3, kernel2, activation='relu', padding='same')(d0)
b1 = MaxPooling1D(pool_size=2)(c1)
d1 = Dropout(dropout)(b1)

# flatten and classify with a single sigmoid output (the flare probability)
f = Flatten()(d1)
z0 = Dense(dense, activation='relu')(f)
d2 = Dropout(dropout)(z0)
y_out = Dense(1, activation='sigmoid')(d2)

cnn = Model(inputs=x_in, outputs=y_out)

4. Compile the CNN#

Next, we compile the model. As in Feinstein et al. 2020, we select the Adam optimizer and the binary cross entropy loss function (as this is a binary classification problem).

You can learn more about optimizers and about loss functions in the Keras documentation.
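To see what binary cross entropy actually measures, here is a hand-computed sketch (for intuition only; Keras computes this internally): the loss is small when a confident prediction matches the true label and grows rapidly when a confident prediction is wrong.

# binary cross entropy for one example: -[y*log(p) + (1-y)*log(1-p)]
def bce(y_true, p_pred):
    return -(y_true * np.log(p_pred) + (1 - y_true) * np.log(1 - p_pred))

print(bce(1, 0.9))  # confident and correct: small loss (~0.105)
print(bce(1, 0.1))  # confident and wrong: large loss (~2.303)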

# Compile Model
optimizer = 'adam'
fit_metrics = ['accuracy'] 
loss = 'binary_crossentropy'
cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics)
cnn.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 200, 1)]          0         
                                                                 
 conv1d (Conv1D)             (None, 200, 7)            119       
                                                                 
 max_pooling1d (MaxPooling1D  (None, 100, 7)           0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 100, 7)            0         
                                                                 
 conv1d_1 (Conv1D)           (None, 100, 3)            1347      
                                                                 
 max_pooling1d_1 (MaxPooling  (None, 50, 3)            0         
 1D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 50, 3)             0         
                                                                 
 flatten (Flatten)           (None, 150)               0         
                                                                 
 dense (Dense)               (None, 32)                4832      
                                                                 
 dropout_2 (Dropout)         (None, 32)                0         
                                                                 
 dense_1 (Dense)             (None, 1)                 33        
                                                                 
=================================================================
Total params: 6,331
Trainable params: 6,331
Non-trainable params: 0
_________________________________________________________________

5. Train the CNN to perform a classification task#

We will start with training for 20 epochs, but this almost certainly won’t be long enough to get great results. Once you’ve run your model and evaluated the fit, you can come back here and run the next cell again for 100 epochs or longer.

You can learn more about the fit method in the Keras documentation.

nb_epoch = 20
batch_size = 64
shuffle = True

# Train
history = cnn.fit(ds.train_data, ds.train_labels,
                  batch_size=batch_size, 
                  epochs=nb_epoch, 
                  validation_data=(ds.val_data, ds.val_labels), 
                  shuffle=shuffle,
                  verbose=True)
Epoch 1/20
23/23 [==============================] - 1s 25ms/step - loss: 0.5994 - accuracy: 0.7261 - val_loss: 0.6056 - val_accuracy: 0.7088
Epoch 2/20
23/23 [==============================] - 0s 18ms/step - loss: 0.5876 - accuracy: 0.7309 - val_loss: 0.6030 - val_accuracy: 0.7088
Epoch 3/20
23/23 [==============================] - 0s 20ms/step - loss: 0.5924 - accuracy: 0.7309 - val_loss: 0.6108 - val_accuracy: 0.7088
Epoch 4/20
23/23 [==============================] - 0s 18ms/step - loss: 0.5813 - accuracy: 0.7309 - val_loss: 0.6014 - val_accuracy: 0.7088
Epoch 5/20
23/23 [==============================] - 0s 17ms/step - loss: 0.5870 - accuracy: 0.7309 - val_loss: 0.5996 - val_accuracy: 0.7088
Epoch 6/20
23/23 [==============================] - 0s 16ms/step - loss: 0.5826 - accuracy: 0.7316 - val_loss: 0.5986 - val_accuracy: 0.7088
Epoch 7/20
23/23 [==============================] - 0s 16ms/step - loss: 0.5813 - accuracy: 0.7322 - val_loss: 0.6030 - val_accuracy: 0.7088
Epoch 8/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5765 - accuracy: 0.7336 - val_loss: 0.5986 - val_accuracy: 0.7088
Epoch 9/20
23/23 [==============================] - 0s 16ms/step - loss: 0.5759 - accuracy: 0.7350 - val_loss: 0.5940 - val_accuracy: 0.7143
Epoch 10/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5707 - accuracy: 0.7411 - val_loss: 0.5878 - val_accuracy: 0.7143
Epoch 11/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5654 - accuracy: 0.7445 - val_loss: 0.6009 - val_accuracy: 0.7143
Epoch 12/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5598 - accuracy: 0.7514 - val_loss: 0.5799 - val_accuracy: 0.7143
Epoch 13/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5546 - accuracy: 0.7520 - val_loss: 0.5841 - val_accuracy: 0.7143
Epoch 14/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5433 - accuracy: 0.7541 - val_loss: 0.5752 - val_accuracy: 0.7308
Epoch 15/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5445 - accuracy: 0.7589 - val_loss: 0.5580 - val_accuracy: 0.7308
Epoch 16/20
23/23 [==============================] - 0s 16ms/step - loss: 0.5336 - accuracy: 0.7664 - val_loss: 0.5717 - val_accuracy: 0.7308
Epoch 17/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5329 - accuracy: 0.7623 - val_loss: 0.5587 - val_accuracy: 0.7308
Epoch 18/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5205 - accuracy: 0.7657 - val_loss: 0.5526 - val_accuracy: 0.7308
Epoch 19/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5161 - accuracy: 0.7739 - val_loss: 0.5342 - val_accuracy: 0.7363
Epoch 20/20
23/23 [==============================] - 0s 15ms/step - loss: 0.5056 - accuracy: 0.7801 - val_loss: 0.5363 - val_accuracy: 0.7363
# save the model to file
cnn_file = 'flare_model.h5'
cnn.save(cnn_file)
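As an optional check, the saved model can be restored with load_model (the same function we will use below to load stella's published weights):

# optional: verify that the saved model can be reloaded from disk
cnn_reloaded = load_model(cnn_file)
cnn_reloaded.summary()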

6. Test the CNN performance#

Apply the CNN to predict flares on the “test” set, not used for training or validating the CNN, and evaluate the performance using a confusion matrix. See the documentation from sklearn on confusion matrices for more information. The code for generating and plotting the confusion matrix below was adapted from the application by Ciprijanovic et al. 2020 for DeepMerge.

def plot_confusion_matrix(cnn, input_data, input_labels):
    
    # Compute flare predictions for the test dataset
    predictions = cnn.predict(input_data)

    # Convert to binary classification 
    predictions = (predictions > 0.5).astype('int32') 
    
    # Compute the confusion matrix by comparing the input labels with the predictions
    cm = metrics.confusion_matrix(input_labels, predictions, labels=[0, 1])
    cm = cm.astype('float')

    # Normalize the confusion matrix results. 
    cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
    
    # Plotting
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.matshow(cm_norm, cmap='binary_r')

    plt.title('Confusion matrix', y=1.08)
    
    # tick labels match the label order used above (0 = no flare, 1 = flare)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['No Flare', 'Flare'])
    
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['No Flare', 'Flare'])

    plt.xlabel('Predicted')
    plt.ylabel('True')

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            ax.text(j, i, format(cm_norm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm_norm[i, j] < thresh else "black")
    plt.show()
plot_confusion_matrix(cnn, ds.test_data, ds.test_labels)
[Figure: confusion matrix for the CNN trained above, evaluated on the test set]
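Beyond the confusion matrix, sklearn also provides scalar summary metrics. As an optional check (thresholding the predicted probabilities at 0.5, as in the plotting function above):

# optional: scalar performance metrics on the test set
test_predictions = (cnn.predict(ds.test_data) > 0.5).astype('int32').flatten()
print('accuracy: ', metrics.accuracy_score(ds.test_labels, test_predictions))
print('precision:', metrics.precision_score(ds.test_labels, test_predictions))
print('recall:   ', metrics.recall_score(ds.test_labels, test_predictions))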

FAQ#

  • The results don’t look great… why? From the confusion matrix in Section 6, when faced with the test dataset (i.e., data not used for training or validation), the model misclassifies a large fraction of the flare events. The published models from Feinstein et al. 2020 perform much better, and the confusion matrix should look more like the results shown below. Note that in this notebook we use only a subset of the available training data and train for fewer epochs than optimal, to save space and time; you are welcome to relax these restrictions, and, as always, check out the stella repository for more information!

file_url = 'https://archive.stsci.edu/hlsps/stella/hlsp_stella_tess_ensemblemodel_s042_tess_v0.1.0_cnn.h5'
pretrained_model = load_model(download_file(file_url, cache=True, show_progress=True))

plot_confusion_matrix(pretrained_model, ds.test_data, ds.test_labels)
[Figure: confusion matrix for the published, pretrained stella model, evaluated on the same test set]
  • Can I improve the model by increasing the number of epochs it is trained for? We only trained for 20 epochs, which is many fewer than the published model. Go back to Section 5 (“Train the CNN to perform a classification task”) and increase the number of epochs to 100 (or more!) and train again. Does your model perform better? Your results may look better/worse/different from the published results due to the stochastic nature of training.

  • Can I try a different model? I think the results could be improved. Yes! You can try adding layers, swapping out the max pooling, changing the activation functions, swapping out the loss function, or trying a different optimizer or learning rate (see the example after this list). Experiment and see what model changes give the best results. Be aware: when you start training again, you pick up where your model left off. If you want to “reset” your model to epoch 0 and random weights, you should run the cells to make and compile the model again.

  • I want to test my model on my training data! No. You will convince yourself that your results are much better than they actually are. Always keep your training, validation, and testing sets completely separate!
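
For instance, one small change to experiment with is the optimizer’s learning rate. A sketch (rebuild the model first if you want to start again from random weights):

# example: recompile with an explicit, smaller learning rate
from keras.optimizers import Adam

cnn.compile(loss=loss, optimizer=Adam(learning_rate=1e-4), metrics=fit_metrics)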

Extensions/Exercises#

  • Is the model “overfitting”? Using the model’s history (saved as a result of the training process), investigate the behavior of the training and validation losses and accuracies as a function of training epoch. Make a plot or two (a minimal plotting sketch follows this list)! How do the training and validation losses compare? How do the training and validation accuracies compare? If the loss for the validation set is higher than for the training set (and, conversely, the accuracy for the validation set is lower than for the training set), the model may be overfitting.

  • Try applying this model to a new dataset Using the built-in functionality provided by the stella package, you can pre-process your own 2-minute cadence TESS light curves and predict flares. An example workflow is shown in Section 7 below.
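
As a starting point for the overfitting exercise above, here is a minimal sketch for plotting the training history (the keys 'loss', 'val_loss', 'accuracy', and 'val_accuracy' are the ones Keras records given our compile settings):

# sketch: compare training and validation curves from the fit history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(history.history['loss'], label='train')
ax1.plot(history.history['val_loss'], label='validation')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax2.plot(history.history['accuracy'], label='train')
ax2.plot(history.history['val_accuracy'], label='validation')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
plt.tight_layout()
plt.show()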

7. Predict flares in a new dataset#

In this step, we will download light curves directly from the TESS archive, pre-process them with stella for input to the CNN, and predict flares. The sample is a set of bright M dwarfs not featured in the training/validation/test sets.

ticids = ['120461526', '278779899', '139754153', '273418879', '52121469', '188580272', '394015919', '402104884']
# for all the selected targets, pull the available lightcurves using the lightkurve package
lcs = []
for name in ticids:
    lc = search_lightcurve(target='TIC'+name, mission='TESS', sector=[1, 2], author='SPOC')
    lc = lc.download_all()
    lcs.append(lc)
# load the CNN using `stella`
cnn_stella = stella.ConvNN(output_dir=data_dir, ds=ds)
fig = plt.figure(0, [8, 10])

for i, lc in enumerate(lcs):
    # use only the first light curve in each set, if more than one exists
    if len(lc) > 0:
        lc = lc[0]

    # predict the flare probability light curve for the input data using `stella`
    # (which applies the necessary pre-processing to the data for input to the CNN)
    cnn_stella.predict(cnn_file, times=lc.time.value, fluxes=lc.flux, errs=lc.flux_err)
    
    ax = fig.add_subplot(4, 2, i + 1)
    im = ax.scatter(cnn_stella.predict_time[0], cnn_stella.predict_flux[0],
                    c=cnn_stella.predictions[0], s=1.)
    
    plt.colorbar(im, ax=ax, label='Probability of Flare')
    ax.set_xlabel('Time [BJD-2457000]')
    ax.set_ylabel('Normalized Flux')
    ax.set_title('TIC {}'.format(lc.targetid));
plt.tight_layout()
plt.show()

[Figure: eight panels of normalized flux vs. time for the selected TIC targets, with each point colored by its predicted flare probability]

FAQ#

  • Why do the data near gaps have a high probability of being flares? This isn’t necessarily the case. The issue is that stella needs 100 data points on either side of a given cadence to create an “example” (those 200-cadence samples we trained on). Cadences within 100 points of a gap edge cannot be properly centered in an example, so stella cannot accurately predict flares for them and skips them.

  • But what if there are flares near the data gaps? There may very well be flares near the data gaps! Unfortunately, stella cannot find those for you at present, and you’ll need to identify them yourself.
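
If you want to locate these gaps yourself, one simple, illustrative approach (find_gaps is a hypothetical helper, not part of stella) is to look for jumps in the time array that are much larger than the nominal 2-minute cadence:

# illustrative: find gaps in a 2-minute-cadence time array (times in days)
def find_gaps(times, cadence_minutes=2.0, factor=5.0):
    '''Return indices where the time step exceeds factor x the nominal cadence.'''
    dt = np.diff(times)
    return np.where(dt > factor * cadence_minutes / (60 * 24))[0]

gap_indices = find_gaps(lcs[0][0].time.value)
print(gap_indices)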

About this Notebook#

Author:
Claire Murray, Assistant Astronomer, cmurray1@stsci.edu

Additional Contributors:
Yotam Cohen, STScI Staff Scientist, ycohen@stsci.edu

Info:
This notebook is based on the stella software package for the CNN used in “Flare Statistics for Young Stars from a Convolutional Neural Network Analysis of TESS Data”, Adina D. Feinstein et al., Astronomical Journal, Volume 160, Issue 5, November 2020, and on the notebook “CNN_for_cluster_masses” by Michelle Ntampaka, Assistant Astronomer, mntampaka@stsci.edu.

Updated On: 2022-05-25

Citations#

If you use this CNN, stella, astropy, or keras for published research, please cite the authors. Follow these links for more information:
