Classifying flaring stars with stella: a convolutional neural network for TESS#

Learning Goals#

In this tutorial, you will see an example of building, compiling, and training a CNN to classify astronomical data in vector form. By the end of this tutorial you will have a working example of a simple Convolutional Neural Network (CNN) in Keras.

Introduction#

CNNs are a class of machine learning (ML) algorithms that learn to recognize patterns in data by applying stacks of trainable convolutional filters. In this notebook, you will walk through the basic steps of applying a CNN to data:

  1. Load the data and visualize a sample of the data.

  2. Build a CNN in Keras.

  3. Compile the CNN.

  4. Train the CNN to perform a classification task.

  5. Evaluate the CNN performance on test data with a confusion matrix.

  6. Build a new, unlabeled dataset and apply the CNN.

CNNs can be applied to a wide range of vector analysis tasks, including classification and regression. Here, we will build, compile, and train a CNN to classify whether a star has undergone a flaring event based on its observed Transiting Exoplanet Survey Satellite (TESS) 2-minute light curve, and to locate the flaring events within the time series. This work is based on the model described in the stella software package.

NOTE: The stella team provides publicly available code and documentation demonstrating the architecture and optimal performance of this model, which we encourage you to check out! The goal of this notebook is to step through the model building and training process.

About this Notebook#

Author:
Claire Murray, Assistant Astronomer, cmurray1@stsci.edu

Additional Contributors:
Yotam Cohen, STScI Staff Scientist, ycohen@stsci.edu

Info:
This notebook is based on the stella software package for the CNN used in “Flare Statistics for Young Stars from a Convolutional Neural Network Analysis of TESS Data”, Adina D. Feinstein et al. Astronomical Journal, Volume 160, Issue 5, November 2020, and the notebook “CNN_for_cluster_masses” by Michelle Ntampaka, Assistant Astronomer, mntampaka@stsci.edu.

Published: 2022-06-01

Updated: 2023-12-13

This notebook currently fails to execute; use it as a reference only.


Imports#

This notebook uses the following:

  • numpy to handle array functions

  • astropy for accessing FITS files

  • matplotlib.pyplot for plotting data

  • keras for building the CNN

  • sklearn for model performance metrics

  • lightkurve for searching for and downloading TESS light curves

If any of these packages are not already installed, you can install them using pip or conda.
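For example, a typical installation command (a sketch; assuming pip, and noting that keras is distributed as part of TensorFlow) is:

pip install numpy astropy matplotlib tensorflow scikit-learn lightkurve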

# arrays
import numpy as np

# fits
from astropy.io import fits
from astropy.utils.data import download_file

# plotting
import matplotlib.pyplot as plt

# keras
from keras.models import Model, load_model
from keras.layers import Input, Flatten, Dense, Dropout, Conv1D, MaxPooling1D

# sklearn for performance metrics
from sklearn import metrics

# lightkurve
from lightkurve import search_lightcurve


# set random seed for reproducibility 
np.random.seed(42)

1. Download the training data#

Load the sample of TESS lightcurves (input vectors) and flare classifications (output labels) to be used to train the CNN.

The training set contains stars observed at 2-minute cadence in TESS Sectors 1 and 2, classified by hand and presented as a flare catalog by Günther et al. 2020. The light curves are processed into examples of length 200 cadences, where each flaring event, if present, is located at the center of the example. The full sample of light curves contains 8694 positive examples (flare) and 35896 negative examples (no flare). For this notebook, we will download a subset of this sample. These data are described in Feinstein et al. 2020.
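To make the framing of the examples concrete, the following minimal sketch (using a hypothetical flux array and flare peak index, not the actual stella preprocessing) shows how a single 200-cadence example centered on a flare could be sliced out of a longer light curve:

import numpy as np

# hypothetical stand-ins: a normalized light curve and a known flare peak index
flux = np.random.normal(1.0, 0.01, 5000)
peak = 2500

cadences = 200                 # length of each training example
half = cadences // 2

# slice a window of 200 cadences centered on the flare peak;
# this window would receive the positive label (1 = flare)
example = flux[peak - half:peak + half]
print(example.shape)           # (200,)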

The CNN will be used to predict the presence of flaring events as a function of observing cadence. The input to the CNN is a light curve (time, flux, and flux error) and the output is a “probability light curve”: the probability (a value between 0 and 1) that the measurement at each cadence belongs to a flaring event (1=flare, 0=no flare). In other words, the CNN performs a classification task at each cadence.

%%time
file_url = 'https://archive.stsci.edu/hlsps/hellouniverse/hellouniverse_stella_500.fits'
hdu = fits.open(download_file(file_url, cache=True))

The stella dataset includes training, test and validation lightcurves (input vectors) and flare labels (output labels). For more on how these are constructed, see Feinstein et al. 2020. For our purposes (i.e., building the stella CNN from scratch to illustrate its structure and function), we first unpack the multi-extension table to isolate the training dataset and training labels, validation dataset and validation labels, and testing dataset and testing labels.

train_data = hdu[1].data['train_data']
train_labels = hdu[1].data['train_labels']

test_data = hdu[2].data['test_data']
test_labels = hdu[2].data['test_labels']

val_data = hdu[3].data['val_data']
val_labels = hdu[3].data['val_labels']
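As a quick sanity check on the unpacked arrays, we can print their shapes and class balance (a minimal sketch; the exact numbers depend on the downloaded subset):

for name, data, labels in [('train', train_data, train_labels),
                           ('validation', val_data, val_labels),
                           ('test', test_data, test_labels)]:
    n_flare = int(np.sum(labels == 1))
    print('{}: data shape {}, {} flares / {} examples'.format(
        name, np.shape(data), n_flare, len(labels)))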

To visualize the structure of the light curves in the training set, in the cell below we plot a random selection of examples. The resulting figure displays a grid of 16 light curves (flux as a function of time), one per panel. If a light curve is labeled as containing a flare, it is colored red and the panel is titled “Flare”; otherwise it is colored black and the panel is titled “Non-flare”.

# select random light curve indices (without repeats):
example_ids = np.random.choice(len(train_labels), 16, replace=False)

# pull the lightcurves and labels for these selections
example_lightcurves = [train_data[j] for j in example_ids]
example_labels = [train_labels[j] for j in example_ids]


# initialize your figure
fig = plt.figure(figsize=(10, 10))

# loop through the randomly selected light curves and plot with labels
colors = {1: 'r', 0: 'k'}
titles = {1: 'Flare', 0: 'Non-flare'}
for i in range(len(example_ids)):
    plt.subplot(4, 4, i + 1)
    plt.plot(example_lightcurves[i], color=colors[example_labels[i]])
    plt.title(titles[example_labels[i]])
    plt.xlabel('Cadences')
    
plt.tight_layout()
plt.show()

2. Build a CNN in Keras#

Here, we will build the CNN model described in Feinstein et al. 2020 and implemented in stella from scratch.

Further details about Conv1D, MaxPooling1D, Dropout, and Dense layers can be found in the Keras Layers Documentation. Further details about the relu and sigmoid activation functions can be found in the Keras Activation Function Documentation.

# ------------------------------------------------------------------------------
# generate the model architecture
# Written for Keras 2
# ------------------------------------------------------------------------------

seed = 2
np.random.seed(seed)

filter1 = 16
filter2 = 64
dense = 32
dropout = 0.1

# Define architecture for model: each input example is a vector of
# length n_cadences with a single flux channel
input_shape = (np.shape(train_data)[1], 1)

x_in = Input(shape=input_shape)
c0 = Conv1D(filter1, 7, activation='relu', padding='same')(x_in)  # 16 filters, kernel size 7
b0 = MaxPooling1D(pool_size=2)(c0)
d0 = Dropout(dropout)(b0)

c1 = Conv1D(filter2, 3, activation='relu', padding='same')(d0)  # 64 filters, kernel size 3
b1 = MaxPooling1D(pool_size=2)(c1)
d1 = Dropout(dropout)(b1)


f = Flatten()(d1)
z0 = Dense(dense, activation='relu')(f)
d2 = Dropout(dropout)(z0)
y_out = Dense(1, activation='sigmoid')(d2)

cnn = Model(inputs=x_in, outputs=y_out)

3. Compile the CNN#

Next, we compile the model. As in Feinstein et al. 2020, we select the Adam optimizer and the binary cross entropy loss function (as this is a binary classification problem).

You can learn more about optimizers and more about loss functions for regression tasks in the Keras documentation.

# Compile Model
optimizer = 'adam'
fit_metrics = ['accuracy'] 
loss = 'binary_crossentropy'
cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics)
cnn.summary()

4. Train the CNN to perform a classification task#

We will start with training for 20 epochs, but this almost certainly won’t be long enough to get great results. Once you’ve run your model and evaluated the fit, you can come back here and run the next cell again for 100 epochs or longer.

You can learn more about fit here.

nb_epoch = 20
batch_size = 64
shuffle = True

# Train
history = cnn.fit(train_data, train_labels,
                  batch_size=batch_size, 
                  epochs=nb_epoch, 
                  validation_data=(val_data, val_labels), 
                  shuffle=shuffle,
                  verbose=True)
# save the model to file
cnn_file = 'flare_model.h5'
cnn.save(cnn_file)

5. Test the CNN performance#

Apply the CNN to predict flares on the “test” set, not used for training or validating the CNN, and evaluate the performance using a confusion matrix. See the documentation from sklearn on confusion matrices for more information. The code for generating and plotting the confusion matrix below was adapted from the application by Ciprijanovic et al. 2020 for DeepMerge.

def plot_confusion_matrix(cnn, input_data, input_labels):
    
    # Compute flare predictions for the test dataset
    predictions = cnn.predict(input_data)

    # Convert to binary classification 
    predictions = (predictions > 0.5).astype('int32') 
    
    # Compute the confusion matrix by comparing the input labels with the predictions.
    # labels=[1, 0] puts the "Flare" class in the first row/column,
    # matching the axis tick labels set below
    cm = metrics.confusion_matrix(input_labels, predictions, labels=[1, 0])
    cm = cm.astype('float')

    # Normalize the confusion matrix results. 
    cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
    
    # Plotting
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.matshow(cm_norm, cmap='binary_r')

    plt.title('Confusion matrix', y=1.08)
    
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Flare', 'No Flare'])
    
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['Flare', 'No Flare'])

    plt.xlabel('Predicted')
    plt.ylabel('True')

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            ax.text(j, i, format(cm_norm[i, j], fmt), 
                    ha="center", va="center", color="white" if cm_norm[i, j] < thresh else "black")
    plt.show()

In the cell below, we execute the code to plot the confusion matrix. The result is a 2x2 grid of squares, with the predicted classes (“Flare” and “No Flare”) on the x-axis and the true classes (“Flare” and “No Flare”) on the y-axis. Numbers at the center of each panel quantify the fraction of results in each true class that are found in each predicted class (e.g., the fraction of true flares which are predicted to be flares), so that the numbers in each horizontal row sum to 1.0. The color of each panel also corresponds to these fractions and ranges from white to black, where white corresponds to 1.0 and black to 0.0.

In this example, the values in the top row are 1.0 and 0.0, meaning that 100% of true flares are predicted to be flares, and 0% of true flares are predicted to be non-flares. The values in the bottom row are 0.7 and 0.3, meaning that 70% of the true non-flares are predicted to be flares, and 30% of the true non-flares are predicted to be non-flares.

plot_confusion_matrix(cnn, test_data, test_labels)
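The confusion matrix can also be summarized with standard per-class statistics. As a minimal sketch (reusing the 0.5 threshold from plot_confusion_matrix above), scikit-learn's classification_report prints precision, recall, and F1 score for each class:

# threshold the predicted probabilities into binary classes
test_preds = (cnn.predict(test_data) > 0.5).astype('int32')

# per-class precision, recall, and F1 score
print(metrics.classification_report(test_labels.astype('int32'), test_preds,
                                    target_names=['No Flare', 'Flare']))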

FAQ#

  • The results don’t look great… why? From the confusion matrix in Section 5, when faced with the test dataset (i.e., data not used for training or validation), the model predicts a large fraction of false positive flare events (and consequently not enough true negative events). The published models from Feinstein et al. 2020 perform much better, and the confusion matrix should look more like the results shown below. We note that in this notebook we are using a subset of the available training data, and we are training the model for only a subset of the optimal number of epochs for space and time considerations; you are welcome to relax these restrictions, and as always check out the stella repository for more information!

For comparison, in the cells below we download the pre-trained model published by the stella team and plot its confusion matrix for the same test data. The figure follows the same conventions as the confusion matrix in Section 5.

In this example (which uses the trained model from Feinstein et al. 2020), the values in the top row are 0.98 and 0.02, meaning that 98% of true flares are predicted to be flares, and 2% of true flares are predicted to be non-flares. The values in the bottom row are 0.7 and 0.3, meaning that 70% of the true non-flares are predicted to be flares, and 30% of the true non-flares are predicted to be non-flares.

file_url = 'https://archive.stsci.edu/hlsps/stella/hlsp_stella_tess_ensemblemodel_s042_tess_v0.1.0_cnn.h5'
pretrained_model = load_model(download_file(file_url, cache=True, show_progress=True))

plot_confusion_matrix(pretrained_model, test_data, test_labels)

  • Can I improve the model by increasing the number of epochs it is trained for? We only trained for 20 epochs, which is many fewer than the published model. Go back to Section 4 (“Train the CNN to perform a classification task”) and increase the number of epochs to 100 (or more!) and train again. Does your model perform better? Your results may look better/worse/different from the published results due to the stochastic nature of training.

  • Can I try a different model? I think the results could be improved. Yes! You can try adding layers, swapping out the max pooling, changing the activation functions, swapping out the loss function, or trying a different optimizer or learning rate; see the sketch after this list for one example. Experiment and see what model changes give the best results. Be aware: when you resume training, you pick up where your model left off. If you want to “reset” your model to epoch 0 and random weights, you should run the cells to make and compile the model again.

  • I want to test my model on my training data! No. You will convince yourself that your results are much better than they actually are. Always keep your training, validation, and testing sets completely separate!
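As a concrete illustration of the “different model” and “reset” points above, here is a minimal sketch (not the published stella configuration) that rebuilds the Section 2 architecture with fresh random weights and recompiles it with a lowered, hypothetical learning rate. Note that simply calling Model(inputs=x_in, outputs=y_out) again would reuse the already-trained layers, so the layers must be recreated to reset the weights:

from keras.optimizers import Adam

def build_cnn(input_shape, filter1=16, filter2=64, dense=32, dropout=0.1):
    """Rebuild the Section 2 architecture with fresh random weights."""
    x_in = Input(shape=input_shape)
    x = Conv1D(filter1, 7, activation='relu', padding='same')(x_in)
    x = MaxPooling1D(pool_size=2)(x)
    x = Dropout(dropout)(x)
    x = Conv1D(filter2, 3, activation='relu', padding='same')(x)
    x = MaxPooling1D(pool_size=2)(x)
    x = Dropout(dropout)(x)
    x = Flatten()(x)
    x = Dense(dense, activation='relu')(x)
    x = Dropout(dropout)(x)
    y_out = Dense(1, activation='sigmoid')(x)
    return Model(inputs=x_in, outputs=y_out)

# fresh model, recompiled with an explicitly lowered (hypothetical) learning rate
cnn_v2 = build_cnn(input_shape)
cnn_v2.compile(loss='binary_crossentropy',
               optimizer=Adam(learning_rate=1e-4),
               metrics=['accuracy'])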

Extensions/Exercises#

  • Is the model “overfitting”? Using the results of the model’s history (saved as a result of the model training process), investigate the behavior of the training and validation losses and accuracies as a function of training epoch. Make a plot or two (see the sketch after this list)! How do the training and validation losses compare? How do the training and validation accuracies compare? If the loss for the validation set is higher than for the training set (and conversely, the accuracy for the validation set is lower than for the training set), the model may be suffering from overfitting.

  • Try applying this model to a new dataset! You can pre-process your own 2-minute cadence TESS light curves and predict flares. An example workflow is shown in Section 6 below.
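Before moving on, here is a minimal plotting sketch for the overfitting exercise above (assuming the history object from Section 4 is still in memory; in older versions of Keras the accuracy keys may be named 'acc' and 'val_acc'):

# plot training vs. validation loss and accuracy per epoch
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.plot(history.history['loss'], label='training')
ax1.plot(history.history['val_loss'], label='validation')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()

ax2.plot(history.history['accuracy'], label='training')
ax2.plot(history.history['val_accuracy'], label='validation')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()

plt.tight_layout()
plt.show()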

6. Predict flares in a new dataset#

In this step, we will download light curves directly from TESS, pre-process them for input to the CNN, and predict flares. The sample is a set of bright M dwarfs not featured in the training/validation/test datasets.

First, select a sample of sources by their TIC ID numbers:

ticids = ['120461526', '278779899', '139754153', '273418879', '52121469', '188580272', '394015919', '402104884']
# for all the selected targets, pull the available lightcurves using the lightkurve package
lcs = []
for name in ticids:
    lc = search_lightcurve(target='TIC'+name, mission='TESS', sector=[1, 2], author='SPOC')
    lc = lc.download_all()
    lcs.append(lc)
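Before processing, it is worth checking what the search actually returned, since some targets may have no 2-minute data in the requested sectors (a minimal sketch):

# report how many light curves were retrieved for each target
for name, lc in zip(ticids, lcs):
    n = len(lc) if lc is not None else 0
    print('TIC {}: {} light curve(s) downloaded'.format(name, n))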

Here we define functions (based on stella) which pre-process the light curves (identify_gaps) and apply a pre-trained model to predict flares (stella_predict).

def identify_gaps(t, cad_pad):
    """
    Identifies which cadences can be predicted on given
    locations of gaps in the data. Will always stay 
    cadences/2 away from the gaps.

    Returns lists of good indices to predict on.
    """

    # SETS ALL CADENCES AVAILABLE
    all_inds = np.arange(0, len(t), 1, dtype=int)

    # REMOVES BEGINNING AND ENDS
    bad_inds = np.arange(0, cad_pad, 1, dtype=int)
    bad_inds = np.append(bad_inds, np.arange(len(t)-cad_pad,
                                             len(t), 1, dtype=int))

    diff = np.diff(t)
    med, std = np.nanmedian(diff), np.nanstd(diff)

    bad = np.where(np.abs(diff) >= med + 1.5*std)[0]
    for b in bad:
        bad_inds = np.append(bad_inds, np.arange(b-cad_pad,
                                                 b+cad_pad,
                                                 1, dtype=int))
    bad_inds = np.sort(bad_inds)
    return np.delete(all_inds, bad_inds)

def stella_predict(modelname, time, flux, err):
    """
    Loads pre-trained model and predicts 
    flares in input light curves.
    
    Returns array of flare predictions
    """
    # load model               
    model = load_model(modelname)

    # GETS REQUIRED INPUT SHAPE FROM MODEL
    cadences = model.input.shape[1]
    cad_pad = cadences // 2  # half-window: cadences needed on each side

    # normalize flux
    lc = flux / np.nanmedian(flux)  

    # identify cadences far enough from data gaps to predict on
    good_inds = identify_gaps(time, cad_pad)

    # reshape data 
    reshaped_data = np.zeros((len(lc), cadences))

    for i in good_inds:
        loc = [int(i-cad_pad), int(i+cad_pad)]
        f = lc[loc[0]:loc[1]]                  
        reshaped_data[i] = f

    reshaped_data = reshaped_data.reshape(reshaped_data.shape[0], 
                                          reshaped_data.shape[1], 1)
    
    # predict flares and reshape output
    preds = model.predict(reshaped_data)
    preds = np.reshape(preds, (len(preds),))

    return time, lc, err, preds

Below we loop through the selected targets and generate flare predictions using the functions defined above. The figure produced by this cell is a 4x2 grid of panels, each displaying a light curve (flux as a function of time). The light curve data points are colored according to the model's predicted probability that a flare is present at each cadence.

fig = plt.figure(0, [8, 10])

for i, lc in enumerate(lcs):
    # pull out the first light curve in each set, if more than one exists
    if len(lc) > 0: 
        lc = lc[0]
        
    # predict the flare probability light curve for the input data using the
    # stella-style functions defined above (which apply the necessary
    # pre-processing to the data for input to the CNN)
    time, flux, err, preds = stella_predict(cnn_file, time=lc.time.value, flux=lc.flux, err=lc.flux_err)
    # print(np.shape(lc.time.value), np.shape(preds))
                   
    ax = fig.add_subplot(4, 2, i+1)
    im = ax.scatter(time, flux, c=preds, s=1.)
    
    plt.colorbar(im, ax=ax, label='Probability of Flare')
    ax.set_xlabel('Time [BJD-2457000]')
    ax.set_ylabel('Normalized Flux')
    ax.set_title('TIC {}'.format(lc.targetid))
plt.tight_layout()
plt.show()

FAQ#

  • Why do the data near gaps have a high probability of being flares? This isn’t necessarily the case. The issue here is that stella needs 100 data points on either side of a given cadence to create an “example” (those 200-cadence samples we trained on). When there’s a gap in the data, cadences within 100 points of the gap can’t be properly centered in an example. As such, stella cannot accurately predict flares at these cadences and skips them.

  • But what if there are flares near the data gaps? There may very well be flares near the data gaps! Unfortunately, stella cannot find those for you at present, and you’ll need to identify them yourself.

Citations#

If you use this CNN, stella, astropy, or keras for published research, please cite the authors. Follow these links for more information:
