"""Tests for galgebra's linear (``Lt``, built via ``Ga.lt``) and multilinear
(``Mlt``) transformation classes."""
import unittest

import pytest
from sympy import symbols, S, Matrix, Symbol

from galgebra.ga import Ga
from galgebra.lt import Mlt


class TestLt(unittest.TestCase):
    """Tests for ``Lt`` objects created through ``Ga.lt``."""

    # reproduce gh-540: callable Lt must accept zero (e.g. projection maps)
    def test_lt_callable_zero(self):
        ga, e1, e2, e3 = Ga.build('e*1|2|3', g=[1, 1, 1])
        # projection onto e1: maps e2 and e3 to zero
        L = ga.lt(lambda x: (x | e1) * e1)
        assert L(e1) == e1
        assert L(e2).is_zero()
        assert L(e3).is_zero()
        # zero map: every basis vector maps to zero
        L_zero = ga.lt(lambda x: x - x)
        for basis_v in [e1, e2, e3]:
            assert L_zero(basis_v).is_zero()

    # reproduce gh-105
    def test_lt_matrix(self):
        base = Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))
        a, b = base.mv()
        A = base.lt([a + b, 2*a - b])
        assert str(A) == 'Lt(a) = a + b\nLt(b) = 2*a - b'
        assert str(A.matrix()) == 'Matrix([[1, 2], [1, -1]])'

    # reproduce gh-461: lt.matrix() on non-Euclidean metrics
    def test_lt_matrix_oblique(self):
        # oblique metric g=[[1,1],[1,2]]: matrix() must not include metric factors
        coords = symbols('x y', real=True)
        ga, e1, e2 = Ga.build('e*1|2', g=[[1, 1], [1, 2]], coords=coords)
        L = ga.lt([e1 + e2, 2*e1 - e2])
        assert L.matrix() == Matrix([[1, 2], [1, -1]])

        # Minkowski metric g=diag(1,-1): same check
        ga2, f0, f1 = Ga.build('e*0|1', g=[1, -1],
                               coords=symbols('t x', real=True))
        L2 = ga2.lt([f0 + f1, 2*f0 - f1])
        assert L2.matrix() == Matrix([[1, 2], [1, -1]])

    # reproduce gh-461: Ga.lt('f') on oblique metric must not mix in metric tensor
    def test_lt_generic_oblique(self):
        coords = symbols('x y', real=True)
        ga, e1, e2 = Ga.build('e*1|2', g=[[1, 1], [1, 2]], coords=coords)
        F = ga.lt('f')
        c1 = F(e1).get_coefs(1)
        c2 = F(e2).get_coefs(1)
        # each coefficient must be a plain symbol, not a metric-weighted expression
        assert all(isinstance(s, Symbol) for s in c1 + c2)
        # and the matrix must match those coefficients exactly
        assert F.matrix() == Matrix([[c1[0], c2[0]], [c1[1], c2[1]]])

    def test_lt_function(self):
        """ Test construction from a function """
        base = Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))
        a, b = base.mv()

        def not_linear(x):
            return x * x
        with pytest.raises(ValueError, match='linear'):
            base.lt(not_linear)

        def not_vector(x):
            return x + S.One
        with pytest.raises(ValueError, match='vector'):
            base.lt(not_vector)

        def ok(x):
            return (x | b) * a + 2*x
        f = base.lt(ok)
        x = base.mv('x', 'vector')
        y = base.mv('y', 'vector')
        assert f(x) == ok(x)
        # `^` binds tighter than `==`; parentheses only make that explicit.
        assert f(x ^ y) == (ok(x) ^ ok(y))
        assert f(1 + 2*(x ^ y)) == 1 + 2*(ok(x) ^ ok(y))

    def test_deprecations(self):
        base = Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))
        # `L` rather than `l`: avoids the ambiguous-name pitfall (PEP 8 E741)
        # and matches the naming used by the other tests in this class.
        L = base.lt([[1, 2], [3, 4]])
        with pytest.warns(DeprecationWarning):
            assert L.X == L.Ga.coord_vec
        with pytest.warns(DeprecationWarning):
            assert L.coords == L.Ga.coords

        L = base.lt('L', mode='a')
        with pytest.warns(DeprecationWarning):
            assert L.mode == 'a'
        with pytest.warns(DeprecationWarning):
            assert not L.fct_flg

        L = base.lt('L', mode='s', f=True)
        with pytest.warns(DeprecationWarning):
            assert L.mode == 's'
        with pytest.warns(DeprecationWarning):
            assert L.fct_flg


