/ galgebra / atoms.py
atoms.py
  1  """ Sympy primitives for representing atoms of ga expressions """
  2  
  3  from typing import Union
  4  
  5  from sympy import Symbol, AtomicExpr, S, Basic, sympify, MatrixExpr
  6  from sympy import Determinant as _Determinant
  7  from sympy.core import numbers
  8  from sympy.core.function import AppliedUndef, UndefinedFunction
  9  from sympy.printing.pretty.stringpict import prettyForm, stringPict
 10  from sympy.printing.pretty.pretty_symbology import U
 11  
 12  
 13  __all__ = [
 14      'BasisVectorSymbol',
 15      'BasisBladeSymbol',
 16      'BasisBladeNoWedgeSymbol',
 17      'BasisBaseSymbol',
 18      'DotProductSymbol',
 19  ]
 20  
 21  
 22  def _all_same(items):
 23      return all(x == items[0] for x in items)
 24  
 25  
 26  class BasisVectorSymbol(Symbol):
 27      """ A symbol representing a basis vector """
 28      is_commutative = False
 29  
 30      def _latex(self, print_obj):
 31          try:
 32              return print_obj._print_Symbol(self, style="bold")
 33          except TypeError:
 34              # too old a sympy version for `style=`
 35              return r"\mathbf{{{}}}".format(print_obj._print_Symbol(self))
 36  
 37  
 38  class _GradedSymbol(AtomicExpr):
 39      """ Base class for all graded symbols
 40  
 41      Constructing this from a single symbol returns that symbol itself.
 42      Constructing from no symbols returns the scalar `S.One`.
 43      This may change in future.
 44      """
 45      # the scalar isn't commutative, but __new__ ensures we do not ever create
 46      # this type of objects for scalars
 47      is_commutative = False
 48  
 49      def __new__(cls, *args: BasisVectorSymbol) -> Union[
 50          numbers.One,
 51          BasisVectorSymbol,
 52          "_GradedSymbol"
 53      ]:
 54          if len(args) == 0:
 55              return S.One
 56          elif len(args) == 1:
 57              return args[0]
 58          else:
 59              return super().__new__(cls, *args)
 60  
 61  
 62  class _JoinedPrinterMixin(Basic):
 63      """ Helper class to print `Basic.args` joined by symbol.
 64  
 65      Subclasses must populate `_op_sym` and `_op_sym_latex`
 66      """
 67  
 68      def _sympystr(self, printer):
 69          return self._op_sympystr.join(
 70              printer._print(v)
 71              for v in self.args
 72          )
 73  
 74      def _pretty(self, printer):
 75          ret = []
 76          for i, v in enumerate(self.args):
 77              if i != 0:
 78                  ret.append(self._op_pretty)
 79              ret.append(printer._print(v))
 80          return prettyForm(*stringPict.next(*ret))
 81  
 82      def _latex(self, printer):
 83          return self._op_latex.join(
 84              printer._print(v)
 85              for v in self.args
 86          )
 87  
 88  
 89  class BasisBaseSymbol(_GradedSymbol, _JoinedPrinterMixin):
 90      r""" A basis base in a non-orthogonal algebra, such as :math:`e_1 e_2` """
 91      _op_sympystr = '*'
 92      _op_pretty = prettyForm('*')
 93      _op_latex = ''
 94  
 95  
 96  class BasisBladeSymbol(_GradedSymbol, _JoinedPrinterMixin):
 97      r""" A basis blade such as :math:`e_1 \wedge e_2` """
 98      _op_sympystr = '^'
 99      _op_pretty = prettyForm('^')
100      _op_latex = r'\wedge '
101  
102  
class BasisBladeNoWedgeSymbol(BasisBladeSymbol):
    r""" A basis blade with shortened rendering such as :math:`e_{12}`

    Every factor must be named ``<root>_<subscript>`` with a common root;
    the blade is then printed as the single basis vector
    ``<root>_<all subscripts concatenated>``.
    """
    def _split_name(self):
        """ Return ``(root, joined_subscripts)`` for the factors of this blade.

        Raises `ValueError` if a factor name does not contain exactly one
        underscore, or if the factors do not all share the same root.
        """
        roots = []
        subscripts = []
        for basis_vec in self.args:
            pieces = basis_vec.name.split('_')
            if len(pieces) != 2:
                raise ValueError('Incompatible basis vector {} for wedgeless printing'.format(basis_vec))
            root, subscript = pieces
            roots.append(root)
            subscripts.append(subscript)
        if not _all_same(roots):
            raise ValueError('No unique root symbol to use for wedgeless printing')
        return roots[0], ''.join(subscripts)

    def __common_printer(self, printer):
        # render as though this blade were one basis vector, e.g. e_12
        root, sub = self._split_name()
        as_vector = BasisVectorSymbol("{}_{}".format(root, sub))
        return printer._print(as_vector)

    # str, pretty and LaTeX printing all share the same rendering
    _sympystr = _pretty = _latex = __common_printer
126  
127  
class DotProductSymbol(AtomicExpr):
    """ A symbol used to represent a dot product, like :class:`sympy.DotProduct`

    The two operands are stored as ``self.args``; each printer method below
    unpacks exactly two of them.
    """
    is_real = True

    def _sympystr(self, printer):
        a, b = self.args
        return "({}.{})".format(printer._print(a), printer._print(b))

    def _latex(self, printer):
        a, b = self.args
        return r"\left ({}\cdot {}\right ) ".format(printer._print(a), printer._print(b))

    def _pretty(self, printer):
        a, b = self.args
        # Only emit the unicode dot operator when the pretty printer may use
        # unicode; in ASCII mode fall back to the '.' that str() also uses.
        # Printers without a `_use_unicode` attribute keep the unicode glyph.
        if getattr(printer, '_use_unicode', True):
            dot = U('DOT OPERATOR')
        else:
            dot = '.'
        pform = prettyForm(*stringPict.next(
            printer._print(a),
            printer._print(dot),
            printer._print(b),
        ))
        return prettyForm(*pform.parens())
148  
149  
class MatrixFunctionClass(UndefinedFunction):
    """ Like a MatrixSymbol, but for functions.

    ``MatrixFunctionClass(name, (m, n))`` creates an undefined-function class
    that also derives from `MatrixExpr`, with its class-level ``shape`` set
    to ``(m, n)`` after strict sympification.
    """
    def __new__(mcl, name, shape, **kwargs):
        func_bases = (AppliedUndef, MatrixExpr)
        new_cls = super().__new__(mcl, name, func_bases, {}, **kwargs)
        rows, cols = shape
        new_cls.shape = (sympify(rows, strict=True), sympify(cols, strict=True))
        return new_cls
157  
158  
# workaround until pygae/galgebra#495 is truly fixed
def MatrixFunction(name, m, n):
    """ Create a `MatrixFunctionClass` called *name* with shape ``(m, n)``. """
    shape = (m, n)
    return MatrixFunctionClass(name, shape)
162  
163  
# workaround until sympy/sympy#19354 is merged: when the installed sympy does
# not already mark `Determinant` as commutative, export a subclass that does.
if _Determinant.is_commutative is True:
    Determinant = _Determinant
else:
    class Determinant(_Determinant):
        is_commutative = True