# pylint: disable=logging-fstring-interpolation
import logging
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np
import torch
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
)

from forte.common.configuration import Config
from forte.common.resources import Resources
from forte.data.data_pack import DataPack
from forte.data.ontology.top import Annotation
from forte.processors.base.batch_processor import RequestPackingProcessor
from ft.onto.base_ontology import EntityMention, Subword

class BioBERTNERPredictor(RequestPackingProcessor):
    """A Named Entity Recognizer fine-tuned on BioBERT.

    Note that to use :class:`BioBERTNERPredictor`, the :attr:`ontology` of
    :class:`Pipeline` must be an ontology that includes
    ``ft.onto.base_ontology.Subword`` and ``ft.onto.base_ontology.Sentence``.
    """

    # Widest tokenized batch (special tokens included) that ``predict``
    # sends through the model; wider batches take a placeholder path.
    # NOTE(review): 512 is the usual BERT position-embedding limit; confirm
    # it matches the checkpoint configured in ``model_path``.
    _MAX_SEQ_LENGTH = 512

    def __init__(self):
        super().__init__()
        # All of these are populated by ``initialize``.
        self.resources = None
        self.device = None
        self.ft_configs = None
        self.model_config = None
        self.model = None
        self.tokenizer = None

    def initialize(self, resources: Resources, configs: Config):
        """Load the tokenizer and the token-classification model.

        Args:
            resources: Shared pipeline resources. If it holds a truthy
                ``"device"`` entry, that device is used; otherwise CUDA is
                used when available, falling back to CPU.
            configs: Processor configuration. ``configs.model_path`` names
                the pretrained checkpoint; a path containing ``.ckpt`` is
                loaded with ``from_tf=True``.
        """
        super().initialize(resources, configs)
        requested_device = resources.get("device")
        if requested_device:
            self.device = requested_device
        else:
            self.device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        self.resources = resources
        self.ft_configs = configs

        model_path = self.ft_configs.model_path
        self.model_config = AutoConfig.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_path,
            from_tf=bool(".ckpt" in model_path),
            config=self.model_config,
        )
        # ``predict`` moves the tokenized inputs to ``self.device``; the
        # weights must live on the same device, otherwise the forward pass
        # fails with a device mismatch (e.g. CUDA inputs vs. CPU weights).
        self.model.to(self.device)

    @torch.no_grad()
    def predict(
        self, data_batch: Dict[str, Dict[str, List[str]]]
    ) -> Dict[str, Dict[str, List]]:
        """Predict one NER label for every Subword of every sentence.

        Args:
            data_batch: Batch produced by the request-packing batcher. Reads
                ``data_batch["context"]`` (the sentence texts handed to the
                tokenizer) and ``data_batch["Subword"]["tid"]`` (one list of
                Subword tids per sentence).

        Returns:
            ``{"Subword": {"ner": [...], "tid": [...]}}`` where, for each
            sentence, ``ner`` is a numpy array of label strings and ``tid``
            is a numpy array of the corresponding Subword tids.
        """
        # numpy typing is removed due to incompatibility issues caused by
        # upgrading numpy from version 1.19 to 1.21.
        sentences = data_batch["context"]
        subwords = data_batch["Subword"]

        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True)
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        input_shape = inputs["input_ids"].shape

        if input_shape[1] > self._MAX_SEQ_LENGTH:
            # TODO: Temporarily work around the length problem.
            # The real solution should further split up the sentences to
            # make the sentences shorter.
            # Every position receives label id 2. NOTE(review): presumably
            # the id of the "O" tag in this checkpoint's label map; confirm.
            labels_idx = (
                inputs["input_ids"]
                .new_full(input_shape, 2, device="cpu")[:, 1:-1]
                .numpy()
            )
        else:
            logits = self.model(**inputs)[0].cpu().numpy()
            # Softmax is monotonic, so the argmax of the raw logits equals
            # the argmax of the probabilities; skipping the unshifted
            # ``np.exp`` avoids inf/inf = NaN on large logits.
            # ``[:, 1:-1]`` removes placeholders: the first and last token
            # positions of each row.
            labels_idx = logits.argmax(axis=-1)[:, 1:-1]

        id2label = self.model.config.id2label
        pred: Dict = {"Subword": {"ner": [], "tid": []}}
        for i, tids in enumerate(subwords["tid"]):
            # Index rather than slice, so a sentence with more subwords
            # than predicted positions raises instead of silently losing
            # tags.
            ner_tags = [
                id2label[int(labels_idx[i, j])] for j in range(len(tids))
            ]
            pred["Subword"]["ner"].append(np.array(ner_tags))
            pred["Subword"]["tid"].append(np.array(tids))
        return pred

    def _complete_entity(
        self,
        subword_entities: List[Dict[str, Any]],
        data_pack: DataPack,
        tids: List[int],
    ) -> Tuple[int, int]:
        """Complete an entity span from its predicted subword entities.

        Start from the first subword with a predicted entity. If it is a
        continuation piece (e.g. "##on"), move back to previous subwords
        until one starts a word (e.g. "br"). Likewise, extend the end
        forward over any trailing continuation pieces.

        Args:
            subword_entities: Non-empty list of entity dicts with keys
                ``idx``, ``label`` and ``tid``, in sentence order.
            data_pack: Pack holding the Subword entries referenced by
                ``tids``.
            tids: Subword tids of the whole sentence.

        Returns:
            ``(first_idx, last_idx)``: inclusive indices into ``tids``.
        """
        first_idx: int = subword_entities[0]["idx"]
        first_tid = subword_entities[0]["tid"]
        while (
            first_idx > 0
            and not data_pack.get_entry(first_tid).is_first_segment
        ):
            first_idx -= 1
            first_tid = tids[first_idx]

        last_idx: int = subword_entities[-1]["idx"]
        while (
            last_idx < len(tids) - 1
            and not data_pack.get_entry(tids[last_idx + 1]).is_first_segment
        ):
            last_idx += 1
        return first_idx, last_idx

    def _compose_entities(
        self,
        entities: List[Dict[str, Any]],
        data_pack: DataPack,
        tids: List[int],
    ) -> List[Tuple[int, int]]:
        """Compose entity spans from subword entity predictions.

        Label syntax: a "B" label indicates the beginning of an entity, an
        "I" label indicates the continuation of an entity, and an "O" label
        indicates the absence of an entity. Example::

            with - br - ##on - ##chi - ##oli - ##tis - .
            O    - B  - I    - I     - I     - I     - O

        Due to possible instabilities of the model on out-of-distribution
        data, the prediction may not always follow the label format.

        Example 1::

            with - br - ##on - ##chi - ##oli - ##tis - .
            O    - B  - I    - O     - I     - O     - O

        Example 2::

            with - br - ##on - ##chi - ##oli - ##tis - .
            O    - O  - O    - I     - I     - I     - O

        This method takes the entity predictions of subwords and recovers
        the set of complete entities. A new entity starts only at a "B"
        label on a word-initial subword; every other predicted subword is
        appended to the entity currently being accumulated.

        Args:
            entities: Predicted (non-ignored) subword entity dicts with keys
                ``idx``, ``label`` and ``tid``, in sentence order.
            data_pack: Pack holding the Subword entries.
            tids: Subword tids of the whole sentence.

        Returns:
            A list of ``(begin_idx, end_idx)`` tuples: inclusive indices
            into ``tids`` of each entity's first and last subword.
        """
        complete_entities: List[Tuple[int, int]] = []
        subword_entities: List[Dict[str, Any]] = []
        for entity in entities:
            subword = data_pack.get_entry(entity["tid"])
            if entity["label"] == "B" and subword.is_first_segment:
                # Flush the existing entity and start a new entity.
                if subword_entities:
                    complete_entities.append(
                        self._complete_entity(
                            subword_entities, data_pack, tids
                        )
                    )
                subword_entities = [entity]
            else:
                # Continue accumulating subword entities.
                subword_entities.append(entity)

        if subword_entities:
            complete_entities.append(
                self._complete_entity(subword_entities, data_pack, tids)
            )
        return complete_entities

    def pack(
        self,
        pack: DataPack,
        predict_results: Optional[Dict[str, Dict[str, List[Any]]]] = None,
        context: Optional[Annotation] = None,
    ):
        """Write the prediction results back to the data pack by
        aggregating subwords into named entity mentions.

        Args:
            pack: Pack owning the Subword entries referenced by the tids.
            predict_results: Output of :meth:`predict`; nothing is written
                when it is ``None``.
            context: Unused; a warning is logged when it is provided.
        """
        if predict_results is None:
            return
        if context is not None:
            logging.warning("context parameter is not used in pack() method.")

        ignored_labels = set(self.ft_configs.ignore_labels)
        subword_results = predict_results["Subword"]
        for tids, labels in zip(subword_results["tid"], subword_results["ner"]):
            # Keep only labels not listed in ``ignore_labels``.
            entities = [
                dict(idx=idx, label=label, tid=tid)
                for idx, (label, tid) in enumerate(zip(labels, tids))
                if label not in ignored_labels
            ]
            entity_groups = self._compose_entities(entities, pack, tids)

            # Create one EntityMention covering each composed entity and
            # tag it with the configured NER type.
            for first_idx, last_idx in entity_groups:
                first_token: Subword = pack.get_entry(tids[first_idx])
                last_token: Subword = pack.get_entry(tids[last_idx])
                mention = EntityMention(
                    pack, first_token.span.begin, last_token.span.end
                )
                mention.ner_type = self.ft_configs.ner_type

    @classmethod
    def default_configs(cls):
        r"""Default config for NER Predictor."""
        return {
            "model_path": None,
            "ner_type": "BioEntity",
            "ignore_labels": ["O"],
            "batcher": {
                "batch_size": 10,
                "context_type": "ft.onto.base_ontology.Sentence",
                "requests": {
                    "ft.onto.base_ontology.Subword": [],
                    "ft.onto.base_ontology.Sentence": [],
                },
            },
        }

    def record(self, record_meta: Dict[str, Set[str]]):
        r"""Add the output type record of this processor to ``record_meta``.

        Records that this processor produces
        ``ft.onto.base_ontology.EntityMention`` entries with the
        ``ner_type`` attribute.

        Args:
            record_meta: The field in the datapack for type record that
                needs to be filled in for consistency checking.
        """
        record_meta["ft.onto.base_ontology.EntityMention"] = {"ner_type"}

    def expected_types_and_attributes(self):
        r"""Declare the input types this processor expects.

        Expects ``ft.onto.base_ontology.Subword`` with attribute
        ``is_first_segment`` and ``ft.onto.base_ontology.Sentence``. These
        are checked before running the processor if the pipeline is
        initialized with ``enforce_consistency=True`` or
        :meth:`~forte.pipeline.Pipeline.enforce_consistency` was enabled for
        the pipeline.
        """
        return {
            "ft.onto.base_ontology.Subword": {"is_first_segment"},
            "ft.onto.base_ontology.Sentence": set(),
        }