# Source code for meersolar.pipeline.flagging

import sys, traceback, time, gc
import os, numpy as np, copy, glob, argparse
from datetime import datetime as dt, timezone
from functools import partial
from meersolar.pipeline.basic_func import *
from dask import delayed, compute, config
from casatasks import casalog

# Remove the log file reported by ``casalog.logfile()`` so it does not linger.
# This is best effort only: a failure here must never prevent the module from
# being imported, so ordinary errors are ignored (but KeyboardInterrupt and
# SystemExit are no longer swallowed, unlike with a bare ``except``).
try:
    logfile = casalog.logfile()
    # Delete through the OS API instead of a shell command so that paths with
    # spaces or shell metacharacters are handled safely.
    if logfile and os.path.isfile(logfile):
        os.remove(logfile)
except Exception:
    pass


def _flagdata_best_effort(flagdata, step, **kwargs):
    """
    Run one CASA ``flagdata`` call without letting a failure abort the caller

    Parameters
    ----------
    flagdata : callable
        The ``flagdata`` task imported from ``casatasks``
    step : str
        Short label of the flagging step, used only in the warning message
    **kwargs
        Keyword arguments forwarded unchanged to ``flagdata``
    """
    try:
        with suppress_casa_output():
            flagdata(**kwargs)
    except Exception as err:
        # Every step is best effort: report the problem and let the
        # remaining flagging steps run instead of hiding it silently.
        print(f"Warning: flagging step '{step}' failed: {err}\n")


def _flag_time_chunk(msname, memory_limit):
    """
    Choose the ``ntime`` value used by the tfcrop, rflag and extend passes

    Parameters
    ----------
    msname : str
        Measurement set name
    memory_limit : float
        Memory limit in GB, forwarded to ``get_chunk_size``

    Returns
    -------
    str or float
        "scan" when a single chunk is enough (or fewer than two timestamps
        are available), otherwise the chunk duration, never smaller than the
        smallest spacing between consecutive timestamps
    """
    nchunk = get_chunk_size(msname, memory_limit=memory_limit)
    if nchunk <= 1:
        return "scan"
    msmd = msmetadata()
    msmd.open(msname)
    try:
        times = msmd.timesforspws(0)
    finally:
        # Always release the metadata tool, even if the query fails.
        msmd.close()
    if len(times) < 2:
        # np.diff() would be empty and the time resolution undefined.
        return "scan"
    total_time = np.nanmax(times) - np.nanmin(times)
    timeres = np.nanmin(np.diff(times))
    ntime = float(total_time / nchunk)
    if ntime < timeres:
        ntime = timeres
    return ntime


