import types
import psutil
import numpy as np
import glob
import os
import traceback
import warnings
import astropy.units as u
from sunpy.net import Fido, attrs as a
from sunpy.timeseries import TimeSeries
from astroquery.jplhorizons import Horizons
from astropy.visualization import ImageNormalize, LogStretch
from astropy.wcs import FITSFixedWarning
from astropy.io import fits
from astropy.time import Time
from astropy.coordinates import SkyCoord
from casatools import msmetadata, ms as casamstool, table
from datetime import datetime as dt, timedelta
from .basic_utils import *
from .resource_utils import *
from .udocker_utils import *
from .ms_metadata import *
from .image_utils import *
# Ignore every warning of category FITSFixedWarning (imported from astropy.wcs
# above). The filter is installed at import time of this module.
warnings.simplefilter("ignore", category=FITSFixedWarning)
#######################
# MS metadata related
#######################
[docs]
def get_fluxcals(msname):
    """
    Collect flux-calibrator fields and their scans (all scans, valid and invalid)

    Parameters
    ----------
    msname : str
        Name of the ms. If ``msname + "/SUBMSS"`` exists, every ``*.ms``
        found inside that directory is inspected instead of ``msname`` itself.

    Returns
    -------
    list
        Fluxcal field names, in the order they were first encountered
    dict
        Sorted, de-duplicated scan numbers keyed by fluxcal field name
    """
    known_fluxcals = ("J1939-6342", "J0408-6545")
    submss_dir = msname + "/SUBMSS"
    if os.path.exists(submss_dir):
        ms_paths = glob.glob(submss_dir + "/*.ms")
    else:
        ms_paths = [msname]
    md_tool = msmetadata()
    fluxcal_fields = []
    collected = {}
    for ms_path in ms_paths:
        md_tool.open(ms_path)
        for name in md_tool.fieldnames():
            if name not in known_fluxcals:
                continue
            if name not in fluxcal_fields:
                fluxcal_fields.append(name)
            # Accumulate across sub-MSs; duplicates are removed further down.
            collected.setdefault(name, []).extend(
                md_tool.scansforfield(name).tolist()
            )
        md_tool.close()
    md_tool.done()
    del md_tool
    fluxcal_scans = {
        name: np.unique(found).tolist() for name, found in collected.items()
    }
    return fluxcal_fields, fluxcal_scans
[docs]
def get_polcals(msname):
    """
    Collect polarization-calibrator fields and their scans (all scans, valid and invalid)

    Parameters
    ----------
    msname : str
        Name of the ms. If ``msname + "/SUBMSS"`` exists, the ``*.ms``
        entries inside it are scanned instead.

    Returns
    -------
    list
        Polcal field names, in the order they were first encountered
    dict
        Sorted, de-duplicated scan numbers keyed by polcal field name
    """
    # Field names recognised as polarization calibrators (3C286 and 3C138
    # together with the alternative names listed for each).
    polcal_names = (
        "3C286",
        "1328+307",
        "1331+305",
        "J1331+3030",
        "3C138",
        "0518+165",
        "0521+166",
        "J0521+1638",
    )
    has_submss = os.path.exists(msname + "/SUBMSS")
    mslist = glob.glob(msname + "/SUBMSS/*.ms") if has_submss else [msname]
    md_tool = msmetadata()
    polcal_fields, polcal_scans = [], {}
    for sub_ms in mslist:
        md_tool.open(sub_ms)
        matching = [f for f in md_tool.fieldnames() if f in polcal_names]
        for fname in matching:
            if fname not in polcal_fields:
                polcal_fields.append(fname)
            new_scans = md_tool.scansforfield(fname).tolist()
            if fname in polcal_scans:
                polcal_scans[fname] += new_scans
            else:
                polcal_scans[fname] = new_scans
        md_tool.close()
    md_tool.done()
    del md_tool
    for fname in list(polcal_scans):
        polcal_scans[fname] = np.unique(polcal_scans[fname]).tolist()
    return polcal_fields, polcal_scans
