"""Source code for augmenty.token.swap."""

import random
from functools import partial
from typing import Callable, Iterator

import numpy as np
import spacy
from spacy.language import Language
from spacy.training import Example

from ..augment_utilities import make_text_from_orth

def token_swap_augmenter_v1(
    nlp: Language,
    example: Example,
    level: float,
    respect_ents: bool,
    respect_sentences: bool,
) -> Iterator[Example]:  # type: ignore
    """Yield one example in which neighbouring tokens were randomly swapped.

    Each token of the reference doc (``example.y``) is, with probability
    ``level``, swapped with a randomly chosen direct neighbour. A token takes
    part in at most one swap. The per-token annotations in
    ``example.to_dict()["token_annotation"]`` are swapped along with the
    tokens, except ``"SENT_START"`` and ``"SPACY"``, which stay positional.

    Args:
        nlp: Pipeline whose ``make_doc`` is used to build the augmented doc.
        example: The example to augment.
        level: Probability of attempting a swap for each token.
        respect_ents: If True, a swap is only performed when both tokens are
            outside any entity, both are inside the same entity, or one of
            them is a one-word entity (in which case the entity tags move with
            the tokens). If False, the ``"entities"`` annotation is removed
            from the output instead of being corrected.
        respect_sentences: If True, a token whose ``is_sent_end`` is True is
            never swapped.

    Yields:
        A single augmented ``Example``.
    """
    example_dict = example.to_dict()
    ref = example.y
    n_tok = len(ref)
    tok_anno = example_dict["token_annotation"]

    is_swapped = set()  # type: ignore  # indices that already took part in a swap

    for idx in range(n_tok):  # type: ignore
        if idx in is_swapped or random.random() >= level:
            continue

        # Select which neighbour; fall back to the other side at the borders.
        # NOTE(review): `0 <` (rather than `0 <=`) means token 1 never picks
        # token 0 as its partner -- looks like an off-by-one, confirm intent.
        fb = random.choice([1, -1])
        other = idx + fb if 0 < idx + fb < n_tok else idx - fb
        if other in is_swapped:
            continue

        lo, hi = (idx, other) if idx < other else (other, idx)
        if lo < 0 or hi >= n_tok:  # e.g. if n_tok == 1
            continue

        if respect_sentences is True and (
            ref[hi].is_sent_end is True or ref[lo].is_sent_end is True
        ):
            continue

        # Reset for every candidate pair so an earlier one-word-entity swap
        # cannot leak into later swaps.
        swap_ents = False
        if respect_ents is True:
            # ent_iob codes:
            # 0: not labelled
            # 2: not an ent
            # 3: start ent
            # 1: in ent
            lo_iob, hi_iob = ref[lo].ent_iob, ref[hi].ent_iob
            if lo_iob in {0, 2} and hi_iob in {0, 2}:
                # Neither is an entity: swap and keep ent spans the same.
                pass
            elif hi_iob == 1 and lo_iob in {1, 3}:
                # Both part of the same entity: swap and keep ent spans the same.
                pass
            elif (lo_iob == 3 and ref[hi].ent_type == 0) or (
                hi_iob == 3
                # The last token of the doc has no successor to inspect; an
                # entity starting there is necessarily one word long.
                and (hi + 1 == n_tok or ref[hi + 1].ent_iob in {0, 2})
            ):
                # 1st or 2nd token is a one-word entity: swap the tokens and
                # move the entity tags with them.
                # NOTE(review): this also fires when `lo` is the last token of
                # a preceding multi-token entity -- confirm that is intended.
                swap_ents = True
            else:
                # A swap would break an entity span: don't swap.
                continue

        is_swapped.add(lo)
        is_swapped.add(hi)

        for key in tok_anno:
            if key in ("SENT_START", "SPACY"):
                continue
            if key == "HEAD" and ref.has_annotation("HEAD"):
                # Redirect heads that point at either swapped token. Both
                # masks are computed before writing so that the second
                # assignment cannot undo the first one.
                head = np.array(tok_anno[key])  # type: ignore
                points_to_hi = head == hi
                points_to_lo = head == lo
                head[points_to_hi] = lo
                head[points_to_lo] = hi
                tok_anno[key] = head.tolist()
            values = tok_anno[key]
            values[hi], values[lo] = values[lo], values[hi]

        if swap_ents:
            ents = example_dict["doc_annotation"]["entities"]
            ent_lo, ent_hi = ents[lo], ents[hi]
            # Only tags of the form "<prefix>-<label>" have a prefix to keep in
            # place; "O" (and other one-character tags) are swapped wholesale,
            # otherwise tags such as "O-PER" / "U" would be produced.
            if len(ent_lo) > 1 and len(ent_hi) > 1:
                ent_lo, ent_hi = (
                    ent_hi[0] + ent_lo[1:],
                    ent_lo[0] + ent_hi[1:],
                )  # swap the BILOU prefix back so it stays positional
            ents[hi], ents[lo] = ent_lo, ent_hi

    if respect_ents is False:
        # Entity tags were not corrected, so drop them from the output.
        example_dict["doc_annotation"].pop("entities", None)

    text = make_text_from_orth(example_dict)
    doc = nlp.make_doc(text)
    yield example.from_dict(doc, example_dict)

@spacy.registry.augmenters("token_swap_v1")  # type: ignore
def create_token_swap_augmenter_v1(
    level: float,
    respect_ents: bool = True,
    respect_sentences: bool = True,
) -> Callable[[Language, Example], Iterator[Example]]:  # type: ignore
    """Creates an augmenter that randomly swaps two neighbouring tokens.

    Args:
        level: The probability to swap two tokens.
        respect_ents: Should the pipeline respect entities? Defaults to True.
            In which case it will not swap a token inside an entity with a
            token outside the entity span, unless it is a one word span. If
            False it will disregard correcting the entity labels.
        respect_sentences: Should it respect end of sentence boundaries?
            Defaults to True, indicating that it will not swap an end of
            sentence token. If False it will disregard correcting the sentence
            start as this becomes arbitrary.

    Returns:
        The augmenter.
    """
    return partial(
        token_swap_augmenter_v1,
        level=level,
        respect_sentences=respect_sentences,
        respect_ents=respect_ents,
    )