import random
from functools import partial
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Union,
)
import numpy as np
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.training import Example
from spacy.util import registry
from augmenty.util import Augmenter
from ..augment_utilities import make_text_from_orth
from .utils import offset_range
# type alias for the accepted entity formats
ENTITY = Union[str, List[str], Span, Doc]
def _spacing_to_str(spacing: Union[List[str], List[bool]]) -> List[str]:
def to_string(x: Union[str, bool]) -> str:
if isinstance(x, str):
return x
return " " if x else ""
return [to_string(x) for x in spacing]
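# A minimal behavior sketch of _spacing_to_str (values are illustrative):
#   _spacing_to_str([" ", ""])     -> [" ", ""]   (strings pass through unchanged)
#   _spacing_to_str([True, False]) -> [" ", ""]   (bools map to space / no space)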
def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, Any]:
    """Normalize an entity into token-level annotations and its string form."""
    spacy = None
    pos = None
    tag = None
    morph = None
    lemma = None
if isinstance(entity, str):
ent_doc = nlp(entity)
orth = [tok.text for tok in ent_doc]
spacy = [tok.whitespace_ for tok in ent_doc]
elif isinstance(entity, list):
orth = entity
elif isinstance(entity, (Span, Doc)):
orth = [tok.text for tok in entity]
spacy = [tok.whitespace_ for tok in entity]
pos = [tok.pos_ for tok in entity]
tag = [tok.tag_ for tok in entity]
morph = [tok.morph for tok in entity]
lemma = [tok.lemma_ for tok in entity]
else:
raise ValueError(
f"entity must be of type str, List[str] or Span, not {type(entity)}",
)
    # if not specified, use default values
if spacy is None:
spacy = [" "] * len(orth)
if pos is None:
pos = ["PROPN"] * len(orth)
if tag is None:
tag = ["PROPN"] * len(orth)
if morph is None:
morph = [""] * len(orth)
if lemma is None:
lemma = orth
_spacy = _spacing_to_str(spacy)
str_repr = ""
for e, s in zip(orth[:-1], _spacy[:-1]):
str_repr += e + s
str_repr += orth[-1]
return {
"ORTH": orth,
"SPACY": spacy,
"POS": pos,
"TAG": tag,
"MORPH": morph,
"LEMMA": lemma,
"STR": str_repr,
}
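# A minimal sketch of the normalized output for a list input (assuming a loaded
# pipeline `nlp`; POS/TAG fall back to "PROPN" since a token list carries no tags):
#   __normalize_entity(["Kenneth", "Enevoldsen"], nlp) returns
#   {"ORTH": ["Kenneth", "Enevoldsen"], "SPACY": [" ", " "],
#    "POS": ["PROPN", "PROPN"], "TAG": ["PROPN", "PROPN"], "MORPH": ["", ""],
#    "LEMMA": ["Kenneth", "Enevoldsen"], "STR": "Kenneth Enevoldsen"}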
def _update_span_annotations(
span_anno: Dict[str, list],
ent: Span,
offset: int,
entity_offset: int,
) -> Dict[str, list]:
"""Update the span annotations to be in line with the new doc."""
ent_range = (ent.start + offset, ent.end + offset)
for anno_key, spans in span_anno.items():
new_spans = []
        for span_start, span_end, label, kb_id in spans:
            span_start, span_end = offset_range(
                current_range=(span_start, span_end),
                inserted_range=ent_range,
                offset=entity_offset,
            )
            new_spans.append((span_start, span_end, label, kb_id))
span_anno[anno_key] = new_spans
return span_anno
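# Illustrative example (hypothetical values): if a replacement makes the entity
# 3 units longer (entity_offset = 3), a span annotation located after the entity,
# e.g. (20, 25, label, kb_id), is shifted by offset_range to (23, 28, label, kb_id),
# while spans located before the entity are left untouched.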
def ent_augmenter_v1(
nlp: Language,
example: Example,
level: float,
ent_dict: Dict[str, Iterable[ENTITY]],
replace_consistency: bool,
resolve_dependencies: bool,
) -> Iterator[Example]:
    """Replace entities in an example based on the ent_dict lookup."""
    replaced_ents: Dict[str, ENTITY] = {}
example_dict = example.to_dict()
offset = 0
str_offset = 0
spans_anno = example_dict["doc_annotation"]["spans"]
tok_anno = example_dict["token_annotation"]
ents = example_dict["doc_annotation"]["entities"]
should_update_heads = example.y.has_annotation("HEAD") and resolve_dependencies
if should_update_heads:
head = np.array(tok_anno["HEAD"]) # type: ignore
for ent in example.y.ents:
if ent.label_ in ent_dict and random.random() < level:
if replace_consistency and ent.text in replaced_ents:
new_ent = replaced_ents[ent.text]
else:
if isinstance(ent_dict[ent.label_], Generator):
new_ent = next(ent_dict[ent.label_]) # type: ignore
else:
new_ent = random.sample(ent_dict[ent.label_], k=1)[0] # type: ignore
if replace_consistency:
replaced_ents[ent.text] = new_ent
normalized_ent = __normalize_entity(new_ent, nlp)
new_ent = normalized_ent["ORTH"]
spacing = normalized_ent["SPACY"]
str_ent = normalized_ent["STR"]
# Handle token annotations
len_ent = len(new_ent)
str_len_ent = len(str_ent)
ent_range = (ent.start + offset, ent.end + offset)
i = slice(*ent_range)
tok_anno["ORTH"][i] = new_ent
tok_anno["LEMMA"][i] = normalized_ent["LEMMA"]
tok_anno["POS"][i] = normalized_ent["POS"]
tok_anno["TAG"][i] = normalized_ent["TAG"]
tok_anno["MORPH"][i] = normalized_ent["MORPH"]
tok_anno["DEP"][i] = [ent[0].dep_] + ["flat"] * (len_ent - 1)
            # set sentence start based on the first token of the original entity
tok_anno["SENT_START"][i] = [ent[0].sent_start] + [0] * (len_ent - 1)
            # set the trailing whitespace equal to that of the last token of the original entity
spacing[-1:] = [ent[-1].whitespace_]
tok_anno["SPACY"][i] = spacing
entity_offset = len_ent - (ent.end - ent.start)
entity_str_offset = str_len_ent - len(ent.text)
if should_update_heads:
# Handle HEAD
head[head > ent.start + offset] += entity_offset # type: ignore
                # keep the first head, correcting for the change in entity size;
                # point the remaining tokens at the index of the first token
head = np.concatenate(
[
np.array(head[: ent.start + offset]), # before # type: ignore
np.array(
[head[ent.root.i + offset]] # type: ignore
+ [ent.start + offset] * (len_ent - 1),
), # the entity
np.array(head[ent.end + offset :]), # after # type: ignore
],
)
spans_anno = _update_span_annotations(
spans_anno,
ent,
str_offset,
entity_str_offset,
)
offset += entity_offset
str_offset += entity_str_offset
# Handle entities IOB tags
if len_ent == 1:
ents[i] = ["U-" + ent.label_]
else:
ents[i] = (
["B-" + ent.label_]
+ ["I-" + ent.label_] * (len_ent - 2)
+ ["L-" + ent.label_]
)
if should_update_heads:
tok_anno["HEAD"] = head.tolist() # type: ignore
else:
tok_anno["HEAD"] = list(range(len(tok_anno["ORTH"]))) # type: ignore
text = make_text_from_orth(example_dict)
doc = nlp.make_doc(text)
yield Example.from_dict(doc, example_dict)
@registry.augmenters("ents_replace_v1")
def create_ent_augmenter_v1(
level: float,
ent_dict: Dict[str, Iterable[ENTITY]],
replace_consistency: bool = True,
resolve_dependencies: bool = True,
) -> Augmenter:
"""Create an augmenter which replaces an entity based on a dictionary
lookup.
Args:
level: the percentage of entities to be augmented.
ent_dict: A dictionary with keys corresponding
the the entity type you wish to replace (e.g. "PER") and a itarable of the
replacements entities. A replacement can be either 1) a list of string of the desired entity
i.e. ["Kenneth", "Enevoldsen"], 2) a string of the desired entity i.e. "Kenneth Enevoldsen", this
will be split using the tokenizer of the nlp pipeline, or 3) Span object with the desired entity, here all information will be passed
on except for the dependency tree.
replace_consistency: Should an entity always be replaced with the same entity?
resolve_dependencies: Attempts to resolve the dependency tree
by setting head of the original entitity aa the head of the
first token in the new entity. The remainder is the passed as
Returns:
The augmenter
Example:
>>> ent_dict = {"ORG": [["Google"], ["Apple"]],
>>> "PERSON": [["Kenneth"], ["Lasse", "Hansen"]]}
>>> # augment 10% of names
>>> ent_augmenter = create_ent_augmenter(ent_dict, level = 0.1)
"""
return partial(
ent_augmenter_v1,
level=level,
ent_dict=ent_dict,
replace_consistency=replace_consistency,
resolve_dependencies=resolve_dependencies,
)
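# A minimal usage sketch (the pipeline name and input text are illustrative, and
# it assumes augmenty's top-level `augmenty.texts` helper):
#   import augmenty, spacy
#   nlp = spacy.load("en_core_web_sm")
#   augmenter = create_ent_augmenter_v1(
#       level=1.0, ent_dict={"ORG": [["Google"], "Apple Inc."]}
#   )
#   augmented = list(augmenty.texts(["Microsoft makes software."], augmenter, nlp))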
def generator_from_name_dict(
names: Dict[str, List[str]], # type: ignore
patterns: List[List[str]], # type: ignore
names_p: Dict[str, List[float]], # type: ignore
patterns_p: Optional[List[float]], # type: ignore
) -> Iterator[List[str]]:
"""A utility function for create_pers_replace_augmenter, which creates an
infinite generator based on a names dictionary and a list of patterns,
where the string in the pattern correspond to the list in the pattern."""
lp = len(patterns)
while True:
i = np.random.choice(lp, size=1, replace=True, p=patterns_p)[0] # type: ignore
yield [
str(np.random.choice(names[p], size=1, replace=True, p=names_p.get(p))[0]) # type: ignore
for p in patterns[i]
]
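# A small behavior sketch (values are illustrative): with
#   names = {"first": ["Kenneth", "Lasse"], "last": ["Enevoldsen"]}
#   patterns = [["first", "last"]]
# the generator yields token lists such as ["Lasse", "Enevoldsen"] indefinitely,
# sampling uniformly whenever names_p/patterns_p supply no probabilities.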
@registry.augmenters("per_replace_v1")
def create_per_replace_augmenter_v1(
names: Dict[
str,
List[str],
], # {"firstname": ["Kenneth", "Lasse"], "lastname": ["Enevoldsen", "Hansen"]}
patterns: List[List[str]], # ["firstname", "firstname", "lastname"]
level: float,
names_p: Optional[Dict[str, List[float]]] = None,
patterns_p: Optional[List[float]] = None,
replace_consistency: bool = True,
person_tag: str = "PERSON",
) -> Augmenter:
"""Create an augmenter which replaces a name (PER) with a news sampled from
the names dictionary.
Args:
names: A dictionary of list of names to sample from.
These could for example include first name and last names.
pattern: The pattern to create the names. This should be a
list of patterns.
Where a pattern is a list of strings, where the string denote the list in
the names dictionary in which to sample from.
level: The proportion of PER entities to replace.
names_p: The probability to sample each name. An empty dictionary "{}", indicates equal probability for each name.
patterns_p: The probability to sample each pattern. None indicates equal probability for each pattern.
replace_consistency: Should the entity always be replaced with the same entity?
person_tag: The tag of the person entity (e.g. "PERSON" or "PER").
Returns:
The augmenter
Example:
>>> names = {"firstname": ["Kenneth", "Lasse"],
>>> "lastname": ["Enevoldsen", "Hansen"]}
>>> patterns = [["firstname"], ["firstname", "lastname"],
>>> ["firstname", "firstname", "lastname"]]
>>> person_tag = "PERSON"
>>> # replace 10% of names:
>>> per_augmenter = create_per_replace_augmenter(names, patterns, level=0.1,
>>> person_tag=person_tag)
"""
if names_p is None:
names_p = {}
names_gen = generator_from_name_dict(names, patterns, names_p, patterns_p)
return create_ent_augmenter_v1(
ent_dict={person_tag: names_gen},
level=level,
replace_consistency=replace_consistency,
)
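# A minimal usage sketch (pipeline and input text are illustrative; assumes the
# `augmenty.texts` helper as in the sketch above):
#   augmenter = create_per_replace_augmenter_v1(
#       names={"firstname": ["Kenneth"], "lastname": ["Enevoldsen"]},
#       patterns=[["firstname", "lastname"]],
#       level=1.0,
#   )
#   augmented = list(augmenty.texts(["Lasse Hansen wrote this."], augmenter, nlp))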
def ent_format_augmenter_v1(
nlp: Language,
example: Example,
reordering: List[Union[int, None]],
formatter: List[Union[Callable[[Token], str], None]],
level: float,
ent_types: Optional[List[str]] = None,
) -> Iterator[Example]:
    """Reorder and reformat the tokens of matching entities in an example."""
    example_dict = example.to_dict()
tok_anno = example_dict["token_annotation"]
for ent in example.y.ents:
if (ent_types is None or ent.label_ in ent_types) and random.random() < level:
# reorder tokens
new_ent = []
ent_ = list(ent)
            for i in reordering:
                # guard against indices outside the remaining (shrinking) token list
                if i is not None and i >= len(ent_):
                    continue
                new_ent += ent_ if i is None else [ent_.pop(i)]
# format tokens
new_ent_ = [
e.text if f is None else f(e) for e, f in zip(new_ent, formatter)
]
if len(new_ent_) < len(new_ent):
new_ent_ += [e.text for e in new_ent[len(new_ent_) :]]
tok_anno["ORTH"][ent.start : ent.end] = new_ent_
tok_anno["LEMMA"][ent.start : ent.end] = new_ent_
text = make_text_from_orth(example_dict)
doc = nlp.make_doc(text)
yield Example.from_dict(doc, example_dict)
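# Illustrative semantics (hypothetical entity ["Kenneth", "Enevoldsen"]):
#   reordering=[-1, None] pops the last token first, then appends the rest,
#   giving ["Enevoldsen", "Kenneth"];
#   formatter=[lambda t: t.text.upper()] formats only the first reordered token,
#   so the resulting ORTH becomes ["ENEVOLDSEN", "Kenneth"], since tokens beyond
#   the formatter list fall back to their original text.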