# Source code for holodeck.librarian.fit_spectra

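"""Fit power-law and turnover models to the spectra stored in holodeck libraries.

The spectra of a library are distributed over MPI tasks (via `mpi4py`), each spectrum is fit
using several different numbers of frequency bins, and the combined results are written with
`numpy.savez` to the path returned by `lib_utils.get_fits_path`.
"""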

import argparse
from pathlib import Path

import numpy as np
import h5py
import tqdm

import holodeck as holo
import holodeck.librarian
from holodeck.librarian import lib_utils, FITS_NBINS_PLAW, FITS_NBINS_TURN
from holodeck.constants import YR


def main():
    """Parse the command-line arguments and run the spectra fits.

    Positional argument `path` is the library directory.  With `--all`/`-a`, every library found
    recursively within `path` is fit (optionally restricted to paths containing the pattern given
    as the option's value); otherwise only the single library at `path` is fit.  `--recreate`/`-r`
    is forwarded to the fitting functions to replace an existing fits file.
    """
    log = holo.log

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'path',
        default=None,
        help='library directory to run fits on; must contain the `sam_lib.hdf5` file',
    )
    parser.add_argument(
        '--recreate', '-r',
        action='store_true',
        default=False,
        help='recreate/replace existing fits file with a new one.',
    )
    parser.add_argument(
        '--all', '-a',
        nargs='?',
        const=True,
        default=False,
        help=(
            "recursively find all libraries within the given path, and fit them.  "
            "Optional argument is a pattern that all found paths must match, e.g. 'uniform-07'."
        ),
    )

    # ---- Run sub-command

    args = parser.parse_args()
    log.debug(f"{args=}")
    path = Path(args.path)

    # `args.all` is `False` when the flag is absent, `True` when the flag is given without a
    # value, and otherwise the pattern string that was passed on the command line
    if args.all is False:
        fit_library_spectra(path, log, recreate=args.recreate)
    else:
        pattern = args.all if args.all is not True else None
        fit_all_libraries_in_path(path, log, pattern, recreate=args.recreate)

    log.debug("done")



def fit_library_spectra(library_path, log, recreate=False):
    """Calculate line fits to library spectra using MPI.

    The root task (rank 0) reads the `fobs` and `gwb` datasets from the library file, converts
    the GWB to a PSD with `holo.utils.char_strain_to_psd`, and flattens the resulting
    (samples, frequencies, realizations) array into individual spectra.  A random permutation of
    the spectra is split across all MPI tasks, each task fits its share with both the power-law
    and turnover models, and the root task gathers the results, restores the original ordering,
    and saves everything with `numpy.savez`.

    Parameters
    ----------
    library_path : str or pathlib.Path
        Path to the library file, or to a directory, in which case the file name is obtained from
        `lib_utils.get_sam_lib_fname`.
    log : logger
        Logger used for status messages.  Besides the usual logging methods it must provide
        `INFO` and `WARNING` attributes, which are passed to `log.log` as levels.
    recreate : bool
        If the fits file already exists: when `True` it is recomputed and written again; when
        `False` all tasks return without doing any fitting.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        (Root task only.)  If the resolved `library_path` is not an existing file.
    Exception
        Whatever error is raised while importing/initializing `mpi4py` is logged and re-raised.

    """

    # make sure MPI is working
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
    except Exception as err:
        holo.log.error(f"failed to load `mpi4py` in {__file__}: {err}")
        holo.log.error("`mpi4py` may not be included in the standard `requirements.txt` file")
        holo.log.error("Check if you have `mpi4py` installed, and if not, please install it")
        raise

    # ---- setup path and check for an existing fits file (root task only)

    return_flag = None
    if comm.rank == 0:
        log.info(f"Fitting library from path {library_path}")
        library_path = Path(library_path)
        if library_path.is_dir():
            library_path = lib_utils.get_sam_lib_fname(library_path, gwb_only=False)
        if not library_path.exists() or not library_path.is_file():
            err = f"{library_path=} must point to an existing library file!"
            # no exception is being handled here, so log a plain error (not `log.exception`)
            log.error(err)
            raise FileNotFoundError(err)

        log.debug(f"library path = {library_path}")

        fits_path = lib_utils.get_fits_path(library_path)
        return_flag = False
        if fits_path.exists():
            lvl = log.INFO if recreate else log.WARNING
            log.log(lvl, f"library fits already exists: {fits_path}")
            if recreate:
                log.log(lvl, "re-fitting data into new file")
            else:
                return_flag = True

    # exit if we're not recreating an existing fits file.  This is a collective call: every task
    # must take part so that all of them return together.  It is done *before* the library is
    # loaded, so that nothing is read from disk when there is nothing to do.
    return_flag = comm.bcast(return_flag, root=0)
    if return_flag:
        return

    # ---- load library GWB and convert to PSD (root task only)

    fobs = None
    psd = None
    indices = None
    if comm.rank == 0:
        with h5py.File(library_path, 'r') as library:
            fobs = library['fobs'][()]
            psd = holo.utils.char_strain_to_psd(fobs[np.newaxis, :, np.newaxis], library['gwb'][()])

        nsamps, nfreqs, nreals = psd.shape
        log.debug(f"{nsamps=}, {nfreqs=}, {nreals=}")

        # make a copy of the `psd` in the current shape, so that we can confirm shape manipulations
        # work later on
        psd_check = psd.copy()

        # ---- reshape PSD into (N, F) and we will split the N points across all processors

        # (S, F, R) ==> (S, R, F)
        psd = np.moveaxis(psd, -1, 1)
        # (S, R, F) ==> (S*R, F)
        psd = psd.reshape((-1, nfreqs))
        # total number of spectra that will be fit
        ntot = psd.shape[0]

        # shuffle the spectra before splitting them into one group of indices per task
        indices = np.random.permutation(ntot)
        indices = np.array_split(indices, comm.size)
        num_ind_per_proc = [len(ii) for ii in indices]
        log.info(f"{ntot=} cores={comm.size} || max runs per core = {np.max(num_ind_per_proc)}")

    # distribute quantities to all tasks
    fobs = comm.bcast(fobs, root=0)
    psd = comm.bcast(psd, root=0)
    indices = comm.scatter(indices, root=0)
    comm.barrier()

    # select the PSD spectra for each task
    my_psd = psd[indices]

    # ---- Run fits

    nbins_plaw, fits_plaw = fit_spectra_plaw(comm, fobs, my_psd, nbins_list=FITS_NBINS_PLAW)
    nbins_turn, fits_turn = fit_spectra_turn(comm, fobs, my_psd, nbins_list=FITS_NBINS_TURN)

    # ---- gather results and save to output

    comm.barrier()
    all_indices = comm.gather(indices, root=0)
    all_fits_plaw = comm.gather(fits_plaw, root=0)
    all_fits_turn = comm.gather(fits_turn, root=0)
    all_psd = comm.gather(my_psd, root=0)

    if comm.rank == 0:
        # recombine the scatter indices so that we can sort back to the original order of PSD entries
        indices = np.concatenate(all_indices)
        # find the ordering to sort indices
        idx = np.argsort(indices)

        # re-combine all of the separate arrays, [(N1, ...), (N2, ...), ...] ===> (N1+N2+..., ...)
        fits_plaw = np.concatenate(all_fits_plaw, axis=0)
        fits_turn = np.concatenate(all_fits_turn, axis=0)
        all_psd = np.concatenate(all_psd, axis=0)

        # return elements to original order, to match original GWB/PSD
        fits_plaw = fits_plaw[idx]
        fits_turn = fits_turn[idx]
        all_psd = all_psd[idx]
        # confirm that the resorting worked correctly.  NaN entries must compare equal to
        # themselves, otherwise a single NaN in the PSD would make this check fail spuriously.
        assert np.array_equal(all_psd, psd, equal_nan=True)

        # reshape arrays to convert back to (Samples, Realizations, ...)
        len_nbins_plaw = len(nbins_plaw)
        len_nbins_turn = len(nbins_turn)
        npars_plaw = np.shape(fits_plaw)[-1]
        npars_turn = np.shape(fits_turn)[-1]
        # (S*R, B, P) ==> (S, R, B, P)
        fits_plaw = fits_plaw.reshape(nsamps, nreals, len_nbins_plaw, npars_plaw)
        fits_turn = fits_turn.reshape(nsamps, nreals, len_nbins_turn, npars_turn)
        # (S*R, F) ==> (S, R, F)
        all_psd = all_psd.reshape(nsamps, nreals, nfreqs)
        # (S, R, F) ==> (S, F, R)
        all_psd = np.moveaxis(all_psd, 1, -1)
        # confirm that reshaping worked correctly (NaN-safe, as above)
        assert np.array_equal(all_psd, psd_check, equal_nan=True)

        # Report how many fits failed
        fails = np.any(~np.isfinite(fits_plaw), axis=-1)
        lvl = log.INFO if np.any(fails) else log.DEBUG
        log.log(lvl, f"Failed to fit {holo.utils.frac_str(fails)} spectra with power-law model")

        fails = np.any(~np.isfinite(fits_turn), axis=-1)
        lvl = log.INFO if np.any(fails) else log.DEBUG
        log.log(lvl, f"Failed to fit {holo.utils.frac_str(fails)} spectra with turn-over model")

        # --- Save to output file

        # NOTE(review): the `psd` saved here is the flattened (S*R, F) array, while the fits are
        # saved with shape (S, R, B, P) -- confirm this is what readers of the file expect.
        np.savez(
            fits_path, fobs=fobs, psd=psd, version=holo.librarian.__version__,
            nbins_plaw=nbins_plaw, fits_plaw=fits_plaw, nbins_turn=nbins_turn, fits_turn=fits_turn,
        )
        log.warning(f"Saved fits to {fits_path} size: {holo.utils.get_file_size(fits_path)}")

    return
def fit_all_libraries_in_path(path, log, pattern=None, recreate=False):
    """Recursively find all `sam_lib.hdf5` files in the given path, and construct spectra fits for them.

    Parameters
    ----------
    path : str or pathlib.Path
        Top-level directory in which to search for library files.
    log : logger
        Logger used for status messages; also passed on to `fit_library_spectra`.
    pattern : str or None
        If given, only library files whose full path contains this string are fit.
    recreate : bool
        Passed on to `fit_library_spectra` for every library that is found.

    Returns
    -------
    None

    """
    path = Path(path)
    msg = "" if pattern is None else f" that match pattern {pattern}"
    log.info(f"fitting all libraries in path {path}" + msg)

    sub_paths = _find_sam_lib_in_path_tree(path, pattern=pattern)
    log.info(f"found {len(sub_paths)} sam_lib files")

    # libraries are fit one after another, in the order in which they were found
    for pp in sub_paths:
        log.info(f"path: {pp}")
        fit_library_spectra(pp, log, recreate=recreate)

    return
def _find_sam_lib_in_path_tree(path, pattern=None):
    """Recursive method to find `sam_lib.hdf5` files anywhere in the given path.

    Parameters
    ----------
    path : pathlib.Path
        File or directory to examine.
    pattern : str or None
        If given, a file is only returned when this string is contained in its full path.

    Returns
    -------
    list of pathlib.Path
        All files named `sam_lib.hdf5` found at or below `path`.  Directories named `sims` or
        `logs` are not searched.

    """
    if path.is_file():
        # if a pattern is given, and it's not in this path, return nothing
        if (pattern is not None) and (pattern not in str(path)):
            return []
        # if we find the library file, return it
        if path.name == "sam_lib.hdf5":
            return [path]
        return []

    # don't recursively follow into these subdirectories
    if path.name in ['sims', 'logs']:
        return []

    # accumulate paths from all subdirectories
    path_list = []
    for pp in path.iterdir():
        path_list += _find_sam_lib_in_path_tree(pp, pattern=pattern)

    return path_list


# ==============================================================================
# ====    Fitting Functions    ====
# ==============================================================================


def _fit_spectra(comm, freqs, psd, nbins_list, fit_npars, fit_func):
    """Fit every spectrum in `psd` using the lowest `nbin` frequencies, for each `nbin` given.

    Parameters
    ----------
    comm : MPI communicator or None
        Only used to decide whether to show a progress bar: shown when `comm` is `None` or this
        is the rank 0 task.
    freqs : (F,) array
        Frequencies at which the spectra are given.
    psd : (N, F) array
        Spectra to be fit, one per row.
    nbins_list : (B,) sequence of int
        Numbers of (lowest) frequency bins to use; one fit is performed for each entry.
    fit_npars : int
        Number of parameters returned by `fit_func`.
    fit_func : callable
        Called as `fit_func(xx, yy)` with the selected frequencies and PSD values, and must
        return the `fit_npars` best-fit parameters.  A `RuntimeError` raised by this function is
        treated as a failed fit.

    Returns
    -------
    fits : (N, B, fit_npars) array
        Fit parameters for each spectrum and each number of bins.  Failed fits are all NaN.

    Raises
    ------
    ValueError
        If any entry of `nbins_list` is larger than the number of frequencies `F`.

    """
    assert np.ndim(psd) == 2
    npoints, nfreqs = np.shape(psd)
    assert len(freqs) == nfreqs
    assert np.ndim(nbins_list) == 1

    # validate the bin numbers once up-front, instead of again for every spectrum; this also
    # reports the problem on tasks that have no spectra to fit
    for nbin in nbins_list:
        if nbin > nfreqs:
            raise ValueError(f"Cannot fit for {nbin=} bins, data has {nfreqs=} frequencies!")

    # parameters stored when a fit is not possible, or fails
    bad_pars = [np.nan] * fit_npars

    def fit_if_all_finite(xx, yy):
        # any non-finite value invalidates the whole spectrum
        if np.any(~np.isfinite(yy)):
            return bad_pars

        # only positive values are passed to the fit, and there must be at least as many of them
        # as there are parameters
        sel = (yy > 0.0)
        if np.count_nonzero(sel) < fit_npars:
            return bad_pars

        try:
            pars = fit_func(xx[sel], yy[sel])
        except RuntimeError:
            return bad_pars

        return pars

    fits = np.zeros([npoints, len(nbins_list), fit_npars])
    # only show a progress bar on a single task
    iterator = tqdm.trange(npoints) if (comm is None) or (comm.rank == 0) else range(npoints)
    for ii in iterator:
        yy = psd[ii, :]
        for nn, nbin in enumerate(nbins_list):
            fits[ii, nn, :] = fit_if_all_finite(freqs[:nbin], yy[:nbin])

    return fits


def fit_spectra_plaw(comm, freqs, psd, nbins_list=FITS_NBINS_PLAW):
    """Fit each spectrum in `psd` with `holo.utils.fit_powerlaw_psd` (2 parameters per fit).

    Parameters
    ----------
    comm : MPI communicator or None
        Passed to `_fit_spectra` (controls the progress bar).
    freqs : (F,) array
        Frequencies at which the spectra are given.
    psd : (N, F) array
        Spectra to be fit, one per row.
    nbins_list : (B,) sequence of int
        Numbers of (lowest) frequency bins to use for the fits.

    Returns
    -------
    nbins_list : (B,) sequence of int
        The input `nbins_list`, unchanged.
    fits : (N, B, 2) array
        Fit parameters; all NaN where a fit failed.

    """

    def fit_func(xx, yy):
        # `1/YR` is the third positional argument -- presumably the reference frequency of the
        # power-law; confirm against `holo.utils.fit_powerlaw_psd`.  Only the first element of
        # the returned value is used as the fit parameters.
        return holo.utils.fit_powerlaw_psd(xx, yy, 1/YR)[0]

    fit_npars = 2
    fits = _fit_spectra(comm, freqs, psd, nbins_list, fit_npars, fit_func)
    return nbins_list, fits


def fit_spectra_turn(comm, freqs, psd, nbins_list=FITS_NBINS_TURN):
    """Fit each spectrum in `psd` with `holo.utils.fit_turnover_psd` (4 parameters per fit).

    Parameters
    ----------
    comm : MPI communicator or None
        Passed to `_fit_spectra` (controls the progress bar).
    freqs : (F,) array
        Frequencies at which the spectra are given.
    psd : (N, F) array
        Spectra to be fit, one per row.
    nbins_list : (B,) sequence of int
        Numbers of (lowest) frequency bins to use for the fits.

    Returns
    -------
    nbins_list : (B,) sequence of int
        The input `nbins_list`, unchanged.
    fits : (N, B, 4) array
        Fit parameters; all NaN where a fit failed.

    """

    def fit_func(xx, yy):
        # `1/YR` is the third positional argument -- presumably the reference frequency of the
        # model; confirm against `holo.utils.fit_turnover_psd`.  Only the first element of the
        # returned value is used as the fit parameters.
        return holo.utils.fit_turnover_psd(xx, yy, 1/YR)[0]

    fit_npars = 4
    fits = _fit_spectra(comm, freqs, psd, nbins_list, fit_npars, fit_func)
    return nbins_list, fits