# test_gorilla.py
"""Tests for patch application, reversion and original-attribute lookup in ``mlflow.utils.gorilla``."""

import pytest

from mlflow.utils import gorilla


class Delegator:
    """Non-data descriptor whose attribute access always yields ``delegated_fn``.

    Because only ``__get__`` is defined, reading the attribute through the
    class (``A.f3``) returns ``delegated_fn``. Reading it with
    ``object.__getattribute__`` returns the ``Delegator`` instance itself. The
    tests use this to check that gorilla can tell the raw stored attribute apart
    from the descriptor-resolved value.
    """

    def __init__(self, delegated_fn):
        self.delegated_fn = delegated_fn

    def __get__(self, instance, owner):
        # ``instance`` and ``owner`` are ignored: the wrapped function object is
        # returned as-is, without binding it as a method.
        return self.delegated_fn


def delegate(delegated_fn):
    """Return a decorator that replaces its target with ``Delegator(delegated_fn)``.

    The decorated function's own body is discarded.
    """
    return lambda fn: Delegator(delegated_fn)


def gen_class_A_B():
    """Create a fresh pair of classes ``A`` and ``B(A)`` for a single test.

    ``A`` defines plain methods ``f1``, ``f2`` and ``delegated_f3``, plus a
    ``Delegator`` descriptor ``f3`` that resolves to ``delegated_f3``. ``B``
    overrides ``f1`` and inherits everything else. Building new classes on every
    call keeps patches applied in one test from leaking into another.

    Returns:
        tuple: ``(A, B)``.
    """

    class A:
        def f1(self):
            pass

        def f2(self):
            pass

        def delegated_f3(self):
            pass

        @delegate(delegated_f3)
        def f3(self):
            pass

    class B(A):
        def f1(self):
            pass

    return A, B


@pytest.fixture
def gorilla_setting():
    """Settings that allow overwriting existing attributes and store the originals."""
    return gorilla.Settings(allow_hit=True, store_hit=True)


def test_basic_patch_for_class(gorilla_setting):
    """Apply and revert patches on a base class and a subclass in interleaved order.

    Checks that the live attribute and ``get_original_attribute`` are correct
    after every step. Covers ``A.f1`` and its override ``B.f1``, and ``A.f2``,
    which ``B`` only inherits.
    """
    A, B = gen_class_A_B()

    original_A_f1 = A.f1
    original_A_f2 = A.f2
    original_B_f1 = B.f1

    def patched_A_f1(self):
        pass

    def patched_A_f2(self):
        pass

    def patched_B_f1(self):
        pass

    patch_A_f1 = gorilla.Patch(A, "f1", patched_A_f1, gorilla_setting)
    patch_A_f2 = gorilla.Patch(A, "f2", patched_A_f2, gorilla_setting)
    patch_B_f1 = gorilla.Patch(B, "f1", patched_B_f1, gorilla_setting)

    # Before any patch is applied, the originals are the plain class attributes.
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.apply(patch_A_f1)
    assert A.f1 is patched_A_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    # B overrides f1, so patching A.f1 must not change B's original.
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    gorilla.apply(patch_B_f1)
    assert A.f1 is patched_A_f1
    assert B.f1 is patched_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    gorilla.apply(patch_A_f2)
    assert A.f2 is patched_A_f2
    # B inherits f2, so it sees A's patch.
    assert B.f2 is patched_A_f2
    assert gorilla.get_original_attribute(A, "f2") is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.revert(patch_A_f2)
    assert A.f2 is original_A_f2
    assert B.f2 is original_A_f2
    assert gorilla.get_original_attribute(A, "f2") is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    # Reverting the subclass patch leaves the base-class patch in place.
    gorilla.revert(patch_B_f1)
    assert A.f1 is patched_A_f1
    assert B.f1 is original_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    gorilla.revert(patch_A_f1)
    assert A.f1 is original_A_f1
    assert B.f1 is original_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1


def test_patch_for_descriptor(gorilla_setting):
    """Patch an attribute that is a descriptor, and patch one with a descriptor.

    ``get_original_attribute`` is checked both with and without
    ``bypass_descriptor_protocol``. The first form returns the descriptor-resolved
    value. The second returns the raw ``Delegator`` object stored on the class.
    """
    A, _ = gen_class_A_B()

    # Raw class attribute: the Delegator instance, not the resolved function.
    original_A_f3_raw = object.__getattribute__(A, "f3")

    def patched_A_f3(self):
        pass

    patch_A_f3 = gorilla.Patch(A, "f3", patched_A_f3, gorilla_setting)

    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (
        gorilla.get_original_attribute(A, "f3", bypass_descriptor_protocol=True)
        is original_A_f3_raw
    )

    gorilla.apply(patch_A_f3)
    assert A.f3 is patched_A_f3
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (
        gorilla.get_original_attribute(A, "f3", bypass_descriptor_protocol=True)
        is original_A_f3_raw
    )

    gorilla.revert(patch_A_f3)
    assert A.f3 is A.delegated_f3
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (
        gorilla.get_original_attribute(A, "f3", bypass_descriptor_protocol=True)
        is original_A_f3_raw
    )

    # Patch with a descriptor: class access resolves through the new Delegator,
    # while the raw attribute is the Delegator object itself.
    @delegate(patched_A_f3)
    def new_patched_A_f3(self):
        pass

    new_patch_A_f3 = gorilla.Patch(A, "f3", new_patched_A_f3, gorilla_setting)
    gorilla.apply(new_patch_A_f3)
    assert A.f3 is patched_A_f3
    assert object.__getattribute__(A, "f3") is new_patched_A_f3
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (
        gorilla.get_original_attribute(A, "f3", bypass_descriptor_protocol=True)
        is original_A_f3_raw
    )

    # Reverting a descriptor-valued patch must restore the original raw
    # descriptor, not its resolved value.
    gorilla.revert(new_patch_A_f3)
    assert A.f3 is A.delegated_f3
    assert object.__getattribute__(A, "f3") is original_A_f3_raw
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3


@pytest.mark.parametrize("store_hit", [True, False])
def test_patch_on_inherit_method(store_hit):
    """Patch a method that the subclass only inherits, then revert it.

    After reverting, ``B`` must again resolve ``f2`` to ``A``'s implementation,
    and no ``f2`` entry may remain in ``B.__dict__``.
    """
    A, B = gen_class_A_B()

    original_A_f2 = A.f2

    def patched_B_f2(self):
        pass

    gorilla_setting = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    patch_B_f2 = gorilla.Patch(B, "f2", patched_B_f2, gorilla_setting)
    gorilla.apply(patch_B_f2)

    assert B.f2 is patched_B_f2

    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.revert(patch_B_f2)
    assert B.f2 is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2
    assert "f2" not in B.__dict__  # assert no side effect after reverting


@pytest.mark.parametrize("store_hit", [True, False])
def test_patch_on_attribute_not_exist(store_hit):
    """Patch an attribute that does not exist on the class yet.

    The patch must be callable while applied. Reverting it must remove the
    attribute entirely.
    """
    A, _ = gen_class_A_B()

    def patched_fx(self):
        return 101

    gorilla_setting = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    fx_patch = gorilla.Patch(A, "fx", patched_fx, gorilla_setting)
    gorilla.apply(fx_patch)
    a1 = A()
    assert a1.fx() == 101
    gorilla.revert(fx_patch)
    assert not hasattr(A, "fx")