[docs]
def get_phasecals(msname):
    """
    Get phasecal field names and scans (all scans, valid and invalid)

    Parameters
    ----------
    msname : str
        Name of the ms. If ``msname + "/SUBMSS"`` exists, every ``*.ms``
        inside that directory is inspected instead.

    Returns
    -------
    list
        Phasecal field names, in the order they were first encountered
    dict
        Sorted, de-duplicated scan numbers keyed by phasecal field name
    dict
        Catalogue flux entry keyed by phasecal field name

    Notes
    -----
    Only the "U" and "L" bands have a calibrator catalogue. For any other
    band no phase calibrator is reported (previously this raised because the
    catalogue variables were never assigned).
    """
    fluxcal_names = ("J1939-6342", "J0408-6545")
    catalog_files = {"U": "/UHF_band_cal.npy", "L": "/L_band_cal.npy"}
    if os.path.exists(msname + "/SUBMSS"):
        mslist = glob.glob(msname + "/SUBMSS/*.ms")
    else:
        mslist = [msname]
    phasecal_fields = []
    phasecal_scans = {}
    phasecal_flux_list = {}
    datadir = get_datadir()
    catalogs = {}  # band name -> (names, fluxes); each file is loaded once
    msmd = msmetadata()
    try:
        for sub_ms in mslist:
            bandname = get_band_name(sub_ms)
            if bandname not in catalogs:
                if bandname in catalog_files:
                    # NOTE(review): allow_pickle=True unpickles arbitrary
                    # objects; acceptable only because the file comes from
                    # the package data directory returned by get_datadir().
                    names, fluxes = np.load(
                        datadir + catalog_files[bandname], allow_pickle=True
                    ).tolist()
                else:
                    names, fluxes = [], []
                catalogs[bandname] = (names, fluxes)
            phasecals, phasecal_flux = catalogs[bandname]
            msmd.open(sub_ms)
            try:
                for field in msmd.fieldnames():
                    # Flux calibrators are handled by get_fluxcals().
                    if field not in phasecals or field in fluxcal_names:
                        continue
                    if field not in phasecal_fields:
                        phasecal_fields.append(field)
                    phasecal_scans.setdefault(field, []).extend(
                        msmd.scansforfield(field).tolist()
                    )
                    phasecal_flux_list[field] = phasecal_flux[
                        phasecals.index(field)
                    ]
            finally:
                msmd.close()
    finally:
        msmd.done()
    for field in phasecal_scans:
        phasecal_scans[field] = np.unique(phasecal_scans[field]).tolist()
    return phasecal_fields, phasecal_scans, phasecal_flux_list
[docs]
def get_valid_scans(msname, field="", min_scan_time=1, n_threads=-1):
    """
    Get valid list of scans

    Parameters
    ----------
    msname : str
        Measurement set name
    field : str
        Field names or numeric field ids (comma separated). Surrounding
        whitespace of each entry is ignored. An empty string selects every
        field.
    min_scan_time : float
        Minimum valid scan time in minutes
    n_threads : int
        Forwarded to ``limit_threads``

    Returns
    -------
    list
        Scan numbers, in ascending order, whose duration (rounded to 0.1
        minute) is at least ``min_scan_time`` and whose field matches the
        selection

    Raises
    ------
    ValueError
        If a ``field`` entry is neither a known field name nor an integer
    """
    limit_threads(n_threads=n_threads)
    from casatools import ms as casamstool

    mstool = casamstool()
    mstool.open(msname)
    try:
        scan_summary = mstool.getscansummary()
    finally:
        mstool.close()
    selected_field = []
    if field != "":
        msmd = msmetadata()
        msmd.open(msname)
        try:
            for f in field.split(","):
                f = f.strip()
                with suppress_output():
                    try:
                        field_id = msmd.fieldsforname(f)[0]
                    except Exception:
                        # Not a known field name: treat the entry as a
                        # numeric field id (int() raises if it is not one).
                        field_id = int(f)
                selected_field.append(field_id)
        finally:
            msmd.close()
            msmd.done()
    valid_scans = []
    scans = np.sort(np.array([int(i) for i in scan_summary.keys()]))
    for scan in scans:
        # Only the sub-scan stored under key "0" is considered.
        info = scan_summary[str(scan)]["0"]
        if selected_field and info["FieldId"] not in selected_field:
            continue
        # assumes BeginTime/EndTime are in days, hence *86400 for seconds
        # — TODO confirm against the ms tool documentation
        duration = (info["EndTime"] - info["BeginTime"]) * 86400.0
        duration = round(duration / 60.0, 1)
        if duration >= min_scan_time:
            valid_scans.append(scan)
    return valid_scans
