/ test / test_lt.py
test_lt.py
  1  import unittest
  2  
  3  import pytest
  4  from sympy import symbols, S, Matrix, Symbol
  5  
  6  from galgebra.ga import Ga
  7  from galgebra.lt import Mlt
  8  
  9  
 10  class TestLt(unittest.TestCase):
 11  
 12      # reproduce gh-540: callable Lt must accept zero (e.g. projection maps)
 13      def test_lt_callable_zero(self):
 14          ga, e1, e2, e3 = Ga.build('e*1|2|3', g=[1, 1, 1])
 15          # projection onto e1: maps e2 and e3 to zero
 16          L = ga.lt(lambda x: (x | e1) * e1)
 17          assert L(e1) == e1
 18          assert L(e2).is_zero()
 19          assert L(e3).is_zero()
 20          # zero map: every basis vector maps to zero
 21          L_zero = ga.lt(lambda x: x - x)
 22          for basis_v in [e1, e2, e3]:
 23              assert L_zero(basis_v).is_zero()
 24  
 25      # reproduce gh-105
 26      def test_lt_matrix(self):
 27          base = Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))
 28          a, b = base.mv()
 29          A = base.lt([a+b, 2*a-b])
 30          assert str(A) == 'Lt(a) = a + b\nLt(b) = 2*a - b'
 31          assert str(A.matrix()) == 'Matrix([[1, 2], [1, -1]])'
 32  
 33      # reproduce gh-461: lt.matrix() on non-Euclidean metrics
 34      def test_lt_matrix_oblique(self):
 35          # oblique metric g=[[1,1],[1,2]]: matrix() must not include metric factors
 36          coords = symbols('x y', real=True)
 37          ga, e1, e2 = Ga.build('e*1|2', g=[[1, 1], [1, 2]], coords=coords)
 38          L = ga.lt([e1 + e2, 2*e1 - e2])
 39          assert L.matrix() == Matrix([[1, 2], [1, -1]])
 40  
 41          # Minkowski metric g=diag(1,-1): same check
 42          ga2, f0, f1 = Ga.build('e*0|1', g=[1, -1], coords=symbols('t x', real=True))
 43          L2 = ga2.lt([f0 + f1, 2*f0 - f1])
 44          assert L2.matrix() == Matrix([[1, 2], [1, -1]])
 45  
 46      # reproduce gh-461: Ga.lt('f') on oblique metric must not mix in metric tensor
 47      def test_lt_generic_oblique(self):
 48          coords = symbols('x y', real=True)
 49          ga, e1, e2 = Ga.build('e*1|2', g=[[1, 1], [1, 2]], coords=coords)
 50          F = ga.lt('f')
 51          c1 = F(e1).get_coefs(1)
 52          c2 = F(e2).get_coefs(1)
 53          # each coefficient must be a plain symbol, not a metric-weighted expression
 54          assert all(isinstance(s, Symbol) for s in c1 + c2)
 55          # and the matrix must match those coefficients exactly
 56          assert F.matrix() == Matrix([[c1[0], c2[0]], [c1[1], c2[1]]])
 57  
 58      def test_lt_function(self):
 59          """ Test construction from a function """
 60          base = Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))
 61          a, b = base.mv()
 62  
 63          def not_linear(x):
 64              return x * x
 65          with pytest.raises(ValueError, match='linear'):
 66              base.lt(not_linear)
 67  
 68          def not_vector(x):
 69              return x + S.One
 70          with pytest.raises(ValueError, match='vector'):
 71              base.lt(not_vector)
 72  
 73          def ok(x):
 74              return (x | b) * a + 2*x
 75          f = base.lt(ok)
 76          x = base.mv('x', 'vector')
 77          y = base.mv('y', 'vector')
 78          assert f(x) == ok(x)
 79          assert f(x^y) == ok(x)^ok(y)
 80          assert f(1 + 2*(x^y)) == 1 + 2*(ok(x)^ok(y))
 81  
 82      def test_deprecations(self):
 83          base = Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))
 84          l = base.lt([[1, 2], [3, 4]])
 85          with pytest.warns(DeprecationWarning):
 86              assert l.X == l.Ga.coord_vec
 87          with pytest.warns(DeprecationWarning):
 88              assert l.coords == l.Ga.coords
 89  
 90          l = base.lt('L', mode='a')
 91          with pytest.warns(DeprecationWarning):
 92              assert l.mode == 'a'
 93          with pytest.warns(DeprecationWarning):
 94              assert not l.fct_flg
 95  
 96          l = base.lt('L', mode='s', f=True)
 97          with pytest.warns(DeprecationWarning):
 98              assert l.mode == 's'
 99          with pytest.warns(DeprecationWarning):