class TestMlt(unittest.TestCase):
    """Tests for ``Mlt`` built from callables, name strings and
    component expressions."""

    def test_basic(self):
        # from TensorDef.py
        coords = symbols('t x y z', real=True)
        st4d, g0, g1, g2, g3 = Ga.build('gamma*t|x|y|z', g=[1, -1, -1, -1],
                                        coords=coords)

        A = st4d.mv('T', 'bivector')

        def TA(a1, a2):
            return A | (a1 ^ a2)

        T = Mlt(TA, st4d)

        # tests begin

        a1 = st4d.mv('a1', 'vector')
        a2 = st4d.mv('a2', 'vector')
        a3 = st4d.mv('a3', 'vector')
        a4 = st4d.mv('a4', 'vector')

        # calling the Mlt is like calling the function
        assert T(a1, a2) == TA(a1, a2)

        # for addition, argument slots are reused
        assert (T + T)(a1, a2) == T(a1, a2) + T(a1, a2)
        assert (T - T)(a1, a2) == T(a1, a2) - T(a1, a2)

        # for multiplication, argument slots are chained
        # (`*`, `^` and `|` all bind tighter than `==`)
        assert (T * T)(a1, a2, a3, a4) == (TA(a1, a2) * T(a3, a4))
        assert (T ^ T)(a1, a2, a3, a4) == (TA(a1, a2) ^ T(a3, a4))
        assert (T | T)(a1, a2, a3, a4) == (TA(a1, a2) | T(a3, a4))

        # Test linearity properties. Note that this behavior is implied by our
        # test that T and TA are equivalent above, but it does exercise
        # `Mlt.__call__` with compound expressions as arguments.
        alpha = st4d.mv('alpha', 'scalar')
        b = st4d.mv('b', 'vector')

        assert T(alpha * a1, a2) == alpha * T(a1, a2)
        assert T(a1, alpha * a2) == alpha * T(a1, a2)
        assert T(a1 + b, a2) == T(a1, a2) + T(b, a2)
        assert T(a1, a2 + b) == T(a1, a2) + T(a1, b)

    def test_from_str(self):
        coords = symbols('x y', real=True)
        g, e1, e2 = Ga.build('e*1|2', coords=coords, g=[1, 1])

        a1 = g.mv('a1', 'vector')
        a2 = g.mv('a2', 'vector')
        a1x, a1y = a1.get_coefs(1)
        a2x, a2y = a2.get_coefs(1)

        # one-d
        T = Mlt('T', g, nargs=1)
        v = T(a1)

        # Two new symbols created
        Tx, Ty = sorted(v.free_symbols - {a1x, a1y},
                        key=lambda sym: sym.sort_key())
        assert v == (
            Tx * a1x +
            Ty * a1y
        )

        # two-d
        T = Mlt('T', g, nargs=2)
        v = T(a1, a2)

        # four new symbols created
        Txx, Txy, Tyx, Tyy = sorted(v.free_symbols - {a1x, a1y, a2x, a2y},
                                    key=lambda sym: sym.sort_key())
        assert v == (
            Txx * a1x * a2x +
            Txy * a1x * a2y +
            Tyx * a1y * a2x +
            Tyy * a1y * a2y
        )

    def test_str_arithmetic(self):
        """Mlt arithmetic on string-constructed tensors routes through the
        component-expression constructor and must not raise NotImplementedError."""
        coords = symbols('x y', real=True)
        g, e1, e2 = Ga.build('e*1|2', coords=coords, g=[1, 1])

        a1 = g.mv('a1', 'vector')
        a2 = g.mv('a2', 'vector')

        # Local names deliberately differ from the module-level `sympy.S`
        # import so it is not shadowed inside this test.
        S_mlt = Mlt('S', g, nargs=2)
        T_mlt = Mlt('T', g, nargs=2)

        assert (S_mlt + T_mlt)(a1, a2) == S_mlt(a1, a2) + T_mlt(a1, a2)
        assert (S_mlt - T_mlt)(a1, a2) == S_mlt(a1, a2) - T_mlt(a1, a2)

    def test_from_component_expression(self):
        """Mlt constructed from a pre-built component expression (issue #578)."""
        coords = symbols('x y', real=True)
        g, e1, e2 = Ga.build('e*1|2', coords=coords, g=[1, 1])

        a1 = g.mv('a1', 'vector')
        a2 = g.mv('a2', 'vector')

        # Build the same rank-2 tensor as test_from_str but pass the
        # fvalue expression directly instead of a string name.
        T_str = Mlt('T', g, nargs=2)
        fvalue = T_str.fvalue  # sympy expression in slot variables

        # Reconstruct using the component-expression path
        T_expr = Mlt(fvalue, g, nargs=2)
        assert T_expr(a1, a2) == T_str(a1, a2)

        # nargs is required for component expressions
        with pytest.raises(TypeError):
            Mlt(fvalue, g)

    def test_deprecations(self):
        g = Ga('e*a|b', g=[1, 1])
        # `extact_basis_indexes` (sic) is the deprecated public name under
        # test; do not "fix" the spelling here.
        with pytest.warns(DeprecationWarning):
            assert Mlt.extact_basis_indexes(g) == ['a', 'b']