import numbers
import sympy as sp
import numpy as np
from shenfun.config import config
# NOTE: a "[docs]" Sphinx viewcode link artifact stood here; it is not code.
class Coordinates:
    """Class for handling curvilinear coordinates

    Parameters
    ----------
    psi : tuple or sp.Symbol
        The new coordinates
    rv : tuple
        The position vector in terms of the new coordinates
    assumptions : Sympy assumptions
        One or more `Sympy assumptions <https://docs.sympy.org/latest/modules/assumptions/index.html>`_
    replace : sequence of two-tuples
        Use Sympy's replace with these two-tuples
    measure : Python function to replace Sympy's count_ops.
        For example, to discourage the use of powers in an expression use::

            def discourage_powers(expr):
                POW = sp.Symbol('POW')
                count = sp.count_ops(expr, visual=True)
                count = count.replace(POW, 100)
                count = count.replace(sp.Symbol, type(sp.S.One))
                return count

    Notes
    -----
    Basis vectors, metric tensors, scaling factors, determinants and
    Christoffel symbols are computed on first request and cached on the
    instance. The ``get_*`` methods return the cached objects themselves,
    not copies.
    """

    def __init__(self, psi, rv, assumptions=True, replace=(), measure=sp.count_ops):
        # A single symbol is wrapped so that psi is always a sequence.
        self._psi = (psi,) if isinstance(psi, sp.Symbol) else psi
        self._rv = rv
        self._assumptions = assumptions
        self._replace = replace
        self._measure = measure
        # Lazily computed caches; None means "not computed yet".
        self._hi = None   # scaling factors
        self._b = None    # covariant basis
        self._bt = None   # contravariant basis
        self._e = None    # normal (unit length) basis
        self._g = None    # covariant metric tensor
        self._gt = None   # contravariant metric tensor
        self._gn = None   # normal metric tensor
        self._ct = None   # Christoffel symbols of the second kind
        # Keyed on the ``covariant`` flag of the corresponding getters.
        self._det_g = {True: None, False: None}
        self._sqrt_det_g = {True: None, False: None}

    @property
    def b(self):
        """Covariant basis vectors, see :meth:`get_covariant_basis`"""
        return self.get_covariant_basis()

    @property
    def e(self):
        """Normal basis vectors, see :meth:`get_normal_basis`"""
        return self.get_normal_basis()

    @property
    def hi(self):
        """Scaling factors, see :meth:`get_scaling_factors`"""
        return self.get_scaling_factors()

    @property
    def sg(self):
        """Square root of the determinant of the covariant metric tensor

        Returns the integer 1 for Cartesian coordinates.
        """
        if self.is_cartesian:
            return 1
        return self.get_sqrt_det_g(True)

    @property
    def coordinates(self):
        """The tuple ``(psi, rv, assumptions, replace)``"""
        return (self.psi, self.rv, self._assumptions, self._replace)

    @property
    def psi(self):
        """The new coordinates"""
        return self._psi

    @property
    def rv(self):
        """The position vector"""
        return self._rv

    @property
    def is_orthogonal(self):
        """Whether the covariant metric tensor is diagonal"""
        return sp.Matrix(self.get_covariant_metric_tensor()).is_diagonal()

    @property
    def is_cartesian(self):
        """Whether the covariant metric tensor is the identity

        Always False if the number of coordinates differs from the number
        of components of the position vector.
        """
        if len(self.psi) != len(self.rv):
            return False
        return sp.Matrix(self.get_covariant_metric_tensor()).is_Identity

    def get_det_g(self, covariant=True):
        """Return determinant of metric tensor

        Parameters
        ----------
        covariant : bool, optional
            If True use the covariant metric tensor, otherwise the
            contravariant metric tensor.
        """
        if self._det_g[covariant] is not None:
            return self._det_g[covariant]
        if covariant:
            g = sp.Matrix(self.get_covariant_metric_tensor()).det()
        else:
            g = sp.Matrix(self.get_contravariant_metric_tensor()).det()
        g = g.factor()
        g = self.refine(g)
        g = sp.simplify(g, measure=self._measure)
        self._det_g[covariant] = g
        return g

    def get_sqrt_det_g(self, covariant=True):
        """Return square root of determinant of metric tensor

        Parameters
        ----------
        covariant : bool, optional
            If True use the covariant metric tensor, otherwise the
            contravariant metric tensor.

        Note
        ----
        If the result is recognised as a ``numbers.Number`` it is converted
        to a Python ``float`` or ``complex``.
        """
        if self._sqrt_det_g[covariant] is not None:
            return self._sqrt_det_g[covariant]
        g = self.get_det_g(covariant)
        sg = sp.simplify(sp.sqrt(g), measure=self._measure)
        sg = self.refine(sg)
        if isinstance(sg, numbers.Number):
            if isinstance(sg, numbers.Real):
                sg = float(sg)
            elif isinstance(sg, numbers.Complex):
                sg = complex(sg)
            else:
                raise NotImplementedError
        self._sqrt_det_g[covariant] = sg
        return sg

    def get_cartesian_basis(self):
        """Return Cartesian basis vectors"""
        return np.eye(len(self.rv), dtype=object)

    def get_scaling_factors(self):
        """Return scaling factors

        Scaling factor i is the square root of the sum of squares of the
        components of covariant basis vector i.
        """
        if self._hi is not None:
            return self._hi
        hi = np.zeros_like(self.psi)
        for i, s in enumerate(np.sum(self.b**2, axis=1)):
            hi[i] = sp.sqrt(self.refine(sp.simplify(s, measure=self._measure)))
            hi[i] = self.refine(hi[i])
        self._hi = hi
        return hi

    def get_normal_basis(self):
        """Return normal basis vectors

        Each covariant basis vector divided by its own length.
        """
        if self._e is not None:
            return self._e
        b = self.b
        e = np.zeros_like(b)
        for i, bi in enumerate(b):
            l = sp.sqrt(sp.simplify(np.dot(bi, bi)))
            l = self.refine(l)
            e[i] = bi / l
        self._e = e
        return e

    def get_covariant_basis(self):
        """Return covariant basisvectors

        Row i holds the derivatives of the components of the position
        vector with respect to coordinate ``psi[i]``. The array has shape
        ``(len(psi), len(rv))``; no additional vector is appended when
        there are fewer coordinates than position vector components.
        """
        if self._b is not None:
            return self._b
        b = np.zeros((len(self.psi), len(self.rv)), dtype=object)
        for i, ti in enumerate(self.psi):
            for j, rj in enumerate(self.rv):
                b[i, j] = rj.diff(ti, 1)
                b[i, j] = sp.simplify(b[i, j], measure=self._measure)
                b[i, j] = self.refine(b[i, j])
        self._b = b
        return b

    def get_contravariant_basis(self):
        """Return contravariant basisvectors

        Row i is the sum over j of contravariant metric tensor entry
        ``[i, j]`` times covariant basis vector j.
        """
        if self._bt is not None:
            return self._bt
        bt = np.zeros_like(self.b)
        g = self.get_contravariant_metric_tensor()
        b = self.b
        for i in range(len(self.psi)):
            for j in range(len(self.psi)):
                bt[i] += g[i, j]*b[j]
        # bt has len(rv) columns, which may exceed len(psi); simplify
        # every component, not only the leading len(psi) ones.
        for i in range(bt.shape[0]):
            for j in range(bt.shape[1]):
                bt[i, j] = sp.simplify(bt[i, j], measure=self._measure)
        self._bt = bt
        return bt

    def get_normal_metric_tensor(self):
        """Return normal metric tensor

        Entry ``[i, j]`` is the dot product of normal basis vectors i and j.
        """
        if self._gn is not None:
            return self._gn
        gn = np.zeros((len(self.psi), len(self.psi)), dtype=object)
        e = self.e
        for i in range(len(self.psi)):
            for j in range(len(self.psi)):
                gn[i, j] = sp.simplify(np.dot(e[i], e[j]).expand(), measure=self._measure)
                gn[i, j] = self.refine(gn[i, j])
        self._gn = gn
        return gn

    def get_covariant_metric_tensor(self):
        """Return covariant metric tensor

        Entry ``[i, j]`` is the dot product of covariant basis vectors i
        and j.
        """
        if self._g is not None:
            return self._g
        g = np.zeros((len(self.psi), len(self.psi)), dtype=object)
        b = self.b
        for i in range(len(self.psi)):
            for j in range(len(self.psi)):
                g[i, j] = sp.simplify(np.dot(b[i], b[j]).expand(), measure=self._measure)
                g[i, j] = self.refine(g[i, j])
        self._g = g
        return g

    def get_contravariant_metric_tensor(self):
        """Return contravariant metric tensor

        Computed as the matrix inverse of the covariant metric tensor.
        """
        if self._gt is not None:
            return self._gt
        g = self.get_covariant_metric_tensor()
        gt = sp.Matrix(g).inv()
        for i in range(gt.shape[0]):
            for j in range(gt.shape[1]):
                gt[i, j] = sp.simplify(gt[i, j], measure=self._measure)
        gt = np.array(gt)
        self._gt = gt
        return gt

    def get_christoffel_second(self):
        """Return Christoffel symbol of second kind

        Entry ``[k, i, j]`` is the dot product of contravariant basis
        vector k with the derivative of covariant basis vector i with
        respect to ``psi[j]``.
        """
        if self._ct is not None:
            return self._ct
        b = self.get_covariant_basis()
        bt = self.get_contravariant_basis()
        ndim = len(self.psi)
        ct = np.zeros((ndim,)*3, object)
        for i in range(ndim):
            for j in range(ndim):
                # Derivative of covariant basis vector i along psi[j];
                # independent of k, so computed once per (i, j).
                dbi = np.array([bij.diff(self.psi[j], 1) for bij in b[i]])
                for k in range(ndim):
                    ct[k, i, j] = sp.simplify(np.dot(dbi, bt[k]), measure=self._measure)
        self._ct = ct
        return ct

    def get_metric_tensor(self, kind='normal'):
        """Return metric tensor of the requested kind

        Parameters
        ----------
        kind : str, optional
            'covariant', 'contravariant' or 'normal'

        Raises
        ------
        NotImplementedError
            For any other value of ``kind``.
        """
        if kind == 'covariant':
            gij = self.get_covariant_metric_tensor()
        elif kind == 'contravariant':
            gij = self.get_contravariant_metric_tensor()
        elif kind == 'normal':
            gij = self.get_normal_metric_tensor()
        else:
            raise NotImplementedError
        return gij

    def get_basis(self, kind='normal'):
        """Return basis vectors of the requested kind

        Parameters
        ----------
        kind : str, optional
            'covariant', 'contravariant' or 'normal'

        Raises
        ------
        AssertionError
            For any other value of ``kind``.
        """
        if kind == 'covariant':
            return self.get_covariant_basis()
        if kind == 'contravariant':
            return self.get_contravariant_basis()
        if kind != 'normal':
            # Raised explicitly (rather than through an ``assert``
            # statement) so the check is not stripped under ``python -O``,
            # while the exception type stays what callers saw before.
            raise AssertionError("kind must be 'covariant', 'contravariant' or 'normal', got %r" % (kind,))
        return self.get_normal_basis()

    def refine(self, sc):
        """Apply the assumptions and replacement rules to an expression

        Runs ``sp.refine`` with the assumptions given at construction and
        then applies each ``(old, new)`` pair of the ``replace`` sequence
        through the expression's ``replace`` method, in order.
        """
        sc = sp.refine(sc, self._assumptions)
        for old, new in self._replace:
            sc = sc.replace(old, new)
        return sc

    def subs(self, s0, s1):
        """Substitute ``s0`` with ``s1`` in place

        The substitution is applied to the covariant basis, the covariant
        and contravariant metric tensors, the square root of the
        determinant of the covariant metric tensor, the scaling factors,
        ``psi`` and ``rv``. Quantities are computed first if not already
        cached.

        Note
        ----
        Other cached quantities (normal and contravariant basis, normal
        metric tensor, determinants, Christoffel symbols) are not touched
        by this method.
        """
        b = self.get_covariant_basis()
        for i in range(b.shape[0]):
            for j in range(b.shape[1]):
                b[i, j] = b[i, j].subs(s0, s1)
        g = self.get_covariant_metric_tensor()
        gt = self.get_contravariant_metric_tensor()
        for i in range(g.shape[0]):
            for j in range(g.shape[1]):
                g[i, j] = g[i, j].subs(s0, s1)
                gt[i, j] = gt[i, j].subs(s0, s1)
        sg = self.get_sqrt_det_g()
        # get_sqrt_det_g may hand back a plain float/complex, which has no
        # ``subs`` method and has nothing to substitute anyway.
        if hasattr(sg, 'subs'):
            sg = sg.subs(s0, s1)
        self._sqrt_det_g[True] = sg
        hi = self.get_scaling_factors()
        for i, h in enumerate(hi):
            hi[i] = h.subs(s0, s1)
        self._psi = tuple([p.subs(s0, s1) for p in self._psi])
        self._rv = tuple([r.subs(s0, s1) for r in self._rv])

    def latex_basis_vectors(self, symbol_names=None, replace=None, kind=None):
        """Return the basis vectors as a LaTeX string

        Each basis vector is written as a combination of the Cartesian
        unit vectors i, j and k, one vector per row, with rows separated by
        a LaTeX line break.

        Parameters
        ----------
        symbol_names : dict, optional
            Map from symbols to the names used for printing. Coordinates
            in ``psi`` that are missing from the map are printed using
            ``str``.
        replace : sequence of two-tuples, optional
            Each tuple is passed to Sympy's replace for every printed
            component. Only the printed output is affected, the stored
            basis vectors are left untouched.
        kind : str, optional
            'covariant', 'contravariant' or 'normal'. Read from
            ``config['basisvectors']`` if not given. Any value other than
            'covariant' and 'contravariant' selects the normal basis.
        """
        if kind is None:
            kind = config['basisvectors']
        if kind == 'covariant':
            b = self.get_covariant_basis()
        elif kind == 'contravariant':
            b = self.get_contravariant_basis()
        else:
            b = self.get_normal_basis()
        psi = self.psi
        symbols = {p: str(p) for p in psi}
        if symbol_names is not None:
            # Merge instead of overwrite, so a partial mapping does not
            # lead to a KeyError in the lookups below.
            symbols.update(symbol_names)
        k = {0: '\\mathbf{i}', 1: '\\mathbf{j}', 2: '\\mathbf{k}'}
        m = ' '
        bl = 'e' if kind == 'normal' else 'b'
        for i, p in enumerate(psi):
            if kind in ('covariant', 'normal'):
                m += '\\mathbf{%s}_{%s} ='%(bl, symbols[p])
            else:
                m += '\\mathbf{%s}^{%s} ='%(bl, symbols[p])
            for j in range(b.shape[1]):
                # Work on a local name: b is the cached basis array and
                # must not be modified by a formatting routine.
                bij = b[i, j]
                if bij == 1:
                    m += (k[j]+'+')
                elif bij != 0:
                    if replace is not None:
                        for repl in replace:
                            assert len(repl) == 2
                            bij = bij.replace(*repl)
                    sl = sp.latex(bij, symbol_names=symbols)
                    # A leading minus takes the place of the pending '+'
                    if sl.startswith('-') and not isinstance(bij, sp.Add):
                        m = m.rstrip('+')
                    if isinstance(bij, sp.Add):
                        sl = '\\left(%s\\right)'%(sl)
                    m += (sl+'\\,'+k[j]+'+')
            m = m.rstrip('+')
            m += ' \\\\ '
        m += ' '
        return r'%s'%(m)