# Source code for meersolar.pipeline.single_image_meerpbcor

import numpy as np, copy, psutil, os, astropy.units as u, warnings, gc, traceback, time, argparse, sys
from astropy.io import fits
from numpy.linalg import inv
from astropy.coordinates import AltAz, EarthLocation, SkyCoord
from astropy.time import Time
from astropy.wcs import WCS
from scipy.interpolate import RectBivariateSpline
from joblib import Parallel, delayed as joblid_delayed
from astropy.wcs import FITSFixedWarning
from meersolar.pipeline.basic_func import *
from casatasks import casalog

# Best-effort removal of the log file reported by casalog. A failure here
# (no log file, missing permissions, casalog unavailable) must never stop the
# module from importing, but KeyboardInterrupt/SystemExit are no longer
# swallowed and the path is no longer passed through a shell.
try:
    casalogfile = casalog.logfile()
    os.remove(casalogfile)
except Exception:
    pass

# FITS headers written by imaging software routinely trigger this warning.
warnings.simplefilter("ignore", category=FITSFixedWarning)

# Define MeerKAT location
MEERLAT = -30.7133  # Latitude in degree
MEERLON = 21.4429  # Longitude in degree
MEERALT = 1086.6  # Altitude in meter
datadir = get_datadir()


def get_IQUV(filename):
    """
    Get IQUV from a fits

    Parameters
    ----------
    filename : str
        Fits image name

    Returns
    -------
    dict
        Stokes planes keyed by "I", "Q", "U" and "V". When the image has no
        Stokes axis, or the Stokes axis has a single plane, "Q", "U" and "V"
        are zero-valued planes derived from "I".
    """
    data = fits.getdata(filename).astype("float32")
    header = fits.getheader(filename)
    if header.get("CTYPE3") == "STOKES":
        stokesaxis = 3
    elif header.get("CTYPE4") == "STOKES":
        stokesaxis = 4
    else:
        stokesaxis = 1
    shape = data.shape
    # numpy stores FITS axes in reverse order, so for a 4D cube FITS axis 3
    # is numpy axis 1 and FITS axis 4 is numpy axis 0.
    if stokesaxis == 3:
        n_stokes = shape[1]
    elif stokesaxis == 4:
        n_stokes = shape[0]
    else:
        n_stokes = 1
    stokes = {}
    # Decide on the length of the Stokes axis itself; testing the other
    # leading axis as well sent single-Stokes cubes down the full-Stokes path.
    if n_stokes > 1:
        if stokesaxis == 3:
            stokes["I"] = data[0, 0, :, :]
            stokes["Q"] = data[0, 1, :, :]
            stokes["U"] = data[0, 2, :, :]
            stokes["V"] = data[0, 3, :, :]
        else:
            stokes["I"] = data[0, 0, :, :]
            stokes["Q"] = data[1, 0, :, :]
            stokes["U"] = data[2, 0, :, :]
            stokes["V"] = data[3, 0, :, :]
    else:
        stokes["I"] = data[0, 0, :, :]
        # Multiplying by zero allocates a new array and keeps NaN pixels NaN.
        stokes["Q"] = stokes["I"] * 0
        stokes["U"] = stokes["I"] * 0
        stokes["V"] = stokes["I"] * 0
    return stokes
def put_IQUV(filename, stokes, header):
    """
    Put IQUV into a fits

    Parameters
    ----------
    filename : str
        Fits image name
    stokes : dict
        Stokes planes keyed by "I", "Q", "U" and "V"
    header : dict
        Image header

    Returns
    -------
    str
        Name of the fits image written (same as ``filename``)
    """
    if header.get("CTYPE3") == "STOKES":
        stokesaxis = 3
    elif header.get("CTYPE4") == "STOKES":
        stokesaxis = 4
    else:
        stokesaxis = 1
    naxis = header["NAXIS"]
    # numpy shape is the FITS axis order reversed
    shape = tuple(header[f"NAXIS{axis}"] for axis in range(naxis, 0, -1))
    # Zero-initialise so that planes which are not written below do not end
    # up in the output file as uninitialised memory.
    data = np.zeros(shape, dtype=np.float32)
    if stokesaxis == 3:
        n_stokes = shape[1]
    elif stokesaxis == 4:
        n_stokes = shape[0]
    else:
        n_stokes = 1
    # Decide on the length of the Stokes axis itself, not on either leading axis.
    if n_stokes > 1:
        if stokesaxis == 3:
            data[0, 0, :, :] = stokes["I"]
            data[0, 1, :, :] = stokes["Q"]
            data[0, 2, :, :] = stokes["U"]
            data[0, 3, :, :] = stokes["V"]
        else:
            data[0, 0, :, :] = stokes["I"]
            data[1, 0, :, :] = stokes["Q"]
            data[2, 0, :, :] = stokes["U"]
            data[3, 0, :, :] = stokes["V"]
    else:
        data[0, 0, :, :] = stokes["I"]
    fits.writeto(filename, data=data, header=header, overwrite=True)
    return filename
