# testmodels.py
"""
Models module tests
"""

import unittest

from unittest.mock import patch

import torch

from txtai.models import Models


class TestModels(unittest.TestCase):
    """
    Models tests.
    """

    @patch("torch.cuda.is_available")
    def testDeviceid(self, cuda):
        """
        Test the deviceid method.

        torch.cuda.is_available is patched so that both the CUDA-available
        and CUDA-unavailable branches run regardless of the host hardware.

        Args:
            cuda: mock replacing torch.cuda.is_available
        """

        # CUDA reported as available: True -> 0, False -> -1,
        # integer ids are returned unchanged
        cuda.return_value = True
        self.assertEqual(Models.deviceid(True), 0)
        self.assertEqual(Models.deviceid(False), -1)
        self.assertEqual(Models.deviceid(0), 0)
        self.assertEqual(Models.deviceid(1), 1)

        # Test direct torch device: a torch.device input is returned as-is
        # E1101 (no-member) is disabled because pylint cannot see torch.device
        # pylint: disable=E1101
        self.assertEqual(Models.deviceid(torch.device("cpu")), torch.device("cpu"))

        # CUDA reported as unavailable: every boolean/int request resolves to -1
        cuda.return_value = False
        self.assertEqual(Models.deviceid(True), -1)
        self.assertEqual(Models.deviceid(False), -1)
        self.assertEqual(Models.deviceid(0), -1)
        self.assertEqual(Models.deviceid(1), -1)

    def testDevice(self):
        """
        Test the device method.

        Both a device string and an existing torch.device should resolve to
        an equal torch.device instance.
        """

        # E1101 (no-member) is disabled because pylint cannot see torch.device
        # pylint: disable=E1101
        self.assertEqual(Models.device("cpu"), torch.device("cpu"))
        self.assertEqual(Models.device(torch.device("cpu")), torch.device("cpu"))