# Source code for shenfun.optimization.numba.diagma
import numba as nb
import numpy as np
from .la import Solve_axis_2D, Solve_axis_3D, Solve_axis_4D
# Public names of this module: the rank/axis dispatcher (DiagMA_Solve) and
# the numba kernel it applies (DiagMA_inner_solve).
__all__ = ['DiagMA_inner_solve', 'DiagMA_Solve']
[docs]
def DiagMA_Solve(x, data, axis=0):
    """Solve in place along ``axis`` of ``x`` using the DiagMA kernel.

    The per-line work is done by ``DiagMA_inner_solve``, which divides each
    entry along the first axis of its input by the matching entry of
    ``data[0]``. Arrays of rank 2-4 are dispatched to the ``Solve_axis_*``
    helpers imported from ``.la``. Higher-rank arrays have ``axis`` moved to
    the front (as a view) and are passed to the kernel directly.

    Parameters
    ----------
    x : ndarray
        Right-hand side; overwritten with the solution.
    data : indexable of arrays
        Matrix storage. The kernel reads only ``data[0]`` (presumably the
        main diagonal -- TODO confirm against the matrix class).
    axis : int, optional
        Axis of ``x`` to solve along. Negative values count from the end,
        as in numpy.

    Returns
    -------
    None
        ``x`` is modified in place.
    """
    n = x.ndim
    # Normalize negative axes so every branch sees a non-negative index.
    # Without this, axis=-1 would skip the moveaxis in the rank > 4 branch
    # and solve along axis 0. It also protects the Solve_axis_* helpers in
    # case they only compare against non-negative indices.
    if axis < 0:
        axis += n
    if n == 1:
        DiagMA_inner_solve(x, data)
    elif n == 2:
        Solve_axis_2D(data, x, DiagMA_inner_solve, axis)
    elif n == 3:
        Solve_axis_3D(data, x, DiagMA_inner_solve, axis)
    elif n == 4:
        Solve_axis_4D(data, x, DiagMA_inner_solve, axis)
    else:
        if axis > 0:
            # np.moveaxis returns a view, so the kernel's in-place writes land
            # in the caller's array; the axis does not need moving back.
            x = np.moveaxis(x, axis, 0)
        DiagMA_inner_solve(x, data)
[docs]
@nb.njit
def DiagMA_inner_solve(u, data):
    """Divide ``u`` in place, row by row, by the entries of ``data[0]``.

    For every ``k`` in ``range(len(data[0]))``, ``u[k]`` is divided in
    place by ``data[0][k]``. ``u[k]`` is a scalar when ``u`` is 1D and a
    sub-array otherwise. Entries of ``u`` past that length are left
    untouched. Nothing is returned.

    NOTE(review): there is no length check; this assumes
    ``u.shape[0] >= len(data[0])``. numba does not bounds-check array
    indexing by default, so a shorter ``u`` would not raise -- confirm
    callers guarantee this.
    """
    diagonal = data[0]
    for k in range(len(diagonal)):
        u[k] /= diagonal[k]