# pylint: disable=logging-fstring-interpolation
from typing import Dict, List, Optional, Tuple, Any, Set
import logging
import numpy as np
import torch
from forte.common.configuration import Config
from forte.common.resources import Resources
from forte.data.data_pack import DataPack
from forte.data.ontology.top import Annotation
from forte.processors.base.batch_processor import RequestPackingProcessor
from ft.onto.base_ontology import EntityMention, Subword
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
)

__all__ = [
    "BioBERTNERPredictor",
]


class BioBERTNERPredictor(RequestPackingProcessor):
    """
    A named entity recognizer fine-tuned on BioBERT.

    Note that to use :class:`BioBERTNERPredictor`, the :attr:`ontology` of
    :class:`Pipeline` must be an ontology that includes
    ``ft.onto.base_ontology.Subword`` and ``ft.onto.base_ontology.Sentence``.
    """

    def __init__(self):
        super().__init__()
        self.resources = None
        self.device = None
        self.ft_configs = None
        self.model_config = None
        self.model = None
        self.tokenizer = None

    def initialize(self, resources: Resources, configs: Config):
        super().initialize(resources, configs)

        if resources.get("device"):
            self.device = resources.get("device")
        else:
            self.device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        self.resources = resources
        self.ft_configs = configs

        model_path = self.ft_configs.model_path
        self.model_config = AutoConfig.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_path,
            from_tf=bool(".ckpt" in model_path),
            config=self.model_config,
        )
        self.model.to(self.device)
        # Switch to eval mode so dropout layers are inactive during inference.
        self.model.eval()

    @torch.no_grad()
    def predict(
        self, data_batch: Dict[str, Dict[str, List[str]]]
    ) -> Dict[str, Dict[str, List]]:
        # The numpy types are omitted from the return annotation due to
        # incompatibilities introduced by upgrading numpy from 1.19 to 1.21.
        sentences = data_batch["context"]
        subwords = data_batch["Subword"]

        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True)
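        # Note: with a BERT-style tokenizer this prepends [CLS] and appends
        # [SEP]; those positions are stripped below ("[:, 1:-1]") before the
        # labels are aligned with the pack's Subword annotations.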
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        input_shape = inputs["input_ids"].shape

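        # BERT-family encoders accept at most 512 positions, so longer
        # batches cannot go through the model as-is. The fallback below fills
        # in label id 2 (assumed to be the "O"/outside tag in this model's
        # label map) for every subword instead of running the model.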
        if input_shape[1] > 512:
            # TODO: Temporarily work around the length problem. The real
            #   solution should further split up the sentences to make them
            #   shorter.
            labels_idx = (
                inputs["input_ids"]
                .new_full(input_shape, 2, device="cpu")[:, 1:-1]
                .numpy()
            )
        else:
            outputs = self.model(**inputs)[0].cpu().numpy()
            score = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)[:, 1:-1]  # Remove placeholders.
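
        # Convert predicted label indices to tag strings, grouped per
        # sentence and keyed by each subword's tid so that pack() can write
        # the results back to the right annotations.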
        pred: Dict = {"Subword": {"ner": [], "tid": []}}
        for i in range(len(subwords["tid"])):
            tids = subwords["tid"][i]
            ner_tags = []
            for j in range(len(tids)):
                ner_tags.append(self.model.config.id2label[labels_idx[i, j]])
            pred["Subword"]["ner"].append(np.array(ner_tags))
            pred["Subword"]["tid"].append(np.array(tids))
        return pred

    def _complete_entity(
        self,
        subword_entities: List[Dict[str, Any]],
        data_pack: DataPack,
        tids: List[int],
    ) -> Tuple[int, int]:
        """Completes an entity span from the predicted subword entities.

        Starts from the first subword with a predicted entity. If that
        subword is a continuation piece (e.g. "##on"), moves back to the
        previous subword until reaching one that starts a word (e.g. "br").
        """
        first_idx: int = subword_entities[0]["idx"]
        first_tid = subword_entities[0]["tid"]
        while (
            first_idx > 0
            and not data_pack.get_entry(first_tid).is_first_segment
        ):
            first_idx -= 1
            first_tid = tids[first_idx]
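
        # Symmetrically, extend the span forward while the following subword
        # is a continuation piece of the same word.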
        last_idx: int = subword_entities[-1]["idx"]
        while (
            last_idx < len(tids) - 1
            and not data_pack.get_entry(tids[last_idx + 1]).is_first_segment
        ):
            last_idx += 1
        return first_idx, last_idx

    def _compose_entities(
        self,
        entities: List[Dict[str, Any]],
        data_pack: DataPack,
        tids: List[int],
    ) -> List[Tuple[int, int]]:
        """Composes entity spans from subword entity predictions.

        Label syntax: a "B" label indicates the beginning of an entity, an
        "I" label indicates the continuation of an entity, and an "O" label
        indicates the absence of an entity.

        Example:   with - br - ##on - ##chi - ##oli - ##tis - .
                   O    - B  - I    - I     - I     - I     - O

        Due to possible instabilities of the model on out-of-distribution
        data, the predictions may not always follow this label syntax.

        Example 1: with - br - ##on - ##chi - ##oli - ##tis - .
                   O    - B  - I    - O     - I     - O     - O
        Example 2: with - br - ##on - ##chi - ##oli - ##tis - .
                   O    - O  - O    - I     - I     - I     - O

        This method takes the entity predictions of subwords and recovers
        the set of complete entities, each defined by the indices of its
        beginning and ending subwords: (begin_idx, end_idx).
        """
        complete_entities: List[Tuple[int, int]] = []
        subword_entities: List[Dict[str, Any]] = []
        for entity in entities:
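            # A "B" prediction on a word-initial subword opens a new entity;
            # anything else is treated as continuing the current one.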
            subword = data_pack.get_entry(entity["tid"])
            if entity["label"] == "B" and subword.is_first_segment:
                # Flush the existing entity and start a new entity.
                if subword_entities:
                    complete_entity = self._complete_entity(
                        subword_entities, data_pack, tids
                    )
                    complete_entities.append(complete_entity)
                subword_entities = [entity]
            else:
                # Continue accumulating subword entities.
                subword_entities.append(entity)

        if subword_entities:
            complete_entity = self._complete_entity(
                subword_entities, data_pack, tids
            )
            complete_entities.append(complete_entity)
        return complete_entities

    def pack(
        self,
        pack: DataPack,
        predict_results: Optional[Dict[str, Dict[str, List[Any]]]] = None,
        context: Optional[Annotation] = None,
    ):
        """
        Writes the prediction results back to the datapack by aggregating
        subwords into named entity mentions.
        """
        if predict_results is None:
            return
        if context is not None:
            logging.warning(
                "The context parameter is not used in the pack() method."
            )
        for i in range(len(predict_results["Subword"]["tid"])):
            tids = predict_results["Subword"]["tid"][i]
            labels = predict_results["Subword"]["ner"][i]

            # Filter to labels not in `self.ft_configs.ignore_labels`.
            entities = [
                dict(idx=idx, label=label, tid=tid)
                for idx, (label, tid) in enumerate(zip(labels, tids))
                if label not in self.ft_configs.ignore_labels
            ]
            entity_groups = self._compose_entities(entities, pack, tids)

            # Add NER tags and create EntityMention ontologies.
            for first_idx, last_idx in entity_groups:
                first_token: Subword = pack.get_entry(tids[first_idx])
                begin = first_token.span.begin
                last_token: Subword = pack.get_entry(tids[last_idx])
                end = last_token.span.end
                entity = EntityMention(pack, begin, end)
                entity.ner_type = self.ft_configs.ner_type

    @classmethod
    def default_configs(cls):
        r"""Default config for the NER predictor."""
        return {
            "model_path": None,
            "ner_type": "BioEntity",
            "ignore_labels": ["O"],
            "batcher": {
                "batch_size": 10,
                "context_type": "ft.onto.base_ontology.Sentence",
                "requests": {
                    "ft.onto.base_ontology.Subword": [],
                    "ft.onto.base_ontology.Sentence": [],
                },
            },
        }

    def record(self, record_meta: Dict[str, Set[str]]):
        r"""Method to add the output type record of the current processor
        to :attr:`forte.data.data_pack.Meta.record`.

        Args:
            record_meta: the field in the datapack for the type record that
                needs to be filled in for consistency checking.
        """
        record_meta["ft.onto.base_ontology.EntityMention"] = {"ner_type"}

    def expected_types_and_attributes(self):
        r"""Method to add the expected types `ft.onto.base_ontology.Subword`,
        with attribute `is_first_segment`, and
        `ft.onto.base_ontology.Sentence`, which are checked before running
        the processor if the pipeline is initialized with
        `enforce_consistency=True` or
        :meth:`~forte.pipeline.Pipeline.enforce_consistency` was enabled for
        the pipeline.
        """
        return {
            "ft.onto.base_ontology.Subword": {"is_first_segment"},
            "ft.onto.base_ontology.Sentence": set(),
        }
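

# A minimal usage sketch (not part of the processor itself): wiring
# BioBERTNERPredictor into a Forte pipeline. The reader, the upstream
# processors that produce Sentence and Subword annotations, and the model
# path are placeholders for illustration; substitute the components your
# project actually uses.
#
#     from forte.data.data_pack import DataPack
#     from forte.pipeline import Pipeline
#
#     pipeline = Pipeline[DataPack]()
#     pipeline.set_reader(...)   # any reader producing raw text
#     pipeline.add(...)          # sentence splitter + subword tokenizer
#     pipeline.add(
#         BioBERTNERPredictor(),
#         config={"model_path": "path/to/biobert-ner-checkpoint"},
#     )
#     pipeline.initialize()
#     for pack in pipeline.process_dataset("path/to/documents"):
#         for mention in pack.get(EntityMention):
#             print(mention.text, mention.ner_type)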