/ test / python / testpipeline / testtext / testentity.py
testentity.py
 1  """
 2  Entity module tests
 3  """
 4  
 5  import unittest
 6  
 7  from txtai.pipeline import Entity
 8  
 9  
class TestEntity(unittest.TestCase):
    """
    Tests for the Entity pipeline.
    """

    # Sample sentence shared by the tests that run the class-level pipeline
    TEXT = "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg"

    @classmethod
    def setUpClass(cls):
        """
        Builds the entity pipeline shared by the tests in this class.
        """

        cls.entity = Entity("dslim/bert-base-NER")

    def extract(self, **options):
        """
        Runs the shared entity pipeline on the sample sentence.

        Args:
            options: keyword arguments passed straight through to the pipeline call

        Returns:
            pipeline output for the sample sentence
        """

        return self.entity(self.TEXT, **options)

    def testEntity(self):
        """
        Default call: the first element of each result is the entity text
        """

        results = self.extract()

        # Keep only the first field of every result for comparison
        names = [result[0] for result in results]
        self.assertEqual(names, ["Canada", "Manhattan"])

    def testEntityFlatten(self):
        """
        Flattened output, with and without joining
        """

        # flatten=True is expected to produce a plain list of entity strings
        flattened = self.extract(flatten=True)
        self.assertEqual(flattened, ["Canada", "Manhattan"])

        # Adding join=True is expected to produce a single space separated string
        joined = self.extract(flatten=True, join=True)
        self.assertEqual(joined, "Canada Manhattan")

    def testEntityTypes(self):
        """
        Filtering by entity label
        """

        # Restricting to the PER label is expected to leave nothing for this sentence
        people = self.extract(labels=["PER"])
        self.assertFalse(people)

    def testGliner(self):
        """
        Entity pipeline backed by a GLiNER model
        """

        # This test builds its own pipeline instead of using the shared one
        gliner = Entity("neuml/gliner-bert-tiny")

        found = gliner("My name is John Smith.", flatten=True)
        self.assertEqual(found, ["John Smith"])