[docs]
def get_target_fields(msname):
    """
    Get target fields

    A field is a target when it is not reported as a flux, phase or
    polarization calibrator by ``get_fluxcals``, ``get_phasecals`` and
    ``get_polcals``. Polarization calibrators are excluded so that this
    function agrees with ``get_cal_target_scans``.

    Parameters
    ----------
    msname : str
        Name of the measurement set

    Returns
    -------
    list
        Target field names (unique, sorted)
    dict
        Scan numbers keyed by target field name
    """
    fluxcal_field, _ = get_fluxcals(msname)
    phasecal_field, _, _ = get_phasecals(msname)
    polcal_field, _ = get_polcals(msname)
    calibrator_field = set(fluxcal_field) | set(phasecal_field) | set(polcal_field)
    msmd = msmetadata()
    msmd.open(msname)
    try:
        field_names = [str(f) for f in np.unique(msmd.fieldnames())]
        target_fields = [f for f in field_names if f not in calibrator_field]
        target_scans = {
            f: msmd.scansforfield(f).tolist() for f in target_fields
        }
    finally:
        # Release the tool even when a metadata query fails.
        msmd.close()
        msmd.done()
    return target_fields, target_scans
[docs]
def get_caltable_fields(caltable):
    """
    Get caltable field names

    Parameters
    ----------
    caltable : str
        Caltable name

    Returns
    -------
    list
        Names of the fields referenced by the ``FIELD_ID`` column of the
        caltable, ordered by ascending field id

    Raises
    ------
    IndexError
        If a ``FIELD_ID`` value does not correspond to a row of the FIELD
        subtable
    """
    tb = table()
    tb.open(caltable + "/FIELD")
    try:
        field_names = tb.getcol("NAME")
    finally:
        tb.close()
    tb.open(caltable)
    try:
        fields = np.unique(tb.getcol("FIELD_ID"))
    finally:
        tb.close()
    field_name_list = []
    for f in fields:
        # FIELD_ID is a row index into the FIELD subtable. SOURCE_ID points
        # into the SOURCE subtable and need not equal the row number, so it
        # must not be used for this lookup.
        row = int(f)
        if not 0 <= row < len(field_names):
            raise IndexError(
                f"FIELD_ID {row} is not a valid row of {caltable}/FIELD"
            )
        field_name_list.append(str(field_names[row]))
    return field_name_list
[docs]
def get_cal_target_scans(msname):
    """
    Split the scans of a measurement set into calibrator and target scans

    Parameters
    ----------
    msname : str
        Name of the measurement set

    Returns
    -------
    list
        Target scan numbers (every scan that is not a calibrator scan)
    list
        Calibrator scan numbers (fluxcal, then polcal, then phasecal scans)
    list
        Fluxcal scans
    list
        Phasecal scans
    list
        Polcal scans
    """
    _, fluxcal_scans = get_fluxcals(msname)
    _, phasecal_scans, _ = get_phasecals(msname)
    _, polcal_scans = get_polcals(msname)
    # Flatten the per-field scan lists of each calibrator type.
    f_scans = [s for per_field in fluxcal_scans.values() for s in per_field]
    p_scans = [s for per_field in polcal_scans.values() for s in per_field]
    g_scans = [s for per_field in phasecal_scans.values() for s in per_field]
    cal_scans = f_scans + p_scans + g_scans
    md_tool = msmetadata()
    md_tool.open(msname)
    all_scans = md_tool.scannumbers()
    md_tool.close()
    md_tool.done()
    target_scans = [scan for scan in all_scans if scan not in cal_scans]
    return target_scans, cal_scans, f_scans, g_scans, p_scans
