# from numba import njit, prange
import numpy as np
from numpy import newaxis as na
from scipy.fft import ifft2
from . import wfi_coordinate_transformations as wfi
from .filter_detector_properties import FilterDetector, n_mercadtel
from .index_cdte import n_cdte
from .mtf_diffusion import intensity_to_image
from .opticspsf import GeometricOptics
from .polarisation_decomposition import polarisation_mode_decomposition
from .quadrature_integration import QuadratureIntegrator
from .wfi_data import pix
from .zernike import noll_to_zernike, zernike
c = 2.99792458e8  # speed of light in m/s
epsilon_0 = 8.8541878188e-12  # permittivity of free space in F/m

# Legacy helper for a disabled multiprocessing detector-image path. ``MTF_image``
# is not imported by this module, so the helper stays commented out.
# def parallel_MTF_image(args):
#     """Wrapper for MTF_image"""
#     xd, yd, imageX, imageY, Intensity_integrated, npix_boundary = args
#     return MTF_image(xd, yd, imageX, imageY, Intensity_integrated, npix_boundary)
[docs]
class PSFObject:
"""
Monochromatic PSF object class.
Parameters
----------
scanum : int
The SCA number.
scax, scay : float
The pixel positions on the SCA (in mm, FPA coordinates relative to the SCA center).
wavelength : float, optional
The vacuum wavelength in microns.
postage_stamp_size : int, optional
The length of the side of the square postage stamp in native pixels.
ovsamp : int, optional
The oversampling factor for the PSF (number of samples per native pixel).
use_filter : str, optional
The filter configuration to use (1-character code).
npix_boundary : int, optional
?
use_postage_stamp_size : int, optional
Force pupil postage stamp size instead of internal calculation. In native pixels.
ray_trace : bool, optional
Whether to use ray tracing. (Only turn off for testing.)
extra_aberrations: float array, optional
Parameters corresponding to zernike polynomials for introducing aberrations that
add to the optical path length and produce different aberrations. Supports up to
5 parameters (Z2, Z3, Z4, Z5, and Z6 in that order). The effects of each polynomial
are as follows:
Z2: horizontal centering
Z3: vertical centering
Z4: focus
Z5: astigmatism
Z6: also astigmatism
detector_thickness : float, optional
Thickness of the detector in microns. This is used to compute the electric field
and intensity within the detector.
zlen : int, optional
Number of points along the z-axis (depth) within the detector to compute the
electric field and intensity. This is used to compute the electric field
and intensity within the detector.
interference_filter : psfsim.filter_detector_properties.FilterDetector, optional
The interference filter object to use for computing the transmitted electric field
and intensity within the detector. Defaults to an ideal 3-layer interference
filter (to get the right wiggle shape) + thin layer of CdTe/HgCdTe (to give the
additional loss in the blue).
Note that if you really want to model the response of each SCA, you will need some
additional empirical corrections on top of this, both because there are thickness
variations and also because there is a lot of physics associated with whether you get
electrons absorbed right near the illuminated surface that we aren't modeling --- we're
just using a 2-layer "dead zone" with ideal materials to mock up what is probably a
region with varying band gap and probability of collecting the hole that gets released.
cycle : int, optional
Which cycle to use for the Zernike modes.
mjd : float, optional
The MJD to use for the optical model.
Attributes
----------
wavelength : float
The wavelength
interference_filter : psfsim.filter_detector_properties.FilterDetector
The interference filter object.
ulen : int
The length of the FFTs.
optics : psfsim.opticspsf.GeometricOptics
The Geometric Optics object.
Methods
-------
__init__
Constructor.
get_optical_psf
Gets the optical PSF (no detector effects).
"""
[docs]
def __init__(
self,
scanum,
scax,
scay,
wavelength=0.48,
postage_stamp_size=31,
ovsamp=10,
use_filter="H",
npix_boundary=1,
a_lanczos=3,
use_postage_stamp_size=None,
ray_trace=True,
extra_aberrations=None,
detector_thickness=2,
zlen=20,
interference_filter=None,
cycle=9,
mjd=None,
):
[docs]
self.wavelength = wavelength
[docs]
self.npix_boundary = npix_boundary
if interference_filter is None:
# this is the default filter
interference_filter = FilterDetector(
[1.35, 1.82, 2.45, n_cdte(wavelength), n_mercadtel(wavelength)],
[0.163, 0.137, 0.084, 0.010, 0.008],
1,
)
[docs]
self.interference_filter = interference_filter
[docs]
self.postage_stamp_size = postage_stamp_size
[docs]
self.detector_thickness = detector_thickness
[docs]
self.z_array = np.linspace(0, detector_thickness, zlen)
# The following sets the ulen of the GeometricOptics object based on
# use_postage_stamp_size when an explicit native-pixel size is provided.
[docs]
self.ulen = 2048 # default value
if use_postage_stamp_size is not None:
if isinstance(use_postage_stamp_size, bool) or not isinstance(
use_postage_stamp_size, (int | np.integer)
):
raise TypeError(
"use_postage_stamp_size must be a positive integer number of native pixels or None."
)
if use_postage_stamp_size <= 0:
raise ValueError("use_postage_stamp_size must be a positive integer number of native pixels.")
self.ulen = use_postage_stamp_size * self.ovsamp
[docs]
self.optics = GeometricOptics(
scanum,
scax,
scay,
wavelength=wavelength,
use_filter=use_filter,
ulen=self.ulen,
ray_trace=ray_trace,
pixelsampling=10.0 / self.ovsamp,
a_lanczos=a_lanczos,
cycle=cycle,
mjd=mjd,
)
self.ux, self.uy = (
self.optics.u_array(),
self.optics.v_array(),
) # np.meshgrid(self.Optics.uX, self.Optics.uY, indexing='ij')
[docs]
self.u = np.sqrt(self.ux**2 + self.uy**2)
[docs]
self.mask = self.u <= 1
# sX = (self.wavelength / (self.optics.umax - self.optics.umin)) * (
# -(self.optics.ulen / 2.0) + np.array(range(self.optics.ulen))
# ) # postage stamp coordinates along the FPA axes in microns
# sY = (self.wavelength / (self.optics.umax - self.optics.umin)) * (
# -(self.optics.ulen / 2.0) + np.array(range(self.optics.ulen))
# ) # postage stamp coordinates along the FPA axes in microns
# self.sX, self.sY = np.meshgrid(sX, sY, indexing="ij")
# self.dsX = self.optics.wavelength / (
# self.optics.umax - self.optics.umin
# ) # postage stamp pixel size in microns
# self.dsY = self.optics.wavelength / (
# self.optics.umax - self.optics.umin
# ) # postage stamp pixel size in microns
[docs]
self.dx = self.optics.wavelength / np.abs(self.ulen * self.optics.du)
if extra_aberrations is not None:
if len(extra_aberrations) > 5:
raise ValueError("extra_aberrations supports at most 5 coefficients (Z2–Z6).")
noll_coeffs = np.arange(2, 7)
coeff_count = noll_coeffs.size
nArr, mArr = noll_to_zernike(noll_coeffs)
# I think this loop could be avoided but not sure if it's really worth it.
for n, m, mag in zip(nArr, mArr, extra_aberrations[:coeff_count], strict=False):
self.optics.path_difference += (
(
mag
* zernike(
n, m, 2 * self.optics.focalLength * self.optics.urhoPolar, self.optics.uthetaPolar
)
)
if mag is not None
else 0
)
prefactor = (
self.optics.pupil_mask
/ self.optics.determinant
* np.exp(2 * np.pi / self.wavelength * 1j * self.optics.path_difference)
)
x_minus = (-1) ** np.array(range(self.ulen)) # used to translate ftt to image center
ph = np.outer(x_minus, x_minus) # phase required to translate fft to center
[docs]
self.prefactor = prefactor * ph
# self.MTF_array = diffusion_green(self.sX, self.sY)
# def get_ulen(self, ps=20):
# # Returns the required ulen for a postage stamp of size ps in pixels.
#
# pixsize = wfi_data.pix # pixel size in microns
# smin = -ps * pixsize
# smax = (
# ps * pixsize
# ) # Note that uX and uY have to be fourier duals to twice the size of the postage stamp to
# # avoid aliasing from periodic boundary conditions
#
# ulen = 2 * (smax - smin) / self.wavelength
# return ulen
[docs]
def get_optical_psf(self, normalise=True):
"""
Gets the optical PSF (no detector effects).
Returns values of the
optical PSF on the SCA surface in the postage stamp surrounding the point (SCAx, SCAy) in the
SCA. This function is added for testing purposes and to assess the impact of the interference
filter on the PSF and charge diffusion through the HgCdTe layer. Note that the optical PSF
includes the effects of diffraction and pupil mask and is normalized to total discrete flux of 1 when
``normalise`` is True. The
optical psf is saved to ``self.Optical_PSF``.
Parameters
----------
normalise : bool, optional
Currently has no effect.
"""
# prefactor = \
# self.optics.pupilMask*self.optics.determinant*np.exp(2*np.pi/self.wavelength*1j\
# *self.optics.pathDifference)
# x_minus = (-1)**np.array(range(self.optics.ulen))#used to translate ftt to image center
# ph = np.outer(x_minus, x_minus) #phase required to translate fft to center
# prefactor *= ph
# start_time = time.time()
# current_time = time.time()
# old version by Charuhas below
# E_local = np.zeros(self.ux.shape + (3,), dtype=np.complex128)
# E_local[self.mask, 0] = A_TE * np.ones_like(self.ux[self.mask])
# E_local[self.mask, 1] = -np.sqrt(1 - self.u[self.mask] ** 2) * A_TM
# E_local[self.mask, 2] = self.u[self.mask] * A_TM
# local_to_FPA = local_to_fpa_rotation(self.ux, self.uy, sgn=1)
# E_FPA_x = np.zeros_like(self.ux, dtype=np.complex128)
# E_FPA_y = np.zeros_like(self.ux, dtype=np.complex128)
# E_FPA_z = np.zeros_like(self.ux, dtype=np.complex128)
# E_FPA_x[self.mask] = np.sum(local_to_FPA[self.mask, 0, :] * E_local[self.mask, :], axis=-1)
# E_FPA_y[self.mask] = np.sum(local_to_FPA[self.mask, 1, :] * E_local[self.mask, :], axis=-1)
# E_FPA_z[self.mask] = np.sum(local_to_FPA[self.mask, 2, :] * E_local[self.mask, :], axis=-1)
# end_time = time.time()
# print("Time taken to get E field in FPA coordinates = ", end_time - current_time, "\n")
# current_time = time.time()
# E_FPA_x *= self.prefactor
# E_FPA_y *= self.prefactor
# E_FPA_z *= self.prefactor
# Ex = ifft2(E_FPA_x)
# Ey = ifft2(E_FPA_y)
# Ez = ifft2(E_FPA_z)
# Ex = np.fft.ifft2(E_FPA_x, axes=(
# print("Time taken to do ifft = ", time.time() - current_time, "\n")
# current_time = time.time()
# self.Optical_PSF = abs(Ex) ** 2 + abs(Ey) ** 2 + abs(Ez) ** 2
# print("Time taken to compute Optical PSF by squaring the E field = ",
# time.time()-current_time, "\n")
# self.Optical_PSF /= np.sum(self.Optical_PSF*self.dsX*self.dsY) # Normalise to total flux of 1
# self.Optical_PSF *= np.sum(self.dsX*self.dsY)
# New changes by Nihar here, please check before removing this comment
# Goal is to get polarization consistent with raytrace
self.E_FPA_h_polarized = self.optics.rb.E[:, :, 0, 1:4] # horizontal polarization
self.E_FPA_v_polarized = self.optics.rb.E[:, :, 1, 1:4] # vertical polarization
E_FPA_h_polarized = self.prefactor[:, :, np.newaxis] * self.E_FPA_h_polarized
E_FPA_v_polarized = self.prefactor[:, :, np.newaxis] * self.E_FPA_v_polarized
r = np.zeros((self.ux.shape[0], self.ux.shape[1], 3))
r[:, :, 0] = self.ux
r[:, :, 1] = self.uy
r[:, :, 2] = np.sqrt(np.clip(1 - self.u**2, 0.0, None))
# r = r.reshape(self.ux.shape[0], self.ux.shape[1], 3) # reshape to be compatible with E
cB_FPA_h_polarized = np.cross(r, E_FPA_h_polarized)
cB_FPA_v_polarized = np.cross(r, E_FPA_v_polarized)
# Need to add normalization
E_h_polarized = ifft2(E_FPA_h_polarized, axes=(0, 1)) # use first two axes for fft
E_v_polarized = ifft2(E_FPA_v_polarized, axes=(0, 1))
cB_h_polarized = ifft2(cB_FPA_h_polarized, axes=(0, 1))
cB_v_polarized = ifft2(cB_FPA_v_polarized, axes=(0, 1))
# Unsure about the abs here, but leaving it in for now...
self.h_polarized_psf = np.real(
0.5
* epsilon_0
* c
* (
np.conjugate(E_h_polarized[:, :, 0]) * cB_h_polarized[:, :, 1]
- np.conjugate(E_h_polarized[:, :, 1]) * cB_h_polarized[:, :, 0]
)
)
self.v_polarized_psf = np.real(
0.5
* epsilon_0
* c
* (
np.conjugate(E_v_polarized[:, :, 0]) * cB_v_polarized[:, :, 1]
- np.conjugate(E_v_polarized[:, :, 1]) * cB_v_polarized[:, :, 0]
)
)
self.Optical_PSF = 0.5 * (self.h_polarized_psf + self.v_polarized_psf)
if normalise:
total_flux = np.sum(self.Optical_PSF)
if total_flux != 0:
self.Optical_PSF /= total_flux
return
[docs]
def get_Intensity_in_detector(self, nworkers=8):
"""
Gets the total intensity of the h-polarised and v-polarised E-fields in the
detector, integrated over the depth of the detector, after the optical PSF has been computed
Parameters
----------
nworkers : int, optional
The number of workers to use for parallel processing when computing the intensity
in the detector. This is used in the `get_Intensity_from_E` function which is called
within this function.
"""
# Check if Optical_PSF has been computed, if not, compute Optical_PSF
if not hasattr(self, "Optical_PSF"):
self.get_optical_psf()
# Get the TE and TM mode amplitudes for the h-polarized and v-polarized E-fields
TE_TM_h_polarized = polarisation_mode_decomposition(self.ux, self.uy, self.E_FPA_h_polarized, sgn=1)
TE_TM_v_polarized = polarisation_mode_decomposition(self.ux, self.uy, self.E_FPA_v_polarized, sgn=1)
A_TE_h = TE_TM_h_polarized["TE"]
A_TM_h = TE_TM_h_polarized["TM"]
A_TE_v = TE_TM_v_polarized["TE"]
A_TM_v = TE_TM_v_polarized["TM"]
# Obtain the integrated intensity in the detector for the h-polarized and v-polarized E-fields
# by calling get_Intensity_from_E with the corresponding TE and TM mode amplitudes for
# the h-polarized and v-polarized E-fields.
Intensity_integrated_h = self.get_Intensity_from_E(A_TE=A_TE_h, A_TM=A_TM_h, nworkers=nworkers)
Intensity_integrated_v = self.get_Intensity_from_E(A_TE=A_TE_v, A_TM=A_TM_v, nworkers=nworkers)
# Total intensity is the sum of the h-polarized and v-polarized intensities
Intensity_integrated_net = Intensity_integrated_h + Intensity_integrated_v
self.Intensity_in_detector = Intensity_integrated_net
return
[docs]
def get_Intensity_from_E(self, A_TE=1.0e10, A_TM=1.0e10, nworkers=8):
"""
Gets the intensity from the electric field amplitudes in TE and TM modes.
This is used to get the intensity in the detector after passing through the
interference filter.
Uses adaptive Gaussian quadrature integration optimized for exponential decay
in the HgCdTe detector material, replacing the previous trapezoid rule.
Parameters
----------
A_TE : np.ndarray of complex of shape same as `ux` and `uy`
The TE mode amplitude.
A_TM : np.ndarray of complex of shape same as `ux` and `uy`
The TM mode amplitude.
nworkers : int, optional
The number of workers to use for parallel processing when computing the
inverse fourier transforms of the E-field from wave-number space to FPA
postage stamp coordinates.
Returns
-------
Intensity_integrated : np.ndarray of float of shape same as `ux` and `uy`
The intensity in the detector, integrated over the depth of the detector,
after passing through the interference filter, using adaptive Gaussian quadrature.
"""
# Initialize adaptive Gaussian quadrature integrator for detector depth integration
self._quadrature_integrator = QuadratureIntegrator(
self.wavelength, self.detector_thickness, self.ux, self.uy, self.interference_filter
)
# Get optimized nodes for evaluation
(
self._quad_nodes,
self._quad_weights,
self._quad_order,
) = self._quadrature_integrator.get_nodes_and_weights()
filter = self.interference_filter
# Evaluate E-field at adaptive quadrature nodes instead of uniform z_array
E = filter.transmitted_E(self.wavelength, self.ux, self.uy, self._quad_nodes, A_TE=A_TE, A_TM=A_TM)
Ex = E[0]
Ey = E[1]
Ez = E[2]
Ex *= self.prefactor[:, :, na]
Ey *= self.prefactor[:, :, na]
Ez *= self.prefactor[:, :, na]
Ex_postage_stamp = ifft2(Ex, axes=(0, 1), workers=nworkers)
Ey_postage_stamp = ifft2(Ey, axes=(0, 1), workers=nworkers)
Ez_postage_stamp = ifft2(Ez, axes=(0, 1), workers=nworkers)
Intensity = (abs(Ex_postage_stamp) ** 2) + (abs(Ey_postage_stamp) ** 2) + (abs(Ez_postage_stamp) ** 2)
# Apply adaptive Gaussian quadrature integration
# Intensity has shape (ux, uy, nz_quad); integrate along axis 2 using precomputed weights
Intensity_integrated = np.tensordot(Intensity, self._quad_weights, axes=(2, 0))
return Intensity_integrated
[docs]
def get_image_from_Intensity(self, centerpix=True, reflect=True, tophat=True):
"""
Gets the image on the detector from the intensity in the detector by convolving with
the MTF of charge diffusion in the HgCdTe layer. This is used to get the final PSF
image on the detector after including the effects of charge diffusion.
Parameters
----------
centerpix : bool, optional
Whether to center the PSF on a pixel.
reflect : bool, optional
Whether to reflect the PSF.
tophat : bool, optional
Whether to use a tophat function.
oversampling : int, optional
The oversampling factor.
Returns
-------
detector_image : np.ndarray of float of shape (postage_stamp_size, postage_stamp_size)
The final PSF image on the detector after including the effects of charge diffusion.
"""
# Check if Intensity_in_detector has been computed, if not, compute Intensity_in_detector
if not hasattr(self, "Intensity_in_detector"):
self.get_Intensity_in_detector()
self.x_A, self.y_A = wfi.from_sca_to_analysis(
self.optics.scanum, self.optics.scax, self.optics.scay
) # Center of the PSF in Analysis coordinates
if centerpix:
x_out = (self.x_A // pix) * pix + (0.5 * pix)
y_out = (self.y_A // pix) * pix + (0.5 * pix)
else:
x_out = self.x_A
y_out = self.y_A
self.detector_image = intensity_to_image(
self.Intensity_in_detector,
x_in=self.x_A,
y_in=self.y_A,
x_out=x_out,
y_out=y_out,
n_out=self.postage_stamp_size * self.ovsamp,
dx=self.dx,
reflect=reflect,
tophat=tophat,
)
return
# def get_detector_image3(self):
# """
# Returns the postage_stamp_size x postage_stamp_size detector image as a 2D array of intensity values.
# """
# # if not hasattr(self, 'Intensity'):
# # self.get_E_in_detector()
# self.detector_image3 = fftconvolve(self.Intensity_in_detector, self.MTF_array, mode="same")
# # XAnalysis, YAnalysis = wfi.fromSCAtoAnalysis(self.optics.scaNum, self.optics.scaX,
# # self.optics.scaY) #Center of the PSF in Analysis coordinates
# # imageX = XAnalysis + self.sX[:,0] # Note that self.sX and self.sY are in microns whereas
# # Analysis coordinates and MTF are in mm
# # imageY = YAnalysis + self.sY[0,:]
# # MTF_array = np.zeros_like(self.sX, dtype=np.float64)
# def get_detector_image2(self):
# """
# Returns the postage_stamp_size x postage_stamp_size detector image as a 2D array of intensity values.
# """
# # if not hasattr(self, 'Intensity'):
# # self.get_E_in_detector()
# pix = 1.0
# # ps = self.ulen / pix
# # Compute the detector image by summing the contributions from all points in the postage stamp
# # detector_image = np.zeros((, 4088, self.optics.ulen, self.optics.ulen), dtype=np.float64)
# XAnalysis, YAnalysis = wfi.fromSCAtoAnalysis(
# self.optics.scaNum, self.optics.scaX, self.optics.scaY
# ) # Center of the PSF in Analysis coordinates
# imageX = XAnalysis + self.sX[:, 0] # Note that self.sX and self.sY and
# imageY = YAnalysis + self.sY[0, :]
# Xd = np.floor(XAnalysis // pix) * pix
# Yd = np.floor(YAnalysis // pix) * pix
# xd_array = (
# Xd
# - (np.floor((self.postage_stamp_size - 1) / 2) * pix)
# + pix * np.arange(int(self.postage_stamp_size))
# )
# yd_array = (
# Yd
# - (np.floor((self.postage_stamp_size - 1) / 2) * pix)
# + pix * np.arange(int(self.postage_stamp_size))
# )
# xD, yD = np.meshgrid(xd_array, yd_array, indexing="ij")
# result = MTF_SCA_postage_stamp(imageX, imageY, xD, yD, self.Intensity_integrated, self.npix_boundary)
# self.detector_image2 = result
# def get_detector_image(self, nworkers=8, chunk_size=1):
# """
# Returns the postage_stamp_size x postage_stamp_size detector image as a 2D array of intensity values.
# """
# # if not hasattr(self, 'Intensity'):
# # self.get_E_in_detector()
# pix = 10
# # Compute the detector image by summing the contributions from all points in the postage stamp
# # detector_image = np.zeros((, 4088, self.optics.ulen, self.optics.ulen), dtype=np.float64)
# XAnalysis, YAnalysis = wfi.fromSCAtoAnalysis(
# self.optics.scaNum, self.optics.scaX, self.optics.scaY
# ) # Center of the PSF in Analysis coordinates
# imageX = (
# XAnalysis + self.sX
# ) # Note that self.sX and self.sY are in microns whereas Analysis coordinates and MTF are in mm
# imageY = YAnalysis + self.sY
# self.imageX = imageX
# self.imageY = imageY
# Xd = np.floor(XAnalysis // pix) * pix
# Yd = np.floor(YAnalysis // pix) * pix
# xd_array = (
# Xd
# - (np.floor((self.postage_stamp_size - 1) / 2) * pix)
# + pix * np.arange(int(self.postage_stamp_size))
# )
# yd_array = (
# Yd
# - (np.floor((self.postage_stamp_size - 1) / 2) * pix)
# + pix * np.arange(int(self.postage_stamp_size))
# )
# xD, yD = np.meshgrid(xd_array, yd_array, indexing="ij")
# mask = (np.maximum(np.abs(xD), np.abs(yD)) <= 20440).astype(
# np.float64
# ) # Mask to zero out values outside the SCA
# shape = (int(self.postage_stamp_size), int(self.postage_stamp_size))
# detector_image = np.zeros(shape, dtype=np.float64)
# tasks = [
# (
# xd_array[index_xd],
# yd_array[index_yd],
# imageX,
# imageY,
# self.Intensity_integrated,
# self.npix_boundary,
# )
# for index_xd in range(self.postage_stamp_size)
# for index_yd in range(self.postage_stamp_size)
# ]
# with ProcessPoolExecutor(max_workers=nworkers) as executor:
# results = list(executor.map(parallel_MTF_image, tasks, chunksize=chunk_size))
# detector_image = np.array(results).reshape(shape)
# # Mask out values outside the SCA
# detector_image *= mask
# self.detector_image = detector_image
# def get_E_in_detector(self, filter=interference_filter, detector_thickness=2, zlen=20, nworkers=8):
#
# # Check is self.A_TE_h and self.A_TM_h exist, if not call get_TE_TM_modes
# if not hasattr(self, 'A_TE_h') or not hasattr(self, 'A_TM_h'):
# self.get_TE_TM_modes(filter=filter,
# detector_thickness=detector_thickness, zlen=zlen, nworkers=nworkers)
#
# # dZ = z_array[1] - z_array[0]
# # ulen = self.optics.ulen
# #uX = self.optics.u_array()
# #uY = self.optics.v_array()
# # uX, uY = np.meshgrid(uX, uY, indexing='ij')
# # uX, uY = np.meshgrid(self.uX, self.uY, indexing='ij')
# E_h_polarized = filter.Transmitted_E(self.wavelength, self.ux,
# self.uy, self.z_array, A_TE=self.A_TE_h, A_TM=self.A_TM_h)
# Ex_h = E_h_polarized[0]
# Ey_h = E_h_polarized[1]
# Ez_h = E_h_polarized[2]
# end_time = time.time()
# print("Time taken to get transmitted E field through filter = ", end_time - current_time, "\n")
# current_time = time.time()
# Ex_h *= self.prefactor[:, :, na]
# Ey_h *= self.prefactor[:, :, na]
# Ez_h *= self.prefactor[:, :, na]
# end_time = time.time()
# print("Time taken to multiply by prefactor = ", end_time - current_time, "\n")
# current_time = time.time()
# Ex_h_postage_stamp = ifft2(Ex_h, axes=(0, 1), workers=nworkers)
# Ey_h_postage_stamp = ifft2(Ey_h, axes=(0, 1), workers=nworkers)
# Ez_h_postage_stamp = ifft2(Ez_h, axes=(0, 1), workers=nworkers)
# Intensity_h = (abs(Ex_h_postage_stamp) ** 2) + (
# abs(Ey_h_postage_stamp) ** 2) + (abs(Ez_h_postage_stamp) ** 2)
# self.Filtered_PSF_h = Intensity_h[:, :, 0] # /np.sum(Intensity_h[:,:,0]*self.dsX*self.dsY)
# # Filtered PSF normalise to total flux of 1 (introduced only for testing purposes)
# # self.Filtered_PSF *= np.sum(self.dsX*self.dsY)
# self.Intensity_h = Intensity_h
# self.Intensity_integrated_h = np.trapz(Intensity_h, x=self.z_array, axis=2)
# return