def get_brightness(stokes):
    """
    Build the 2x2 brightness matrix of every pixel from a Stokes dictionary
    (X and Y are in opposite convention of IAU in MeerKAT)

    Parameters
    ----------
    stokes : dict
        Stokes planes keyed by "I", "Q", "U" and "V"

    Returns
    -------
    numpy.array
        Complex brightness matrices; the two image axes come first (in
        transposed order) followed by the 2x2 matrix axes
    """
    stokes_i, stokes_q, stokes_u, stokes_v = (
        stokes[key].astype("float32") for key in ("I", "Q", "U", "V")
    )
    correlations = [
        stokes_i - stokes_q,  # XX
        stokes_u - 1j * stokes_v,  # XY
        stokes_u + 1j * stokes_v,  # YX
        stokes_i + stokes_q,  # YY
    ]
    stacked = np.array(correlations).astype("complex64")
    # Reverse the axes so the four correlations become the last axis, then
    # split that axis into the 2x2 matrix.
    transposed = stacked.T
    n_first, n_second = transposed.shape[0], transposed.shape[1]
    return transposed.reshape(n_first, n_second, 2, 2)
def make_stokes(b):
    """
    Convert a brightness matrix back into Stokes images

    Parameters
    ----------
    b : numpy.array
        Brightness matrix with the 2x2 matrix axes first (``b[i, j, ...]``)

    Returns
    -------
    dict
        Real-valued Stokes planes keyed by "I", "Q", "U" and "V"
    """
    xx, xy, yx, yy = (
        b[row, col, ...].astype("complex64") for row in (0, 1) for col in (0, 1)
    )
    return {
        "I": np.real(xx + yy) / 2.0,
        "Q": np.real(yy - xx) / 2.0,
        "U": np.real(xy + yx) / 2.0,
        "V": np.real(1j * (xy - yx)) / 2.0,
    }
def load_beam(image_file, band=""):
    """
    Load MeerKAT beam

    Parameters
    ----------
    image_file : str
        Image name (Assuming single spectral image)
    band : str, optional
        Band name, "U" or "L" (If not provided, check from header or frequency)

    Returns
    -------
    numpy.array
        l,m coordinates
    numpy.array
        Full Jones complex beam averaged over the beam-model channels that
        fall inside the image channel

    Notes
    -----
    Returns None (after printing a message) when the image has no frequency
    axis or is not in UHF or L-band.
    """
    hdr = fits.getheader(image_file)
    if hdr.get("CTYPE3") == "FREQ":
        freq = hdr["CRVAL3"]
        delfreq = hdr["CDELT3"]
    elif hdr.get("CTYPE4") == "FREQ":
        freq = hdr["CRVAL4"]
        delfreq = hdr["CDELT4"]
    else:
        print("No frequency axis in image.")
        return
    # abs() keeps freq1 <= freq2 even when the frequency increment is negative
    half_width = abs(delfreq) / 2
    freq1 = (freq - half_width) / 10**6  # In MHz
    freq2 = (freq + half_width) / 10**6  # In MHz
    if band == "":
        try:
            band = hdr["BAND"]
        except KeyError:
            if freq1 >= 544 and freq2 <= 1088:  # UHF band
                band = "U"
            elif freq1 >= 856 and freq2 <= 1712:  # L band
                band = "L"
            else:
                print("Image is not in UHF or L-band.")
                return
    if band == "U":
        beam_file = datadir + "/MeerKAT_antavg_Uband.npz"
    elif band == "L":
        beam_file = datadir + "/MeerKAT_antavg_Lband.npz"
    else:
        print("Image is not in UHF or L-band.")
        return
    # The context manager closes the npz archive once the arrays are read.
    with np.load(beam_file, mmap_mode="r") as beam_data:
        # NOTE(review): "freqs" is compared against MHz values, so it is
        # assumed to be stored in MHz - confirm against the beam files.
        freqs = beam_data["freqs"]
        coords = np.deg2rad(
            beam_data["coords"]
        )  # It is done as l,m values were converted into degree
        # Beam-model channels lying inside the image channel
        pos = np.where((freqs >= freq1) & (freqs <= freq2))[0]
        if pos.size == 0:
            # Image channel narrower than the beam-model spacing: use the
            # nearest beam channel instead of averaging an empty selection,
            # which would give an all-NaN beam.
            pos = np.array([np.argmin(np.abs(freqs - freq / 10**6))])
        beam = beam_data["beams"][:, pos, ...].mean(1)
    beam = beam.astype("complex64")
    del freqs
    gc.collect()
    return coords, beam
