Source code for shenfun.io

import sys
import os
import numpy as np
from mpi4py import MPI
from mpi4py_fft.io import NCFile, HDF5File
from .generate_xdmf import generate_xdmf

# Public names of this module. HDF5File and NCFile are imported from
# mpi4py_fft.io and generate_xdmf from .generate_xdmf; they are re-exported here.
__all__ = ['HDF5File', 'NCFile', 'ShenfunFile', 'Checkpoint', 'generate_xdmf']

# Communicator used by Checkpoint, both for opening the checkpoint file with
# the "mpio" driver and for the collective kill-file check.
comm = MPI.COMM_WORLD

def ShenfunFile(name, T, backend='hdf5', mode='r', mesh='quadrature', **kw):
    """Return a file handler

    Parameters
    ----------
    name : str
        Name of file, without ending
    T : :class:`.TensorProductSpace`
        The space used for the data stored. Can also be
        :class:`.CompositeSpace` or :class:`.VectorSpace`.
    backend : str, optional
        ``hdf5`` or ``netcdf4``. Default is ``hdf5``. The comparison is
        case-insensitive; any value other than ``hdf5`` selects the
        NetCDF4 handler.
    mode : str, optional
        ``r`` or ``w``. Default is ``r``.
    mesh : str, optional
        - 'quadrature' - use quadrature mesh of self
        - 'uniform' - use uniform mesh for non-periodic bases
    **kw
        Additional keyword arguments, forwarded unchanged to the file class.

    Returns
    -------
    Class instance
        Instance of either :class:`.HDF5File` or :class:`.NCFile`

    Raises
    ------
    AssertionError
        If the NetCDF4 handler is selected and the keyword argument
        ``forward_output`` is given as anything but ``False``.
    """
    use_hdf5 = backend.lower() == 'hdf5'

    # Raised explicitly rather than through an ``assert`` statement so the
    # check is still enforced when Python runs with -O. The exception type is
    # kept as AssertionError for callers that already catch it.
    if not use_hdf5 and kw.get('forward_output', False) is not False:
        raise AssertionError("NetCDF4 cannot store complex arrays, use HDF5")

    # One squeezed coordinate array per entry returned by T.mesh.
    domain = [np.squeeze(d) for d in T.mesh(kind=mesh)]

    if use_hdf5:
        return HDF5File(name+'.h5', domain=domain, mode=mode, **kw)
    return NCFile(name+'.nc', domain=domain, mode=mode, **kw)