def single_ms_flag(
    msname="",
    badspw="",
    bad_ants_str="",
    datacolumn="data",
    use_tfcrop=True,
    use_rflag=False,
    flagdimension="freqtime",
    flag_autocorr=True,
    n_threads=-1,
    memory_limit=-1,
    dry_run=False,
):
    """
    Flag on a single ms

    The steps are run in this order: manual flags for bad spectral windows
    and bad antennas, clipping of zero-amplitude data, optional flagging of
    auto-correlations, then (if requested) tfcrop and/or rflag followed by a
    flag-extension pass. Each ``flagdata`` call is best effort; a failing
    step prints a warning and the remaining steps still run.

    Parameters
    ----------
    msname : str
        Measurement set name
    badspw : str, optional
        Bad spectral window
    bad_ants_str : str, optional
        Bad antenna string
    datacolumn : str, optional
        Data column
    use_tfcrop : bool, optional
        Use tfcrop or not
    use_rflag : bool, optional
        Use rflag or not
    flagdimension : str, optional
        Flag dimension (only applicable for tfcrop)
    flag_autocorr : bool, optional
        Flag autocorrelations or not
    n_threads : int, optional
        Number of OpenMP threads
    memory_limit : float, optional
        Memory limit in GB
    dry_run : bool, optional
        Return the amount of pre-occupied memory in GB

    Returns
    -------
    float or None
        Resident memory of the current process in GB if ``dry_run`` is True,
        otherwise None
    """
    limit_threads(n_threads=n_threads)
    from casatasks import flagdata

    if dry_run:
        process = psutil.Process(os.getpid())
        mem = round(process.memory_info().rss / 1024**3, 2)  # in GB
        return mem
    msname = msname.rstrip("/")
    try:
        ##############################
        # Flagging bad channels
        ##############################
        if badspw != "":
            print(f"Flagging bad spectral windows: {badspw}\n")
            _flagdata_best_effort(
                flagdata,
                "bad spectral windows",
                vis=msname,
                mode="manual",
                spw=badspw,
                cmdreason="badchan",
                flagbackup=False,
            )
        ##############################
        # Flagging bad antennas
        ##############################
        if bad_ants_str != "":
            print(f"Flagging bad antenna: {bad_ants_str}\n")
            _flagdata_best_effort(
                flagdata,
                "bad antennas",
                vis=msname,
                mode="manual",
                antenna=bad_ants_str,
                cmdreason="badant",
                flagbackup=False,
            )
        #################################
        # Clip zero amplitude data points
        #################################
        # NOTE(review): autocorr=flag_autocorr is forwarded to the clip call
        # exactly as before -- confirm this selection is intended here.
        _flagdata_best_effort(
            flagdata,
            "clip zeros",
            vis=msname,
            mode="clip",
            clipzeros=True,
            datacolumn=datacolumn,
            autocorr=flag_autocorr,
            flagbackup=False,
        )
        #################################
        # Flag auto-correlations
        #################################
        if flag_autocorr:
            _flagdata_best_effort(
                flagdata,
                "auto-correlations",
                vis=msname,
                mode="manual",
                autocorr=True,
                datacolumn=datacolumn,
                flagbackup=False,
            )
        ####################################################
        # Check if required columns are present for residual
        ####################################################
        if datacolumn == "residual" or datacolumn == "RESIDUAL":
            modelcolumn_present = check_datacolumn_valid(
                msname, datacolumn="MODEL_DATA"
            )
            corcolumn_present = check_datacolumn_valid(
                msname, datacolumn="CORRECTED_DATA"
            )
            if modelcolumn_present == False or corcolumn_present == False:
                datacolumn = "corrected"
        elif datacolumn == "RESIDUAL_DATA":
            modelcolumn_present = check_datacolumn_valid(
                msname, datacolumn="MODEL_DATA"
            )
            datacolumn_present = check_datacolumn_valid(msname, datacolumn="DATA")
            if modelcolumn_present == False or datacolumn_present == False:
                datacolumn = "corrected"
        #################################################
        # Whether corrected data column is present or not
        #################################################
        if datacolumn == "corrected" or datacolumn == "CORRECTED_DATA":
            corcolumn_present = check_datacolumn_valid(
                msname, datacolumn="CORRECTED_DATA"
            )
            if corcolumn_present == False:
                print(
                    "Corrected data column is chosen for flagging, but it is not present.\n"
                )
                return
            else:
                datacolumn = "corrected"
        #################################################
        # Whether data column is present or not
        #################################################
        if datacolumn == "data" or datacolumn == "DATA":
            datacolumn_present = check_datacolumn_valid(msname, datacolumn="DATA")
            if datacolumn_present == False:
                print("Data column is chosen for flagging, but it is not present.\n")
                return
            else:
                datacolumn = "data"
        # The automatic modes and the extension pass all need the time chunk
        # size, which is only computed when at least one mode is requested.
        if use_tfcrop or use_rflag:
            ###########################
            # Determining time chunking
            ###########################
            ntime = _flag_time_chunk(msname, memory_limit)
            ##############
            # Tfcrop flag
            ##############
            if use_tfcrop:
                _flagdata_best_effort(
                    flagdata,
                    "tfcrop",
                    vis=msname,
                    mode="tfcrop",
                    timefit="line",
                    freqfit="line",
                    extendflags=False,
                    flagdimension=flagdimension,
                    timecutoff=5.0,
                    freqcutoff=5.0,
                    extendpols=True,
                    growaround=False,
                    action="apply",
                    flagbackup=False,
                    overwrite=True,
                    writeflags=True,
                    datacolumn=datacolumn,
                    ntime=ntime,
                )
            #############
            # Rflag flag
            #############
            # Run rflag only when it was requested; previously use_rflag was
            # never consulted at this point.
            if use_rflag:
                _flagdata_best_effort(
                    flagdata,
                    "rflag",
                    vis=msname,
                    mode="rflag",
                    timefit="line",
                    freqfit="line",
                    extendflags=False,
                    timedevscale=5.0,
                    freqdevscale=5.0,
                    extendpols=True,
                    growaround=False,
                    action="apply",
                    flagbackup=False,
                    overwrite=True,
                    writeflags=True,
                    datacolumn=datacolumn,
                    ntime=ntime,
                )
            ##############
            # Extend flag
            ##############
            _flagdata_best_effort(
                flagdata,
                "extend",
                vis=msname,
                mode="extend",
                datacolumn="data",
                clipzeros=True,
                extendflags=False,
                extendpols=True,
                growtime=80.0,
                growfreq=80.0,
                growaround=False,
                flagneartime=False,
                flagnearfreq=False,
                action="apply",
                flagbackup=False,
                overwrite=True,
                writeflags=True,
                ntime=ntime,
            )
    except Exception:
        traceback.print_exc()
    finally:
        time.sleep(5)
        drop_cache(msname)
    return
