Source code for meersolar.utils.proc_manage_utils

import types
import resource
import psutil
import dask
import numpy as np
import warnings
import gc
import logging
import time
import glob
import os
import subprocess
import sys
import tempfile
import shutil
from pathlib import Path
from dask import delayed, config
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SLURMCluster
from datetime import datetime as dt, timedelta
from .basic_utils import *


#################################
# Process management
#################################
def get_nprocess_solarpipe(jobid):
    """
    Get number of processes of a job that are currently running

    The PIDs are read from ``<cachedir>/pids/pids_<jobid>.txt`` and each one
    is checked with ``psutil.pid_exists``.

    Parameters
    ----------
    jobid : int
        Job ID

    Returns
    -------
    int
        Number of running processes

    Raises
    ------
    OSError
        If the PID file of the job does not exist or cannot be read
    """
    cachedir = get_cachedir()
    pid_file = f"{cachedir}/pids/pids_{jobid}.txt"
    # A file holding a single PID is loaded as a 0-d array, which cannot be
    # iterated; normalise to a flat 1-d array first.
    pids = np.atleast_1d(np.loadtxt(pid_file, unpack=True)).ravel()
    return sum(1 for pid in pids if psutil.pid_exists(int(pid)))
def get_jobid():
    """
    Get Job ID with millisecond-level resolution.

    The new ID is appended to ``<cachedir>/jobids.txt``. While doing so,
    entries older than 15 days are dropped from that file; the special
    entry ``0`` is always kept.

    Returns
    -------
    int
        Job ID in the format YYYYMMDDHHMMSSmmm (milliseconds)
    """
    from datetime import timezone

    cachedir = get_cachedir()
    jobid_file = os.path.join(cachedir, "jobids.txt")

    prev_jobids = []
    if os.path.exists(jobid_file):
        # atleast_1d: a single-entry file is loaded as a 0-d array
        loaded = np.atleast_1d(np.loadtxt(jobid_file, unpack=True, dtype="int64"))
        prev_jobids = [str(jid) for jid in loaded.ravel()]

    time_format = "%Y%m%d%H%M%S%f"
    # Naive UTC timestamp, comparable with the naive values strptime returns
    now = dt.now(timezone.utc).replace(tzinfo=None)
    cutoff = now - timedelta(days=15)

    kept_jobids = []
    for job_id in prev_jobids:
        if job_id == "0":
            # Job ID 0 is always kept; it is not a timestamp, so it must be
            # handled before any attempt to parse it.
            kept_jobids.append(job_id)
            continue
        try:
            # pad to 14 date/time digits + 6 fractional digits if truncated
            job_time = dt.strptime(job_id.ljust(20, "0"), time_format)
        except ValueError:
            # Not a valid timestamp-style ID: drop it instead of crashing
            continue
        if job_time >= cutoff:
            kept_jobids.append(job_id)

    # ms = first 3 digits of microseconds
    cur_jobid = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"
    kept_jobids.append(cur_jobid)

    job_ids_int = np.array([int(jid) for jid in kept_jobids], dtype=np.int64)
    np.savetxt(jobid_file, job_ids_int, fmt="%d")
    return int(cur_jobid)