def get_radec_grid(image_file):
    """
    Compute the RA and Dec of every pixel of an image.

    Parameters
    ----------
    image_file : str
        FITS image file name

    Returns
    -------
    ra : 2D numpy.ndarray
        RA values in degrees for each pixel, shape (NAXIS2, NAXIS1)
    dec : 2D numpy.ndarray
        Dec values in degrees for each pixel, shape (NAXIS2, NAXIS1)
    """
    header = fits.getheader(image_file)
    celestial_wcs = WCS(header).celestial
    n_rows, n_cols = header["NAXIS2"], header["NAXIS1"]
    # Zero-based pixel indices: rows follow FITS axis 2, columns FITS axis 1
    row_index, col_index = np.indices((n_rows, n_cols))
    sky_positions = celestial_wcs.pixel_to_world(col_index, row_index)
    return sky_positions.ra.deg, sky_positions.dec.deg
def get_pointingcenter_radec(image_file):
    """
    Get image pointing center RA DEC

    The values are read from the CRVAL1 and CRVAL2 header keywords, i.e. the
    image reference coordinate is taken as the pointing center.

    Parameters
    ----------
    image_file : str
        Image file name

    Returns
    -------
    float
        RA in degree
    float
        DEC in degree
    """
    hdr = fits.getheader(image_file)
    # Only the two reference values are needed; no WCS object is required.
    ra0 = float(hdr["CRVAL1"])
    dec0 = float(hdr["CRVAL2"])
    return ra0, dec0
def radec_to_lm(ra_deg, dec_deg, ra0_deg, dec0_deg):
    """
    Project RA/Dec onto l,m direction cosines about a phase center.

    Parameters
    ----------
    ra_deg, dec_deg : 2D arrays
        RA and Dec in degrees
    ra0_deg, dec0_deg : float
        Phase center RA and Dec in degrees

    Returns
    -------
    l, m : 2D arrays
        Direction cosines (dimensionless)
    """
    dec_rad = np.deg2rad(dec_deg)
    dec0_rad = np.deg2rad(dec0_deg)
    # RA offset from the phase center, in radians
    ra_offset = np.deg2rad(ra_deg) - np.deg2rad(ra0_deg)
    l_cosine = np.cos(dec_rad) * np.sin(ra_offset)
    m_cosine = np.sin(dec_rad) * np.cos(dec0_rad) - np.cos(dec_rad) * np.sin(
        dec0_rad
    ) * np.cos(ra_offset)
    return l_cosine, m_cosine
def get_parallactic_angle(
    obs_time, ra_deg, dec_deg, LAT=MEERLAT, LON=MEERLON, ALT=MEERALT
):
    """
    Compute the parallactic angle of a sky position at a given time

    Parameters
    ----------
    obs_time : str
        Observation time in YYYY-MM-DDThh:mm:ss format
    ra_deg : float
        RA in degree
    dec_deg : float
        DEC in degree
    LAT : float, optional
        Observatory latitude in degree
    LON : float, optional
        Observatory longitude in degree
    ALT : float, optional
        Observatory altitude in meter

    Returns
    -------
    float
        Parallactic angle in degree
    """
    observatory = EarthLocation(lat=LAT * u.deg, lon=LON * u.deg, height=ALT * u.m)
    horizon_frame = AltAz(obstime=Time(obs_time), location=observatory)
    target = SkyCoord(ra=ra_deg * u.deg, dec=dec_deg * u.deg, frame="icrs")
    horizontal = target.transform_to(horizon_frame)
    azimuth = horizontal.az.rad
    elevation = horizontal.alt.rad
    latitude = np.deg2rad(LAT)
    numerator = np.sin(azimuth) * np.cos(latitude)
    denominator = np.cos(elevation) * np.sin(latitude) - np.sin(elevation) * np.cos(
        latitude
    ) * np.cos(azimuth)
    return np.rad2deg(np.arctan2(numerator, denominator))
