"""Functions for the RSCD correction for MIRI science data."""
import logging
import numpy as np
from stdatamodels.jwst.datamodels import dqflags
# Module-level logger; every message emitted by the functions below goes through it.
log = logging.getLogger(__name__)

# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "do_correction",
    "correction_skip_groups",
    "get_rscd_parameters",
    "flag_rscd",
    "apply_rscd_flags",
]
[docs]
def do_correction(output_model, rscd_model):
    """
    Flag the RSCD-affected initial groups of each integration of MIRI ramp data.

    The number of leading groups to mark as 'DO_NOT_USE' is taken from the RSCD reference
    file, which holds one value for the first integration and a separate value for the
    second and all later integrations.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Input ramp datamodel
    rscd_model : `~stdatamodels.jwst.datamodels.RSCDModel`
        RSCD reference datamodel

    Returns
    -------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Ramp datamodel with RSCD affected groups flagged as DO_NOT_USE. When the reference
        file has no entry for the exposure's READPATT/SUBARRAY combination, the model is
        returned unflagged with ``meta.cal_step.rscd`` set to "SKIPPED".
    """
    param = get_rscd_parameters(output_model, rscd_model)
    if not param:
        log.warning(
            "READPATT, SUBARRAY combination not found in ref file: RSCD correction will be skipped"
        )
        output_model.meta.cal_step.rscd = "SKIPPED"
        return output_model

    skip_first, skip_later = param["skip_int1"], param["skip_int2p"]

    # A negative first-integration value comes from an older (deprecated) reference file
    # model that has no dedicated entry for integration 1; fall back to skipping 1 group.
    if skip_first < 0:
        for message in (
            "RSCD reference file is of a deprecated model.",
            "There are no values for first integration",
            "Setting number of groups to skip in first integration to 1",
        ):
            log.warning(message)
        skip_first = 1

    log.info(f"# groups from RSCD reference file for int 1 to flag: {skip_first}")
    log.info(f"# groups from RSCD reference file for int 2 and higher to flag: {skip_later}")
    return correction_skip_groups(output_model, skip_first, skip_later)
[docs]
def correction_skip_groups(output, group_skip_int1, group_skip_int2p):
    """
    Flag the leading RSCD-affected groups of each integration as DO_NOT_USE.

    The requested skip counts are first adapted to the exposure:

    * exposures with fewer than 3 groups are not flagged at all and the step is marked
      SKIPPED;
    * exposures with 5 or fewer groups skip exactly one group;
    * a count that would leave fewer than 3 groups is lowered to ``ngroups - 3``.

    Integration 1 (only when this file actually contains it) uses ``group_skip_int1``;
    every later integration in the file uses ``group_skip_int2p``. The per-pixel counts
    may be lowered further for saturated pixels by `flag_rscd`.

    Parameters
    ----------
    output : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged
    group_skip_int1 : int
        Number of groups to skip at the beginning of the ramp for integration 1
    group_skip_int2p : int
        Number of groups to skip at the beginning of the ramp for integration 2 and higher

    Returns
    -------
    output: `~stdatamodels.jwst.datamodels.RampModel`
        Ramp datamodel with RSCD affected groups flagged as DO_NOT_USE
    """
    exposure = output.meta.exposure
    sci_ngroups = exposure.ngroups
    sci_nints = exposure.nints
    # integration_start is only set for segmented data, where this file may begin part-way
    # through the exposure; unsegmented data always begins at integration 1.
    first_int = exposure.integration_start
    if first_int is None:
        first_int = 1

    log.debug(f"RSCD correction using: nints={sci_nints}, ngroups={sci_ngroups}")
    log.debug(f"The first integration in the data is integration: {first_int}")

    # The last-frame step has already rejected one group and a slope needs at least two
    # valid groups, so RSCD flagging needs at least 3 groups. (MIRI's APT minimum is 5
    # groups, so this only triggers in rare special cases.)
    if sci_ngroups < 3:
        log.warning("Too few groups to apply RSCD correction")
        log.warning("RSCD step will be skipped")
        output.meta.cal_step.rscd = "SKIPPED"
        return output

    # Short ramps skip exactly one group, whatever the reference file says.
    if sci_ngroups <= 5:
        group_skip_int1 = group_skip_int2p = 1
        log.info(
            f"Number of groups to skip for integration 1 (for data with <= 5 groups): "
            f"{group_skip_int1}"
        )
        log.info(
            f"Number of groups to skip for integration 2+ (for data with <= 5 groups): "
            f"{group_skip_int2p}"
        )

    # Never skip so many groups that fewer than 3 remain.
    max_groups_skip = max(0, sci_ngroups - 3)
    if sci_ngroups < group_skip_int1 + 3 and max_groups_skip != group_skip_int1:
        log.info(f"Changing the # of groups to skip in int 1 to {max_groups_skip}")
        group_skip_int1 = max_groups_skip
    if (
        sci_nints > 1
        and sci_ngroups < group_skip_int2p + 3
        and max_groups_skip != group_skip_int2p
    ):
        group_skip_int2p = max_groups_skip
        log.info(f"Changing the # of groups to skip in int 2 and higher to {max_groups_skip}")

    # Integration 1: present only when the data start at integration 1 (unsegmented data,
    # or the first segment of segmented data). Index 0 is its zero-based position.
    if first_int == 1:
        skip_map, n_lowered, n_first_only = flag_rscd(output, 0, 0, group_skip_int1)
        output = apply_rscd_flags(output, 0, 0, skip_map)
        log.info(
            "Number of usable bright pixels with rscd flag groups "
            f"not set to DO_NOT_USE: {n_lowered}"
        )
        output.meta.rscd.keep_bright_firstgroup_int1 = n_first_only
        output.meta.rscd.keep_groups_saturation_int1 = n_lowered
        output.meta.rscd.ngroups_skip_int1 = group_skip_int1

    # Integrations 2 and higher. The upper bound comes from the data shape rather than
    # nints, because for segmented data nints counts the whole exposure, not this file.
    # When the file does not start at integration 1, every integration in it is "later".
    n_ints_in_file = output.data.shape[0]
    later_first = 1 if first_int == 1 else 0  # zero-based
    later_last = n_ints_in_file - 1  # zero-based, inclusive
    if sci_nints > 1:
        skip_map, n_lowered, n_first_only = flag_rscd(
            output, later_first, later_last, group_skip_int2p
        )
        output = apply_rscd_flags(output, later_first, later_last, skip_map)
        log.info(
            "Number of usable bright pixels with rscd flag groups "
            f"not set to DO_NOT_USE: {n_lowered}"
        )
        output.meta.rscd.keep_bright_firstgroup_int2p = n_first_only
        output.meta.rscd.keep_groups_saturation_int2p = n_lowered
        output.meta.rscd.ngroups_skip_int2p = group_skip_int2p

    output.meta.cal_step.rscd = "COMPLETE"
    return output
