# Source code for augmenty.token.swap

import random
from functools import partial
from typing import Callable, Iterator

import numpy as np
import spacy
from spacy.language import Language
from spacy.training import Example

from ..augment_utilities import make_text_from_orth


def token_swap_augmenter_v1(
    nlp: Language,
    example: Example,
    level: float,
    respect_ents: bool,
    respect_sentences: bool,
) -> Iterator[Example]:  # type: ignore
    """Yield a version of ``example`` with neighbouring tokens randomly swapped.

    The function walks over the tokens of ``example.y``. For each token not
    already marked as swapped it draws ``random.random()``; when the draw is
    below ``level`` it picks the next or previous token at random and, unless
    one of the checks below vetoes it, exchanges the two tokens' entries in
    every token-level annotation except ``SENT_START`` and ``SPACY``.

    Args:
        nlp: Pipeline used only to build the new predicted doc via
            ``nlp.make_doc`` from the re-assembled text.
        example: The example to augment.
        level: Probability threshold compared against ``random.random()``
            for each candidate token.
        respect_ents: If True, a swap is only performed when neither token
            is an entity, both tokens belong to the same entity, or one of
            them is a single-token entity (whose tag is then moved along
            with the token). If False, the ``entities`` annotation is dropped
            from the output.
        respect_sentences: If True, a pair is skipped when either token has
            ``is_sent_end`` set.

    Yields:
        A single ``Example`` built with ``example.from_dict`` from the
        modified annotation dictionary.
    """
    example_dict = example.to_dict()
    ref = example.y
    n_tok = len(ref)

    # ent_iob codes: 0 = not labelled, 2 = not an entity,
    #                3 = starts an entity, 1 = inside an entity
    non_ent_iob = {0, 2}

    is_swapped = set()  # type: ignore

    tok_anno = example_dict["token_annotation"]
    for i in range(n_tok):  # type: ignore
        if i in is_swapped:
            continue
        if random.random() >= level:
            continue

        # Select which neighbour to swap with; fall back to the other side
        # when the chosen one is outside the document. Index 0 is a valid
        # neighbour, hence ``<=`` (the previous ``<`` silently excluded it).
        fb = random.choice([1, -1])
        min_i = i + fb if 0 <= i + fb < n_tok else i - fb

        if min_i in is_swapped:
            continue
        if min_i > i:
            i, min_i = min_i, i  # make so that i is always the biggest
            # mark the forward token so the loop does not swap it again
            is_swapped.add(i)
        if min_i < 0 or i == n_tok:  # e.g. if n_tok == 1
            continue

        if respect_sentences is True and (
            ref[i].is_sent_end is True or ref[min_i].is_sent_end is True
        ):
            continue

        # Decided afresh for every pair so a previous swap cannot leak into
        # this one.
        swap_ents = False
        if respect_ents is True:
            iob_lo = ref[min_i].ent_iob
            iob_hi = ref[i].ent_iob
            if iob_lo in non_ent_iob and iob_hi in non_ent_iob:
                # Neither is an entity: swap and keep ent spans the same.
                pass
            elif iob_hi == 1 and iob_lo in {1, 3}:
                # Both part of the same entity: swap and keep ent spans the
                # same.
                pass
            elif (iob_lo == 3 and ref[i].ent_type == 0) or (
                iob_hi == 3
                # The last token has no successor; indexing ``i + 1`` there
                # would raise IndexError, so treat it as the end of the
                # entity.
                and (i + 1 >= n_tok or ref[i + 1].ent_iob in non_ent_iob)
            ):
                # 1st or 2nd token is a one-word entity: swap the tokens and
                # move the entity tag with them.
                swap_ents = True
            else:
                # Swapping would move a token across an entity boundary.
                continue

        for k in tok_anno:
            if k in ["SENT_START", "SPACY"]:
                continue
            if k == "HEAD":
                if ref.has_annotation("HEAD"):
                    head = np.array(tok_anno[k])  # type: ignore
                    # Compute both masks BEFORE writing: a chained
                    # ``head[head == i], head[head == min_i] = min_i, i``
                    # evaluates the second mask after the first write and
                    # collapses every reference onto ``i``.
                    pointed_at_hi = head == i
                    pointed_at_lo = head == min_i
                    head[pointed_at_hi] = min_i
                    head[pointed_at_lo] = i
                    tok_anno[k] = head
                else:
                    continue
            tok_anno[k][i], tok_anno[k][min_i] = tok_anno[k][min_i], tok_anno[k][i]

        if swap_ents:
            ents = example_dict["doc_annotation"]["entities"]
            ent_lo, ent_hi = ents[min_i], ents[i]
            # Only recombine prefix and label when BOTH tags carry a label.
            # Mixing a bare tag such as "O" with a labelled one would create
            # invalid tags like "O-PER" / "U"; in that case the plain
            # exchange below is what moves the entity with its token.
            if len(ent_lo) > 1 and len(ent_hi) > 1:
                ent_lo, ent_hi = (
                    ent_hi[0] + ent_lo[1:],
                    ent_lo[0] + ent_hi[1:],
                )  # keep the positional prefix, exchange the labels
            ents[i], ents[min_i] = ent_lo, ent_hi

    if respect_ents is False:
        # Entity tags were not kept in sync with the swaps, so drop them.
        example_dict["doc_annotation"].pop("entities")
    text = make_text_from_orth(example_dict)
    doc = nlp.make_doc(text)
    yield example.from_dict(doc, example_dict)


@spacy.registry.augmenters("token_swap_v1")  # type: ignore
def create_token_swap_augmenter_v1(
    level: float,
    respect_ents: bool = True,
    respect_sentences: bool = True,
) -> Callable[[Language, Example], Iterator[Example]]:  # type: ignore
    """Creates an augmenter that randomly swaps two neighbouring tokens.

    Args:
        level: The probability to swap two tokens.
        respect_ents: Should the pipeline respect entities? Defaults to
            True. In which case it will not swap a token inside an entity
            with a token outside the entity span, unless it is a one word
            span. If False it will disregard correcting the entity labels.
        respect_sentences: Should it respect end of sentence boundaries?
            Defaults to True, indicating that it will not swap an end of
            sentence token. If False it will disregard correcting the
            sentence start as this becomes arbitrary.

    Returns:
        The augmenter.
    """
    return partial(
        token_swap_augmenter_v1,
        level=level,
        respect_sentences=respect_sentences,
        respect_ents=respect_ents,
    )