def get_beam_interpolator(jones, coords):
    """
    Build spline interpolators for the real and imaginary parts of a beam

    Parameters
    ----------
    jones : numpy.array
        Jones array (shape, npol, l_npix, m_npix)
    coords : numpy.array
        l,m coordinates (the same vector is used for both axes)

    Returns
    -------
    tuple
        Eight interpolation functions in the order
        j00_r, j00_i, j01_r, j01_i, j10_r, j10_i, j11_r, j11_i
    """

    def fit_plane(plane):
        # NaN samples are replaced by zero before fitting the spline
        return RectBivariateSpline(x=coords, y=coords, z=np.nan_to_num(plane))

    splines = []
    for element in range(4):
        component = jones[element, ...]
        splines.append(fit_plane(np.real(component)))
        splines.append(fit_plane(np.imag(component)))
    return tuple(splines)
def apply_parallactic_rotation(jones, p_angle):
    """
    Rotate a Jones matrix by the parallactic angle: J' = J.R(p_angle),
    i.e. the rotation matrix multiplies J from the right, as needed in the
    RIME context (sky-frame transformation).

    Parameters
    ----------
    jones : ndarray
        Jones matrix, shape (4, H, W), with components:
        [0] = J_00, [1] = J_01, [2] = J_10, [3] = J_11
    p_angle : float
        Parallactic angle in degree

    Returns
    -------
    jones_rot : ndarray
        Rotated Jones matrix, shape (4, H, W)
    """
    theta = np.deg2rad(p_angle)
    cos_t = np.cos(theta)
    neg_sin_t = -np.sin(theta)
    top_left, top_right, bottom_left, bottom_right = jones
    rotated = [
        top_left * cos_t - top_right * neg_sin_t,
        top_left * neg_sin_t + top_right * cos_t,
        bottom_left * cos_t - bottom_right * neg_sin_t,
        bottom_left * neg_sin_t + bottom_right * cos_t,
    ]
    return np.stack(rotated, axis=0).astype("complex64")
