Source code for meersolar.pipeline.do_target_split

import numpy as np, os, time, traceback, gc, argparse
from meersolar.pipeline.basic_func import *
from casatasks import casalog
from dask import delayed, compute

# Best-effort removal of the CASA log file reported by casalog (presumably
# created as a side effect of importing casatasks -- TODO confirm).
try:
    logfile = casalog.logfile()
    if os.path.isdir(logfile) and not os.path.islink(logfile):
        import shutil

        shutil.rmtree(logfile, ignore_errors=True)
    elif os.path.lexists(logfile):
        os.remove(logfile)
except Exception:
    # Cleanup is optional; never let it prevent the module from loading.
    # (Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit through.)
    pass


def chanlist_to_str(lst):
    """
    Convert a list of channel numbers into a range string

    Consecutive channels are merged into inclusive ``start~end`` ranges and
    the ranges are joined with ``;``.

    Parameters
    ----------
    lst : list
        Channel numbers (any order, duplicates are ignored)

    Returns
    -------
    str
        Range string, e.g. ``"1~3;5~6;9~9"``; empty string for empty input
    """
    # De-duplicate so a repeated channel cannot produce overlapping ranges
    channels = sorted(set(lst))
    if not channels:
        return ""
    ranges = []
    start = prev = channels[0]
    for chan in channels[1:]:
        if chan != prev + 1:
            # Gap found: close the current range and start a new one
            ranges.append(f"{start}~{prev}")
            start = chan
        prev = chan
    ranges.append(f"{start}~{prev}")
    return ";".join(ranges)
def split_scan(
    msname="",
    outputvis="",
    scan="",
    width="",
    timebin="",
    datacolumn="",
    spw="",
    corr="",
    timerange="",
    n_threads=-1,
    dry_run=False,
):
    """
    Split a single target scan

    After a successful split the weights are re-initialised to ones, exact
    zeros are flagged, and an empty ``.splited`` marker file is written inside
    the output measurement set. If that marker already exists, nothing is done.

    Parameters
    ----------
    msname : str
        Measurement set
    outputvis : str
        Output measurement set
    scan : int
        Scan number
    width : int
        Channel width
    timebin : str
        Timebin width
    datacolumn : str
        Datacolumn to split
    spw : str, optional
        Spectral window to split
    corr : str, optional
        Correlation to split
    timerange : str, optional
        Time range to split
    n_threads : int, optional
        Number of OpenMP threads
    dry_run : bool, optional
        If True, do not split anything; only report the resident memory of
        the current process after the CASA modules have been imported

    Returns
    -------
    str or float
        Splited measurement set, or the resident memory in GB if ``dry_run`` is True
    """
    limit_threads(n_threads=n_threads)
    from casatasks import split, flagdata, initweights
    from casatools import msmetadata

    msmd = msmetadata()
    if dry_run:
        process = psutil.Process(os.getpid())
        mem = round(process.memory_info().rss / 1024**3, 2)  # in GB
        return mem

    marker = f"{outputvis}/.splited"
    if os.path.exists(marker):
        # A previous run finished this split completely; nothing to redo
        return outputvis

    # Always release the metadata tool, even if the scan lookup fails
    try:
        msmd.open(msname)
        fields = msmd.fieldsforscan(int(scan))
    finally:
        msmd.close()
        del msmd
    fields_str = ",".join(str(f) for f in fields)

    # Remove leftovers of an incomplete earlier attempt. Done without a shell
    # so paths containing spaces or shell metacharacters are handled safely.
    import shutil

    for stale in (outputvis, outputvis + ".flagversions"):
        if os.path.isdir(stale) and not os.path.islink(stale):
            shutil.rmtree(stale)
        elif os.path.lexists(stale):
            os.remove(stale)

    print(f"Spliting scan : {scan} of ms: {msname}\n")
    print(
        f"split(vis='{msname}',outputvis='{outputvis}',field='{fields_str}',scan='{scan}',spw='{spw}',correlation='{corr}',timerange='{timerange}',width={width},timebin='{timebin}',datacolumn='{datacolumn}')\n"
    )
    with suppress_casa_output():
        split(
            vis=msname,
            outputvis=outputvis,
            field=fields_str,
            correlation=corr,
            scan=scan,
            spw=spw,
            timerange=timerange,
            width=width,
            timebin=timebin,
            datacolumn=datacolumn,
        )
    ##########################################
    # Initiate proper weighting
    ##########################################
    print("Initiating weights ....")
    with suppress_casa_output():
        initweights(vis=outputvis, wtmode="ones", dowtsp=True)
        flagdata(
            vis=outputvis,
            mode="clip",
            clipzeros=True,
            datacolumn="data",
            flagbackup=False,
        )
    # Written last, so its presence means every step above succeeded
    with open(marker, "a"):
        pass
    return outputvis
