Source code for meersolar.utils.casatasks

import types
import psutil
import numpy as np
import glob
import os
import traceback
from casatasks import casalog
from casatools import msmetadata, ms as casamstool, table
from .basic_utils import *
from .resource_utils import *

# Best-effort removal of the file reported by ``casalog.logfile()`` at import
# time. The path is deleted through ``os.remove`` instead of a shell command,
# so names containing spaces or shell metacharacters are handled correctly.
try:
    logfile = casalog.logfile()
    if logfile:
        os.remove(logfile)
except Exception:
    # Cleanup is optional: a missing file or an unusable logger must not stop
    # the module from importing. KeyboardInterrupt/SystemExit still propagate.
    pass


#############################
# General CASA tasks
#############################


def check_scan_in_caltable(caltable, scan):
    """
    Check scan number available in caltable or not

    Parameters
    ----------
    caltable : str
        Name of the caltable
    scan : int
        Scan number (converted with ``int()`` before the lookup)

    Returns
    -------
    bool
        Whether scan is present in the ``SCAN_NUMBER`` column of the caltable or not
    """
    tb = table()
    tb.open(caltable)
    try:
        scans = tb.getcol("SCAN_NUMBER")
    finally:
        # Always release the table, even if the column cannot be read.
        tb.close()
    return int(scan) in scans
def reset_weights_and_flags(
    msname="", restore_flag=True, force_reset=False, n_threads=-1, dry_run=False
):
    """
    Reset weights and flags for the ms

    The reset is skipped when ``<msname>/.reset`` already exists, unless
    ``force_reset`` is set. After a successful reset that marker file is created.

    Parameters
    ----------
    msname : str
        Measurement set
    restore_flag : bool, optional
        Restore flags or not (removes ``<msname>.flagversions`` and unflags the data)
    force_reset : bool, optional
        Force reset
    n_threads : int, optional
        Number of CPU threads, forwarded to ``limit_threads``
    dry_run : bool, optional
        If True, only report the memory usage of the current process

    Returns
    -------
    float or None
        Resident memory of the current process in GB (rounded to two decimals)
        when ``dry_run`` is True, otherwise None
    """
    limit_threads(n_threads=n_threads)
    from casatasks import flagdata
    import shutil

    if dry_run:
        process = psutil.Process(os.getpid())
        mem = round(process.memory_info().rss / 1024**3, 2)  # in GB
        return mem
    msname = msname.rstrip("/")
    # Resolve the path once, BEFORE changing directory: a relative name such
    # as "data/x.ms" would otherwise point somewhere else after os.chdir().
    ms_abspath = os.path.abspath(msname)
    reset_marker = os.path.join(ms_abspath, ".reset")
    if force_reset or not os.path.exists(reset_marker):
        # The working-directory change is kept from the original implementation.
        os.chdir(os.path.dirname(ms_abspath))
        if restore_flag:
            print(f"Restoring flags of measurement set : {msname}")
            flagversions = ms_abspath + ".flagversions"
            if os.path.exists(flagversions):
                shutil.rmtree(flagversions, ignore_errors=True)
            flagdata(vis=ms_abspath, mode="unflag", flagbackup=False)
        print(f"Resetting previous weights of the measurement set: {msname}")
        msmd = msmetadata()
        msmd.open(ms_abspath)
        try:
            # Number of correlations of the first polarization setup.
            npol = msmd.ncorrforpol()[0]
        finally:
            msmd.close()
        tb = table()
        tb.open(ms_abspath, nomodify=False)
        try:
            colnames = tb.colnames()
            nrows = tb.nrows()
            if "WEIGHT" in colnames:
                print(f"Resetting weight column to ones of measurement set : {msname}.")
                tb.putcol("WEIGHT", np.ones((npol, nrows)))
            if "SIGMA" in colnames:
                print(f"Resetting sigma column to ones of measurement set: {msname}.")
                tb.putcol("SIGMA", np.ones((npol, nrows)))
            if "WEIGHT_SPECTRUM" in colnames:
                print(f"Removing weight spectrum of measurement set: {msname}.")
                tb.removecols("WEIGHT_SPECTRUM")
            if "SIGMA_SPECTRUM" in colnames:
                print(f"Removing sigma spectrum of measurement set: {msname}.")
                tb.removecols("SIGMA_SPECTRUM")
            tb.flush()
        finally:
            # Release the write lock even if one of the column updates fails.
            tb.close()
        # Equivalent of `touch`: create the marker if needed and refresh its mtime.
        with open(reset_marker, "a"):
            pass
        os.utime(reset_marker, None)
    return