def save_main_process_info(pid, jobid, msname, workdir, outdir, cpu_frac, mem_frac):
    """
    Save main processes info

    Before writing, bookkeeping files of jobs older than 15 days
    (``main_pids_<jobid>.txt`` and the matching ``pids/pids_<jobid>.txt``)
    are removed from the cache directory. Files of job ID ``0`` and files
    whose ID cannot be parsed as a timestamp are left untouched.

    Parameters
    ----------
    pid : int
        Main job process id
    jobid : int
        Job ID
    msname : str
        Main measurement set
    workdir : str
        Work directory
    outdir : str
        Output directory
    cpu_frac : float
        CPU fraction of the job
    mem_frac : float
        Memory fraction of the job

    Returns
    -------
    str
        Job info file name
    """
    from datetime import timezone

    cachedir = get_cachedir()
    time_format = "%Y%m%d%H%M%S%f"
    # Naive UTC timestamp, comparable with the naive values strptime returns
    cutoff = dt.now(timezone.utc).replace(tzinfo=None) - timedelta(days=15)

    for prev_file in glob.glob(f"{cachedir}/main_pids_*.txt"):
        # splitext removes the real ".txt" suffix (rstrip would strip a
        # character set, not a suffix)
        stem = os.path.splitext(os.path.basename(prev_file))[0]
        job_id = stem.split("main_pids_")[-1]
        if job_id == "0":
            continue  # Job ID 0 is always kept
        try:
            # pad to 14 date/time digits + 6 fractional digits if truncated
            job_time = dt.strptime(job_id.ljust(20, "0"), time_format)
        except ValueError:
            continue  # unknown naming scheme: do not delete
        if job_time >= cutoff:
            continue  # recent job: keep its files
        # Expired job: remove its bookkeeping files. Missing files are
        # ignored because another process may have cleaned up already.
        for stale_file in (prev_file, f"{cachedir}/pids/pids_{job_id}.txt"):
            try:
                os.remove(stale_file)
            except FileNotFoundError:
                pass

    main_job_file = f"{cachedir}/main_pids_{jobid}.txt"
    main_str = f"{jobid} {pid} {msname} {workdir} {outdir} {cpu_frac} {mem_frac}"
    with open(main_job_file, "w") as f:
        f.write(main_str)
    return main_job_file
