Classifying JWST-HST galaxy mergers with CNNs#

Learning Goals#

In this tutorial, you will see an example of building, compiling, and training a CNN on simulated astronomical data. By the end of this tutorial you will have a working example of a simple Convolutional Neural Network (CNN) in Keras.


CNNs are a class of machine learning (ML) algorithms that can extract information from images. In this notebook, you will walk through the basic steps of applying a CNN to data:

  1. Load the data and visualize a sample of the data.

  2. Divide the data into training, validation, and testing sets.

  3. Build a CNN in Keras.

  4. Compile the CNN.

  5. Train the CNN to perform a classification task.

  6. Evaluate the results.

CNNs can be applied to a wide range of image recognition tasks, including classification and regression. In this tutorial, we will build, compile, and train a CNN to classify whether a galaxy has undergone a merger, using simulated Hubble Space Telescope images of galaxies. This work is based on the public data and code from DeepMerge (Ciprijanovic et al. 2020).

NOTE: The DeepMerge team has publicly available code for demonstrating the architecture and optimal performance of the model, which we encourage you to check out! The goal of this notebook is to step through the model building and training process.


This notebook uses the following packages:

  • numpy to handle array functions

  • astropy for downloading and accessing FITS files

  • matplotlib.pyplot for plotting data

  • keras and tensorflow for building the CNN

  • sklearn for some utility functions

If you do not have these packages installed, you can install them using pip or conda.

# arrays
import numpy as np

# fits
from astropy.io import fits
from astropy.utils.data import download_file
from astropy.visualization import simple_norm

# plotting
from matplotlib import pyplot as plt

# keras
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Activation, Dropout, BatchNormalization
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.regularizers import l2
from keras.callbacks import EarlyStopping

# sklearn (for machine learning)
from sklearn.model_selection import train_test_split
from sklearn import metrics

# from IPython import get_ipython
# get_ipython().run_line_magic('matplotlib', 'notebook')

1. Load the data and visualize a sample of the data#

Load the simulated galaxy observations (3-band images) and merger probabilities (output labels).

In total, there are 15,426 simulated images, each in three filters (F814W from the Advanced Camera for Surveys and F160W from the Wide Field Camera 3 on the Hubble Space Telescope (HST), and F160W and F356W from the Near Infrared Camera on the James Webb Space Telescope (JWST)), retrieved and augmented from synthetic observations of the Illustris cosmological simulation. The sample includes 8120 galaxy mergers and 7306 non-mergers. Two versions of the sample are available, with and without realistic observational and experimental noise (“noisy” and “pristine”, respectively). The sample construction and augmentation process for the HST images is described in detail in Ciprijanovic et al. 2020, and is identical for the mock JWST images.

These datasets are hosted at the Mikulski Archive for Space Telescopes as the DEEPMERGE high-level science product (HLSP).

The CNN will be trained to distinguish between merging and non-merging galaxies.

Load the data#

The simulated images are stored in FITS format. We refer you to the Astropy Documentation for further information about this format.

For this example, we will download the “pristine” set of galaxy images, i.e., those without added observational noise. To select the “noisy” sample, change the version below. Alternatively, you can download data files from the DEEPMERGE website.

version = 'pristine'
file_url = ''+version+'.fits'
hdu = fits.open(download_file(file_url, cache=True, show_progress=True))

Explore the headers of the file for information about its contents (for example, by displaying hdu[0].header and hdu[1].header):

SIMPLE  =                    T / conforms to FITS standard                      
BITPIX  =                  -64 / array data type                                
NAXIS   =                    4 / number of array dimensions                     
NAXIS1  =                   75                                                  
NAXIS2  =                   75                                                  
NAXIS3  =                    3                                                  
NAXIS4  =                15426                                                  
EXTEND  =                    T                                                  
NAME1   = 'ImageX  '                                                            
NAME2   = 'ImageY  '                                                            
NAME3   = 'filter  '           / F814W,F356W,F160W                              
NAME4   = 'object  '                                                            
EXTNAME = 'Images  '                                                            
BUNIT   = 'microjanskies/arcsec^2' / image units                                
PIXSIZE =               0.1875 / arcsec                                         
DOI     = '10.17909/t9-vqk6-pc80'                                               
HLSPID  = 'DEEPMERGE'                                                           
HLSPLEAD= 'Aleksandra Ciprijanovic'                                             
HLSPNAME= 'Mock Image Training Sets for DeepMerge'                              
SIMULATD=                    T                                                  
HLSPTARG= 'Illustris Simulation'                                                
HLSPVER = 'v1      '                                                            
INSTRUME= 'ACS,WFC3,NIRCam'                                                     
LICENSE = 'CC BY 4.0'                                                           
LICENURL= ''                        
OBSERVAT= 'HST,JWST'                                                            
TELESCOP= 'HST,JWST'                                                            
PROPOSID= 'HST AR#13887'                                                        
REFERENC= '2021MNRAS.506..677C'                                                 
XTENSION= 'BINTABLE'           / binary table extension                         
BITPIX  =                    8 / array data type                                
NAXIS   =                    2 / number of array dimensions                     
NAXIS1  =                    8 / length of dimension 1                          
NAXIS2  =                15426 / length of dimension 2                          
PCOUNT  =                    0 / number of group parameters                     
GCOUNT  =                    1 / number of groups                               
TFIELDS =                    1 / number of table fields                         
TTYPE1  = 'MergerLabel'                                                         
TFORM1  = 'D       '                                                            
EXTNAME = 'MergerLabel'                                                         

The file includes a primary header card with overall information, an image card with the simulated images, and a bintable with the merger labels for the images (1=merger, 0=non-merger).

Plot example images#

For a random selection of images, plot the images and their corresponding labels:

The image array hdu[0].data has shape (15426, 3, 75, 75), i.e., (object, filter, y, x).
# set the random seed to get the same random set of images each time, or comment it out to get different ones!
# np.random.seed(206265)

# select 16 random image indices:
example_ids = np.random.choice(hdu[1].data.shape[0], 16)
# pull the F160W image (index=1) from the simulated dataset for these selections
examples = [hdu[0].data[j, 1, :, :] for j in example_ids]

# initialize your figure
fig = plt.figure(figsize=(8, 8)) 

# loop through the randomly selected images and plot with labels
for i, image in enumerate(examples):
    ax = fig.add_subplot(4, 4, i+1)
    norm = simple_norm(image, 'log', max_percent=99.75)

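    ax.imshow(image, aspect='equal', cmap='binary_r', norm=norm)
    # label each panel with its merger class (1=merger, 0=non-merger)
    ax.set_title('Class = ' + str(hdu[1].data['MergerLabel'][example_ids[i]]))
    ax.axis('off')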
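
One detail worth knowing: np.random.choice samples with replacement by default, so the 16 indices selected above can occasionally contain duplicates. The following is a minimal sketch of drawing distinct indices with NumPy's Generator interface. The name n_images is a stand-in for hdu[1].data.shape[0], and the seed value is reused from the comment above purely for illustration.

```python
import numpy as np

# stand-in for hdu[1].data.shape[0] (the number of simulated images)
n_images = 15426

# a seeded Generator gives a repeatable selection
rng = np.random.default_rng(206265)

# replace=False guarantees 16 distinct image indices
example_ids = rng.choice(n_images, size=16, replace=False)

print(len(set(example_ids.tolist())))
```

If you prefer a different set of galaxies on every run, drop the seed argument.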

2. Divide data into training, validation, and testing sets#

To divide the data set into training, validation, and testing data we will use Scikit-Learn’s train_test_split function.

We will denote the input images as X and their corresponding labels (i.e. the integer indicating whether or not they are a merger) as y, following the convention used by sklearn.

X = hdu[0].data
y = hdu[1].data

Following the authors, we will split the data into a 70:10:20 ratio of train:validate:test.

# as above, set the random seed to randomly split the images in a repeatable way. Feel free to try different values!
random_state = 42

X = np.asarray(X).astype('float32')
y = np.asarray(y).astype('float32')

# First split off 30% of the data for validation+testing
X_train, X_split, y_train, y_split = train_test_split(X, y, test_size=0.3, random_state=random_state, shuffle=True)

# Then divide this subset into validation and testing sets
X_valid, X_test, y_valid, y_test = train_test_split(X_split, y_split, test_size=0.666, random_state=random_state, shuffle=True)
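
To see why test_size=0.3 followed by test_size=0.666 gives roughly a 70:10:20 split, it helps to work through the arithmetic. The sketch below assumes that the held-out count is rounded up at each step (our understanding of how train_test_split treats a fractional test_size); the exact counts may differ by one image if that assumption is off.

```python
import math

n_total = 15426  # number of images in the DEEPMERGE sample

# first split: test_size=0.3 sets aside the validation+testing subset
# (assumption: the held-out count is rounded up)
n_split = math.ceil(0.3 * n_total)
n_train = n_total - n_split

# second split: test_size=0.666 of the held-out subset becomes the test set
n_test = math.ceil(0.666 * n_split)
n_valid = n_split - n_test

for name, n in [("train", n_train), ("valid", n_valid), ("test", n_test)]:
    print(f"{name}: {n} images ({n / n_total:.3f} of the sample)")
```

The second fraction is 0.666 rather than 0.2 because it applies to the 30% subset, not to the full sample: 0.3 × 0.666 ≈ 0.2.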

Next, rearrange the image array so that it has the shape (number_of_images, image_width, image_length, 3). This is referred to as a “channels last” approach, where the final axis denotes the number of “colors” or “channels”. The three-filter images have three channels, similar to RGB images in formats such as jpg and png. CNNs will work with an arbitrary number of channels.

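imsize = np.shape(X_train)[2]

# move the filter axis from position 1 to the end: (N, 3, 75, 75) -> (N, 75, 75, 3)
# (a plain reshape would give the same shape but scramble the pixels across channels)
X_train = np.moveaxis(X_train, 1, -1)
X_valid = np.moveaxis(X_valid, 1, -1)
X_test = np.moveaxis(X_test, 1, -1)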
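
Moving the filter axis to the end is a different operation from reshaping: reshape keeps the data in the same memory order and only reinterprets the shape, whereas np.moveaxis (or transpose) actually relocates the channel axis. Here is a minimal sketch on a toy array (the array sizes are made up for illustration):

```python
import numpy as np

# toy "channels first" stack: 2 images, 3 filters, 4x4 pixels
X = np.arange(2 * 3 * 4 * 4, dtype='float32').reshape(2, 3, 4, 4)

# move the filter axis (axis 1) to the end -> shape (2, 4, 4, 3)
X_moved = np.moveaxis(X, 1, -1)

# reshape produces the same shape but does not relocate the axis
X_reshaped = X.reshape(-1, 4, 4, 3)

# channel 0 of image 0 should still be the first filter's image
moved_ok = np.array_equal(X_moved[0, :, :, 0], X[0, 0])
reshaped_ok = np.array_equal(X_reshaped[0, :, :, 0], X[0, 0])
print(moved_ok, reshaped_ok)
```

Both results have shape (2, 4, 4, 3), but only the moveaxis version keeps each filter's image intact in its own channel.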

3. Build a CNN in Keras#

Here, we will build the model described in Section 3 of Ciprijanovic et al. 2020.

Further details about Conv2D, MaxPooling2D, BatchNormalization, Dropout, and Dense layers can be found in the Keras Layers Documentation. Further details about the sigmoid and softmax activation functions can be found in the Keras Activation Function Documentation.

# ------------------------------------------------------------------------------
# generate the model architecture
# Written for Keras 2
# ------------------------------------------------------------------------------

# Define architecture for model
data_shape = np.shape(X)
input_shape = (imsize, imsize, 3)

x_in = Input(shape=input_shape)
c0 = Convolution2D(8, (5, 5), activation='relu', strides=(1, 1), padding='same')(x_in)
b0 = BatchNormalization()(c0)
d0 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b0)
e0 = Dropout(0.5)(d0)

c1 = Convolution2D(16, (3, 3), activation='relu', strides=(1, 1), padding='same')(e0)
b1 = BatchNormalization()(c1)
d1 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b1)
e1 = Dropout(0.5)(d1)

c2 = Convolution2D(32, (3, 3), activation='relu', strides=(1, 1), padding='same')(e1)
b2 = BatchNormalization()(c2)
d2 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b2)
e2 = Dropout(0.5)(d2)

f = Flatten()(e2)
z0 = Dense(64, activation='softmax', kernel_regularizer=l2(0.0001))(f)
z1 = Dense(32, activation='softmax', kernel_regularizer=l2(0.0001))(z0)
y_out = Dense(1, activation='sigmoid')(z1)

cnn = Model(inputs=x_in, outputs=y_out)

4. Compile the CNN#

Next, we compile the model. As in Ciprijanovic et al. 2020, we select the Adam optimizer and the binary cross-entropy loss function (as this is a binary classification problem).

You can learn more about optimizers and loss functions in the Keras documentation.
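
For intuition, the binary cross-entropy for labels y and predicted probabilities p is the mean of -[y log p + (1 - y) log(1 - p)]. The following NumPy sketch is our own illustration of that formula, not the Keras implementation; the function name and the eps clipping value are assumptions chosen for the example.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; eps clipping avoids log(0)."""
    y_true = np.asarray(y_true, dtype='float64')
    y_pred = np.clip(np.asarray(y_pred, dtype='float64'), eps, 1 - eps)
    per_sample = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
    return float(-np.mean(per_sample))

# made-up labels and predictions, for illustration only
labels = [1, 0, 1, 0]
confident = binary_crossentropy(labels, [0.9, 0.1, 0.9, 0.1])
uninformed = binary_crossentropy(labels, [0.5, 0.5, 0.5, 0.5])
print(confident, uninformed)
```

A network that always outputs 0.5 scores log(2) ≈ 0.693, which is a useful reference value when reading the loss curves later in the notebook.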

# Compile Model
optimizer = 'adam'
fit_metrics = ['accuracy']
loss = 'binary_crossentropy'
cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics)
# print a layer-by-layer summary of the model
cnn.summary()
Model: "model"
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 75, 75, 3)]       0         
 conv2d (Conv2D)             (None, 75, 75, 8)         608       
 batch_normalization (BatchN  (None, 75, 75, 8)        32        
 max_pooling2d (MaxPooling2D  (None, 37, 37, 8)        0         
 dropout (Dropout)           (None, 37, 37, 8)         0         
 conv2d_1 (Conv2D)           (None, 37, 37, 16)        1168      
 batch_normalization_1 (Batc  (None, 37, 37, 16)       64        
 max_pooling2d_1 (MaxPooling  (None, 18, 18, 16)       0         
 dropout_1 (Dropout)         (None, 18, 18, 16)        0         
 conv2d_2 (Conv2D)           (None, 18, 18, 32)        4640      
 batch_normalization_2 (Batc  (None, 18, 18, 32)       128       
 max_pooling2d_2 (MaxPooling  (None, 9, 9, 32)         0         
 dropout_2 (Dropout)         (None, 9, 9, 32)          0         
 flatten (Flatten)           (None, 2592)              0         
 dense (Dense)               (None, 64)                165952    
 dense_1 (Dense)             (None, 32)                2080      
 dense_2 (Dense)             (None, 1)                 33        
Total params: 174,705
Trainable params: 174,593
Non-trainable params: 112
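
The parameter counts in the summary can be checked by hand. The short sketch below uses helper functions that we define here for illustration (they are not part of Keras) and reproduces the totals from the layer sizes used in the model above.

```python
def conv2d_params(kernel, c_in, c_out):
    # one kernel x kernel filter per (input, output) channel pair, plus one bias per output channel
    return (kernel * kernel * c_in + 1) * c_out

def batchnorm_params(channels):
    # gamma and beta (trainable) plus moving mean and variance (non-trainable)
    return 4 * channels

def dense_params(n_in, n_out):
    # weight matrix plus one bias per output unit
    return (n_in + 1) * n_out

counts = [
    conv2d_params(5, 3, 8), batchnorm_params(8),
    conv2d_params(3, 8, 16), batchnorm_params(16),
    conv2d_params(3, 16, 32), batchnorm_params(32),
    dense_params(9 * 9 * 32, 64), dense_params(64, 32), dense_params(32, 1),
]
total = sum(counts)

# only the batch-normalization moving statistics are non-trainable
non_trainable = 2 * (8 + 16 + 32)
trainable = total - non_trainable
print(total, trainable, non_trainable)
```

Note that almost all of the parameters sit in the first dense layer, which connects the 2592 flattened features to 64 units.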

5. Train the CNN to perform a classification task#

We will start by training for 20 epochs, but this almost certainly won’t be long enough to get great results. We set the “batch size” of the network (i.e., the number of samples propagated through the network before each weight update; see the Keras documentation) to 128. Once you’ve run your model and evaluated the fit, you can come back here and run the next cell again for 100 epochs or longer. This step will likely take many minutes. The training step is typically the computational bottleneck for using CNNs. However, once a CNN is trained, it can effectively be “packaged up” for future use on the original or other machines. In other words, it doesn’t have to be retrained every time one wants to use it!

You can learn more about the fit method in the Keras documentation.
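
As a rough guide to how much work one training run involves, the number of weight updates follows directly from the batch size. This is a back-of-the-envelope sketch; the training-set size n_train is an assumption (about 70% of the 15,426 images).

```python
import math

n_train = 10798   # assumption: about 70% of the 15,426 images
batch_size = 128
nb_epoch = 20

# the last batch of each epoch may be smaller than batch_size
steps_per_epoch = math.ceil(n_train / batch_size)
total_updates = steps_per_epoch * nb_epoch
print(steps_per_epoch, total_updates)
```

Increasing the number of epochs to 100 multiplies the number of updates, and roughly the run time, by five.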

nb_epoch = 20
batch_size = 128
shuffle = True

# Train
history = cnn.fit(X_train, y_train,
                  batch_size=batch_size,
                  epochs=nb_epoch,
                  validation_data=(X_valid, y_valid),
                  shuffle=shuffle)

6. Visualize CNN performance#

To visualize the performance of the CNN, we plot the evolution of the accuracy and loss as a function of training epochs, for the training set and for the validation set.

# plotting from history

loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

epochs = list(range(len(loss)))

figsize = (6, 4)
fig, axis1 = plt.subplots(figsize=figsize)
plot1_lacc = axis1.plot(epochs, acc, 'navy', label='accuracy')
plot1_val_lacc = axis1.plot(epochs, val_acc, 'deepskyblue', label="validation accuracy")

plot1_loss = axis1.plot(epochs, loss, 'red', label='loss')
plot1_val_loss = axis1.plot(epochs, val_loss, 'lightsalmon', label="validation loss")

plots = plot1_loss + plot1_val_loss
labs = [plot.get_label() for plot in plots]
plt.title("Loss/Accuracy History (Pristine Images)")
axis1.legend(loc='lower right')

Observe how the loss for the validation set is higher than for the training set (and conversely, the accuracy for the validation set is lower than for the training set), suggesting that this model is suffering from overfitting. Revisit the original paper and notice the strategies they employ to improve the validation accuracy. Observe their Figure 2 for an example of what the results of a properly-trained network look like!
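
One simple way to quantify this is to find the epoch where the validation loss bottoms out and to look at how the gap between validation and training loss evolves. In this minimal sketch the dictionary history is a plain dict standing in for history.history, and the numbers are made up for illustration (they are not results from this notebook).

```python
import numpy as np

# made-up curves, for illustration only (not results from this notebook)
history = {
    'loss':     [0.69, 0.62, 0.55, 0.49, 0.44, 0.40],
    'val_loss': [0.70, 0.64, 0.60, 0.59, 0.61, 0.65],
}

loss = np.array(history['loss'])
val_loss = np.array(history['val_loss'])

# epoch at which the validation loss bottoms out
best_epoch = int(np.argmin(val_loss))

# gap between validation and training loss; a growing gap suggests overfitting
gap = val_loss - loss
gap_is_growing = bool(gap[-1] > gap[best_epoch])
print(best_epoch, gap_is_growing)
```

When the training loss keeps falling after the validation loss has turned around, further epochs mostly help the network memorize the training set.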

7. Predict mergers!#

Apply the CNN to predict mergers in the “test” set, not used for training or validating the CNN.

test_predictions = cnn.predict(X_test)

Below, we use a confusion matrix to evaluate the model performance on the test data. See the documentation from sklearn on confusion matrices for more information.

def plot_confusion_matrix(cnn, input_data, input_labels):
    # Compute merger predictions for the test dataset
    predictions = cnn.predict(input_data)

    # Convert to binary classification 
    predictions = (predictions > 0.5).astype('int32') 
    # Compute the confusion matrix by comparing the input labels with the binary predictions
    cm = metrics.confusion_matrix(input_labels, predictions, labels=[0, 1])
    cm = cm.astype('float')

    # Normalize the confusion matrix results. 
    cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
    # Plotting
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.matshow(cm_norm, cmap='binary_r')

    plt.title('Confusion matrix', y=1.08)
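    # labels=[0, 1] above puts non-mergers (0) first and mergers (1) second;
    # rows are the true classes and columns are the predicted classes
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['No Merger', 'Merger'])
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['No Merger', 'Merger'])
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')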


    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            ax.text(j, i, format(cm_norm[i, j], fmt), 
                    ha="center", va="center", 
                    color="white" if cm_norm[i, j] < thresh else "black")
plot_confusion_matrix(cnn, X_test, y_test)
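
To make the thresholding and row normalization concrete, here is a minimal NumPy-only sketch of the same computation on a handful of made-up labels and network outputs (the numbers are purely illustrative). As in the notebook, 1 denotes a merger and 0 a non-merger.

```python
import numpy as np

# made-up labels and network outputs, for illustration only
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 1])
y_prob = np.array([0.9, 0.4, 0.7, 0.2, 0.6, 0.1, 0.3, 0.8])

# threshold the sigmoid outputs at 0.5 to get binary classes
y_pred = (y_prob > 0.5).astype('int32')

# cm[i, j] counts objects of true class i that were predicted as class j
cm = np.zeros((2, 2))
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1

# divide each row by its total so that every row sums to 1
cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
print(cm_norm)
```

With this normalization, each row gives the fractions of one true class that were classified correctly and incorrectly, so the off-diagonal entries are the false positive and false negative rates.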


  • How do I interpret these results? The confusion matrix shows that the model produces a large fraction of false positives (roughly 25%) and false negatives (roughly 36%). The published models from Ciprijanovic et al. 2020 perform much better. Note that in this notebook we train for only a fraction of the optimal number of epochs, to save space and time, but you are welcome to relax these restrictions, and as always, check out the DeepMerge code for more information!

  • Can I improve the model by changing it? We only trained for 20 epochs, which is many fewer than the published model. Go back to Section 5 (“Train the CNN to perform a classification task”) and increase the number of epochs to 100 (or more!) and train again. Does your model perform better? Your results may look better/worse/different from the published results due to the stochastic nature of training.

  • Can I try a different model? I think the results could be improved. Yes! You can try adding layers, swapping out the max pooling, changing the activation functions, swapping out the loss function, or trying a different optimizer or learning rate. Experiment and see what model changes give the best results. You should be aware: when you start training again, you pick up where your model left off. If you want to “reset” your model to epoch 0 and random weights, you should run the cells to make and compile the model again.

  • I want to test my model on my training data! No. You will convince yourself that your results are much better than they actually are. Always keep your training, validation, and testing sets completely separate!


  • Effect of noise? Try re-training the network with “noisy” data (i.e., modify the version in Section 1 to “noisy” and download the associated data product). Do the results change? If so, how and why? What are the pros and cons of using noisy vs. pristine data to train an ML model?

  • Effect of wavelength? The DEEPMERGE HLSP includes mock galaxy images in 2 filters only (only HST data). If you train the network with this data (hint: this will require downloading it from the website, or modifying the download cells to point to the correct URL; and also modifying the shapes of the training, validation and test data, as well as the network inputs), how do the results change?

  • Early stopping? The DeepMerge team employed “early stopping” to minimize overfitting. Try implementing it in the network here! The Keras documentation on the EarlyStopping callback will be useful. For example, you can recompile the model, train for many more epochs, and pass a callback to cnn.fit through its callbacks argument, e.g.,

    callback = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=50)
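
To build intuition for what the patience argument does, here is a simplified pure-Python sketch of patience-based stopping. It is our own illustration, not the Keras implementation: the real callback has more options (such as min_delta and restore_best_weights) and its exact bookkeeping may differ. The function name and the validation losses are made up.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the epoch at which training would stop, or None if it never does."""
    best = float('inf')
    wait = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best:
            # improvement: remember it and reset the counter
            best = val_loss
            wait = 0
        else:
            # no improvement: stop once this has happened `patience` times in a row
            wait += 1
            if wait >= patience:
                return epoch
    return None

# made-up validation losses, for illustration only
val_losses = [0.70, 0.64, 0.60, 0.59, 0.61, 0.65, 0.66, 0.70]
stop_short = early_stopping_epoch(val_losses, patience=3)
stop_long = early_stopping_epoch(val_losses, patience=50)
print(stop_short, stop_long)
```

A large patience such as 50 only has an effect if you train for many more epochs than that, which is why the suggestion above pairs it with a much longer training run.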

Don’t forget, the DeepMerge team provides code for building their production-level model and verifying its results, please check them out for more extensions and ideas!

About this Notebook#

Claire Murray, Assistant Astronomer,

Additional Contributors:
Yotam Cohen, STScI Staff Scientist,

This notebook is based on the code repository for the paper ”DeepMerge: Classifying High-redshift Merging Galaxies with Deep Neural Networks”, A. Ćiprijanović, G.F. Snyder, B. Nord, J.E.G. Peek, Astronomy & Computing, Volume 32, July 2020, and the notebook “CNN_for_cluster_masses” by Michelle Ntampaka, Assistant Astronomer,

Updated On: 2022-5-25


If you use this data set, astropy, or keras for published research, please cite the authors. Follow these links for more information:
