# test_dependencies_schema.py
"""Tests for retriever-schema handling in ``mlflow.models.dependencies_schemas``."""

from unittest import mock

from mlflow.models import dependencies_schemas
from mlflow.models.dependencies_schemas import (
    DependenciesSchemas,
    DependenciesSchemasType,
    RetrieverSchema,
    _get_dependencies_schemas,
    _get_retriever_schema,
    set_retriever_schema,
)


def _assert_schema_reset():
    """Assert that no retriever schema remains registered in the global state.

    Exiting a ``_get_dependencies_schemas()`` context resets the schemas that
    ``set_retriever_schema`` registered. After a test's main ``with`` block
    has exited, the state seen here must therefore be empty.
    """
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict() is None
    assert _get_retriever_schema() == []


def test_retriever_creation():
    """The constructor stores every field verbatim."""
    vsi = RetrieverSchema(
        name="index-name",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    assert vsi.name == "index-name"
    assert vsi.primary_key == "primary-key"
    assert vsi.text_column == "text-column"
    assert vsi.doc_uri == "doc-uri"
    assert vsi.other_columns == ["column1", "column2"]


def test_retriever_to_dict():
    """``to_dict`` wraps the fields in a one-element list under the retrievers key."""
    vsi = RetrieverSchema(
        name="index-name",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    expected_dict = {
        DependenciesSchemasType.RETRIEVERS.value: [
            {
                "name": "index-name",
                "primary_key": "primary-key",
                "text_column": "text-column",
                "doc_uri": "doc-uri",
                "other_columns": ["column1", "column2"],
            }
        ]
    }
    assert vsi.to_dict() == expected_dict


def test_retriever_from_dict():
    """``from_dict`` rebuilds a schema from a flat field mapping."""
    data = {
        "name": "index-name",
        "primary_key": "primary-key",
        "text_column": "text-column",
        "doc_uri": "doc-uri",
        "other_columns": ["column1", "column2"],
    }
    vsi = RetrieverSchema.from_dict(data)
    assert vsi.name == "index-name"
    assert vsi.primary_key == "primary-key"
    assert vsi.text_column == "text-column"
    assert vsi.doc_uri == "doc-uri"
    assert vsi.other_columns == ["column1", "column2"]


def test_dependencies_schemas_to_dict():
    """``DependenciesSchemas.to_dict`` nests retriever schemas under ``dependencies_schemas``."""
    vsi = RetrieverSchema(
        name="index-name",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    schema = DependenciesSchemas(retriever_schemas=[vsi])
    expected_dict = {
        "dependencies_schemas": {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "name": "index-name",
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                    "doc_uri": "doc-uri",
                    "other_columns": ["column1", "column2"],
                }
            ]
        }
    }
    assert schema.to_dict() == expected_dict


def test_set_retriever_schema_creation():
    """Without an explicit name, the schema is registered as ``"retriever"``."""
    set_retriever_schema(
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri",
                    "name": "retriever",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                }
            ]
        }

    # Schema is automatically reset
    _assert_schema_reset()


def test_set_retriever_schema_creation_with_name():
    """An explicit ``name`` is used as the registered schema's name."""
    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_2",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                }
            ]
        }

    # Schema is automatically reset
    _assert_schema_reset()


def test_set_retriever_schema_empty_creation():
    """With nothing registered, the dependencies schemas serialize to ``None``."""
    _assert_schema_reset()


def test_multiple_set_retriever_schema_creation_with_name():
    """Schemas with distinct names are all kept, in registration order."""
    set_retriever_schema(
        name="my_ret_1",
        primary_key="primary-key-2",
        text_column="text-column-1",
        doc_uri="doc-uri-3",
        other_columns=["column1", "column2"],
    )

    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri-3",
                    "name": "my_ret_1",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key-2",
                    "text_column": "text-column-1",
                },
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_2",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
            ]
        }

    # Schema is automatically reset
    _assert_schema_reset()


def test_multiple_set_retriever_schema_with_same_name_with_different_schemas():
    """Re-registering a name with different fields warns and overrides in place."""
    set_retriever_schema(
        name="my_ret_1",
        primary_key="primary-key-2",
        text_column="text-column-1",
        doc_uri="doc-uri-3",
        other_columns=["column1", "column2"],
    )
    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )

    with mock.patch.object(dependencies_schemas, "_logger") as mock_logger:
        set_retriever_schema(
            name="my_ret_1",
            primary_key="primary-key",
            text_column="text-column",
            doc_uri="doc-uri",
            other_columns=["column1", "column2"],
        )
        mock_logger.warning.assert_called_once_with(
            "A retriever schema with the name 'my_ret_1' already exists. "
            "Overriding the existing schema."
        )

    # The overridden entry keeps its original position (first) in the list.
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_1",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_2",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
            ]
        }

    # Schema is automatically reset, as in the other tests that register schemas
    _assert_schema_reset()


def test_multiple_set_retriever_schema_with_same_name_with_same_schema():
    """Re-registering a name with identical fields is silent and keeps one entry."""
    set_retriever_schema(
        name="my_ret_1",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )

    with mock.patch.object(dependencies_schemas, "_logger") as mock_logger:
        set_retriever_schema(
            name="my_ret_1",
            primary_key="primary-key",
            text_column="text-column",
            doc_uri="doc-uri",
            other_columns=["column1", "column2"],
        )
        mock_logger.warning.assert_not_called()

    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_1",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_2",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
            ]
        }

    # Schema is automatically reset, as in the other tests that register schemas
    _assert_schema_reset()