/ test / test_printer.py
test_printer.py
  1  import contextlib
  2  import io
  3  import textwrap
  4  import sys
  5  
  6  import pytest
  7  from sympy import Symbol, Derivative
  8  
  9  from galgebra.printer import GaPrinter, GaLatexPrinter, oprint, _texify, Get_Program
 10  from galgebra.ga import Ga
 11  
 12  
 13  has_ordered_dictionaries = sys.version_info >= (3, 6)
 14  
 15  
 16  def test_latex_flg_GaPrinter():
 17      g3d, e_1, e_2, e_3 = Ga.build('e*1|2|3')
 18      t = Symbol('theta')
 19  
 20      mv = t + 0*e_1
 21  
 22      # it shouldn't matter whether the latex printer is enabled, if we call
 23      # the non-latex one we should get a non latex result.
 24      assert GaPrinter().doprint(mv) == 'theta'
 25      GaLatexPrinter.redirect()
 26      try:
 27          assert GaPrinter().doprint(mv) == 'theta'
 28      finally:
 29          GaLatexPrinter.restore()
 30  
 31  
 32  def test_latex_flg_Symbol_sortkey():
 33      """
 34      Symbol.sort_key should not be affected by whether latex printing is enabled.
 35  
 36      `sort_key` affects the order of expressions returned by `sympy.simplify`.
 37      """
 38      t = Symbol('theta')
 39  
 40      t_sort = t.sort_key()
 41  
 42      # sort key is cached, and we want to compare the behavior between the cache
 43      # being populated from within a latex context to it being populated from
 44      # outside it
 45      t.sort_key.cache_clear()
 46  
 47      GaLatexPrinter.redirect()
 48      try:
 49          t_latex_sort = t.sort_key()
 50      finally:
 51          GaLatexPrinter.restore()
 52  
 53      assert t_sort == t_latex_sort
 54  
 55  
 56  def test_deprecated_get_program():
 57      with pytest.warns(DeprecationWarning):
 58          # returns nothing and does very little
 59          assert Get_Program() is None
 60  
 61  
 62  def test_oprint():
 63      s = io.StringIO()
 64      with contextlib.redirect_stdout(s):
 65          oprint(
 66              'int', 1,
 67              'dictionary', dict(a=1, b=2),
 68              'set', {1},
 69              'tuple', (1, 2),
 70              'list', [1, 2, 3],
 71              'str', 'a quote: "',
 72              'deriv', Derivative(Symbol('x'), Symbol('x'), evaluate=False),
 73          )
 74  
 75      if has_ordered_dictionaries:
 76          assert s.getvalue() == textwrap.dedent("""\
 77              int        = 1
 78              dictionary = {a: 1, b: 2}
 79              set        = {1}
 80              tuple      = (1, 2)
 81              list       = [1, 2, 3]
 82              str        = a quote: "
 83              deriv      = D{x}x
 84              """)
 85  
 86  
 87  def test_oprint_dict_mode():
 88      s = io.StringIO()
 89      with contextlib.redirect_stdout(s):
 90          oprint(
 91              'int', 1,
 92              'dictionary', dict(a=1, b=2),
 93              'set', {1},
 94              'tuple', (1, 2),
 95              'list', [1, 2, 3],
 96              'str', 'a quote: "',
 97              dict_mode=True
 98          )
 99  
100      if has_ordered_dictionaries:
101          assert s.getvalue() == textwrap.dedent("""\
102              int   = 1
103              dictionary:
104              a -> 1
105              b -> 2
106              set   = {1}
107              tuple = (1, 2)
108              list  = [1, 2, 3]
109              str   = a quote: "
110              """)
111  
112  
def test_texify():
    """Check the LaTeX text that _texify produces for a set of input strings."""
    cases = [
        # operators
        ('a|b', r'a\cdot b'),
        ('a^b', r'a\W b'),
        ('a*b', r'a b'),
        ('a<b', r'a\rfloor b'),
        ('a>b', r'a\lfloor b'),
        ('a>>b', r'a \times b'),
        ('a<<b', r'a \bar{\times} b'),

        # grad
        ('grad(a)', r'\boldsymbol{\nabla} (a)'),
        ('a rgrad', r'a \bar{\boldsymbol{\nabla}} '),
        # words that merely contain "grad" are left alone
        ('gradual', r'gradual'),

        # a superscript written with {} is not turned into a wedge
        ('x^{2}', r'x^{2}'),

        # @@ used to be an internal marker
        ('@@ is safe', r'@@ is safe'),
    ]

    for source, expected in cases:
        assert _texify(source) == expected
134  
135  
def test_no_extra_cdot():
    """Regression test for issue 494: latex output has no spurious \\cdot."""
    from sympy import symbols

    coords = symbols('x y z', real=True)
    ga = Ga('e', g=[1, 1, 1], coords=coords)
    scalar_field = ga.mv('f', 'scalar', f=True)
    nabla = ga.grad
    printer = GaLatexPrinter()

    rendered = printer.doprint(nabla * scalar_field)

    # the printed gradient of a scalar field should not contain \cdot
    assert r'\cdot' not in rendered