NIRCam PSF Photometry With Space_Phot

Author: Ori Fox
Last Updated: January, 2025

This notebook currently fails to execute; use it as a reference only.

Table of contents

  1. Introduction

  2. Setup
    2.1 Python imports
    2.2 Download data

  3. Bright, Single Object
    3.1 Multiple, Level2 Files

  4. Faint/Upper Limit, Single Object
    4.1 Multiple, Level2 Files

  5. Stellar Field (LMC)
    5.1 Multiple, Level2 Files
    5.2 Single, Level3 Mosaicked File

1. Introduction

Packages to install: drizzlepac, space_phot (https://github.com/jpierel14/space_phot), photutils (from the main branch: git+https://github.com/astropy/photutils), jupyter

Goals:

PSF Photometry can be obtained using:

  • grid of PSF models from WebbPSF

  • single effective PSF (ePSF) NOT YET AVAILABLE

  • grid of effective PSF NOT YET AVAILABLE

The notebook shows:

  • how to obtain the PSF model from WebbPSF (or build an ePSF)

  • how to perform PSF photometry on the image

2. Setup

2.1 Python imports

from astropy.io import fits
from astropy.nddata import extract_array
from astropy.coordinates import SkyCoord
from astropy import wcs
from astropy.wcs.utils import skycoord_to_pixel
from astropy import units as u
import numpy as np
import pandas as pd
from astropy.visualization import simple_norm
from urllib.parse import urlparse
import requests
import time
import math
import logging
from jwst.associations import load_asn
import matplotlib.pyplot as plt
%matplotlib inline

from astroquery.mast import Observations
import os
import tarfile

# Background and PSF Functions
from photutils.background import MMMBackground, MADStdBackgroundRMS
from photutils.detection import DAOStarFinder

# space_phot provides the PSF fitting used throughout this notebook.
import space_phot
from importlib.metadata import version
# Show which space_phot version is installed, for reproducibility.
version('space_phot')

2.2 Download data

def get_asn_filenames(program, observation, objnum, filtername):
    """Query MAST to determine the name of the association file for the given
    program, observation, and object number. This function is more convenient
    than hardcoding a filename because the filenames include datestamps of when
    they were produced. So each time the files are reprocessed in MAST, the
    filenames change.

    Parameters
    ----------
    program : int
        Program ID number. e.g. 1067

    observation : int
        Observation number. e.g. 24

    objnum : int
        Object number. Zero-padded to 5 digits and matched against the
        ``_image3_<objnum>_asn`` part of the association filename.

    filtername : str
        Name of the filter used in the observation. e.g. "F444W"

    Returns
    -------
    matching_files : list of str
        Basenames of the matching image3 association files, without
        duplicates, in the order MAST returned them. Empty if nothing matches.
    """
    prog_str = str(program).zfill(5)
    obs_str = str(observation).zfill(3)
    obj_str = str(objnum).zfill(5)

    obs_id_table = Observations.query_criteria(instrument_name=["NIRCAM/IMAGE"],
                                               provenance_name=["CALJWST"],  # Executed observations
                                               obs_id=['jw' + prog_str + '-o' + obs_str + '*'],
                                               filters=[filtername.upper()]
                                               )
    matching_uris = []
    for exposure in obs_id_table:
        products = Observations.get_product_list(exposure)
        filtered_products = Observations.filter_products(products, productSubGroupDescription='ASN')
        matching_uris.extend(filtered_products['dataURI'])

    asn_tag = f'_image3_{obj_str}_asn'
    names = (os.path.basename(uri) for uri in matching_uris if asn_tag in uri)
    # Several observation rows can list the same association product, so
    # de-duplicate while keeping first-seen order (callers index with [0]).
    return list(dict.fromkeys(names))


def download_files(files_to_download):
    """Download a list of files from MAST into the current working directory.

    Files that already exist locally are skipped. If MAST reports a download
    as failed, the failure is logged with ``logging.error`` and the remaining
    files are still attempted.

    Parameters
    ----------
    files_to_download : list
        List of filenames
    """
    for file in files_to_download:
        # Check if the file already exists in the current working directory
        if os.path.exists(file):
            print(f"File {file} already exists. Skipping download.")
            continue
        cal_uri = f'mast:JWST/product/{file}'
        # Observations.download_file returns (status, msg, url) and signals a
        # failed download through status == 'ERROR', so check it explicitly
        # instead of silently continuing with a missing file.
        status, msg, _ = Observations.download_file(cal_uri)
        if status == 'ERROR':
            logging.error("Failed to download %s: %s", file, msg)


# Download NIRCam Data PID 1537 (Calibration Program) and NIRCam Data PID 1476 (LMC)
files_to_download = ['jw01537-o024_t001_nircam_clear-f444w-sub160_i2d.fits',
                     'jw01537024001_0310a_00001_nrcblong_cal.fits',
                     'jw01537024001_0310a_00002_nrcblong_cal.fits',
                     'jw01537024001_0310a_00003_nrcblong_cal.fits',
                     'jw01537024001_0310a_00004_nrcblong_cal.fits',
                     'jw01537024001_0310k_00001_nrcblong_cal.fits',
                     'jw01537024001_0310k_00002_nrcblong_cal.fits',
                     'jw01537024001_0310k_00003_nrcblong_cal.fits',
                     'jw01537024001_0310k_00004_nrcblong_cal.fits',
                     'jw01476-o001_t001_nircam_clear-f150w_i2d.fits',
                     'jw01476001007_02101_00001_nrca1_cal.fits',
                     'jw01476001007_02101_00002_nrca1_cal.fits',
                     'jw01476001007_02101_00003_nrca1_cal.fits',
                     'jw01476001008_02101_00001_nrca1_cal.fits',
                     'jw01476001008_02101_00002_nrca1_cal.fits',
                     'jw01476001008_02101_00003_nrca1_cal.fits',
                     'jw01476001008_02101_00004_nrca1_cal.fits',
                     'jw01476001008_02101_00005_nrca1_cal.fits',
                     'jw01476001008_02101_00006_nrca1_cal.fits'
                     ]

# Get the names of the related association files and add those
# to the list of files to download.
# Each entry is (program, observation, object number, filter).
asn_queries = [(1537, 24, 1, 'F444W'),
               (1476, 1, 23, 'F150W')]
asn_files_to_download = []
for asn_program, asn_obs, asn_obj, asn_filter in asn_queries:
    asn_matches = get_asn_filenames(asn_program, asn_obs, asn_obj, asn_filter)
    if not asn_matches:
        # Fail with a readable message instead of an IndexError on [0].
        raise RuntimeError(f"No image3 association file found in MAST for program "
                           f"{asn_program}, observation {asn_obs}, object {asn_obj}, "
                           f"filter {asn_filter}")
    asn_files_to_download.append(asn_matches[0])
files_to_download += asn_files_to_download

# Call the function to download files
download_files(files_to_download)

3. Bright, Single Object

3.1 Multiple, Level2 Files

# Level 3 Files: NIRCam Data PID 1537 (Calibration Program):
lvl3 = 'jw01537-o024_t001_nircam_clear-f444w-sub160_i2d.fits'
lvl3
# The level-3 primary header names the association table (ASNTABLE) that
# lists the level-2 exposures; close both files once they have been read.
with fits.open(lvl3) as hdl:
    hdr = hdl[0].header
asnfile = hdr['ASNTABLE']
with open(asnfile) as asn_fh:
    asn_data = load_asn(asn_fh)
# Exposure names of the members of the first product in the association
lvl2_prelim = [member['expname'] for member in asn_data['products'][0]['members']]

lvl2_prelim
# Sort out LVL2 Data That Includes The Actual Source (there are 4 detectors)
source_location = SkyCoord('5:05:30.6593', '+52:49:49.862', unit=(u.hourangle, u.deg))
lvl2 = []
for ref_image in lvl2_prelim:
    print(ref_image)
    with fits.open(ref_image) as ref_fits:
        ref_data = ref_fits['SCI', 1].data
        # skycoord_to_pixel returns (x, y); extract_array wants the position in
        # array order, i.e. (y, x).
        ref_x, ref_y = skycoord_to_pixel(source_location,
                                         wcs.WCS(ref_fits['SCI', 1].header, ref_fits))
        print(ref_x, ref_y)
        try:
            # With the default mode='partial' this raises NoOverlapError (a
            # ValueError subclass) when the cutout falls entirely off this
            # detector; ValueError also covers a non-finite pixel position.
            extract_array(ref_data, (11, 11), (ref_y, ref_x))
        except ValueError as e:
            # Expected for detectors that do not cover the source: skip them.
            print(f'{ref_image}: source not on this detector, skipped ({e})')
            continue
    lvl2.append(ref_image)
    print(ref_image + ' added to final list')

lvl2
# Change all DQ flagged pixels to NANs
# A pixel is treated as flagged when its DO_NOT_USE bit (value 1 in the JWST
# DQ bitmask) is set. Testing the bit, rather than dq == 1, also catches pixels
# that carry DO_NOT_USE together with other flags.
for file in lvl2:
    # One handle in update mode; closing it writes the modified SCI data back.
    with fits.open(file, mode='update') as hdul:
        data = hdul['SCI', 1].data
        dq = hdul['DQ', 1].data
        data[(dq & 1) != 0] = np.nan
        hdul['SCI', 1].data = data
# Examine the First Image
ref_image = lvl2[0]
print(ref_image)
# Left open on purpose: the next cell reads this file's WCS through ref_fits.
ref_fits = fits.open(ref_image)
# Read SCI from the handle already held instead of opening the file again.
ref_data = ref_fits['SCI', 1].data
norm1 = simple_norm(ref_data, stretch='linear', min_cut=-1, max_cut=10)

plt.imshow(ref_data, origin='lower', norm=norm1, cmap='gray')
plt.gca().tick_params(labelcolor='none', axis='both', color='none')
plt.show()
lvl2[0]
# Zoom in to see the source
# skycoord_to_pixel returns (x, y); extract_array takes the position in array
# (y, x) order.
ref_x, ref_y = skycoord_to_pixel(source_location, wcs.WCS(ref_fits['SCI', 1].header, ref_fits))
ref_cutout = extract_array(ref_data, (11, 11), (ref_y, ref_x))
norm1 = simple_norm(ref_cutout, stretch='linear', min_cut=-10, max_cut=1000)
plt.imshow(ref_cutout, origin='lower',
           norm=norm1, cmap='gray')
plt.title('PID1537,Obs024')
plt.gca().tick_params(labelcolor='none', axis='both', color='none')
plt.show()

ref_cutout
# Set environmental variables
# Data locations for webbpsf and synphot; the archives downloaded below are
# unpacked under webbpsf_folder and the current directory respectively.
os.environ["WEBBPSF_PATH"] = "./webbpsf-data/webbpsf-data"
os.environ["PYSYN_CDBS"] = "./grp/redcat/trds/"

# required webbpsf data
boxlink = 'https://stsci.box.com/shared/static/qxpiaxsjwo15ml6m4pkhtk36c9jgj70k.gz'
boxfile = './webbpsf-data/webbpsf-data-LATEST.tar.gz'
# Fetched over HTTPS (like boxlink) because the tarball is extracted afterwards;
# plain http would let it be altered in transit.
synphot_url = 'https://ssb.stsci.edu/trds/tarfiles/synphot5.tar.gz'
synphot_file = './synphot5.tar.gz'

webbpsf_folder = './webbpsf-data'
synphot_folder = './grp'

def download_file(url, dest_path, timeout=60):
    """Stream ``url`` to the local file ``dest_path``.

    Parameters
    ----------
    url : str
        Source URL. Only the ``http`` and ``https`` schemes are accepted.
    dest_path : str
        Local path the response body is written to (overwritten if present).
    timeout : float, optional
        Timeout in seconds passed to ``requests.get``. Default 60.

    Raises
    ------
    ValueError
        If the URL scheme is not http or https (checked before any network use).
    requests.HTTPError
        If the server answers with an error status.
    """
    parsed_url = urlparse(url)
    if parsed_url.scheme not in ["http", "https"]:
        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme}")

    # The context manager releases the streamed connection even on failure.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        try:
            with open(dest_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except BaseException:
            # Don't leave a truncated archive behind that a later step would
            # try to extract; re-raise so the caller still sees the failure.
            if os.path.exists(dest_path):
                os.remove(dest_path)
            raise


# Gather webbpsf files
# Test for the data directory webbpsf will actually read (WEBBPSF_PATH) rather
# than just the parent folder: previously an interrupted download left the
# folder in place and every later run skipped the download.
if not os.path.isdir(os.environ["WEBBPSF_PATH"]):
    os.makedirs(webbpsf_folder, exist_ok=True)
    download_file(boxlink, boxfile)
    with tarfile.open(boxfile) as gzf:
        gzf.extractall(webbpsf_folder, filter='data')

# Gather synphot files (same reasoning, checked against PYSYN_CDBS)
if not os.path.isdir(os.environ["PYSYN_CDBS"]):
    os.makedirs(synphot_folder, exist_ok=True)
    download_file(synphot_url, synphot_file)
    with tarfile.open(synphot_file) as gzf:
        gzf.extractall('./', filter='data')

# Get PSF from WebbPSF
jwst_obs = space_phot.observation2(lvl2)
psfs = space_phot.get_jwst_psf(jwst_obs, source_location)
# Quick look at the first PSF model that was returned
first_psf = psfs[0]
plt.imshow(first_psf.data)
plt.show()

# Do PSF Photometry using space_phot (details of fitting are in documentation)
# https://st-phot.readthedocs.io/en/latest/examples/plot_a_psf.html#jwst-images
# Parameter bounds for the fit; the fit_* options are described in the docs above.
fit_bounds = {
    'flux': [-10, 10000],
    'centroid': [-2, 2],
    'bkg': [0, 50],
}
jwst_obs.psf_photometry(
    psfs,
    source_location,
    bounds=fit_bounds,
    fit_width=5,
    fit_bkg=True,
    fit_flux='single',
)

# Diagnostic plots produced by space_phot
jwst_obs.plot_psf_fit()
plt.show()

jwst_obs.plot_psf_posterior(minweight=.0005)
plt.show()

print(jwst_obs.psf_result.phot_cal_table)
# Calculate Average Magnitude from Table
cal_table = jwst_obs.psf_result.phot_cal_table
mag_arr = cal_table['mag']
magerr_arr = cal_table['magerr']

# Mean magnitude over the table rows; row errors summed in quadrature.
# NOTE(review): the quadrature sum is not divided by the number of rows, so this
# is not the standard error of the mean. With fit_flux='single' the rows may all
# come from one shared fit; confirm how space_phot fills 'magerr' before relying on it.
mag_lvl2psf = np.mean(mag_arr)
squared_errs = [err**2 for err in magerr_arr]
magerr_lvl2psf = math.sqrt(sum(squared_errs))
print(round(mag_lvl2psf, 4), round(magerr_lvl2psf, 4))

4. Faint/Upper Limit, Single Object

4.1 Multiple, Level2 Files

# Level 3 Files
lvl3 = 'jw01537-o024_t001_nircam_clear-f444w-sub160_i2d.fits'
lvl3
# load_asn is already imported in the setup cell.
# The level-3 primary header names the association table (ASNTABLE); close
# both files once they have been read.
with fits.open(lvl3) as hdl:
    hdr = hdl[0].header
asnfile = hdr['ASNTABLE']
with open(asnfile) as asn_fh:
    asn_data = load_asn(asn_fh)
# Exposure names of the members of the first product in the association
lvl2_prelim = [member['expname'] for member in asn_data['products'][0]['members']]

lvl2_prelim
# Sort out LVL2 Data That Includes The Actual Source (there are 4 detectors)
source_location = SkyCoord('5:05:30.6186', '+52:49:49.130', unit=(u.hourangle, u.deg))
lvl2 = []
for ref_image in lvl2_prelim:
    print(ref_image)
    with fits.open(ref_image) as ref_fits:
        ref_data = ref_fits['SCI', 1].data
        # skycoord_to_pixel returns (x, y); extract_array wants the position in
        # array order, i.e. (y, x).
        ref_x, ref_y = skycoord_to_pixel(source_location,
                                         wcs.WCS(ref_fits['SCI', 1].header, ref_fits))
        print(ref_x, ref_y)
        try:
            # With the default mode='partial' this raises NoOverlapError (a
            # ValueError subclass) when the cutout falls entirely off this
            # detector; ValueError also covers a non-finite pixel position.
            extract_array(ref_data, (11, 11), (ref_y, ref_x))
        except ValueError as e:
            # Expected for detectors that do not cover the position: skip them.
            print(f'{ref_image}: source not on this detector, skipped ({e})')
            continue
    lvl2.append(ref_image)
    print(ref_image + ' added to final list')

lvl2
# Change all DQ flagged pixels to NANs
# A pixel is treated as flagged when its DO_NOT_USE bit (value 1 in the JWST
# DQ bitmask) is set. Testing the bit, rather than dq == 1, also catches pixels
# that carry DO_NOT_USE together with other flags.
for file in lvl2:
    # One handle in update mode; closing it writes the modified SCI data back.
    with fits.open(file, mode='update') as hdul:
        data = hdul['SCI', 1].data
        dq = hdul['DQ', 1].data
        data[(dq & 1) != 0] = np.nan
        hdul['SCI', 1].data = data
# Examine the First Image
ref_image = lvl2[0]
print(ref_image)
# Left open on purpose: the next cell reads this file's WCS through ref_fits.
ref_fits = fits.open(ref_image)
# Read SCI from the handle already held instead of opening the file again.
ref_data = ref_fits['SCI', 1].data
norm1 = simple_norm(ref_data, stretch='linear', min_cut=-1, max_cut=25)

plt.imshow(ref_data, origin='lower', norm=norm1, cmap='gray')
plt.gca().tick_params(labelcolor='none', axis='both', color='none')
plt.show()
# Pick a blank part of the sky to calculate the upper limit: show an 11x11
# pixel cutout around source_location.
# skycoord_to_pixel returns (x, y); extract_array takes the position in array
# (y, x) order.
ref_x, ref_y = skycoord_to_pixel(source_location, wcs.WCS(ref_fits['SCI', 1].header, ref_fits))
ref_cutout = extract_array(ref_data, (11, 11), (ref_y, ref_x))
norm1 = simple_norm(ref_cutout, stretch='linear', min_cut=-1, max_cut=25)
plt.imshow(ref_cutout, origin='lower',
           norm=norm1, cmap='gray')
# The data shown come from the PID 1537, observation 024 files selected above.
plt.title('PID1537,Obs024')
plt.gca().tick_params(labelcolor='none', axis='both', color='none')
plt.show()
# Get PSF from WebbPSF
jwst_obs = space_phot.observation2(lvl2)
psfs = space_phot.get_jwst_psf(jwst_obs, source_location)
# Quick look at the first PSF model that was returned
first_psf = psfs[0]
plt.imshow(first_psf.data)
plt.show()

# Do PSF Photometry using space_phot (details of fitting are in documentation)
# https://st-phot.readthedocs.io/en/latest/examples/plot_a_psf.html#jwst-images
# The centroid is held fixed here (fit_centroid='fixed'), so only flux and
# background bounds are given.
fit_bounds = {'flux': [-10, 1000], 'bkg': [0, 50]}
jwst_obs.psf_photometry(psfs, source_location,
                        bounds=fit_bounds,
                        fit_width=5,
                        fit_bkg=True,
                        fit_centroid='fixed',
                        fit_flux='single')

# Diagnostic plots produced by space_phot
jwst_obs.plot_psf_fit()
plt.show()

jwst_obs.plot_psf_posterior(minweight=.0005)
plt.show()

print(jwst_obs.psf_result.phot_cal_table)

# Print Upper Limits (5-sigma)
magupper_lvl2psf = jwst_obs.upper_limit(nsigma=5)
magupper_lvl2psf

5. Stellar Field (LMC)

5.1 Multiple, Level2 Files

Now do the same thing for a larger group of stars and test for speed.

# Level 3 Files: NIRCam Data PID 1476 (LMC)
lvl3 = 'jw01476-o001_t001_nircam_clear-f150w_i2d.fits'
lvl3
# The level-3 primary header names the association table (ASNTABLE); close
# both files once they have been read.
with fits.open(lvl3) as hdl:
    hdr = hdl[0].header
asnfile = hdr['ASNTABLE']
with open(asnfile) as asn_fh:
    asn_data = load_asn(asn_fh)
lvl2 = [member['expname'] for member in asn_data['products'][0]['members']]

# Keep only the exposures whose filename contains "nrca1"
lvl2 = [s for s in lvl2 if "nrca1" in s]
lvl2
# Find Stars in Level 3 File
# Get rough estimate of background (There are Better Ways to Do Background Subtraction)
bkgrms = MADStdBackgroundRMS()
mmm_bkg = MMMBackground()

# Read WCS and data from the same SCI extension, and close the file afterwards.
with fits.open(lvl3) as im:
    w = wcs.WCS(im['SCI', 1])
    sci_data = im['SCI', 1].data
    std = bkgrms(sci_data)
    bkg = mmm_bkg(sci_data)
    # Subtract from a copy so the file's data array is never modified.
    data_bkgsub = sci_data.copy()
data_bkgsub -= bkg

# DAOStarFinder's ``fwhm`` argument is a FWHM in pixels. The old comment called
# this value a sigma "for F770W", but lvl3 is an F150W mosaic.
# NOTE(review): 1.636 px presumably approximates the NIRCam F150W PSF FWHM; confirm.
fwhm_psf = 1.636
threshold = 5.

daofind = DAOStarFinder(threshold=threshold * std, fwhm=fwhm_psf, exclude_border=True)
found_stars = daofind(data_bkgsub)
found_stars.pprint_all(max_lines=10)
# Filter out only stars you want
# Top panel: sharpness vs. mag; bottom panel: roundness2 vs. mag. Red dashed
# lines mark the cuts applied in ``mask`` below.
plt.figure(figsize=(12, 8))
plt.clf()

ax1 = plt.subplot(2, 1, 1)

ax1.set_xlabel('mag')
ax1.set_ylabel('sharpness')

xlim0 = np.min(found_stars['mag']) - 0.25
xlim1 = np.max(found_stars['mag']) + 0.25
ylim0 = np.min(found_stars['sharpness']) - 0.15
ylim1 = np.max(found_stars['sharpness']) + 0.15

ax1.set_xlim(xlim0, xlim1)
ax1.set_ylim(ylim0, ylim1)

ax1.scatter(found_stars['mag'], found_stars['sharpness'], s=10, color='k')

# Cuts: sh_inf < sharpness < sh_sup and umag_lim < mag < lmag_lim
sh_inf = 0.40
sh_sup = 0.82
lmag_lim = -1.0
umag_lim = -6.0

ax1.plot([xlim0, xlim1], [sh_sup, sh_sup], color='r', lw=3, ls='--')
ax1.plot([xlim0, xlim1], [sh_inf, sh_inf], color='r', lw=3, ls='--')
ax1.plot([lmag_lim, lmag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')
ax1.plot([umag_lim, umag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')

ax2 = plt.subplot(2, 1, 2)

ax2.set_xlabel('mag')
ax2.set_ylabel('roundness')

# Pad the roundness range on both sides. The upper limit previously
# subtracted the padding, which clipped the highest points off the panel.
ylim0 = np.min(found_stars['roundness2']) - 0.25
ylim1 = np.max(found_stars['roundness2']) + 0.25

ax2.set_xlim(xlim0, xlim1)
ax2.set_ylim(ylim0, ylim1)

# Cut: round_inf < roundness2 < round_sup
round_inf = -0.40
round_sup = 0.40

ax2.scatter(found_stars['mag'], found_stars['roundness2'], s=10, color='k')

ax2.plot([xlim0, xlim1], [round_sup, round_sup], color='r', lw=3, ls='--')
ax2.plot([xlim0, xlim1], [round_inf, round_inf], color='r', lw=3, ls='--')
ax2.plot([lmag_lim, lmag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')
ax2.plot([umag_lim, umag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')

plt.tight_layout()
# Apply the cuts above, plus a small pixel window on the mosaic
# (NOTE(review): presumably chosen to keep the example short; confirm).
mask = ((found_stars['mag'] < lmag_lim) & (found_stars['mag'] > umag_lim) & (found_stars['roundness2'] > round_inf)
        & (found_stars['roundness2'] < round_sup) & (found_stars['sharpness'] > sh_inf)
        & (found_stars['sharpness'] < sh_sup) & (found_stars['xcentroid'] > 1940) & (found_stars['xcentroid'] < 2000)
        & (found_stars['ycentroid'] > 1890) & (found_stars['ycentroid'] < 1960))

found_stars_sel = found_stars[mask]

print('Number of stars found originally:', len(found_stars))
print('Number of stars in final selection:', len(found_stars_sel))
found_stars_sel
# Convert pixel to wcs coords
skycoords = w.pixel_to_world(found_stars_sel['xcentroid'], found_stars_sel['ycentroid'])
len(skycoords)
lvl2
# Spot-check the DQ value at row 233, column 340 of the first level-2 exposure
file = lvl2[0]
with fits.open(file) as dq_hdul:
    # Copy so the array does not depend on the (now closed) file.
    dq = dq_hdul['DQ', 1].data.copy()
dq[233, 340]
# Change all DQ flagged pixels to NANs
# Test the DO_NOT_USE bit (value 1 in the JWST DQ bitmask) instead of matching
# only the two exact DQ values 262657 and 262661. Both of those values contain
# that bit, so they are still masked, along with every other DO_NOT_USE combination.
for file in lvl2:
    # One handle in update mode; closing it writes the modified SCI data back.
    with fits.open(file, mode='update') as hdul:
        data = hdul['SCI', 1].data
        dq = hdul['DQ', 1].data
        data[(dq & 1) != 0] = np.nan
        hdul['SCI', 1].data = data
# Create a grid for fast lookup using WebbPSF. The larger the grid, the better the photometric precision.
# Developer note. Would be great to have a fast/approximate look up table.
jwst_obs = space_phot.observation2(lvl2)
grid = space_phot.util.get_jwst_psf_grid(jwst_obs, num_psfs=4)
# Now Loop Through All Stars and Build Photometry Table
# Rows are collected in a list and turned into one DataFrame after the loop.
# Previously each star was pd.concat'ed as an unnamed frame (columns 0..3), so
# every star after the first landed in separate, NaN-padded columns.
phot_rows = []

# A fresh observation object for the fits, as in the original notebook.
# NOTE(review): confirm whether re-creating it after building the grid is needed.
jwst_obs = space_phot.observation2(lvl2)
for star_num, source_location in enumerate(skycoords, start=1):
    tic = time.perf_counter()
    print('Starting', star_num, ' of', len(skycoords), ':', source_location)
    psfs = space_phot.util.get_jwst_psf_from_grid(jwst_obs, source_location, grid)
    jwst_obs.psf_photometry(
        psfs,
        source_location,
        bounds={
            'flux': [-100, 1000],
            'centroid': [-2., 2.],
            'bkg': [0, 50]
        },
        fit_width=3,
        fit_bkg=False,
        fit_flux='single',
        maxiter=5000
    )

    jwst_obs.plot_psf_fit()
    plt.show()

    phot_table = jwst_obs.psf_result.phot_cal_table
    ra = phot_table['ra'][0]
    dec = phot_table['dec'][0]
    # Mean magnitude over the table rows; row errors summed in quadrature
    mag_lvl2psf = np.mean(phot_table['mag'])
    magerr_lvl2psf = math.sqrt(sum(p**2 for p in phot_table['magerr']))
    phot_rows.append([ra, dec, mag_lvl2psf, magerr_lvl2psf])

    toc = time.perf_counter()
    print("Elapsed Time for Photometry:", toc - tic)

df = pd.DataFrame(phot_rows, columns=['ra', 'dec', 'mag', 'magerr'])

5.2 Single, Level3 Mosaicked File

lvl3
# Now do the same photometry on the Level 3 Data
ref_image = lvl3
# Open once, read SCI from that handle, and close the file after plotting.
with fits.open(ref_image) as ref_fits:
    ref_data = ref_fits['SCI', 1].data
    norm1 = simple_norm(ref_data, stretch='linear', min_cut=-1, max_cut=10)

    plt.imshow(ref_data, origin='lower',
               norm=norm1, cmap='gray')
    plt.gca().tick_params(labelcolor='none', axis='both', color='none')
    plt.show()
# Get PSF from WebbPSF and drizzle it to the source location
# Develop Note: Need Grid Capability for Level3 Data
jwst3_obs = space_phot.observation3(lvl3)
# Now Loop Through All Stars and Build Photometry Table
# Rows are collected in a list and turned into one DataFrame after the loop.
# Previously each star was pd.concat'ed as an unnamed frame (columns 0..3), so
# every star after the first landed in separate, NaN-padded columns.
phot3_rows = []

for star_num, source_location in enumerate(skycoords, start=1):
    tic = time.perf_counter()
    print('Starting', star_num, ' of', len(skycoords), ':', source_location)
    # Uses both the level-2 observation (jwst_obs) from section 5.1 and the level-3 one
    psf3 = space_phot.get_jwst3_psf(jwst_obs, jwst3_obs, source_location, num_psfs=4)
    jwst3_obs.psf_photometry(
        psf3,
        source_location,
        bounds={
            'flux': [-1000, 10000],
            'centroid': [-2, 2],
            'bkg': [0, 50]
        },
        fit_width=5,
        fit_bkg=True,
        fit_flux=True
    )

    jwst3_obs.plot_psf_fit()
    plt.show()

    phot3_table = jwst3_obs.psf_result.phot_cal_table
    ra = phot3_table['ra'][0]
    dec = phot3_table['dec'][0]
    mag_lvl3psf = phot3_table['mag'][0]
    magerr_lvl3psf = phot3_table['magerr'][0]
    phot3_rows.append([ra, dec, mag_lvl3psf, magerr_lvl3psf])

    toc = time.perf_counter()
    print("Elapsed Time for Photometry:", toc - tic)

# Replaces the level-2 table from section 5.1, as the original cell did.
df = pd.DataFrame(phot3_rows, columns=['ra', 'dec', 'mag', 'magerr'])

Space Telescope Logo