class Checkpoint:
    """Class for checkpointing simulations

    Checkpoint data are used to store intermediate simulation results, and can
    be used to restart a simulation at a later stage, with no loss of accuracy.

    Data is provided as dictionaries. The checkpoint dictionary is represented
    as::

        data = {
            '0': {'U': [U_hat], 'phi': [phi_hat]},
            '1': {'U': [U0_hat], 'phi': [phi_hat0]},
            ...
        }

    which contains solutions to be stored at possibly several different
    timesteps. The current timestep is 0, previous is 1 and so on if more is
    needed by the integrator.

    Note that checkpoint is storing results from spectral space.

    Parameters
    ----------
    filename : str
        Name of the checkpoint file, without ending. The data is stored in
        ``filename + '.chk.h5'``.
    checkevery : int, optional
        A checkpoint is written by :meth:`update` whenever the timestep is a
        multiple of this number. Default is 10.
    data : dict, optional
        The checkpoint dictionary described above. The dictionary is stored
        by reference, not copied. If omitted, a new empty dictionary is
        created for this instance.
    """

    def __init__(self, filename, checkevery=10, data=None):
        self.f = None
        self.filename = filename
        # A fresh dict per instance: a literal ``{}`` default in the signature
        # would be one object shared by every Checkpoint created without data.
        self.data = {} if data is None else data
        self.checkevery = checkevery

    def open(self, mode='r+'):
        """Open the checkpoint file and store the handle in ``self.f``.

        Parameters
        ----------
        mode : str, optional
            Mode passed on to ``h5py.File``. Default is ``r+``.
        """
        import h5py
        self.f = h5py.File(self.filename+'.chk.h5', mode, driver="mpio", comm=comm)

    def close(self):
        """Close the checkpoint file if ``self.f`` is set and truthy."""
        if self.f:
            self.f.close()

    def update(self, t, tstep):
        """Write a checkpoint if one is due or a stop has been requested.

        If no file has been opened by this instance yet, the file is first
        created in mode ``w`` with attributes ``tstep=0`` and ``t=0.0``.
        A checkpoint is then written when ``tstep`` is a multiple of
        ``checkevery`` or when :meth:`check_if_kill` returns True. In the
        latter case the process exits through ``sys.exit(1)`` after writing.

        Parameters
        ----------
        t : float
            Value stored in the file attribute ``t``.
        tstep : int
            Value stored in the file attribute ``tstep``.
        """
        if self.f is None:
            self.open(mode='w')
            self.f.attrs.create('tstep', 0)
            self.f.attrs.create('t', 0.0)
            self.close()

        kill = self.check_if_kill()
        if tstep % self.checkevery == 0 or kill:
            if comm.Get_rank() == 0:
                # for safety
                self._backup()
            self.open()
            try:
                for key, val in self.data.items():
                    self.write(int(key), val)
                self.f.attrs['tstep'] = tstep
                self.f.attrs['t'] = t
            finally:
                # Do not leave the file open if writing fails.
                self.close()
            if kill:
                sys.exit(1)

    def _backup(self):
        """Copy the current checkpoint file to ``filename + '.old.chk.h5'``."""
        # Local import, in the same way h5py is imported in ``open``.
        import shutil
        src = self.filename + '.chk.h5'
        dst = self.filename + '.old.chk.h5'
        try:
            # A direct file copy instead of a shell ``cp`` command: filenames
            # with spaces or shell metacharacters are handled correctly and
            # no external program is required.
            shutil.copy(src, dst)
        except OSError as err:
            # The backup is best effort; a failed copy must not stop the run.
            print(f'Could not back up {src} to {dst}: {err}', file=sys.stderr)

    def write(self, step, d):
        """Write the arrays in ``d`` to the open checkpoint file.

        Parameters
        ----------
        step : int
            Index of the stored timestep, used as the dataset name.
        d : dict
            Maps a group name to an iterable of arrays. Each array must
            provide ``local_slice()``, ``global_shape`` and ``dtype``. All
            arrays under one name are written to the same dataset
            ``name/step``.
        """
        for name, val in d.items():
            self.f.require_group(name)
            for u in val:
                s = u.local_slice()
                self.f[name].require_dataset(str(step), shape=u.global_shape, dtype=u.dtype)
                self.f["/".join((name, str(step)))][s] = u

    def read(self, u, name, **kw):
        """Read the dataset ``name/step`` from the checkpoint file into ``u``.

        Parameters
        ----------
        u : array
            Array to fill in place. Must provide ``local_slice()``, which
            selects the part of the dataset that is read.
        name : str
            Name of the group holding the dataset.
        step : int, optional
            Keyword argument giving the index of the stored timestep.
            Default is 0.
        """
        step = kw.get('step', 0)
        self.open()
        try:
            s = u.local_slice()
            dset = "/".join((name, str(step)))
            u[:] = self.f[dset][s]
        finally:
            # Do not leave the file open if the dataset is missing.
            self.close()

    @staticmethod
    def check_if_kill():
        """Check if user has put a file named killshenfun in running folder.

        The result is combined over all ranks of the communicator, so every
        rank returns the same value. If the file is found, rank 0 removes it
        and prints a message.

        Returns
        -------
        bool
            True if any rank found the file, otherwise False.
        """
        found = 1 if 'killshenfun' in os.listdir(os.getcwd()) else 0
        if comm.allreduce(found) > 0:
            if comm.Get_rank() == 0:
                try:
                    os.remove('killshenfun')
                except FileNotFoundError:
                    # Another rank reported the file but it is not visible
                    # from this rank's working directory.
                    pass
                print('killshenfun Found! Stopping simulations cleanly by checkpointing...')
            return True
        return False