def save_pid(pid, pid_file):
    """
    Save PID

    The PID is appended to the PIDs already stored in ``pid_file``; the file
    (and its parent directory) is created if it does not exist yet.

    Parameters
    ----------
    pid : int
        Process ID
    pid_file : str
        File to save
    """
    # Cast once so both branches store integers (a string PID would
    # otherwise turn the array into strings and break the "%d" format).
    pid = int(pid)
    parent_dir = os.path.dirname(pid_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    if os.path.exists(pid_file):
        # ndmin=1 keeps a single stored PID as a 1-d array
        pids = np.loadtxt(pid_file, dtype=np.int64, ndmin=1)
        pids = np.append(pids, pid)
    else:
        pids = np.array([pid], dtype=np.int64)
    np.savetxt(pid_file, pids, fmt="%d")
def generate_activate_env(outfile="activate_env.sh"):
    """
    Generate a shell script that activates the current Python environment.

    This works for both Conda and virtualenv environments and is safe for use
    in non-interactive shells (e.g., Slurm batch jobs) by explicitly sourcing
    `conda.sh`. If conda is not found in $PATH, the script tries to load
    either the `anaconda` or the `anaconda3` module, depending on which one
    `module avail` reports at generation time.

    Parameters
    ----------
    outfile : str
        Path to the shell script to write (default: ./activate_env.sh).

    Returns
    -------
    pathlib.Path
        Resolved path of the written script
    """
    import shlex

    outfile = Path(outfile).expanduser().resolve()
    lines = ["#!/bin/bash", ""]

    def module_exists(name):
        """Check if `module avail` lists a module called *name*."""
        try:
            # `module` is normally a shell function, not an executable, so it
            # has to be run through a login shell. It commonly reports on
            # stderr, hence the merge into stdout.
            result = subprocess.run(
                ["bash", "-lc", f"module avail {shlex.quote(name)} 2>&1"],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                check=False,
                timeout=60,
            )
        except (OSError, subprocess.SubprocessError):
            return False
        # Match whole module names ("anaconda" or "anaconda/<version>") so
        # that "anaconda" is not reported just because "anaconda3" exists.
        return any(
            token == name or token.startswith(f"{name}/")
            for token in result.stdout.split()
        )

    # Conda-based environment
    if "CONDA_DEFAULT_ENV" in os.environ:
        conda_env = os.environ["CONDA_DEFAULT_ENV"]
        lines.append("# === Activate Conda Environment Safely ===")
        lines.append("if ! command -v conda >/dev/null 2>&1; then")
        if module_exists("anaconda"):
            lines.append("    module load anaconda")
        elif module_exists("anaconda3"):
            lines.append("    module load anaconda3")
        else:
            lines.append("    echo 'No Conda module found (anaconda or anaconda3)'")
            lines.append("    exit 1")
        lines.append("fi")
        lines.append("source $(conda info --base)/etc/profile.d/conda.sh")
        lines.append(f"conda activate {shlex.quote(conda_env)}")
    # Virtualenv-based environment
    elif "VIRTUAL_ENV" in os.environ:
        venv_path = os.environ["VIRTUAL_ENV"]
        lines.append("# === Activate Virtualenv ===")
        # Quoted so that paths containing spaces survive in the script
        lines.append(f"source {shlex.quote(f'{venv_path}/bin/activate')}")
    else:
        python_path = sys.executable
        lines.append(
            "# === No Conda/Virtualenv Detected — Using current Python directly ==="
        )
        message = f"No Conda or virtualenv detected; using: {python_path}"
        lines.append(f"echo {shlex.quote(message)}")
        lines.append(f"export PATH={shlex.quote(os.path.dirname(python_path))}:$PATH")

    # Write file (explicit encoding: the script contains a non-ASCII dash)
    with open(outfile, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    os.chmod(outfile, 0o755)
    print(f"Created activation script at: {outfile}")
    return outfile
def get_total_worker(cluster):
    """
    Count the workers registered on a cluster

    Parameters
    ----------
    cluster : dask.cluster
        Dask cluster (its ``workers`` attribute is measured with ``len``)

    Returns
    -------
    int
        Number of workers
    """
    registered_workers = cluster.workers
    return len(registered_workers)
def scale_worker_and_wait(dask_cluster, nworker, timeout=60, poll_interval=1):
    """
    Scale worker and wait until it is done

    Parameters
    ----------
    dask_cluster : dask.cluster
        Dask cluster
    nworker : int
        Number of worker
    timeout : float, optional
        Maximum time to wait in seconds; on expiry a message is printed and
        the function returns
    poll_interval : float, optional
        Check interval in seconds

    Returns
    -------
    int
        0 if the cluster reached ``nworker`` workers, 1 if the timeout expired
    """
    print(f"Start scaling to {nworker} workers")
    dask_cluster.scale(nworker)
    # Wall-clock deadline: honours the caller's timeout and still terminates
    # when poll_interval is zero.
    deadline = time.monotonic() + timeout
    while True:
        if len(dask_cluster.workers) == nworker:
            print(f"Successfully scaled to {nworker} workers")
            return 0
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline; guard against negative intervals
        time.sleep(min(max(poll_interval, 0), remaining))
    print(f"Dask cluster did not scale to {nworker} within {timeout} seconds.")
    return 1
def wait_for_dask_workers(client, min_worker=1, timeout=60):
    """
    Block until enough workers are connected to the Dask client.

    Thin wrapper that forwards to ``client.wait_for_workers``.

    Parameters
    ----------
    client : dask.distributed.Client
        Dask client
    min_worker : int, optional
        Number of workers to wait for (default: 1)
    timeout : float, optional
        Maximum time to wait in seconds (default: 60)

    Raises
    ------
    TimeoutError
        If the required number of workers do not connect in time.
    """
    wait_options = {"n_workers": min_worker, "timeout": timeout}
    client.wait_for_workers(**wait_options)
def get_scheduler_name():
    """
    Detect the job scheduler available on this machine

    Each scheduler is identified by one of its submission/query commands
    being present on ``PATH``; the first match in the probe order wins.

    Returns
    -------
    str
        Scheduler name (slurm, lsf, sge, pbs, htcondor, mab, oar), or
        ``local`` when none of the probed commands is found
    """
    # (command to look for, scheduler name) in priority order.
    # qhost is probed before qsub: the code treats qhost as the SGE marker,
    # and only falls back to "pbs" when qsub is present without it.
    command_probes = (
        ("sbatch", "slurm"),
        ("bsub", "lsf"),
        ("qhost", "sge"),
        ("qsub", "pbs"),
        ("condor_submit", "htcondor"),
        ("msub", "mab"),
        ("oarsub", "oar"),
    )
    for command, scheduler in command_probes:
        if shutil.which(command):
            return scheduler
    return "local"
def get_local_dask_cluster(
    njobs,
    dask_dir,
    cpu_frac=0.8,
    mem_frac=0.8,
    ncpu=-1,
    mem=-1,
    spill_frac=0.7,
    verbose=True,
):
    """
    Create a local Dask cluster

    The cluster is started with one single-threaded worker process whose
    memory limit is the usable memory computed below.

    Parameters
    ----------
    njobs : int
        Number of MS tasks (ideally = number of MS files).
        Not used by this function.
    dask_dir : str
        Dask temporary directory; a ``dask_<unix time>`` sub-directory is
        created inside it
    cpu_frac : float, optional
        Fraction of total CPUs to use. Accepted for interface compatibility;
        it does not influence the cluster that is created here.
    mem_frac : float, optional
        Fraction of total memory to use (capped at 0.8)
    ncpu : int, optional
        Number of CPUs to use. Accepted for interface compatibility; it does
        not influence the cluster that is created here.
    mem : float, optional
        Memory in GB to use (if specified, mem_frac will be ignored)
    spill_frac : float, optional
        Spill to disk at this fraction
    verbose : bool, optional
        Verbose (details of cluster)

    Returns
    -------
    client : dask.distributed.Client
        Dask client
    cluster : dask.distributed.LocalCluster
        Dask cluster
    str
        Dask directory
    """
    logging.getLogger("distributed").setLevel(logging.ERROR)
    print("Creating local cluster on the current node.")

    # Set up Dask working directories
    dask_dir = os.path.join(dask_dir.rstrip("/"), f"dask_{int(time.time())}")
    dask_dir_tmp = os.path.join(dask_dir, "tmp")
    os.makedirs(dask_dir_tmp, exist_ok=True)

    total_mem = psutil.virtual_memory().total / 1024**3  # In GB
    # An explicit memory request overrides the fraction; never exceed 80 %
    if mem > 0:
        mem_frac = mem / total_mem
    mem_frac = min(mem_frac, 0.8)
    usable_mem = total_mem * mem_frac

    # Raise file descriptor limit (best effort). Skipped when a limit is
    # reported as unlimited, because 80 % of RLIM_INFINITY is meaningless.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if resource.RLIM_INFINITY not in (soft, hard):
        wanted_soft = int(hard * 0.8)
        if soft < wanted_soft:
            try:
                resource.setrlimit(resource.RLIMIT_NOFILE, (wanted_soft, hard))
            except (ValueError, OSError) as err:
                # Some platforms refuse the change; keep the current limit
                warnings.warn(f"Could not raise open-file limit: {err}")

    dask.config.set(
        {
            "temporary-directory": dask_dir,
            "distributed.worker.memory.target": spill_frac,
            "distributed.worker.memory.spill": spill_frac + 0.1,
            "distributed.worker.memory.pause": spill_frac + 0.2,
            "distributed.worker.memory.terminate": spill_frac + 0.25,
        }
    )

    # NOTE(review): usable_mem is computed in units of 1024**3 bytes but is
    # passed with a "GB" suffix - confirm which unit is intended.
    cluster = LocalCluster(
        n_workers=1,
        threads_per_worker=1,
        memory_limit=f"{usable_mem}GB",
        local_directory=dask_dir,
        dashboard_address=":0",
        processes=True,
        env={
            "TMPDIR": dask_dir_tmp,
            "TMP": dask_dir_tmp,
            "TEMP": dask_dir_tmp,
            "DASK_TEMPORARY_DIRECTORY": dask_dir_tmp,
            "MALLOC_TRIM_THRESHOLD_": "0",
            "PYTHONWARNINGS": "ignore::UserWarning:contextlib",
        },
    )
    client = Client(cluster, heartbeat_interval="5s")
    client.run_on_scheduler(gc.collect)
    if verbose:
        print("####################################################")
        print(f"Dask dashboard available at: {client.dashboard_link}")
        print("####################################################")
    return client, cluster, dask_dir
def get_slurm_dask_cluster(
    njobs,
    config_yaml,
    dask_dir,
    cpu_frac=0.8,
    mem_frac=0.8,
    ncpu=-1,
    mem=-1,
    spill_frac=0.7,
    verbose=True,
):
    """
    Launch a SLURMCluster using a YAML configuration and return a connected
    Dask client.

    The YAML file is merged into the Dask configuration before the cluster
    is constructed, and the cluster is then scaled to ``njobs``.

    Parameters
    ----------
    njobs : int
        Number of expected tasks (used for worker scaling)
    config_yaml : str
        Path to Dask SLURMCluster YAML configuration
    dask_dir : str
        Dask working directory (for temporary files); a ``dask_<unix time>``
        sub-directory is created inside it
    cpu_frac : float
        Accepted for interface compatibility; not used by this function
    mem_frac : float
        Accepted for interface compatibility; not used by this function
    ncpu : int
        Accepted for interface compatibility; not used by this function
    mem : float
        Accepted for interface compatibility; not used by this function
    spill_frac : float
        Fraction of memory to spill to disk
    verbose : bool
        Print Dask dashboard URL and diagnostics

    Returns
    -------
    client : dask.distributed.Client
        Connected Dask client
    cluster : dask_jobqueue.SLURMCluster
        SLURM Dask cluster
    str
        Dask directory used
    """
    logging.getLogger("distributed").setLevel(logging.ERROR)

    dask_dir = os.path.join(dask_dir.rstrip("/"), f"dask_{int(time.time())}")
    dask_dir_tmp = os.path.join(dask_dir, "tmp")
    os.makedirs(dask_dir_tmp, exist_ok=True)

    # Raise file descriptor limit (best effort). Skipped when a limit is
    # reported as unlimited, because 80 % of RLIM_INFINITY is meaningless.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if resource.RLIM_INFINITY not in (soft, hard):
        wanted_soft = int(hard * 0.8)
        if soft < wanted_soft:
            try:
                resource.setrlimit(resource.RLIMIT_NOFILE, (wanted_soft, hard))
            except (ValueError, OSError) as err:
                # Some platforms refuse the change; keep the current limit
                warnings.warn(f"Could not raise open-file limit: {err}")

    dask.config.set(
        {
            "temporary-directory": dask_dir,
            "distributed.worker.memory.target": spill_frac,
            "distributed.worker.memory.spill": spill_frac + 0.1,
            "distributed.worker.memory.pause": spill_frac + 0.2,
            "distributed.worker.memory.terminate": spill_frac + 0.25,
        }
    )

    # Load cluster config from YAML
    # NOTE(review): `yaml` is not imported explicitly in this module; it is
    # presumably provided by the star import of basic_utils - confirm.
    with open(config_yaml, "r") as f:
        cluster_config = yaml.safe_load(f) or {}  # empty file loads as None
    # dask.config.update merges `new` into `old`; the live global
    # configuration must be passed as the first argument.
    dask.config.update(dask.config.config, cluster_config, priority="new")

    # Initialize SLURM cluster
    # NOTE(review): newer dask_jobqueue releases appear to rename
    # `env_extra` to `job_script_prologue` - verify against the pinned version.
    cluster = SLURMCluster(
        local_directory=dask_dir_tmp,
        env_extra=[
            f"TMPDIR={dask_dir_tmp}",
            f"TMP={dask_dir_tmp}",
            f"TEMP={dask_dir_tmp}",
            f"DASK_TEMPORARY_DIRECTORY={dask_dir_tmp}",
            "MALLOC_TRIM_THRESHOLD_=0",
            "PYTHONWARNINGS=ignore::UserWarning:contextlib",
        ],
    )

    # Scale workers (1 per task/MS file ideally)
    cluster.scale(njobs)
    client = Client(cluster, heartbeat_interval="5s")
    client.run_on_scheduler(gc.collect)
    if verbose:
        print("####################################################")
        print(f"Dask dashboard available at: {client.dashboard_link}")
        print("####################################################")
    return client, cluster, dask_dir
# Exposing only functions __all__ = [ name for name, obj in globals().items() if isinstance(obj, types.FunctionType) and obj.__module__ == __name__ ]