def get_image_beam(
    image_file,
    pbdir,
    save_beam=True,
    band="",
    apply_parang=True,
    n_cpu=8,
    verbose=False,
):
    """
    Get image beam

    The beam is interpolated onto the image pixel grid and cached in
    ``pbdir`` as ``freq_<MHz>_pb.npy``; an existing cache file is reused
    when it loads and matches the image size.

    Parameters
    ----------
    image_file : str
        Image file name
    pbdir : str
        Primary beam directory
    save_beam : bool, optional
        Save beam of the image
    band : str, optional
        Band name
    apply_parang : bool, optional
        Apply parallactic angle correction
    n_cpu : int, optional
        Number of CPU threads to use (capped at 8)
    verbose : bool, optional
        Verbose output

    Returns
    -------
    numpy.array
        Jones array with the transposed image axes first and the 2x2 Jones
        matrix last, i.e. shape (NAXIS1, NAXIS2, 2, 2). None is returned
        when the image has no frequency axis or the beam cannot be loaded.
    """
    # Only eight interpolators are evaluated, so more threads cannot help.
    n_cpu = min(n_cpu, 8)
    start_time = time.time()
    ##################################
    header = fits.getheader(image_file)
    if header.get("CTYPE3") == "FREQ":
        freq = header["CRVAL3"]
    elif header.get("CTYPE4") == "FREQ":
        freq = header["CRVAL4"]
    else:
        print("No frequency axis in image.")
        return
    freq = round(freq / 10**6, 1)  # In MHz
    pbfile = f"{pbdir}/freq_{freq}_pb.npy"
    image_shape = (header["NAXIS2"], header["NAXIS1"])
    ra0, dec0 = get_pointingcenter_radec(image_file)  # Phase center
    #######################################
    # If beam file exists
    #######################################
    jones_array = None
    if os.path.exists(pbfile):
        if verbose:
            print(f"Loading beam from: {pbfile}")
        try:
            # allow_pickle is kept only so that cache files written by
            # earlier versions (pickled object arrays) can still be read.
            jones_array = np.load(pbfile, allow_pickle=True).astype("complex64")
        except Exception:
            # Unreadable cache: drop it and recompute
            jones_array = None
            try:
                os.remove(pbfile)
            except OSError:
                pass
        else:
            # The cache is keyed on frequency only; a beam computed for an
            # image of a different size must not be reused.
            if jones_array.shape != (4,) + image_shape:
                jones_array = None
    #################################
    # Fresh run
    #################################
    if jones_array is None:
        ra_grid, dec_grid = get_radec_grid(image_file)  # RA DEC grid
        l_grid, m_grid = radec_to_lm(ra_grid, dec_grid, ra0, dec0)
        ############################
        # Load beam
        ############################
        beam_results = load_beam(image_file, band=band)
        if beam_results is None:
            return
        lm_coords, beam = beam_results
        # Real and imaginary interpolators of the four Jones elements, in
        # the order j00_r, j00_i, j01_r, j01_i, j10_r, j10_i, j11_r, j11_i
        interpolators = get_beam_interpolator(beam, lm_coords)
        grid_shape = l_grid.shape
        l_grid_flat = l_grid.flatten()
        m_grid_flat = m_grid.flatten()
        del l_grid, m_grid, ra_grid, dec_grid
        gc.collect()
        with Parallel(n_jobs=n_cpu, backend="threading") as parallel:
            results = parallel(
                joblid_delayed(interpolator)(l_grid_flat, m_grid_flat, grid=False)
                for interpolator in interpolators
            )
        planes = [plane.reshape(grid_shape) for plane in results]
        # Recombine consecutive (real, imaginary) planes into complex elements
        jones_array = np.array(
            [planes[2 * k] + 1j * planes[2 * k + 1] for k in range(4)]
        ).astype("complex64")
        del results, planes
        gc.collect()
        if save_beam:
            # Stored as a plain complex64 array: much smaller than a pickled
            # object array and readable without pickle.
            np.save(pbfile, jones_array)
            if verbose:
                print(f"Beam saved in: {pbfile}")
    if apply_parang:
        obs_time = header["DATE-OBS"]
        p_angle = get_parallactic_angle(
            obs_time, ra0, dec0
        )  # Parallactic angle of the center
        # This is to account B'=P(X)BP(-X) parallactic transform on brightness matrix
        jones_array = apply_parallactic_rotation(jones_array, p_angle)
    # The transpose and reshape give the per-pixel 2x2 layout and are needed
    # whether or not the parallactic rotation was applied.
    jones_array = jones_array.T
    jones_array = jones_array.reshape(jones_array.shape[0], jones_array.shape[1], 2, 2)
    gc.collect()
    if verbose:
        print(f"Beam calculated in : {round(time.time()-start_time,1)}s")
    return jones_array
def get_pbcor_image(
    image_file,
    pbdir,
    pbcor_dir,
    save_beam=True,
    band="",
    apply_parang=True,
    n_cpu=8,
    verbose=False,
):
    """
    Get primary beam corrected image

    The brightness matrix B of every pixel is corrected as
    inv(J) . B . inv(J^H), where J is the Jones matrix of the beam.

    Parameters
    ----------
    image_file : str
        Image file name
    pbdir : str
        Primary beam directory
    pbcor_dir : str
        Primary beam corrected image directory
    save_beam : bool, optional
        Save the beam for the image
    band : str, optional
        Band name
    apply_parang : bool, optional
        Apply parallactic correction
    n_cpu : int, optional
        Number of CPU threads to use
    verbose : bool, optional
        Verbose output

    Returns
    -------
    str
        Primary beam corrected image name. None is returned when the beam
        could not be obtained or any step raised an exception (the traceback
        is printed).
    """
    try:
        image_file = image_file.rstrip("/")
        print(f"Correcting beam for image: {os.path.basename(image_file)}...")
        beam = get_image_beam(
            image_file,
            pbdir,
            save_beam=save_beam,
            band=band,
            apply_parang=apply_parang,
            n_cpu=int(n_cpu),
            verbose=verbose,
        )
        if not isinstance(beam, np.ndarray):
            print(f"Error in correct beam for image: {os.path.basename(image_file)}")
            return
        # Closed-form inverse of the 2x2 Jones matrix of every pixel
        det = beam[..., 0, 0] * beam[..., 1, 1] - beam[..., 0, 1] * beam[..., 1, 0]
        inv_beam = np.empty_like(beam, dtype=np.complex64)
        inv_beam[..., 0, 0] = beam[..., 1, 1] / det
        inv_beam[..., 0, 1] = -beam[..., 0, 1] / det
        inv_beam[..., 1, 0] = -beam[..., 1, 0] / det
        inv_beam[..., 1, 1] = beam[..., 0, 0] / det
        del beam, det
        gc.collect()
        # inv(J^H) equals inv(J)^H, so no second inversion is needed
        inv_beam_H = np.conj(np.swapaxes(inv_beam, -1, -2))
        image_stokes = get_IQUV(image_file)
        B_matrix = get_brightness(image_stokes)
        del image_stokes
        gc.collect()
        B_tmp = np.matmul(B_matrix, inv_beam_H)
        del B_matrix, inv_beam_H
        gc.collect()
        B_cor = np.matmul(inv_beam, B_tmp)
        del B_tmp, inv_beam
        gc.collect()
        # Move the 2x2 matrix axes to the front and undo the image transpose
        B_cor = np.transpose(B_cor, (2, 3, 1, 0))
        pbcor_stokes = make_stokes(B_cor)
        del B_cor
        gc.collect()
        #################################
        pbcor_file = (
            pbcor_dir
            + "/"
            + os.path.basename(image_file).split(".fits")[0]
            + "_pbcor.fits"
        )
        header = fits.getheader(image_file)
        pbcor_file = put_IQUV(pbcor_file, pbcor_stokes, header)
        return pbcor_file
    except Exception:
        traceback.print_exc()
        gc.collect()
        return