def do_flagging(
    msname,
    datacolumn="data",
    flag_bad_ants=True,
    flag_bad_spw=True,
    use_tfcrop=True,
    use_rflag=False,
    flagdimension="freqtime",
    flag_autocorr=True,
    flag_backup=True,
    cpu_frac=0.8,
    mem_frac=0.8,
):
    """
    Function to perform initial flagging

    Existing flags are first cleared with an ``unflag`` call, bad spectral
    windows and antennas are determined (only when flux calibrators are
    found), and ``single_ms_flag`` is then run through Dask on every sub-MS
    under ``<msname>/SUBMSS`` or, if that directory does not exist, on the
    measurement set itself.

    Parameters
    ----------
    msname : str
        Name of the ms
    datacolumn : str, optional
        Data column
    flag_bad_ants : bool, optional
        Flag bad antennas
    flag_bad_spw : bool, optional
        Flag bad channels
    use_tfcrop : bool, optional
        Use tfcrop or not
    use_rflag : bool, optional
        Use rflag or not
    flagdimension : str, optional
        Flag dimension (only for tfcrop)
    flag_autocorr : bool, optional
        Flag auto-correlations
    flag_backup : bool, optional
        Flag backup
    cpu_frac : float, optional
        CPU fraction to use
    mem_frac : float, optional
        Memory fraction to use

    Returns
    -------
    int
        Success message (0 on success, 1 if any exception was raised)
    """
    start_time = time.time()
    try:
        from casatasks import flagdata

        # Resolve to an absolute path before changing directory; a relative
        # path with directory components would otherwise stop resolving.
        msname = os.path.abspath(msname.rstrip("/"))
        mspath = os.path.dirname(msname)
        os.chdir(mspath)
        print("###########################")
        print("Flagging measurement set : ", msname)
        print("###########################\n")
        correct_missing_col_subms(msname)
        print("Restoring all previous flags...")
        with suppress_casa_output():
            flagdata(vis=msname, mode="unflag", spw="0", flagbackup=False)
        fluxcal_field, fluxcal_scans = get_fluxcals(msname)
        if len(fluxcal_field) == 0:
            # Without a flux calibrator neither bad channels nor bad
            # antennas are searched for.
            flag_bad_spw = False
            flag_bad_ants = False
        if flag_bad_spw:
            badspw = get_bad_chans(msname)
        else:
            # Was misspelled "bandspw", leaving badspw undefined below.
            badspw = ""
        if flag_bad_ants:
            bad_ants, bad_ants_str = get_bad_ants(msname, fieldnames=fluxcal_field)
        else:
            bad_ants_str = ""
        ###########################
        # Dask local cluster setup
        ###########################
        if os.path.exists(msname + "/SUBMSS"):
            subms_list = glob.glob(msname + "/SUBMSS/*")
            for subms in subms_list:
                os.system(f"rm -rf {subms}/.flagversions")
        else:
            subms_list = [msname]
        # A dry run reports the memory already occupied by one worker
        # process; it is used to size the per-job memory request.
        task = delayed(single_ms_flag)(dry_run=True)
        mem_limit = run_limited_memory_task(task, dask_dir=mspath)
        dask_client, dask_cluster, n_jobs, n_threads, mem_limit = get_dask_client(
            len(subms_list),
            dask_dir=mspath,
            cpu_frac=cpu_frac,
            mem_frac=mem_frac,
            min_mem_per_job=mem_limit / 0.6,
        )
        try:
            if flag_backup:
                do_flag_backup(msname, flagtype="flagdata")
            tasks = [
                delayed(single_ms_flag)(
                    ms,
                    badspw=badspw,
                    bad_ants_str=bad_ants_str,
                    datacolumn=datacolumn,
                    use_tfcrop=use_tfcrop,
                    use_rflag=use_rflag,
                    flagdimension=flagdimension,
                    flag_autocorr=flag_autocorr,
                    n_threads=n_threads,
                    memory_limit=mem_limit,
                )
                for ms in subms_list
            ]
            compute(*tasks)
        finally:
            # Shut the Dask client and cluster down even if the backup or
            # the computation fails.
            dask_client.close()
            dask_cluster.close()
        print("##################")
        print("Total time taken : ", time.time() - start_time)
        print("##################\n")
        return 0
    except Exception:
        traceback.print_exc()
        print("##################")
        print("Total time taken : " + str(time.time() - start_time) + "s")
        print("##################\n")
        return 1
    finally:
        time.sleep(5)
        drop_cache(msname)