def split_target_scans(
    msname,
    workdir,
    timeres,
    freqres,
    datacolumn,
    spw="",
    spectral_chunk=-1,
    n_spectral_chunk=-1,
    scans=None,
    prefix="targets",
    fullpol=False,
    time_interval=-1,
    time_window=-1,
    quack_timestamps=-1,
    merge_spws=False,
    cpu_frac=0.8,
    mem_frac=0.8,
    max_cpu_frac=0.8,
    max_mem_frac=0.8,
):
    """
    Split target scans

    One output measurement set is produced per (scan, channel block) pair,
    named ``{workdir}/{prefix}_scan_{scan}_spw_{first}~{last}.ms``. Outputs
    that already contain a ``.splited`` marker are reused, not re-split.

    Parameters
    ----------
    msname : str
        Measurement set
    workdir : str
        Work directory (the process changes into it)
    timeres : float
        Time resolution in seconds (no time averaging if not positive)
    freqres : float
        Frequency resolution in MHz (no channel averaging if not positive)
    datacolumn : str
        Data column to split
    spw : str, optional
        Spectral window
    spectral_chunk : float, optional
        Spectral chunk in MHz (a single channel block is used if not positive)
    n_spectral_chunk : int, optional
        Number of spectral chunks to keep, picked at evenly spaced positions
        of the chunk list (all chunks are kept if not positive)
    scans : list, optional
        Scan list to split (all valid target scans if empty or None)
    prefix : str, optional
        Splited ms prefix
    fullpol : bool, optional
        Full polar split
    time_interval : float
        Time interval in seconds
    time_window : float
        Time window in seconds
    quack_timestamps : int, optional
        Number of timestamps ignored at the start and end of each scan
    merge_spws : bool, optional
        Merge spectral window ranges
    cpu_frac : float, optional
        CPU fraction to use
    mem_frac : float, optional
        Memory fraction to use
    max_cpu_frac : float, optional
        Maximum CPU fraction to use
    max_mem_frac : float, optional
        Maximum memory fraction to use

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    list
        Splited ms list (empty on failure)
    """
    start_time = time.time()
    scans = [] if scans is None else scans
    try:
        os.chdir(workdir)
        print(f"Spliting ms : {msname}")
        target_scans, _, _, _, _ = get_cal_target_scans(msname)
        valid_scans = get_valid_scans(msname)
        filtered_scan_list = sorted(
            scan
            for scan in target_scans
            if scan in valid_scans and (len(scans) == 0 or scan in scans)
        )
        #######################################
        # Extracting time frequency information
        #######################################
        from casatools import msmetadata

        msmd = msmetadata()
        try:
            msmd.open(msname)
            chanres = msmd.chanres(0, unit="MHz")[0]
            freqs = msmd.chanfreqs(0, unit="MHz")
        finally:
            # Release the tool even if the metadata query fails
            msmd.close()
        bw = max(freqs) - min(freqs)
        if freqres > 0:  # Image resolution is in MHz
            chanwidth = max(1, int(freqres / chanres))
        else:
            chanwidth = 1
        if timeres > 0:  # Image resolution is in seconds
            timebin = str(timeres) + "s"
        else:
            timebin = ""
        corr = "" if fullpol else "XX,YY"
        #############################
        # Making spectral chunks
        #############################
        # Good ranges are the gaps between consecutive bad ranges
        bad_spws = get_bad_chans(msname).split("0:")[-1].split(";")
        good_spws = []
        for i in range(len(bad_spws) - 1):
            start_chan = int(bad_spws[i].split("~")[-1]) + 1
            end_chan = int(bad_spws[i + 1].split("~")[0]) - 1
            good_spws.append(f"{start_chan}~{end_chan}")
        if spw != "":
            good_spws = "0:" + ";".join(good_spws)
            common_spws = get_common_spw(good_spws, spw)
            good_spws = common_spws.split("0:")[-1].split(";")
        chanlist = []
        if spectral_chunk > 0:
            if spectral_chunk > bw:
                print(
                    f"Given spectral chunk: {spectral_chunk} is more than total bandwidth: {bw} MHz."
                )
                spectral_chunk = bw
            nchan_per_chunk = max(1, int(spectral_chunk / chanres))
            good_channels = []
            for good_spw in good_spws:
                start_chan = int(good_spw.split("~")[0])
                end_chan = int(good_spw.split("~")[-1])
                # NOTE(review): range() excludes end_chan, whereas the
                # non-chunked branch below treats "start~end" as inclusive.
                # Looks like an off-by-one dropping the last channel of each
                # good range; left unchanged because fixing it would change
                # the output ms names -- confirm intent.
                good_channels.extend(range(start_chan, end_chan))
            channel_chunks = split_into_chunks(good_channels, nchan_per_chunk)
            for chunk in channel_chunks:
                chan_str = chanlist_to_str(chunk)
                if chan_str not in chanlist:
                    chanlist.append(chan_str)
            if n_spectral_chunk > 0 and len(chanlist) > 0:
                # Integer-truncated linspace repeats indices when more chunks
                # are requested than exist; duplicates would make two parallel
                # tasks write the same output ms, so keep unique indices only.
                indices = np.unique(
                    np.linspace(
                        0, len(chanlist) - 1, num=n_spectral_chunk, dtype=int
                    )
                )
                chanlist = [chanlist[i] for i in indices]
        else:
            chan_range = ";".join(
                f"{int(good_spw.split('~')[0])}~{int(good_spw.split('~')[-1])}"
                for good_spw in good_spws
            )
            if chan_range not in chanlist:
                chanlist.append(chan_range)
        if merge_spws:
            temp_spw = ";".join(chanlist)
            chanlist = [temp_spw]
        print(f"Spliting channel blocks : {chanlist}")
        ##################################
        # Parallel spliting
        ##################################
        if len(chanlist) > 0:
            total_chunks = len(chanlist) * len(filtered_scan_list)
        else:
            total_chunks = len(filtered_scan_list)
        #############################################
        # Memory limit
        #############################################
        task = delayed(split_scan)(dry_run=True)
        mem_limit = run_limited_memory_task(task, dask_dir=workdir)
        #######################
        dask_client, dask_cluster, max_n_jobs, n_threads, mem_limit = get_dask_client(
            total_chunks,
            dask_dir=workdir,
            cpu_frac=max_cpu_frac,
            mem_frac=max_mem_frac,
            min_mem_per_job=mem_limit / 0.6,
            only_cal=True,
        )
        tasks = []
        splited_ms_list = []
        for scan in filtered_scan_list:
            timerange_list = get_timeranges_for_scan(
                msname,
                scan,
                time_interval,
                time_window,
                quack_timestamps=quack_timestamps,
            )
            timerange = ",".join(timerange_list)
            for chanrange in chanlist:
                # Label the output with the first and last channel of the block
                chanrange_str = (
                    chanrange.split(";")[0].split("~")[0]
                    + "~"
                    + chanrange.split(";")[-1].split("~")[-1]
                )
                outputvis = f"{workdir}/{prefix}_scan_{scan}_spw_{chanrange_str}.ms"
                if os.path.exists(f"{outputvis}/.splited"):
                    print(f"{outputvis} is already splited successfully.")
                    splited_ms_list.append(outputvis)
                else:
                    task = delayed(split_scan)(
                        msname,
                        outputvis,
                        scan,
                        chanwidth,
                        timebin,
                        datacolumn,
                        corr=corr,
                        timerange=timerange,
                        spw="0:" + chanrange,
                        n_threads=n_threads,
                    )
                    tasks.append(task)
        #####################################
        # Adaptive dask client
        #####################################
        if cpu_frac == max_cpu_frac and mem_frac == max_mem_frac:
            # Already at the maximum allowed resources: run everything at once
            total_chunks = len(tasks)
            if total_chunks > 0:
                dask_client, dask_cluster, n_jobs, n_threads, mem_limit = (
                    get_dask_client(
                        total_chunks,
                        dask_dir=workdir,
                        cpu_frac=cpu_frac,
                        mem_frac=mem_frac,
                        min_mem_per_job=mem_limit / 0.6,
                    )
                )
                try:
                    results = compute(*tasks)
                finally:
                    # Shut the cluster down even when a task fails
                    dask_client.close()
                    dask_cluster.close()
                splited_ms_list.extend(results)
        else:
            # Run in batches and raise the resource fractions between batches
            while len(tasks) > 0:
                dask_client, dask_cluster, n_jobs, n_threads, mem_limit = (
                    get_dask_client(
                        len(tasks),
                        dask_dir=workdir,
                        cpu_frac=cpu_frac,
                        mem_frac=mem_frac,
                        min_mem_per_job=mem_limit / 0.6,
                    )
                )
                # At least one task per batch, otherwise the loop never ends
                n_batch = max(1, min(n_jobs, max_n_jobs))
                chunk_tasks, tasks = tasks[:n_batch], tasks[n_batch:]
                try:
                    results = compute(*chunk_tasks)
                finally:
                    dask_client.close()
                    dask_cluster.close()
                splited_ms_list.extend(results)
                n_current_process = (
                    get_nprocess_meersolar(workdir) - 1
                )  # One is subtracted for the current process
                if len(tasks) == 0:
                    break
                elif n_current_process == 0:
                    available_cpu_frac = round(
                        (100 - psutil.cpu_percent(interval=1)) / 100.0, 2
                    )
                    available_mem_frac = round(
                        psutil.virtual_memory().available
                        / psutil.virtual_memory().total,
                        2,
                    )
                    cpu_frac = min(max_cpu_frac, max(cpu_frac, available_cpu_frac))
                    mem_frac = min(max_mem_frac, max(mem_frac, available_mem_frac))
                    print(
                        f"Updated CPU fraction: {cpu_frac}, memory fraction: {mem_frac}."
                    )
        print("##################")
        print("Spliting of target scans are done successfully.")
        print("Total time taken : ", time.time() - start_time)
        print("##################\n")
        return 0, splited_ms_list
    except Exception:
        traceback.print_exc()
        print("##################")
        print("Spliting of target scans are unsuccessful.")
        print("Total time taken : ", time.time() - start_time)
        print("##################\n")
        return 1, []
    finally:
        time.sleep(5)
        drop_cache(msname)
        drop_cache(workdir)