def main():
    """
    Command line entry point: parse the arguments, register the process ID
    and run the primary beam correction on a single image.

    Returns
    -------
    int
        0 on success, 1 on any failure
    """
    rule = "###################"
    parser = argparse.ArgumentParser(
        description="Correct image for full-polar antenna averaged MeerKAT primary beam",
        formatter_class=SmartDefaultsHelpFormatter,
    )

    def add_section(title):
        # Every argument group title is framed by the same rule line
        return parser.add_argument_group(f"{rule}\n{title}\n{rule}")

    ## Essential parameters
    essential = add_section("Essential parameters")
    essential.add_argument(
        "imagename", type=str, help="Name of image (required positional argument)"
    )
    essential.add_argument(
        "--pbdir", type=str, default="", help="Name of primary beam directory"
    )
    essential.add_argument(
        "--pbcor_dir",
        type=str,
        default="",
        help="Name of primary beam corrected image directory",
    )

    ## Advanced parameters
    advanced = add_section("Advanced parameters")
    advanced.add_argument(
        "--no_save_beam",
        action="store_false",
        dest="save_beam",
        help="Do not save beam to disk",
    )
    advanced.add_argument("--band", type=str, default="", help="Band name")
    advanced.add_argument(
        "--no_apply_parang",
        action="store_false",
        dest="apply_parang",
        help="Do not apply parallactic angle correction",
    )
    advanced.add_argument("--verbose", action="store_true", help="Verbose output")

    ## Resource management parameters
    hardware = add_section("Hardware resource management parameters")
    hardware.add_argument(
        "--ncpu",
        type=int,
        default=8,
        help="Number of CPU threads to use",
        metavar="Integer",
    )
    hardware.add_argument("--jobid", type=int, default=0, help="Job ID")

    # Called without any argument: show the help and exit with an error code
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    args = parser.parse_args()

    # Record the process ID under the job ID in the cache directory
    cachedir = get_meersolar_cachedir()
    save_pid(os.getpid(), f"{cachedir}/pids/pids_{args.jobid}.txt")

    status = 1
    try:
        if not (args.imagename and os.path.exists(args.imagename)):
            print("Please provide correct image name.\n")
        elif args.pbdir == "":
            print("Provide an existing directory name in pbdir.")
        else:
            os.makedirs(args.pbdir, exist_ok=True)
            # Corrected images go next to the beams unless stated otherwise
            pbcor_dir = args.pbcor_dir if args.pbcor_dir else args.pbdir
            os.makedirs(pbcor_dir, exist_ok=True)
            pbcor_image = get_pbcor_image(
                args.imagename,
                args.pbdir,
                pbcor_dir,
                band=args.band,
                apply_parang=args.apply_parang,
                save_beam=args.save_beam,
                n_cpu=int(args.ncpu),
                verbose=args.verbose,
            )
            if pbcor_image is None or not os.path.exists(pbcor_image):
                print(f"Primary beam correction is not successful")
            else:
                status = 0
                print(f"Primary beam corrected image: {pbcor_image}")
    except Exception:
        traceback.print_exc()
        status = 1
    return status
if __name__ == "__main__":
    result = main()
    # os._exit() ends the process immediately without flushing stdio
    # buffers, so flush explicitly or the messages printed above are lost
    # whenever the output is redirected to a file or a pipe.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(result)