[docs]
def get_band_name(msname):
    """
    Classify the observing band from the mean frequency of spectral window 0

    Parameters
    ----------
    msname : str
        Name of the ms

    Returns
    -------
    str
        Band name ('U','L','S'). The 'U' range is tested first, so a mean
        frequency inside both ranges is reported as 'U'; anything outside
        both ranges is reported as 'S'.
    """
    md_tool = msmetadata()
    md_tool.open(msname)
    # presumably Hz -> MHz; verify the unit returned by msmetadata.meanfreq
    mean_freq = md_tool.meanfreq(0) / 10**6
    md_tool.close()
    md_tool.done()
    if 544 <= mean_freq <= 1088:
        return "U"
    if 856 <= mean_freq <= 1712:
        return "L"
    return "S"
[docs]
def get_bad_chans(msname):
    """
    Get bad channels to flag

    Parameters
    ----------
    msname : str
        Name of the ms

    Returns
    -------
    str
        SPW string of bad channels for spectral window 0 (for example
        ``"0:0~10;200~250"``). An empty string is returned when the band is
        neither 'U' nor 'L', or when the channel frequencies do not reach
        both the first and the last bad range of the band.
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        chanfreqs = msmd.chanfreqs(0) / 10**6
    finally:
        msmd.close()
        msmd.done()
    bandname = get_band_name(msname)
    # (start, end) ranges in the same unit as chanfreqs; -1 means "from the
    # first channel" as a start and "to the last channel" as an end.
    if bandname == "U":
        bad_freqs = [
            (-1, 580),
            (925, 960),
            (1010, -1),
        ]
    elif bandname == "L":
        bad_freqs = [
            (-1, 879),
            (925, 960),
            (1166, 1186),
            (1217, 1249),
            (1375, 1387),
            (1526, 1626),
            (1681, -1),
        ]
    else:
        print("MeerKAT data is not in UHF or L-band.")
        # No bad-range table for this band; indexing the empty table used
        # to raise IndexError.
        return ""
    if (
        len(chanfreqs) == 0
        or np.min(chanfreqs) > bad_freqs[0][1]
        or np.max(chanfreqs) < bad_freqs[-1][0]
    ):
        return ""
    last_chan = len(chanfreqs) - 1
    selections = []
    end_chan = None
    for start_freq, end_freq in bad_freqs:
        if start_freq == -1:
            start_chan = 0
        else:
            start_chan = int(np.argmin(np.abs(start_freq - chanfreqs)))
        # Stop as soon as a range would start at or before the end of the
        # previous one.
        if selections and start_chan <= end_chan:
            break
        if end_freq == -1:
            end_chan = last_chan
        else:
            end_chan = int(np.argmin(np.abs(end_freq - chanfreqs)))
        if end_chan > start_chan:
            selections.append(f"{start_chan}~{end_chan}")
        else:
            selections.append(str(start_chan))
    return "0:" + ";".join(selections)
[docs]
def get_good_chans(msname):
    """
    Get good channel range to perform gaincal

    Parameters
    ----------
    msname : str
        Name of the ms

    Returns
    -------
    str
        SPW string for spectral window 0 (for example
        ``"0:20~120;150~180"``). An empty string is returned when the band
        is neither 'U' nor 'L', or when the channel frequencies do not reach
        both the first and the last good range of the band.
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        chanfreqs = msmd.chanfreqs(0) / 10**6
    finally:
        msmd.close()
        msmd.done()
    bandname = get_band_name(msname)
    # (start, end) ranges in the same unit as chanfreqs.
    if bandname == "U":
        good_freqs = [
            (581, 924),
            (959, 1009),
        ]
    elif bandname == "L":
        good_freqs = [
            (880, 924),
            (961, 1165),
            (1187, 1216),
            (1250, 1374),
            (1388, 1525),
            (1627, 1680),
        ]
    else:
        print("MeerKAT data is not in UHF or L-band.")
        # No good-range table for this band; indexing the empty table used
        # to raise IndexError.
        return ""
    if (
        len(chanfreqs) == 0
        or np.min(chanfreqs) > good_freqs[0][1]
        or np.max(chanfreqs) < good_freqs[-1][0]
    ):
        return ""
    selections = []
    end_chan = None
    for start_freq, end_freq in good_freqs:
        start_chan = int(np.argmin(np.abs(start_freq - chanfreqs)))
        # Stop as soon as a range would start at or before the end of the
        # previous one.
        if selections and start_chan <= end_chan:
            break
        end_chan = int(np.argmin(np.abs(end_freq - chanfreqs)))
        if end_chan > start_chan:
            selections.append(f"{start_chan}~{end_chan}")
        else:
            selections.append(str(start_chan))
    return "0:" + ";".join(selections)