[docs]
def flag_rscd(output_model, int_start, int_end, rscd_skip):
    """
    Find the initial groups to set to DO_NOT_USE based on RSCD rules.

    Every pixel of every integration in ``int_start``..``int_end`` (zero-based, inclusive)
    starts with ``rscd_skip`` groups to skip. Where the group that would be the second
    valid group after the skipped ones is flagged SATURATED, that pixel's skip count is
    lowered one group at a time (never below 0) while it stays saturated at the group being
    checked. Ideally two valid groups remain, but backing off may leave only one (the
    ramp_fit parameter ``suppress_one_group = False`` can still derive a value from it).
    Pixels that are already SATURATED in the first group cannot be recovered and keep
    ``rscd_skip``. Pixels whose count was lowered in at least one integration get the
    FLUX_ESTIMATED flag OR-ed into ``output_model.pixeldq`` (modified in place).

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged. Its ``pixeldq`` may be updated in place.
    int_start : int
        Starting integration (zero-based index into ``groupdq``).
    int_end : int
        Ending integration (zero-based, inclusive).
    rscd_skip : int
        Number of groups to skip at the beginning of the ramp for integration range.

    Returns
    -------
    skip_array : ndarray
        Signed integer array of shape (n_ints, ny, nx) with the number of groups to skip
        for each integration and pixel.
    num_rscd_lowered : int
        The number of pixels (each counted once over the integration range) where the
        number of RSCD groups to flag as DO_NOT_USE was lowered because of saturation.
    num_only_one_group : int
        The number of pixels whose skip count had to be lowered all the way to 0 in at
        least one integration, so the ramp is used from its first group. A requested
        ``rscd_skip`` of 0 never counts here, since no back-off was needed.
    """
    sat_flag = dqflags.group["SATURATED"]
    # The integrations being processed; shape (n_ints, ngroups, ny, nx). Taking n_ints
    # from the slice keeps skip_array exactly aligned with the data it is compared to.
    groupdq = output_model.groupdq[int_start : int_end + 1]
    n_ints, _, y_dim, x_dim = groupdq.shape

    # Signed dtype, so the back-off decrement below can never wrap around even when
    # rscd_skip arrives as an unsigned integer (e.g. from a reference-file table).
    skip_array = np.full((n_ints, y_dim, x_dim), rscd_skip, dtype=np.int64)

    def saturated_at(group_index):
        """Return the (n_ints, ny, nx) mask of SATURATED flags in one zero-based group."""
        return (groupdq[:, group_index] & sat_flag) > 0

    # Pixels saturated already in the first group cannot be recovered by backing off,
    # so they are excluded from every saturation check below.
    is_group_1_sat = saturated_at(0)

    # min_group is 1-based: the second group that would remain after skipping
    # rscd_skip groups.
    min_group = rscd_skip + 2
    is_sat_problem = saturated_at(min_group - 1) & ~is_group_1_sat
    num_sat = int(is_sat_problem.sum())
    log.info(
        f" There are {num_sat} saturated pixels that require the number of "
        "rscd groups flagged to be lowered"
    )
    if num_sat == 0:
        return skip_array, 0, 0

    # Dynamic back-off: every pixel still saturated at the group being checked loses one
    # skipped group, then the check moves one group earlier.
    while num_sat > 0 and min_group > 1:
        skip_array[is_sat_problem] = np.maximum(skip_array[is_sat_problem] - 1, 0)
        min_group -= 1
        is_sat_problem = saturated_at(min_group - 1) & ~is_group_1_sat
        num_sat = int(is_sat_problem.sum())

    # A pixel counts as backed off if its skip count was lowered in ANY integration.
    was_backed_off = skip_array < rscd_skip
    is_backed_off_2d = np.any(was_backed_off, axis=0)
    num_rscd_lowered = int(is_backed_off_2d.sum())
    if num_rscd_lowered:
        output_model.pixeldq[is_backed_off_2d] |= dqflags.pixel["FLUX_ESTIMATED"]
        log.info(
            f"Flagged {num_rscd_lowered} pixels as FLUX_ESTIMATED due to RSCD back-off."
        )

    # Final safety floor; the back-off above never goes below 0, so this only affects
    # a negative rscd_skip.
    skip_array = np.maximum(skip_array, 0)

    # Pixels that had to back off all the way to the first group. Requiring an actual
    # back-off prevents a requested skip of 0 from counting every pixel in the frame.
    is_only_one_group = (skip_array == 0) & was_backed_off
    num_only_one_group_pixels = int(np.any(is_only_one_group, axis=0).sum())

    return skip_array, num_rscd_lowered, num_only_one_group_pixels