def main():
    """
    Command-line entry point for initial flagging of calibrator data

    Parses the command-line arguments, records the process ID, prepares the
    work directory, optionally starts the logger and then calls
    ``do_flagging``.

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    """
    usage = "Initial flagging of calibrator data"
    parser = argparse.ArgumentParser(
        description=usage, formatter_class=SmartDefaultsHelpFormatter
    )

    ## Essential parameters
    basic_args = parser.add_argument_group(
        "###################\nEssential parameters\n###################"
    )
    basic_args.add_argument("msname", type=str, help="Name of measurement set")
    basic_args.add_argument(
        "--workdir", type=str, default="", help="Name of work directory"
    )
    basic_args.add_argument(
        "--datacolumn", type=str, default="DATA", help="Name of the datacolumn"
    )

    # Advanced switches
    adv_args = parser.add_argument_group(
        "###################\nAdvanced parameters\n###################"
    )
    adv_args.add_argument(
        "--no_flag_bad_ants",
        dest="flag_bad_ants",
        action="store_false",
        help="Do not flag bad antennas",
    )
    adv_args.add_argument(
        "--no_flag_bad_spw",
        dest="flag_bad_spw",
        action="store_false",
        help="Do not flag bad spectral windows",
    )
    adv_args.add_argument(
        "--use_tfcrop", action="store_true", help="Use tfcrop flagging"
    )
    adv_args.add_argument("--use_rflag", action="store_true", help="Use rflag flagging")
    adv_args.add_argument(
        "--no_flag_autocorr",
        dest="flag_autocorr",
        action="store_false",
        help="Do not flag auto-correlations",
    )
    adv_args.add_argument(
        "--no_flagbackup",
        dest="flagbackup",
        action="store_false",
        help="Do not backup flags",
    )

    ## Resource management parameters
    hard_args = parser.add_argument_group(
        "###################\nHardware resource management parameters\n###################"
    )
    hard_args.add_argument(
        "--flagdimension", type=str, default="freqtime", help="Flag dimension"
    )
    hard_args.add_argument(
        "--cpu_frac", type=float, default=0.8, help="CPU fraction to use"
    )
    hard_args.add_argument(
        "--mem_frac", type=float, default=0.8, help="Memory fraction to use"
    )
    hard_args.add_argument("--logfile", type=str, default=None, help="Log file")
    hard_args.add_argument("--jobid", type=int, default=0, help="Job ID")

    # Called without any argument: show the help text and stop.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    args = parser.parse_args()

    pid = os.getpid()
    save_pid(pid, datadir + f"/pids/pids_{args.jobid}.txt")

    # Fall back to "<directory of the ms>/workdir" when no usable work
    # directory was given.
    if args.workdir == "" or not os.path.exists(args.workdir):
        workdir = os.path.dirname(os.path.abspath(args.msname)) + "/workdir"
    else:
        workdir = args.workdir
    os.makedirs(workdir, exist_ok=True)

    observer = None
    if os.path.exists(f"{workdir}/jobname_password.npy") and args.logfile:
        time.sleep(5)
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading; this file must only ever come from a trusted source.
        jobname, password = np.load(
            f"{workdir}/jobname_password.npy", allow_pickle=True
        )
        if os.path.exists(args.logfile):
            observer = init_logger(
                "do_flagging", args.logfile, jobname=jobname, password=password
            )

    try:
        if args.msname and os.path.exists(args.msname):
            msg = do_flagging(
                args.msname,
                datacolumn=args.datacolumn,
                flag_bad_ants=args.flag_bad_ants,
                flag_bad_spw=args.flag_bad_spw,
                use_tfcrop=args.use_tfcrop,
                use_rflag=args.use_rflag,
                flagdimension=args.flagdimension,
                flag_autocorr=args.flag_autocorr,
                flag_backup=args.flagbackup,
                cpu_frac=args.cpu_frac,
                mem_frac=args.mem_frac,
            )
        else:
            print("Please provide correct measurement set.\n")
            msg = 1
    except Exception:
        traceback.print_exc()
        msg = 1
    finally:
        time.sleep(5)
        drop_cache(args.msname)
        # Use the resolved work directory; args.workdir may be empty or
        # point to a path that does not exist.
        drop_cache(workdir)
        clean_shutdown(observer)
    return msg
if __name__ == "__main__":
    result = main()
    print(f"Final msg : {result}")
    print("\n###################\nBasic flagging is finished.\n###################\n")
    # os._exit() ends the process immediately without flushing stdio
    # buffers, so flush explicitly or the final messages can be lost when
    # the output is redirected to a file or pipe.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(result)