###########################
# Splitting noise diode
###########################
[docs]
def split_noise_diode_scans(
    msname="",
    noise_on_ms="",
    noise_off_ms="",
    field="",
    scan="",
    datacolumn="data",
    n_threads=-1,
):
    """
    Split noise diode on and off timestamps into two separate measurement sets

    The unique timestamps of the ms are split alternately into two
    temporary measurement sets. The one whose selected data (antenna1 = 1,
    antenna2 = 1) has the larger mean amplitude becomes the noise-on ms,
    the other the noise-off ms.

    Parameters
    ----------
    msname : str
        Measurement set. It is resolved to an absolute path before the
        working directory is changed to its parent directory.
    noise_on_ms : str, optional
        Noise diode on ms (default: ``<msname without .ms>_noise_on.ms``)
    noise_off_ms : str, optional
        Noise diode off ms (default: ``<msname without .ms>_noise_off.ms``)
    field : str, optional
        Field name or id
    scan : str, optional
        Scan number
    datacolumn : str, optional
        Data column to split
    n_threads : int, optional
        Forwarded to ``limit_threads``

    Returns
    -------
    tuple
        Split ms names (noise on, noise off)

    Raises
    ------
    ValueError
        If the ms has fewer than two distinct timestamps
    """
    limit_threads(n_threads=n_threads)
    import shutil

    from casatasks import split

    def _remove(path):
        # Delete a measurement set and its flag-version directory, if present.
        for target in (path, path + ".flagversions"):
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target)
            elif os.path.lexists(target):
                os.remove(target)

    def _mean_amplitude(vis):
        # Mean absolute value of the data selected with antenna1 = antenna2 = 1.
        mstool = casamstool()
        mstool.open(vis)
        try:
            mstool.select({"antenna1": 1, "antenna2": 1})
            return np.nanmean(np.abs(mstool.getdata("DATA")["data"]))
        finally:
            mstool.close()

    # Resolve before chdir: a relative path with a directory component would
    # otherwise no longer point at the ms after the directory change.
    msname = os.path.abspath(msname.rstrip("/"))
    mspath = os.path.dirname(msname)
    os.chdir(mspath)
    print(f"Spliting ms: {msname} into noise diode on and off measurement sets.")
    # Strip only a trailing ".ms" so that a ".ms" in a parent directory name
    # does not truncate the path.
    ms_base = msname[:-3] if msname.endswith(".ms") else msname
    if noise_on_ms == "":
        noise_on_ms = ms_base + "_noise_on.ms"
    if noise_off_ms == "":
        noise_off_ms = ms_base + "_noise_off.ms"
    even_ms = ms_base + "_even.ms"
    odd_ms = ms_base + "_odd.ms"
    # Also clear leftovers of the temporary outputs from an earlier run.
    for stale in (noise_on_ms, noise_off_ms, even_ms, odd_ms):
        _remove(stale)
    tb = table()
    tb.open(msname)
    try:
        times = tb.getcol("TIME")
    finally:
        tb.close()
    unique_times = np.unique(times)
    if len(unique_times) < 2:
        # An empty timerange string would select every timestamp.
        raise ValueError(
            f"{msname} has fewer than two timestamps; cannot split noise diode on/off."
        )
    even_times = unique_times[::2]  # Even-indexed timestamps
    odd_times = unique_times[1::2]  # Odd-indexed timestamps
    even_timerange = ",".join(
        [mjdsec_to_timestamp(t, str_format=1) for t in even_times]
    )
    odd_timerange = ",".join(
        [mjdsec_to_timestamp(t, str_format=1) for t in odd_times]
    )
    for outputvis, timerange in ((even_ms, even_timerange), (odd_ms, odd_timerange)):
        split(
            vis=msname,
            outputvis=outputvis,
            timerange=timerange,
            field=field,
            scan=scan,
            datacolumn=datacolumn,
        )
    even_data = _mean_amplitude(even_ms)
    odd_data = _mean_amplitude(odd_ms)
    if even_data > odd_data:
        on_source, off_source = even_ms, odd_ms
    else:
        on_source, off_source = odd_ms, even_ms
    shutil.move(on_source, noise_on_ms)
    shutil.move(off_source, noise_off_ms)
    return noise_on_ms, noise_off_ms