[docs]
def apply_rscd_flags(output_model, int_start, int_end, skip_array):
    """
    Apply flags for RSCD correction setting DO_NOT_USE to the dq values.

    For each integration ``i`` and pixel ``(y, x)`` in the range, the first
    ``skip_array[i, y, x]`` groups (zero-based group indices ``0 .. skip - 1``) get
    DO_NOT_USE OR-ed into ``groupdq``. A skip count of 0 or less flags nothing.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged; its ``groupdq`` is updated in place.
    int_start : int
        Starting integration (zero-based).
    int_end : int
        Ending integration (zero-based, inclusive).
    skip_array : array_like
        Number of groups to skip at the beginning of the ramp for integration range,
        shape (n_ints, ny, nx) with n_ints = int_end - int_start + 1.

    Returns
    -------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Ramp datamodel with RSCD affected groups flagged as DO_NOT_USE
    """
    skip_array = np.asarray(skip_array)

    # The integrations being flagged; shape (n_ints, ngroups, ny, nx).
    dq = output_model.groupdq[int_start : int_end + 1]
    group_indices = np.arange(dq.shape[1])

    # Broadcast (1, ngroups, 1, 1) against (n_ints, 1, ny, nx) into a 4-D mask that is
    # True for the leading `skip` groups of each pixel. Comparing with "<" directly
    # (rather than "index <= skip - 1") never subtracts from the skip counts, which
    # would wrap around to a huge value for an unsigned dtype holding 0.
    mask = group_indices[None, :, None, None] < skip_array[:, None, :, :]
    dq[mask] |= dqflags.group["DO_NOT_USE"]

    # Explicit write-back; redundant when dq is a view of groupdq (as basic slicing of
    # an ndarray produces), but harmless.
    output_model.groupdq[int_start : int_end + 1] = dq
    return output_model
[docs]
def get_rscd_parameters(input_model, rscd_model):
    """
    Look up the RSCD group-skip values matching the exposure's readout pattern and subarray.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged; ``meta.exposure.readpatt`` and ``meta.subarray.name``
        select the reference-table row.
    rscd_model : `~stdatamodels.jwst.datamodels.RSCDModel`
        RSCD reference file data; rows of ``rscd_group_skip_table`` are matched on their
        "subarray" and "readpatt" columns.

    Returns
    -------
    param : dict
        ``{"skip_int1": <group_skip1>, "skip_int2p": <group_skip>}`` taken from the first
        matching row, or an empty dictionary when no row matches.
    """
    wanted_readpatt = input_model.meta.exposure.readpatt
    wanted_subarray = input_model.meta.subarray.name

    # Older science products call the MIRI LRS slitless subarray SUBPRISM; the reference
    # table uses the newer SLITLESSPRISM name.
    if wanted_subarray.upper() == "SUBPRISM":
        wanted_subarray = "SLITLESSPRISM"

    for row in rscd_model.rscd_group_skip_table:
        row_subarray = row["subarray"]
        row_readpatt = row["readpatt"]
        skip_later_ints = row["group_skip"]  # integration 2 and higher (+)
        skip_first_int = row["group_skip1"]  # integration 1
        if row_subarray == wanted_subarray and row_readpatt == wanted_readpatt:
            return {"skip_int1": skip_first_int, "skip_int2p": skip_later_ints}

    return {}