def correct_missing_col_subms(msname):
    """
    Correct for missing columns in sub-MSs

    Every column that is not present in all ``<msname>/SUBMSS/*.ms`` tables is
    removed from the sub-MSs that do have it, so that all sub-MSs end up with
    the same set of columns.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    """
    tb = table()
    sub_mslist = glob.glob(msname + "/SUBMSS/*.ms")
    column_sets = []
    for sub_ms in sub_mslist:
        tb.open(sub_ms)
        try:
            column_sets.append(set(tb.colnames()))
        finally:
            tb.close()
    if not column_sets:
        return
    # Columns present in at least one sub-MS but not in all of them.
    extra_columns = set.union(*column_sets) - set.intersection(*column_sets)
    for sub_ms, present in zip(sub_mslist, column_sets):
        to_remove = sorted(extra_columns & present)
        if not to_remove:
            # Nothing to drop here; avoid opening the table for writing.
            continue
        tb.open(sub_ms, nomodify=False)
        try:
            for colname in to_remove:
                print(f"Removing column: {colname} from sub-MS: {sub_ms}")
                tb.removecols(colname)
            tb.flush()
        finally:
            # Release the write lock even if a column removal fails.
            tb.close()
    return
def single_mstransform(
    msname="",
    outputms="",
    field="",
    scan="",
    width=1,
    timebin="",
    datacolumn="DATA",
    spw="",
    corr="",
    timerange="",
    numsubms="auto",
    n_threads=-1,
    dry_run=False,
):
    """
    Perform mstransform of a single scan

    Any existing ``outputms`` (and its ``.flagversions``) is removed first. After
    the split, weights are initialised to ones, exact zeros are flagged and an
    empty ``<outputms>/.splited`` marker file is created.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    outputms : str
        Output ms name
    field : str, optional
        Field name (if empty, the first field of ``scan`` is used)
    scan : int or str
        Scan to split (a single scan)
    width : int, optional
        Number of channels to average (averaging is enabled when > 1)
    timebin : str, optional
        Time to average (empty string or None disables time averaging)
    datacolumn : str, optional
        Data column to split
    spw : str, optional
        Spectral window
    corr : str, optional
        Correlation to split
    timerange : str, optional
        Time range
    numsubms : str, optional
        Number of subms
    n_threads : int, optional
        Number of CPU threads (at most 2 are passed to mstransform)
    dry_run : bool, optional
        If True, only report the memory usage of the current process

    Returns
    -------
    str or float or None
        Output measurement set name on success; None if any step raised
        (the traceback is printed and a partial output is removed); resident
        memory of the current process in GB when ``dry_run`` is True
    """
    limit_threads(n_threads=n_threads)
    from casatasks import mstransform, initweights, flagdata
    import shutil

    if dry_run:
        process = psutil.Process(os.getpid())
        mem = round(process.memory_info().rss / 1024**3, 2)  # in GB
        return mem

    def _discard(path):
        # Shell-free replacement for `rm -rf <path>`; failures are ignored.
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)

    # None is treated as "no time averaging", so hand CASA an empty string.
    if timebin is None:
        timebin = ""
    timeaverage = timebin != ""
    chanaverage = width > 1
    outputms = outputms.rstrip("/")
    _discard(outputms)
    _discard(outputms + ".flagversions")
    try:
        # mstransform is never given more than two threads.
        n_threads = 2 if n_threads < 1 else min(n_threads, 2)
        if field == "":
            msmd = msmetadata()
            msmd.open(msname)
            try:
                field = str(msmd.fieldsforscan(int(scan))[0])
            finally:
                msmd.close()
        # The scan selection is passed as a string, whether given as int or str.
        scan_selection = scan if isinstance(scan, str) else str(scan)
        with suppress_casa_output():
            mstransform(
                vis=msname,
                outputvis=outputms,
                spw=spw,
                timerange=timerange,
                field=field,
                scan=scan_selection,
                datacolumn=datacolumn,
                createmms=True,
                correlation=corr,
                timeaverage=timeaverage,
                timebin=timebin,
                chanaverage=chanaverage,
                chanbin=int(width),
                nthreads=n_threads,
                separationaxis="scan",
                numsubms=numsubms,
            )
        with suppress_casa_output():
            initweights(vis=outputms, wtmode="ones", dowtsp=True)
            flagdata(
                vis=outputms,
                mode="clip",
                clipzeros=True,
                datacolumn="data",
                flagbackup=False,
            )
        # Equivalent of `touch`: create the marker if needed and refresh its mtime.
        split_marker = os.path.join(outputms, ".splited")
        with open(split_marker, "a"):
            pass
        os.utime(split_marker, None)
        return outputms
    except Exception:
        traceback.print_exc()
        # Do not leave a half-written output measurement set behind.
        _discard(outputms)
        return
# Expose functions and classes
# Public API: every function or class defined in this module itself.
# Kept as a comprehension so the loop variables do not become module globals
# while globals() is being iterated.
__all__ = [
    symbol
    for symbol, value in globals().items()
    if isinstance(value, (types.FunctionType, type))
    and value.__module__ == __name__
]