[docs]
def determine_noise_diode_cal_scan(msname, scan):
    """
    Determine whether a calibrator scan is a noise-diode cal scan or not

    The check uses the first channel of the good-channel selection returned
    by ``get_good_chans`` and the data selected with antenna1 = antenna2 = 1.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    scan : int
        Scan number

    Returns
    -------
    bool
        Whether it is noise-diode cal scan or not

    Raises
    ------
    ValueError
        If ``get_good_chans`` returns an empty selection, so that no channel
        is available for the check
    """

    def is_noisescan(msname, chan, scan):
        # Compare alternating time samples of the first and last correlation.
        mstool = casamstool()
        mstool.open(msname)
        try:
            mstool.select({"antenna1": 1, "antenna2": 1, "scan_number": scan})
            mstool.selectchannel(nchan=1, width=1, start=chan)
            # assumes axes are (correlation, channel, baseline, time) with
            # ifraxis=True — TODO confirm against the ms tool documentation
            data = mstool.getdata("DATA", ifraxis=True)["data"][:, 0, 0, :]
        finally:
            mstool.close()
        xx = np.abs(data[0, ...])
        yy = np.abs(data[-1, ...])
        # Pair each odd-indexed sample with the even-indexed one before it.
        minlen = min(len(xx[1::2]), len(xx[::2]))
        d_xx = xx[1::2][:minlen] - xx[::2][:minlen]
        d_yy = yy[1::2][:minlen] - yy[::2][:minlen]
        median_d_xx = np.abs(np.nanmedian(d_xx))
        median_d_yy = np.abs(np.nanmedian(d_yy))
        # Both correlations must show a median alternation above 10
        # (threshold in the units of the DATA column).
        return bool(median_d_xx > 10 and median_d_yy > 10)

    print(f"Check noise-diode cal for scan : {scan}")
    good_spw = get_good_chans(msname)
    if good_spw == "":
        raise ValueError(
            f"No good channel selection available for {msname}; "
            "cannot check for noise-diode cal scan."
        )
    # First channel of the first range, e.g. "0:20~120;150~180" -> 20
    chan = int(good_spw.split(";")[0].split(":")[-1].split("~")[0])
    return is_noisescan(msname, chan, scan)