/ tests / utils / test_gorilla.py
test_gorilla.py
  1  import pytest
  2  
  3  from mlflow.utils import gorilla
  4  
  5  
  6  class Delegator:
  7      def __init__(self, delegated_fn):
  8          self.delegated_fn = delegated_fn
  9  
 10      def __get__(self, instance, owner):
 11          return self.delegated_fn
 12  
 13  
 14  def delegate(delegated_fn):
 15      return lambda fn: Delegator(delegated_fn)
 16  
 17  
 18  def gen_class_A_B():
 19      class A:
 20          def f1(self):
 21              pass
 22  
 23          def f2(self):
 24              pass
 25  
 26          def delegated_f3(self):
 27              pass
 28  
 29          @delegate(delegated_f3)
 30          def f3(self):
 31              pass
 32  
 33      class B(A):
 34          def f1(self):
 35              pass
 36  
 37      return A, B
 38  
 39  
 40  @pytest.fixture
 41  def gorilla_setting():
 42      return gorilla.Settings(allow_hit=True, store_hit=True)
 43  
 44  
 45  def test_basic_patch_for_class(gorilla_setting):
 46      A, B = gen_class_A_B()
 47  
 48      original_A_f1 = A.f1
 49      original_A_f2 = A.f2
 50      original_B_f1 = B.f1
 51  
 52      def patched_A_f1(self):
 53          pass
 54  
 55      def patched_A_f2(self):
 56          pass
 57  
 58      def patched_B_f1(self):
 59          pass
 60  
 61      patch_A_f1 = gorilla.Patch(A, "f1", patched_A_f1, gorilla_setting)
 62      patch_A_f2 = gorilla.Patch(A, "f2", patched_A_f2, gorilla_setting)
 63      patch_B_f1 = gorilla.Patch(B, "f1", patched_B_f1, gorilla_setting)
 64  
 65      assert gorilla.get_original_attribute(A, "f1") is original_A_f1
 66      assert gorilla.get_original_attribute(B, "f1") is original_B_f1
 67      assert gorilla.get_original_attribute(B, "f2") is original_A_f2
 68  
 69      gorilla.apply(patch_A_f1)
 70      assert A.f1 is patched_A_f1
 71      assert gorilla.get_original_attribute(A, "f1") is original_A_f1
 72      assert gorilla.get_original_attribute(B, "f1") is original_B_f1
 73  
 74      gorilla.apply(patch_B_f1)
 75      assert A.f1 is patched_A_f1
 76      assert B.f1 is patched_B_f1
 77      assert gorilla.get_original_attribute(A, "f1") is original_A_f1
 78      assert gorilla.get_original_attribute(B, "f1") is original_B_f1
 79  
 80      gorilla.apply(patch_A_f2)
 81      assert A.f2 is patched_A_f2
 82      assert B.f2 is patched_A_f2
 83      assert gorilla.get_original_attribute(A, "f2") is original_A_f2
 84      assert gorilla.get_original_attribute(B, "f2") is original_A_f2
 85  
 86      gorilla.revert(patch_A_f2)
 87      assert A.f2 is original_A_f2
 88      assert B.f2 is original_A_f2
 89      assert gorilla.get_original_attribute(A, "f2") == original_A_f2
 90      assert gorilla.get_original_attribute(B, "f2") == original_A_f2
 91  
 92      gorilla.revert(patch_B_f1)
 93      assert A.f1 is patched_A_f1
 94      assert B.f1 is original_B_f1
 95      assert gorilla.get_original_attribute(A, "f1") == original_A_f1
 96      assert gorilla.get_original_attribute(B, "f1") == original_B_f1
 97  
 98      gorilla.revert(patch_A_f1)
 99      assert A.f1 is original_A_f1
100      assert B.f1 is original_B_f1
101      assert gorilla.get_original_attribute(A, "f1") == original_A_f1
102      assert gorilla.get_original_attribute(B, "f1") == original_B_f1
103  
104  
def test_patch_for_descriptor(gorilla_setting):
    """Patch ``A.f3``, a class attribute backed by a ``Delegator`` descriptor.

    Normal lookup of ``A.f3`` yields ``A.delegated_f3``, while the raw stored
    value is the descriptor object. Throughout apply/revert, and when the patch
    object is itself a descriptor, ``get_original_attribute`` must keep
    reporting both views of the original.
    """
    A, _ = gen_class_A_B()

    # Raw stored value of f3, fetched without invoking the descriptor's __get__.
    raw_f3 = object.__getattribute__(A, "f3")

    def check_originals_preserved():
        # Descriptor-resolved view of the original attribute ...
        assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
        # ... and the raw descriptor object itself.
        assert (
            gorilla.get_original_attribute(A, "f3", bypass_descriptor_protocol=True)
            is raw_f3
        )

    def patched_A_f3(self):
        pass

    plain_patch = gorilla.Patch(A, "f3", patched_A_f3, gorilla_setting)
    check_originals_preserved()

    gorilla.apply(plain_patch)
    assert A.f3 is patched_A_f3
    check_originals_preserved()

    gorilla.revert(plain_patch)
    assert A.f3 is A.delegated_f3
    check_originals_preserved()

    # Second scenario: the patch object is itself a descriptor.
    @delegate(patched_A_f3)
    def descriptor_patch(self):
        pass

    descriptor_patch_obj = gorilla.Patch(A, "f3", descriptor_patch, gorilla_setting)
    gorilla.apply(descriptor_patch_obj)
    # Normal lookup goes through the new descriptor's __get__ ...
    assert A.f3 is patched_A_f3
    # ... while the raw stored value is the new descriptor object.
    assert object.__getattribute__(A, "f3") is descriptor_patch
    check_originals_preserved()
151  
152  
@pytest.mark.parametrize("store_hit", [True, False])
def test_patch_on_inherit_method(store_hit):
    """Patch ``B.f2``, which ``B`` only inherits from ``A``, then revert it.

    For both ``store_hit`` values:
      * the original reported for ``B.f2`` is ``A``'s function;
      * reverting leaves no ``f2`` entry in ``B``'s own namespace.
    """
    A, B = gen_class_A_B()
    inherited_f2 = A.f2

    def patched_B_f2(self):
        pass

    settings = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    patch = gorilla.Patch(B, "f2", patched_B_f2, settings)

    gorilla.apply(patch)
    assert B.f2 is patched_B_f2
    assert gorilla.get_original_attribute(B, "f2") is inherited_f2

    gorilla.revert(patch)
    assert B.f2 is inherited_f2
    assert gorilla.get_original_attribute(B, "f2") is inherited_f2
    # No side effect after reverting: B resolves f2 through A again.
    assert "f2" not in B.__dict__
174  
175  
@pytest.mark.parametrize("store_hit", [True, False])
def test_patch_on_attribute_not_exist(store_hit):
    """Patching a name that ``A`` does not define adds it; reverting removes it."""
    A, _ = gen_class_A_B()

    def patched_fx(self):
        return 101

    fx_patch = gorilla.Patch(
        A, "fx", patched_fx, gorilla.Settings(allow_hit=True, store_hit=store_hit)
    )

    gorilla.apply(fx_patch)
    # The new method is callable on instances while the patch is active.
    assert A().fx() == 101

    gorilla.revert(fx_patch)
    assert not hasattr(A, "fx")
188      gorilla.revert(fx_patch)
189      assert not hasattr(A, "fx")