Source code for shenfun.fourier
#pylint: disable=missing-docstring
import numpy as np
from .bases import *
from .matrices import *
[docs]
def energy_fourier(u, T):
r"""Compute the energy of u using Parceval's theorem
.. math::
\int abs(u)^2 dx = N*\sum abs(u_hat)^2
Parameters
----------
u : Array
The Fourier coefficients
T : TensorProductSpace
See https://en.wikipedia.org/wiki/Parseval's_theorem
"""
if not hasattr(T, 'comm'):
# Just a 1D basis
assert u.ndim == 1
if isinstance(T, R2C):
if u.shape[0] % 2 == 0:
result = (2*np.sum(abs(u[1:-1])**2) +
np.sum(abs(u[0])**2) +
np.sum(abs(u[-1])**2))
else:
result = (2*np.sum(abs(u[1:])**2) +
np.sum(abs(u[0])**2))
else:
result = np.sum(abs(u)**2)
return result
comm = T.comm
assert np.all([base.family() == 'fourier' for base in T.bases])
real = False
for axis, base in enumerate(T.bases):
if isinstance(base, R2C):
real = True
break
if real:
s = [slice(None)]*u.ndim
uaxis = axis + u.ndim-len(T.bases)
if T.forward.output_pencil.subcomm[axis].Get_size() == 1:
# aligned in r2c direction
if base.N % 2 == 0:
s[uaxis] = slice(1, -1)
result = 2*np.sum(abs(u[tuple(s)])**2)
s[uaxis] = 0
result += np.sum(abs(u[tuple(s)])**2)
s[uaxis] = -1
result += np.sum(abs(u[tuple(s)])**2)
else:
s[uaxis] = slice(1, None)
result = 2*np.sum(abs(u[tuple(s)])**2)
s[uaxis] = 0
result += np.sum(abs(u[tuple(s)])**2)
else:
# Data not aligned along r2c axis. Need to check about 0 and -1
if base.N % 2 == 0:
s[uaxis] = slice(1, -1)
result = 2*np.sum(abs(u[tuple(s)])**2)
s[uaxis] = 0
if T.local_slice(True)[axis].start == 0:
result += np.sum(abs(u[tuple(s)])**2)
else:
result += 2*np.sum(abs(u[tuple(s)])**2)
s[uaxis] = -1
if T.local_slice(True)[axis].stop == T.dims()[axis]:
result += np.sum(abs(u[tuple(s)])**2)
else:
result += 2*np.sum(abs(u[tuple(s)])**2)
else:
s[uaxis] = slice(1, None)
result = 2*np.sum(abs(u[tuple(s)])**2)
s[uaxis] = 0
if T.local_slice(True)[axis].start == 0:
result += np.sum(abs(u[tuple(s)])**2)
else:
result += 2*np.sum(abs(u[tuple(s)])**2)
else:
result = np.sum(abs(u[...])**2)
result = comm.allreduce(result)
return result