Source code for mpi4py_fft.io.h5py_file

import numpy as np
from mpi4py import MPI
from .file_base import FileBase

__all__ = ('HDF5File',)

comm = MPI.COMM_WORLD

class HDF5File(FileBase):
    """Class for reading/writing data to HDF5 format

    Parameters
    ----------
    h5name : str
        Name of hdf5 file to be created.
    domain : sequence, optional
        An optional spatial mesh or domain to go with the data.
        Sequence of either

            - 2-tuples, where each 2-tuple contains the (origin, length)
              of each dimension, e.g., (0, 2*pi).

            - Arrays of coordinates, e.g., np.linspace(0, 2*pi, N). One
              array per dimension.

    mode : str, optional
        ``r``, ``w`` or ``a`` for read, write or append. Default is ``a``.
    kw : dict, optional
        Optional additional keyword arguments used when creating the file
        used to store data.
    """
    def __init__(self, h5name, domain=None, mode='a', **kw):
        FileBase.__init__(self, h5name, domain=domain)
        import h5py
        self.f = h5py.File(h5name, mode, driver="mpio", comm=comm, **kw)
        self.close()

    def _check_domain(self, group, field):
        if self.domain is None:
            self.domain = ((0, 2*np.pi),)*field.dimensions
        assert len(self.domain) == field.dimensions
        self.f.require_group(group)
        if "shape" not in self.f[group].attrs:
            self.f[group].attrs.create("shape", field.pencil.shape)
        if "rank" not in self.f[group].attrs:
            self.f[group].attrs.create("rank", field.rank)
        assert field.rank == self.f[group].attrs["rank"]
        assert np.all(field.pencil.shape == self.f[group].attrs["shape"])
        if isinstance(self.domain[0], np.ndarray):
            self.f[group].require_group("mesh")
        else:
            self.f[group].require_group("domain")
        for i in range(field.dimensions):
            d = self.domain[i]
            if isinstance(d, np.ndarray):
                d0 = np.squeeze(d)
                self.f[group]["mesh"].require_dataset(
                    "x{}".format(i), shape=d0.shape, dtype=d0.dtype, data=d0)
            else:
                d0 = np.array([d[0], d[1]])
                self.f[group]["domain"].require_dataset(
                    "x{}".format(i), shape=d0.shape, dtype=d0.dtype, data=d0)
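
    # Usage sketch (illustrative, not part of the original source). The domain
    # argument accepts either (origin, length) 2-tuples or one coordinate
    # array per dimension; the file names below are arbitrary:
    #
    #   >>> f0 = HDF5File('tuples.h5', domain=((0, 2*np.pi), (0, 2*np.pi)),
    #   ...               mode='w')
    #   >>> x = np.linspace(0, 2*np.pi, 16)
    #   >>> f1 = HDF5File('mesh.h5', domain=(x, x), mode='w')
    #
    # Tuple domains are written to a 'domain' subgroup and coordinate arrays
    # to a 'mesh' subgroup by _check_domain when fields are first stored.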
    @staticmethod
    def backend():
        return 'hdf5'
    def open(self, mode='r+'):
        import h5py
        self.f = h5py.File(self.filename, mode, driver="mpio", comm=comm)
    def write(self, step, fields, **kw):
        """Write snapshot ``step`` of ``fields`` to HDF5 file

        Parameters
        ----------
        step : int
            Index of snapshot.
        fields : dict
            The fields to be dumped to file. (key, value) pairs are group name
            and either arrays or 2-tuples, respectively. The arrays are
            complete arrays to be stored, whereas 2-tuples are arrays with
            associated *global* slices.
        as_scalar : boolean, optional
            Whether to store rank > 0 arrays as scalars. Default is False.

        Example
        -------
        >>> from mpi4py import MPI
        >>> from mpi4py_fft import PFFT, HDF5File, newDistArray
        >>> comm = MPI.COMM_WORLD
        >>> T = PFFT(comm, (15, 16, 17))
        >>> u = newDistArray(T, forward_output=False, val=1)
        >>> v = newDistArray(T, forward_output=False, val=2)
        >>> f = HDF5File('h5filename.h5', mode='w')
        >>> f.write(0, {'u': [u, (u, [slice(None), 4, slice(None)])],
        ...             'v': [v, (v, [slice(None), 5, 5])]})
        >>> f.write(1, {'u': [u, (u, [slice(None), 4, slice(None)])],
        ...             'v': [v, (v, [slice(None), 5, 5])]})

        This stores data within two main groups ``u`` and ``v``. The HDF5 file
        will in the end contain the groups::

            /u/3D/{0, 1}
            /u/2D/slice_4_slice/{0, 1}
            /v/3D/{0, 1}
            /v/1D/slice_5_5/{0, 1}

        Note
        ----
        The lists of slices used to store only parts of the arrays are views
        into the *global* arrays.

        """
        self.open()
        FileBase.write(self, step, fields, **kw)
        self.close()
    def read(self, u, name, **kw):
        step = kw.get('step', 0)
        self.open()
        s = u.local_slice()
        dset = "/".join((name, "{}D".format(u.dimensions), str(step)))
        u[:] = self.f[dset][s]
        self.close()
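
    # Read-back sketch (illustrative, not part of the original source), with
    # u and 'h5filename.h5' as in the write() example above; step selects the
    # snapshot and defaults to 0:
    #
    #   >>> f = HDF5File('h5filename.h5', mode='r')
    #   >>> f.read(u, 'u', step=0)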
    def _write_slice_step(self, name, step, slices, field, **kw):
        rank = field.rank
        slices = (slice(None),)*rank + tuple(slices)
        slices = list(slices)
        ndims = slices[rank:].count(slice(None))
        slname = self._get_slice_name(slices[rank:])
        s = field.local_slice()
        slices, inside = self._get_local_slices(slices, s)
        sp = np.nonzero([isinstance(x, slice) for x in slices])[0]
        sf = tuple(np.take(s, sp))
        sl = tuple(slices)
        group = "/".join((name, "{}D".format(ndims), slname))
        self.f.require_group(group)
        N = field.global_shape
        self.f[group].require_dataset(str(step), shape=tuple(np.take(N, sp)),
                                      dtype=field.dtype)
        # The dataset is created by all ranks, but only ranks whose local
        # slab intersects the requested global slice write any data
        if inside == 1:
            self.f["/".join((group, str(step)))][sf] = field[sl]

    def _write_group(self, name, u, step, **kw):
        s = u.local_slice()
        group = "/".join((name, "{}D".format(u.dimensions)))
        self.f.require_group(group)
        self.f[group].require_dataset(str(step), shape=u.global_shape,
                                      dtype=u.dtype)
        self.f["/".join((group, str(step)))][s] = u
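
# Minimal end-to-end sketch (not part of the original module), assuming h5py
# is built with MPI support and using PFFT and newDistArray as in the write()
# docstring above. The file name 'demo.h5', the mesh size (15, 16, 17) and
# the unit domain are arbitrary illustrative choices.
if __name__ == '__main__':
    from mpi4py_fft import PFFT, newDistArray

    T = PFFT(comm, (15, 16, 17))
    u = newDistArray(T, forward_output=False, val=1)

    # Store the full 3D array and a 2D slice for two snapshots
    f = HDF5File('demo.h5', domain=((0, 1), (0, 1), (0, 1)), mode='w')
    for step in (0, 1):
        f.write(step, {'u': [u, (u, [slice(None), 4, slice(None)])]})

    # Read snapshot 0 back into a fresh array and check the roundtrip
    v = newDistArray(T, forward_output=False)
    f.read(v, 'u', step=0)
    assert np.allclose(u, v)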