100              assert l.fct_flg
101  
102  
103  
class TestMlt(unittest.TestCase):
    """Tests for multilinear functions (``galgebra.lt.Mlt``)."""

    def test_basic(self):
        """An ``Mlt`` wrapping a Python function behaves like that function.

        Adapted from TensorDef.py: a two-argument tensor on 4-d spacetime,
        defined by contracting a symbolic bivector ``A`` with the wedge of
        its arguments.
        """
        coords = symbols('t x y z', real=True)
        # Only the algebra is needed; the basis vectors are unused here.
        st4d, *_ = Ga.build('gamma*t|x|y|z', g=[1, -1, -1, -1],
                            coords=coords)

        A = st4d.mv('T', 'bivector')

        def TA(a1, a2):
            return A | (a1 ^ a2)

        T = Mlt(TA, st4d)

        # tests begin

        a1 = st4d.mv('a1', 'vector')
        a2 = st4d.mv('a2', 'vector')
        a3 = st4d.mv('a3', 'vector')
        a4 = st4d.mv('a4', 'vector')

        # Calling the Mlt is like calling the function.
        assert T(a1, a2) == TA(a1, a2)

        # For addition, argument slots are reused. The right-hand sides use
        # the plain function TA throughout as an independent reference.
        assert (T + T)(a1, a2) == TA(a1, a2) + TA(a1, a2)
        assert (T - T)(a1, a2) == TA(a1, a2) - TA(a1, a2)

        # For products, argument slots are chained: the left factor consumes
        # the first two arguments and the right factor the last two.
        assert (T * T)(a1, a2, a3, a4) == TA(a1, a2) * TA(a3, a4)
        assert (T ^ T)(a1, a2, a3, a4) == TA(a1, a2) ^ TA(a3, a4)
        assert (T | T)(a1, a2, a3, a4) == TA(a1, a2) | TA(a3, a4)

        # Linearity in each slot. This is implied by T and TA agreeing above,
        # but it exercises `Mlt.__call__` with compound expressions as arguments.
        alpha = st4d.mv('alpha', 'scalar')
        b = st4d.mv('b', 'vector')

        assert T(alpha * a1, a2) == alpha * T(a1, a2)
        assert T(a1, alpha * a2) == alpha * T(a1, a2)
        assert T(a1 + b, a2) == T(a1, a2) + T(b, a2)
        assert T(a1, a2 + b) == T(a1, a2) + T(a1, b)

    def test_from_str(self):
        """An ``Mlt`` built from a name is expanded with fresh component symbols.

        The result is contracted with the coefficients of its arguments.
        """
        coords = symbols('x y', real=True)
        g, *_ = Ga.build('e*1|2', coords=coords, g=[1, 1])

        a1 = g.mv('a1', 'vector')
        a2 = g.mv('a2', 'vector')
        a1x, a1y = a1.get_coefs(1)
        a2x, a2y = a2.get_coefs(1)

        # One argument slot: two new component symbols are created.
        # NOTE(review): sorting by sort_key is assumed to order the components
        # by index (Tx before Ty); the assertion below depends on that.
        T = Mlt('T', g, nargs=1)
        v = T(a1)
        Tx, Ty = sorted(v.free_symbols - {a1x, a1y}, key=lambda s: s.sort_key())
        assert v == (
            Tx * a1x +
            Ty * a1y
        )

        # Two argument slots: four new component symbols are created.
        T = Mlt('T', g, nargs=2)
        v = T(a1, a2)
        Txx, Txy, Tyx, Tyy = sorted(v.free_symbols - {a1x, a1y, a2x, a2y},
                                    key=lambda s: s.sort_key())
        assert v == (
            Txx * a1x * a2x +
            Txy * a1x * a2y +
            Tyx * a1y * a2x +
            Tyy * a1y * a2y
        )

    def test_str_arithmetic(self):
        """Mlt arithmetic on string-constructed tensors routes through the
        component-expression constructor and must not raise NotImplementedError."""
        coords = symbols('x y', real=True)
        g, *_ = Ga.build('e*1|2', coords=coords, g=[1, 1])

        a1 = g.mv('a1', 'vector')
        a2 = g.mv('a2', 'vector')

        # Local names carry a suffix so sympy's module-level `S` is not shadowed.
        S_mlt = Mlt('S', g, nargs=2)
        T_mlt = Mlt('T', g, nargs=2)

        assert (S_mlt + T_mlt)(a1, a2) == S_mlt(a1, a2) + T_mlt(a1, a2)
        assert (S_mlt - T_mlt)(a1, a2) == S_mlt(a1, a2) - T_mlt(a1, a2)

    def test_from_component_expression(self):
        """Mlt constructed from a pre-built component expression (issue #578)."""
        coords = symbols('x y', real=True)
        g, *_ = Ga.build('e*1|2', coords=coords, g=[1, 1])

        a1 = g.mv('a1', 'vector')
        a2 = g.mv('a2', 'vector')

        # Build the same two-slot tensor as test_from_str, but pass its
        # fvalue expression directly instead of a string name.
        T_str = Mlt('T', g, nargs=2)
        fvalue = T_str.fvalue  # sympy expression in the slot variables

        # Reconstruct using the component-expression path.
        T_expr = Mlt(fvalue, g, nargs=2)
        assert T_expr(a1, a2) == T_str(a1, a2)

        # nargs is required for component expressions.
        with pytest.raises(TypeError):
            Mlt(fvalue, g)

    def test_deprecations(self):
        """Deprecated Mlt helpers still work but emit DeprecationWarning."""
        g = Ga('e*a|b', g=[1, 1])
        # The misspelled name is the deprecated attribute under test; do not "fix" it.
        with pytest.warns(DeprecationWarning):
            assert Mlt.extact_basis_indexes(g) == ['a', 'b']