diff --git a/forte/data/data_store.py b/forte/data/data_store.py
index aab05364b..f43fdef53 100644
--- a/forte/data/data_store.py
+++ b/forte/data/data_store.py
@@ -1481,6 +1481,166 @@ def get_attribute(self, tid: int, attr_name: str) -> Any:
 
         return entry[attr_id]
 
+    def get_attributes_of_tid(self, tid: int, attr_names: List[str]) -> dict:
+        r"""This function returns the value of attributes listed in
+        ``attr_names`` for the entry with ``tid``. It locates the entry data
+        with ``tid``, finds the attributes listed in ``attr_names`` and
+        returns them as a dict.
+
+        Args:
+            tid: Unique id of the entry.
+            attr_names: List of names of the attributes.
+
+        Returns:
+            A dict with keys listed in ``attr_names`` for attributes of the
+            entry with ``tid``.
+
+        Raises:
+            KeyError: when ``tid`` or one of ``attr_names`` is not found.
+        """
+        entry, entry_type = self.get_entry(tid)
+        return self._collect_entry_attributes(entry, entry_type, attr_names)
+
+    def get_attributes_of_tids(
+        self, list_of_tid: List[int], attr_names: List[str]
+    ) -> List[dict]:
+        r"""This function returns the value of attributes listed in
+        ``attr_names`` for the entries listed in ``list_of_tid``.
+        It locates the data of each entry by its ``tid`` and puts the
+        attributes listed in ``attr_names`` in one dict per entry.
+
+        Args:
+            list_of_tid: List of unique ids of the entries.
+            attr_names: List of names of the attributes.
+
+        Returns:
+            A list of dicts, one per ``tid`` and in the order of
+            ``list_of_tid``, each with the names in ``attr_names`` as keys.
+
+        Raises:
+            KeyError: when a ``tid`` or one of ``attr_names`` is not found.
+        """
+        return [
+            self.get_attributes_of_tid(tid, attr_names) for tid in list_of_tid
+        ]
+
+    def get_attributes_of_type(
+        self,
+        type_name: str,
+        attributes_names: List[str],
+        include_sub_type: bool = True,
+        range_span: Optional[Tuple[int, int]] = None,
+    ) -> Iterator[dict]:
+        r"""This function fetches the required attributes of the entries of
+        type ``type_name`` from the data store. If `include_sub_type` is set
+        to True and ``type_name`` is in [Annotation], this function also
+        fetches entries of subtype of ``type_name``. Otherwise, it only
+        fetches entries of type ``type_name``.
+
+        Args:
+            type_name: The fully qualified name of the entry.
+            attributes_names: List of attributes to be fetched for each
+                entry.
+            include_sub_type: A boolean to indicate whether get its subclass.
+            range_span: A tuple that contains the begin and end indices
+                of the searching range of entries.
+
+        Returns:
+            An iterator of dicts, one per matching entry. Each dict holds
+            the ``tid`` of the entry under the key ``"tid"`` followed by
+            the attributes requested in ``attributes_names``.
+
+        Raises:
+            KeyError: when ``type_name`` has no attribute with one of the
+                names in ``attributes_names``.
+            NotImplementedError: when ``type_name`` is a Link or a Group.
+            ValueError: when ``type_name`` is neither an annotation, a Link
+                nor a Group and does not exist in the data store.
+        """
+        entry_class = get_class(type_name)
+        if include_sub_type:
+            # Sorted so that the order of the types handed on is stable.
+            all_types = sorted(
+                {
+                    stored_type
+                    for stored_type in self.__elements
+                    if issubclass(get_class(stored_type), entry_class)
+                }
+            )
+        else:
+            all_types = [type_name]
+
+        if self._is_annotation(type_name):
+            # ``range_span`` is only forwarded when given, so the default
+            # of `co_iterator_annotation_like()` applies otherwise.
+            if range_span is None:
+                entries = self.co_iterator_annotation_like(all_types)
+            else:
+                entries = self.co_iterator_annotation_like(
+                    all_types, range_span=range_span
+                )
+        elif issubclass(entry_class, Link):
+            raise NotImplementedError(
+                f"{type_name} of Link is not currently supported."
+            )
+        elif issubclass(entry_class, Group):
+            raise NotImplementedError(
+                f"{type_name} of Group is not currently supported."
+            )
+        else:
+            if type_name not in self.__elements:
+                raise ValueError(f"type {type_name} does not exist")
+            entries = self.iter(type_name)
+
+        for entry in entries:
+            # NOTE(review): the attribute positions are resolved against
+            # ``type_name`` even when ``entry`` belongs to a subtype; this
+            # assumes subtypes keep the positions of inherited attributes
+            # -- confirm.
+            attrs: dict = {"tid": entry[0]}
+            attrs.update(
+                self._collect_entry_attributes(
+                    entry, type_name, attributes_names
+                )
+            )
+            yield attrs
+
+    def _collect_entry_attributes(
+        self, entry: List[Any], type_name: str, attr_names: List[str]
+    ) -> dict:
+        r"""This function reads the attributes listed in ``attr_names``
+        from the raw ``entry`` data. The position of each attribute inside
+        ``entry`` is looked up in the attribute dictionary of ``type_name``.
+        Called by `get_attributes_of_tid()` and `get_attributes_of_type()`.
+
+        Args:
+            entry: The raw entry data to read the values from.
+            type_name: The fully qualified type name whose attribute
+                layout is used to locate the values.
+            attr_names: List of names of the attributes to read.
+
+        Returns:
+            A dict mapping every name in ``attr_names`` to its value in
+            ``entry``.
+
+        Raises:
+            KeyError: when ``type_name`` has no attribute with one of the
+                requested names.
+        """
+        attrs: dict = {}
+        for attr_name in attr_names:
+            try:
+                attr_id = self._get_type_attribute_dict(type_name)[attr_name][
+                    constants.ATTR_INDEX_KEY
+                ]
+            except KeyError as e:
+                raise KeyError(
+                    f"{type_name} has no {attr_name} attribute."
+                ) from e
+            attrs[attr_name] = entry[attr_id]
+
+        return attrs
+
     def _get_attr(self, tid: int, attr_id: int) -> Any:
         r"""This function locates the entry data with ``tid`` and gets the value
         of ``attr_id`` of this entry. Called by `get_attribute()`.
diff --git a/tests/forte/data/data_store_test.py b/tests/forte/data/data_store_test.py
index 33fc1fc93..b9e7833c9 100644
--- a/tests/forte/data/data_store_test.py
+++ b/tests/forte/data/data_store_test.py
@@ -699,7 +699,6 @@ def value_err_fn():
         self.assertRaises(ValueError, value_err_fn)
 
     def test_add_annotation_raw(self):
-
         # test add Document entry
         tid_doc: int = self.data_store.add_entry_raw(
             type_name="ft.onto.base_ontology.Document",
@@ -1039,6 +1038,126 @@ def test_get_attribute(self):
         ):
             self.data_store.get_attribute(9999, "class")
 
+    def test_get_attributes_of_tid(self):
+        result_dict = self.data_store.get_attributes_of_tid(
+            9999, ["begin", "end", "speaker"]
+        )
+        result_dict2 = self.data_store.get_attributes_of_tid(
+            3456, ["payload_idx", "classifications"]
+        )
+
+        self.assertEqual(result_dict["begin"], 6)
+        self.assertEqual(result_dict["end"], 9)
+        self.assertEqual(result_dict["speaker"], "teacher")
+        self.assertEqual(result_dict2["payload_idx"], 1)
+        self.assertEqual(result_dict2["classifications"], {})
+
+        # Entry with such tid does not exist
+        with self.assertRaisesRegex(KeyError, "Entry with tid 1111 not found."):
+            self.data_store.get_attributes_of_tid(1111, ["speaker"])
+
+        # Get attribute field that does not exist
+        with self.assertRaisesRegex(
+            KeyError, "ft.onto.base_ontology.Sentence has no class attribute."
+        ):
+            self.data_store.get_attributes_of_tid(9999, ["class"])
+
+    def test_get_attributes_of_tids(self):
+        tids_attrs: list[dict]
+        # tids_attrs2: list[dict]
+        tids_attrs = self.data_store.get_attributes_of_tids(
+            [9999, 3456], ["begin", "end", "payload_idx"]
+        )
+        tids_attrs2 = self.data_store.get_attributes_of_tids(
+            [9999], ["begin", "speaker"]
+        )
+
+        self.assertEqual(tids_attrs2[0]["begin"], 6)
+        self.assertEqual(tids_attrs[0]["end"], 9)
+        self.assertEqual(tids_attrs[1]["payload_idx"], 1)
+        self.assertEqual(tids_attrs2[0]["speaker"], "teacher")
+
+        # Entry with such tid does not exist
+        with self.assertRaisesRegex(KeyError, "Entry with tid 1111 not found."):
+            self.data_store.get_attributes_of_tids([1111], ["speaker"])
+
+        # Get attribute field that does not exist
+        with self.assertRaisesRegex(
+            KeyError, "ft.onto.base_ontology.Sentence has no class attribute."
+        ):
+            self.data_store.get_attributes_of_tids([9999], ["class"])
+
+    def test_get_attributes_of_type(self):
+        # get document entries
+        instances = list(
+            self.data_store.get_attributes_of_type(
+                "ft.onto.base_ontology.Document",
+                ["begin", "end", "payload_idx"],
+            )
+        )
+        # print(instances)
+        self.assertEqual(len(instances), 2)
+        # check tid
+        self.assertEqual(instances[0]["tid"], 1234)
+        self.assertEqual(instances[0]["end"], 5)
+        self.assertEqual(instances[1]["tid"], 3456)
+        self.assertEqual(instances[1]["begin"], 10)
+
+        # For types other than annotation, group or link, not support include_subtype
+        instances = list(
+            self.data_store.get_attributes_of_type(
+                "forte.data.ontology.core.Entry", ["begin", "end"]
+            )
+        )
+        self.assertEqual(len(instances), 0)
+
+        self.assertEqual(
+            self.data_store.get_length("forte.data.ontology.core.Entry"), 0
+        )
+
+        # get annotations with subclasses and range annotation
+        instances = list(
+            self.data_store.get_attributes_of_type(
+                "forte.data.ontology.top.Annotation",
+                ["begin", "end"],
+                range_span=(1, 20),
+            )
+        )
+        self.assertEqual(len(instances), 2)
+
+        # get groups with subclasses
+        # instances = list(self.data_store.get_attributes_of_type(
+        #     "forte.data.ontology.top.Group", ["begin", "end"]))
+        # self.assertEqual(len(instances), 3)
+
+        # # get groups with subclasses and range annotation
+        # instances = list(
+        #     self.data_store.get(
+        #         "forte.data.ontology.top.Group", range_span=(1, 20)
+        #     )
+        # )
+        # self.assertEqual(len(instances), 0)
+        #
+        # # get links with subclasses
+        # instances = list(self.data_store.get("forte.data.ontology.top.Link"))
+        # self.assertEqual(len(instances), 1)
+        #
+        # # get links with subclasses and range annotation
+        # instances = list(
+        #     self.data_store.get(
+        #         "forte.data.ontology.top.Link", range_span=(0, 9)
+        #     )
+        # )
+        # self.assertEqual(len(instances), 1)
+        #
+        # # get links with subclasses and range annotation
+        # instances = list(
+        #     self.data_store.get(
+        #         "forte.data.ontology.top.Link", range_span=(4, 11)
+        #     )
+        # )
+        # self.assertEqual(len(instances), 0)
+
     def test_set_attribute(self):
         # change attribute
         self.data_store.set_attribute(9999, "speaker", "student")
@@ -1328,7 +1447,6 @@ def test_get_entry_attribute_by_class(self):
         )
 
     def test_is_subclass(self):
-
         import forte
 
         self.assertEqual(
@@ -1396,7 +1514,6 @@ def test_is_subclass(self):
         )
 
     def test_check_onto_file(self):
-
         expected_type_attributes = {
             "ft.onto.test.Description": {
                 "attributes": {