/ test / python / testmodels / testmodels.py
testmodels.py
 1  """
 2  Models module tests
 3  """
 4  
 5  import unittest
 6  
 7  from unittest.mock import patch
 8  
 9  import torch
10  
11  from txtai.models import Models
12  
13  
class TestModels(unittest.TestCase):
    """
    Unit tests for the device helper methods on Models.
    """

    @patch("torch.cuda.is_available")
    def testDeviceid(self, cuda):
        """
        Check how Models.deviceid resolves gpu flags, with torch.cuda.is_available
        patched to report CUDA as present and then as absent.
        """

        # CUDA reported available: each flag maps to the listed device id
        cuda.return_value = True
        available = ((True, 0), (False, -1), (0, 0), (1, 1))
        for gpu, expected in available:
            self.assertEqual(Models.deviceid(gpu), expected)

        # A torch.device argument should come back as an equal device
        # pylint: disable=E1101
        self.assertEqual(Models.deviceid(torch.device("cpu")), torch.device("cpu"))

        # CUDA reported unavailable: every flag resolves to -1
        cuda.return_value = False
        for gpu in (True, False, 0, 1):
            self.assertEqual(Models.deviceid(gpu), -1)

    def testDevice(self):
        """
        Check that Models.device yields the CPU torch.device for both a device
        string and an existing torch.device.
        """

        # pylint: disable=E1101
        for value in ("cpu", torch.device("cpu")):
            self.assertEqual(Models.device(value), torch.device("cpu"))