# tests/sentence_transformers/test_sentence_transformers_model_export.py
import json
import os
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import sentence_transformers
import yaml
from packaging.version import Version
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType
from sentence_transformers import SentenceTransformer

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.sentence_transformers
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model, infer_signature
from mlflow.models.utils import _read_example, load_serving_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.utils.environment import _mlflow_conda_env

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_logged_code_paths,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    pyfunc_serve_and_score_model,
)
from tests.transformers.version import IS_TRANSFORMERS_V5_OR_LATER


@pytest.fixture
def model_path(tmp_path):
    return tmp_path.joinpath("model")


@pytest.fixture
def basic_model():
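    # all-MiniLM-L6-v2 produces 384-dimensional embeddings; the size assertions below rely on this.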
    return SentenceTransformer("all-MiniLM-L6-v2")


@pytest.fixture
def model_with_remote_code():
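    # gte-base-en-v1.5 ships custom modeling code (768-dim embeddings), hence trust_remote_code=True.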
    return SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)


@pytest.fixture(scope="module")
def spark():
    with SparkSession.builder.master("local[1]").getOrCreate() as s:
        yield s


def test_model_save_and_load(model_path, basic_model):
    mlflow.sentence_transformers.save_model(model=basic_model, path=model_path)

    loaded_model = mlflow.sentence_transformers.load_model(model_path)

    encoded_single = loaded_model.encode("I'm just a simple string; nothing to see here.")
    encoded_multi = loaded_model.encode(["I'm a string", "I'm also a string", "Please encode me"])

    assert isinstance(encoded_single, np.ndarray)
    assert len(encoded_single) == 384
    assert isinstance(encoded_multi, np.ndarray)
    assert len(encoded_multi) == 3
    assert all(len(x) == 384 for x in encoded_multi)


@pytest.mark.skipif(
    Version(sentence_transformers.__version__) < Version("2.4.0"),
    reason="`trust_remote_code` is not supported in Sentence Transformers < 2.3.0 "
    "and `include_prompt` from gte-base-en-v1.5 requires 2.4.0 or above",
)
@pytest.mark.skipif(
    IS_TRANSFORMERS_V5_OR_LATER,
    reason="Alibaba-NLP/gte-base-en-v1.5 has corrupted position_ids buffers on transformers 5.x "
    "due to uninitialized meta-device loading (https://github.com/huggingface/transformers/issues/43957)",
)
def test_model_save_and_load_with_custom_code(model_path, model_with_remote_code):
    mlflow.sentence_transformers.save_model(model=model_with_remote_code, path=model_path)
    loaded_model = mlflow.sentence_transformers.load_model(model_path)

    encoded_single = loaded_model.encode("I'm just a simple string; nothing to see here.")
    assert isinstance(encoded_single, np.ndarray)
    assert len(encoded_single) == 768


def test_dependency_mapping():
    pip_requirements = mlflow.sentence_transformers.get_default_pip_requirements()

    expected_requirements = {"sentence-transformers", "torch", "transformers"}
    assert {package.split("=")[0] for package in pip_requirements}.intersection(
        expected_requirements
    ) == expected_requirements

    conda_requirements = mlflow.sentence_transformers.get_default_conda_env()
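    # In the default conda env, dependencies[2] is the {"pip": [...]} section.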
    pip_in_conda = {
        package.split("=")[0] for package in conda_requirements["dependencies"][2]["pip"]
    }
    expected_conda = {"mlflow"}
    expected_conda.update(expected_requirements)
    assert pip_in_conda.intersection(expected_conda) == expected_conda


def test_logged_data_structure(model_path, basic_model):
    mlflow.sentence_transformers.save_model(model=basic_model, path=model_path)

    with model_path.joinpath("requirements.txt").open() as file:
        requirements = file.read()
    reqs = {req.split("==")[0] for req in requirements.split("\n")}
    expected_requirements = {"sentence-transformers", "torch", "transformers"}
    assert reqs.intersection(expected_requirements) == expected_requirements
    conda_env = yaml.safe_load(model_path.joinpath("conda.yaml").read_bytes())
    assert {req.split("==")[0] for req in conda_env["dependencies"][2]["pip"]}.intersection(
        expected_requirements
    ) == expected_requirements

    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
    assert "model_size_bytes" in mlmodel

    pyfunc_flavor = mlmodel["flavors"]["python_function"]
    assert pyfunc_flavor["loader_module"] == "mlflow.sentence_transformers"
    assert pyfunc_flavor["data"] == mlflow.sentence_transformers.SENTENCE_TRANSFORMERS_DATA_PATH

    st_flavor = mlmodel["flavors"]["sentence_transformers"]
    assert st_flavor["pipeline_model_type"] == "BertModel"
    assert st_flavor["source_model_name"] == "sentence-transformers/all-MiniLM-L6-v2"


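# _get_transformers_model_name maps either a raw repo id or a local download path back to the
# original "<namespace>/<model>" name, as the cases below illustrate.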
@pytest.mark.parametrize(
    ("model_name", "expected"),
    [
        (
            "sentence-transformers/all-MiniLM-L6-v2",
            "sentence-transformers/all-MiniLM-L6-v2",
        ),
        (
            "/path./to_/local-/path?/sentence-transformers_all-MiniLM-L6-v2/",
            "sentence-transformers/all-MiniLM-L6-v2",
        ),
        (
            "/path/to/local/path/custom-user-009_model_name_with_underscore/",
            "custom-user-009/model_name_with_underscore",
        ),
    ],
)
def test_get_transformers_model_name(model_name, expected):
    assert mlflow.sentence_transformers._get_transformers_model_name(model_name) == expected


def test_model_logging_and_inference(basic_model):
    artifact_path = "sentence_transformer"
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(basic_model, name=artifact_path)

    model = mlflow.sentence_transformers.load_model(model_info.model_uri)

    encoded_single = model.encode(
        "Encodings provide a fixed width output regardless of input size."
    )
    encoded_multi = model.encode([
        "Just a small town girl",
        "livin' in a lonely world",
        "she took the midnight train",
        "going anywhere",
    ])

    assert isinstance(encoded_single, np.ndarray)
    assert len(encoded_single) == 384
    assert isinstance(encoded_multi, np.ndarray)
    assert len(encoded_multi) == 4
    assert all(len(x) == 384 for x in encoded_multi)


def test_load_from_remote_uri(model_path, basic_model, mock_s3_bucket):
    mlflow.sentence_transformers.save_model(model=basic_model, path=model_path)
    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)
    model_uri = os.path.join(artifact_root, artifact_path)
    loaded = mlflow.sentence_transformers.load_model(model_uri=str(model_uri))

    encoding = loaded.encode(
        "I can see why these are useful when you do distance calculations on them!"
    )

    assert len(encoding) == 384


def test_log_model_calls_register_model(tmp_path, basic_model):
    artifact_path = "sentence_transformer"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(
            conda_env, additional_pip_deps=["transformers", "torch", "sentence-transformers"]
        )
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name=artifact_path,
            conda_env=str(conda_env),
            registered_model_name="My super cool encoder",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="My super cool encoder",
        )


def test_log_model_with_no_registered_model_name(tmp_path, basic_model):
    artifact_path = "sentence_transformer"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(
            conda_env, additional_pip_deps=["transformers", "torch", "sentence-transformers"]
        )
        mlflow.sentence_transformers.log_model(
            basic_model,
            name=artifact_path,
            conda_env=str(conda_env),
        )
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()


def test_log_with_pip_requirements(tmp_path, basic_model):
    expected_mlflow_version = _mlflow_major_version_string()

    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text("some-clever-package")
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="model", pip_requirements=str(requirements_file)
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "some-clever-package"],
            strict=True,
        )
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            pip_requirements=[f"-r {requirements_file}", "a-hopefully-useful-package"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "some-clever-package", "a-hopefully-useful-package"],
            strict=True,
        )
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            pip_requirements=[f"-c {requirements_file}", "i-dunno-maybe-its-good"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "i-dunno-maybe-its-good", "-c constraints.txt"],
            ["some-clever-package"],
            strict=True,
        )


def test_log_with_extra_pip_requirements(basic_model, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()
    default_requirements = mlflow.sentence_transformers.get_default_pip_requirements()
    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text("effective-package")
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="model", extra_pip_requirements=str(requirements_file)
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_requirements, "effective-package"],
        )
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            extra_pip_requirements=[f"-r {requirements_file}", "useful-package"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_requirements, "effective-package", "useful-package"],
        )
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            extra_pip_requirements=[f"-c {requirements_file}", "constrained-pkg"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [
                expected_mlflow_version,
                *default_requirements,
                "constrained-pkg",
                "-c constraints.txt",
            ],
            ["effective-package"],
        )


def test_model_save_without_conda_env_uses_default_env_with_expected_dependencies(
    basic_model, model_path
):
    mlflow.sentence_transformers.save_model(basic_model, model_path)
    _assert_pip_requirements(
        model_path, mlflow.sentence_transformers.get_default_pip_requirements()
    )


def test_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
    basic_model,
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(basic_model, name=artifact_path)
    _assert_pip_requirements(
        model_info.model_uri, mlflow.sentence_transformers.get_default_pip_requirements()
    )


def test_log_model_with_code_paths(basic_model):
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.sentence_transformers._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(
            __file__, model_info.model_uri, mlflow.sentence_transformers.FLAVOR_NAME
        )
        mlflow.sentence_transformers.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_default_signature_assignment():
    expected_signature = {
        "inputs": '[{"type": "string", "required": true}]',
        "outputs": '[{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1]}}]',
        "params": None,
    }

    default_signature = mlflow.sentence_transformers._get_default_signature()

    assert default_signature.to_dict() == expected_signature


def test_model_pyfunc_save_load(basic_model, model_path):
    mlflow.sentence_transformers.save_model(basic_model, model_path)
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    sentence = "hello world and hello mlflow"
    sentences = [sentence, "goodbye my friends", "i am a sentence"]
    embedding_dim = basic_model.get_sentence_embedding_dimension()

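    # A single string input is treated as a batch of one, so the pyfunc output is 2D.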
    emb0 = loaded_pyfunc.predict(sentence)
    assert emb0.shape == (1, embedding_dim)

    emb1 = loaded_pyfunc.predict(sentences)
    emb2 = loaded_pyfunc.predict(pd.Series(sentences))
    emb3 = loaded_pyfunc.predict(pd.Series(sentences).to_numpy())

    for emb in [emb1, emb2, emb3]:
        assert emb.shape == (3, embedding_dim)

    np.testing.assert_array_equal(emb1, emb2)
    np.testing.assert_array_equal(emb1, emb3)


def test_model_pyfunc_predict_with_params(basic_model, tmp_path):
    sentence = "hello world and hello mlflow"
    params = {"batch_size": 16}

    model_path = tmp_path / "model1"
    signature = infer_signature(sentence, params=params)
    mlflow.sentence_transformers.save_model(basic_model, model_path, signature=signature)
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    embedding_dim = basic_model.get_sentence_embedding_dimension()

    emb0 = loaded_pyfunc.predict(sentence, params)
    assert emb0.shape == (1, embedding_dim)

    with pytest.raises(MlflowException, match=r"Invalid parameters found"):
        loaded_pyfunc.predict(sentence, {"batch_size": "16"})

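    # Without a params schema in the saved signature, supplied params are ignored with a warning.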
    model_path = tmp_path / "model3"
    mlflow.sentence_transformers.save_model(basic_model, model_path)
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        loaded_pyfunc.predict(sentence, params)
    mock_warning.assert_called_with(
        "`params` can only be specified at inference time if the model signature defines a params "
        "schema. This model does not define a params schema. Ignoring provided params: "
        "['batch_size']"
    )


@pytest.mark.skipif(
    Version(sentence_transformers.__version__) >= Version("3.1.0"),
    reason="This test only passes for Sentence Transformers < 3.1.0",
)
def test_model_pyfunc_predict_with_invalid_params(basic_model, tmp_path):
    sentence = "hello world and hello mlflow"
    model_path = tmp_path / "model"
    mlflow.sentence_transformers.save_model(
        basic_model,
        model_path,
        signature=infer_signature(sentence, params={"invalid_param": "value"}),
    )
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    with pytest.raises(
        MlflowException, match=r"Received invalid parameter value for `params` argument"
    ):
        loaded_pyfunc.predict(sentence, {"invalid_param": "random_value"})


def test_spark_udf(basic_model, spark):
    params = {"batch_size": 16}
    with mlflow.start_run():
        signature = infer_signature(SENTENCES, basic_model.encode(SENTENCES), params)
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="my_model", signature=signature
        )

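    # The UDF returns each embedding as an array of doubles, matching the declared result type.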
    result_type = ArrayType(DoubleType())
    loaded_model = mlflow.pyfunc.spark_udf(
        spark,
        model_info.model_uri,
        result_type=result_type,
        params=params,
    )

    df = spark.createDataFrame([("hello MLflow",), ("bye world",)], ["text"])
    df = df.withColumn("embedding", loaded_model("text"))
    assert df.schema[1].dataType == result_type

    pdf = df.toPandas()
    assert pdf.shape == (2, 2)
    assert pdf["embedding"].dtype == "object"

    embeddings = np.array(pdf.embedding.to_list())
    assert embeddings.shape == (2, basic_model.get_sentence_embedding_dimension())


@pytest.mark.parametrize(
    ("input1", "input2"),
    [
        (["hello world"], ["goodbye world!"]),
        (["hello world", "i am mlflow"], ["goodbye world!", "i am mlflow"]),
    ],
)
def test_pyfunc_serve_and_score(input1, input2, basic_model):
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="my_model", input_example=input1
        )
    loaded_pyfunc = pyfunc.load_model(model_uri=model_info.model_uri)
    local_predict = loaded_pyfunc.predict(input1)

    # Check that sending the same input to the served model yields the same result
    inference_data = load_serving_example(model_info.model_uri)
    assert json.loads(inference_data) == {"inputs": input1}
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_data,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    serving_result = json.loads(resp.content.decode("utf-8"))["predictions"]
    np.testing.assert_array_equal(local_predict, serving_result)

    # Check that sending a different input to the served model yields a different result
    inference_data = json.dumps({"inputs": input2})
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_data,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    serving_result = json.loads(resp.content.decode("utf-8"))["predictions"]
    assert not np.equal(local_predict, serving_result).all()


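# Module-level constants shared by the signature tests below; SENTENCES is also referenced by
# test_spark_udf above.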
SENTENCES = ["hello world", "i am mlflow"]
SENTENCES_DF = pd.DataFrame(SENTENCES)
SIGNATURE = infer_signature(
    model_input=SENTENCES,
    model_output=SentenceTransformer("all-MiniLM-L6-v2").encode(SENTENCES),
)
SIGNATURE_FROM_EXAMPLE = infer_signature(
    model_input=SENTENCES_DF,
    model_output=SentenceTransformer("all-MiniLM-L6-v2").encode(SENTENCES),
)


@pytest.mark.parametrize(
    ("example", "signature", "expected_signature"),
    [
        (None, None, mlflow.sentence_transformers._get_default_signature()),
        (SENTENCES_DF, None, SIGNATURE_FROM_EXAMPLE),
        (None, SIGNATURE, SIGNATURE),
        (SENTENCES, SIGNATURE, SIGNATURE),
    ],
)
def test_signature_and_examples_are_saved_correctly(
    example, signature, expected_signature, basic_model, model_path
):
    mlflow.sentence_transformers.save_model(
        basic_model,
        path=model_path,
        signature=signature,
        input_example=example,
    )
    mlflow_model = Model.load(model_path)

    assert mlflow_model.signature == expected_signature

    if example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        if isinstance(example, pd.DataFrame):
            assert mlflow_model.saved_input_example_info["type"] == "dataframe"
            pd.testing.assert_frame_equal(_read_example(mlflow_model, model_path), example)
        else:
            assert mlflow_model.saved_input_example_info["type"] == "json_object"
            np.testing.assert_equal(_read_example(mlflow_model, model_path), example)


def test_model_log_with_signature_inference(basic_model):
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name=artifact_path, input_example=SENTENCES
        )

    loaded_model_info = Model.load(model_info.model_uri)
    assert loaded_model_info.signature == SIGNATURE


def test_verify_task_and_update_metadata():
    # Update embedding task with empty metadata
    metadata = mlflow.sentence_transformers._verify_task_and_update_metadata("llm/v1/embeddings")
    assert metadata == {"task": "llm/v1/embeddings"}
    # Update embedding task with metadata containing the same task
    metadata = mlflow.sentence_transformers._verify_task_and_update_metadata(
        "llm/v1/embeddings", metadata
    )
    assert metadata == {"task": "llm/v1/embeddings"}

    # Update embedding task with metadata containing a different task
    metadata = {"task": "llm/v1/completions"}
    with pytest.raises(
        MlflowException, match=r"Task type is inconsistent with the task value from metadata"
    ):
        mlflow.sentence_transformers._verify_task_and_update_metadata("llm/v1/embeddings", metadata)

    # Invalid task type
    with pytest.raises(MlflowException, match=r"Task type could only be llm/v1/embeddings"):
        mlflow.sentence_transformers._verify_task_and_update_metadata("llm/v1/completions")


def test_model_pyfunc_with_dict_input(basic_model, model_path):
    mlflow.sentence_transformers.save_model(basic_model, model_path, task="llm/v1/embeddings")
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    sentence = "hello world and hello mlflow"
    sentences = [sentence, "goodbye my friends", "i am a sentence"]
    embedding_dim = basic_model.get_sentence_embedding_dimension()

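    # With task="llm/v1/embeddings", predictions come back as a dict with "data" and "usage" keys
    # instead of a raw array.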
    single_input = {"input": sentence}
    emb_single_input = loaded_pyfunc.predict(single_input)

    assert isinstance(emb_single_input, dict)
    assert len(emb_single_input["data"]) == 1
    assert isinstance(emb_single_input["data"][0], dict)
    assert emb_single_input["data"][0]["embedding"].shape == (embedding_dim,)
    assert emb_single_input["usage"]["prompt_tokens"] == 8

    multiple_input = {"input": sentences}
    emb_multiple_input = loaded_pyfunc.predict(multiple_input)

    assert isinstance(emb_multiple_input, dict)
    assert len(emb_multiple_input["data"]) == 3
    assert emb_multiple_input["data"][0]["embedding"].shape == (embedding_dim,)
    assert emb_multiple_input["usage"]["prompt_tokens"] == 19