atoms.py
""" Sympy primitives for representing atoms of ga expressions """

from typing import Union

from sympy import Symbol, AtomicExpr, S, Basic, sympify, MatrixExpr
from sympy import Determinant as _Determinant
from sympy.core import numbers
from sympy.core.function import AppliedUndef, UndefinedFunction
from sympy.printing.pretty.stringpict import prettyForm, stringPict
from sympy.printing.pretty.pretty_symbology import U


__all__ = [
    'BasisVectorSymbol',
    'BasisBladeSymbol',
    'BasisBladeNoWedgeSymbol',
    'BasisBaseSymbol',
    'DotProductSymbol',
]


def _all_same(items) -> bool:
    """ Return True if every element of `items` compares equal to the first.

    An empty sequence yields True: `all` over an empty generator never
    evaluates `items[0]`, so no IndexError is raised.
    """
    return all(x == items[0] for x in items)


class BasisVectorSymbol(Symbol):
    """ A symbol representing a basis vector

    Basis vectors do not commute under the geometric product, so the symbol
    is declared non-commutative to stop sympy from reordering products.
    """
    is_commutative = False

    def _latex(self, print_obj):
        """ LaTeX hook: render the symbol in bold. """
        try:
            return print_obj._print_Symbol(self, style="bold")
        except TypeError:
            # too old a sympy version for `style=`
            return r"\mathbf{{{}}}".format(print_obj._print_Symbol(self))


class _GradedSymbol(AtomicExpr):
    """ Base class for all graded symbols

    Constructing this from a single symbol returns that symbol itself.
    Constructing from no symbols returns the scalar `S.One`.
    This may change in future.
    """
    # The scalar *is* commutative, so marking this class non-commutative would
    # be wrong for it. That is safe because __new__ never creates objects of
    # this type for scalars (nor for a single vector).
    is_commutative = False

    def __new__(cls, *args: BasisVectorSymbol) -> Union[
        numbers.One,
        BasisVectorSymbol,
        "_GradedSymbol"
    ]:
        """ Build a graded symbol from the given basis vectors.

        Returns `S.One` for no arguments and the argument itself for one
        argument. Only two or more arguments produce a new instance.
        """
        if len(args) == 0:
            return S.One
        elif len(args) == 1:
            return args[0]
        else:
            return super().__new__(cls, *args)


class _JoinedPrinterMixin(Basic):
    """ Helper class to print `Basic.args` joined by symbol.

    Subclasses must populate the separator used by each printer:

    * `_op_sympystr` -- a `str` placed between args by the str printer
    * `_op_pretty` -- a `prettyForm` placed between args by the pretty printer
    * `_op_latex` -- a `str` placed between args by the LaTeX printer
    """

    def _sympystr(self, printer):
        """ str hook: print each arg and join the results with `_op_sympystr`. """
        return self._op_sympystr.join(
            printer._print(v)
            for v in self.args
        )

    def _pretty(self, printer):
        """ Pretty-printer hook: lay the args out horizontally with `_op_pretty` between them. """
        ret = []
        for i, v in enumerate(self.args):
            # the separator goes before every arg except the first
            if i != 0:
                ret.append(self._op_pretty)
            ret.append(printer._print(v))
        return prettyForm(*stringPict.next(*ret))

    def _latex(self, printer):
        """ LaTeX hook: print each arg and join the results with `_op_latex`. """
        return self._op_latex.join(
            printer._print(v)
            for v in self.args
        )


class BasisBaseSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis base in a non-orthogonal algebra, such as :math:`e_1 e_2` """
    _op_sympystr = '*'
    _op_pretty = prettyForm('*')
    # plain juxtaposition in LaTeX
    _op_latex = ''


class BasisBladeSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis blade such as :math:`e_1 \wedge e_2` """
    _op_sympystr = '^'
    _op_pretty = prettyForm('^')
    _op_latex = r'\wedge '


class BasisBladeNoWedgeSymbol(BasisBladeSymbol):
    r""" A basis blade with shortened rendering such as :math:`e_{12}` """
    def _split_name(self):
        """ Split the basis-vector names into a shared root and joined subscripts.

        Each arg's name must look like ``root_sub`` and contain exactly one
        underscore. All roots must be equal. For example, ``e_1`` and ``e_2``
        give ``('e', '12')``.

        Raises:
            ValueError: if a name does not contain exactly one underscore, or
                if the roots differ.
        """
        sub_str = []
        root_str = []
        for basis_vec in self.args:
            split_lst = basis_vec.name.split('_')
            if len(split_lst) != 2:
                raise ValueError('Incompatible basis vector {} for wedgeless printing'.format(basis_vec))
            else:
                sub_str.append(split_lst[1])
                root_str.append(split_lst[0])
        if _all_same(root_str):
            return root_str[0], ''.join(sub_str)
        else:
            raise ValueError('No unique root symbol to use for wedgeless printing')

    def __common_printer(self, printer):
        """ Shared str/pretty/LaTeX hook: render as a single basis vector ``root_subs``. """
        # print as if we were a basis vector
        root, sub = self._split_name()
        return printer._print(BasisVectorSymbol("{}_{}".format(root, sub)))

    # the same implementation serves all three printers
    _sympystr = _pretty = _latex = __common_printer


class DotProductSymbol(AtomicExpr):
    """ A symbol used to represent a dot product, like :class:`sympy.DotProduct`

    The printers unpack `self.args` into exactly two operands `(a, b)`, so it
    must be constructed with two arguments.
    """
    # declares the dot product real-valued to sympy's assumption system
    is_real = True

    def _sympystr(self, printer):
        """ str hook: render as ``(a.b)``. """
        a, b = self.args
        return "({}.{})".format(printer._print(a), printer._print(b))

    def _latex(self, printer):
        r""" LaTeX hook: render as ``\left (a\cdot b\right )``. """
        a, b = self.args
        return r"\left ({}\cdot {}\right ) ".format(printer._print(a), printer._print(b))

    def _pretty(self, printer):
        """ Pretty-printer hook: render ``a`` DOT OPERATOR ``b``, wrapped in parentheses. """
        a, b = self.args
        pform = prettyForm(*stringPict.next(
            printer._print(a),
            printer._print(U('DOT OPERATOR')),
            printer._print(b),
        ))
        return prettyForm(*pform.parens())


class MatrixFunctionClass(UndefinedFunction):
    """ Like a MatrixSymbol, but for functions. """
    def __new__(mcl, name, shape, **kwargs):
        """ Create an undefined-function class whose applications are matrix expressions.

        Args:
            name: name of the function.
            shape: an ``(m, n)`` pair. Each entry is converted with
                ``sympify(..., strict=True)`` and stored on the class as
                ``shape``.
            **kwargs: forwarded to `UndefinedFunction`.
        """
        # applied instances derive from both AppliedUndef and MatrixExpr
        cls = super().__new__(mcl, name, (AppliedUndef, MatrixExpr), {}, **kwargs)
        m, n = shape
        cls.shape = sympify(m, strict=True), sympify(n, strict=True)
        return cls


# workaround until pygae/galgebra#495 is truly fixed
def MatrixFunction(name, m, n):
    """ Convenience constructor: ``MatrixFunctionClass(name, (m, n))``. """
    return MatrixFunctionClass(name, (m, n))


# workaround until sympy/sympy#19354 is merged: on sympy versions whose
# Determinant is not marked commutative, use a subclass that is; otherwise
# re-export sympy's class unchanged.
if _Determinant.is_commutative is not True:
    class Determinant(_Determinant):
        is_commutative = True
else:
    Determinant = _Determinant