def main():
    """
    Command line entry point for splitting target scans

    Parses the command line, resolves the work directory, optionally starts
    the remote logger and runs ``split_target_scans``.

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    """
    # Imported locally: this file's visible imports do not include sys
    import sys

    parser = argparse.ArgumentParser(
        description="Split target scans", formatter_class=SmartDefaultsHelpFormatter
    )

    ## Essential parameters
    basic_args = parser.add_argument_group(
        "###################\nEssential parameters\n###################"
    )
    basic_args.add_argument(
        "msname",
        type=str,
        help="Name of measurement set (required positional argument)",
    )
    basic_args.add_argument(
        "--workdir",
        type=str,
        default="",
        help="Name of work directory",
    )

    ## Advanced parameters
    adv_args = parser.add_argument_group(
        "###################\nAdvanced parameters\n###################"
    )
    adv_args.add_argument(
        "--datacolumn",
        type=str,
        default="data",
        help="Data column to split",
    )
    adv_args.add_argument(
        "--spw", type=str, default="", help="Spectral window to split"
    )
    adv_args.add_argument(
        "--scans",
        type=str,
        default="",
        help="Target scan list (default: all)",
        metavar="SCANS (Comma seperated)",
    )
    adv_args.add_argument(
        "--time_window",
        type=float,
        default=-1,
        help="Time window in seconds",
        metavar="Float",
    )
    adv_args.add_argument(
        "--time_interval",
        type=float,
        default=-1,
        help="Time interval in seconds",
        metavar="Float",
    )
    adv_args.add_argument(
        "--quack_timestamps",
        type=int,
        default=-1,
        help="Time stamps to ignore at the start and end of the each scan",
        metavar="Integer",
    )
    adv_args.add_argument(
        "--spectral_chunk",
        type=float,
        default=-1,
        help="Spectral chunk in MHz",
        metavar="Float",
    )
    adv_args.add_argument(
        "--n_spectral_chunk",
        type=int,
        default=-1,
        help="Numbers of spectral chunks to split",
        metavar="Integer",
    )
    adv_args.add_argument(
        "--freqres",
        type=float,
        default=-1,
        help="Frequency to average in MHz",
        metavar="Float",
    )
    adv_args.add_argument(
        "--timeres",
        type=float,
        default=-1,
        help="Time bin to average in seconds",
        metavar="Float",
    )
    adv_args.add_argument(
        "--prefix",
        type=str,
        default="targets",
        help="Splited ms prefix name",
    )
    adv_args.add_argument(
        "--split_fullpol", action="store_true", help="Split full polar data"
    )
    adv_args.add_argument(
        "--merge_spws", action="store_true", help="Merge spectral windows"
    )

    ## Resource management parameters
    hard_args = parser.add_argument_group(
        "###################\nHardware resource management parameters\n###################"
    )
    hard_args.add_argument(
        "--cpu_frac",
        type=float,
        default=0.8,
        help="CPU fraction to use",
        metavar="Float",
    )
    hard_args.add_argument(
        "--mem_frac",
        type=float,
        default=0.8,
        help="Memory fraction to use",
        metavar="Float",
    )
    hard_args.add_argument(
        "--max_cpu_frac",
        type=float,
        default=0.8,
        help="Maximum CPU fraction to use",
        metavar="Float",
    )
    hard_args.add_argument(
        "--max_mem_frac",
        type=float,
        default=0.8,
        help="Maximum memory fraction to use",
        metavar="Float",
    )
    hard_args.add_argument("--logfile", type=str, default=None, help="Log file")
    hard_args.add_argument("--jobid", type=int, default=0, help="Job ID")

    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    args = parser.parse_args()

    pid = os.getpid()
    save_pid(pid, datadir + f"/pids/pids_{args.jobid}.txt")

    # Use the given work directory only if it already exists; otherwise fall
    # back to a "workdir" directory next to the measurement set.
    workdir = (
        args.workdir
        if args.workdir and os.path.exists(args.workdir)
        else os.path.dirname(os.path.abspath(args.msname)) + "/workdir"
    )
    os.makedirs(workdir, exist_ok=True)

    logfile = args.logfile
    observer = None
    if os.path.exists(f"{workdir}/jobname_password.npy") and logfile is not None:
        time.sleep(5)
        # NOTE(review): allow_pickle=True unpickles the file content; this is
        # only safe while the file is written by the pipeline itself.
        jobname, password = np.load(
            f"{workdir}/jobname_password.npy", allow_pickle=True
        )
        if os.path.exists(logfile):
            observer = init_logger(
                "do_target_split", logfile, jobname=jobname, password=password
            )

    try:
        if args.msname and os.path.exists(args.msname):
            print("\n###################################")
            print("Start spliting target scans.")
            print("###################################\n")
            # Skip empty tokens so a trailing comma does not abort the run
            scans = [int(i) for i in args.scans.split(",") if i.strip()]
            msg, final_target_mslist = split_target_scans(
                args.msname,
                workdir,
                float(args.timeres),
                float(args.freqres),
                args.datacolumn,
                spw=str(args.spw),
                time_window=float(args.time_window),
                time_interval=float(args.time_interval),
                quack_timestamps=int(args.quack_timestamps),
                scans=scans,
                fullpol=args.split_fullpol,
                n_spectral_chunk=int(args.n_spectral_chunk),
                prefix=args.prefix,
                merge_spws=args.merge_spws,
                spectral_chunk=float(args.spectral_chunk),
                cpu_frac=float(args.cpu_frac),
                mem_frac=float(args.mem_frac),
                max_cpu_frac=float(args.max_cpu_frac),
                max_mem_frac=float(args.max_mem_frac),
            )
        else:
            print("Please provide correct measurement set.\n")
            msg = 1
    except Exception:
        traceback.print_exc()
        msg = 1
    finally:
        time.sleep(5)
        drop_cache(args.msname)
        # Use the resolved directory: args.workdir is "" when not supplied
        drop_cache(workdir)
        clean_shutdown(observer)
    return msg
if __name__ == "__main__":
    result = main()
    # os._exit() terminates without flushing stdio buffers, so flush here or
    # the final output can be lost when stdout is redirected to a file or pipe.
    print(
        "\n###################\nSpliting target scans are done.\n###################\n",
        flush=True,
    )
    os._exit(result)