Calculating named entity frequencies
Munging data with 🤗 Datasets
For named entity recognition tasks, a handy way to measure class imbalance is to calculate the frequency of each entity type in the data. I wanted to do this with the 🤗 Datasets library for documents annotated in the "inside-outside-beginning" (IOB2) format.
One problem I encountered was that 🤗 Datasets represents the entity tags as integer label IDs rather than human-readable strings:
from datasets import load_dataset
# Load the CoNLL-2003 dataset and show its first training example; note that
# the "ner_tags" field holds integer label IDs, not tag names.
conll = load_dataset(path="conll2003")
conll["train"][0]
So I created a simple function that uses the `Dataset.features` attribute and the `ClassLabel.int2str` method to map each ID to its human-readable string:
from datasets import Dataset
def create_tag_names(
    ds: "Dataset | DatasetDict", tags_col: str
) -> "Dataset | DatasetDict":
    """Add a ``<tags_col>_str`` column with the string names of integer tag IDs.

    The ID-to-name mapping comes from the ``ClassLabel`` nested inside the
    feature ``tags_col`` (reached via its ``.feature`` attribute), so
    ``tags_col`` must be a sequence-of-``ClassLabel`` column such as
    conll2003's ``ner_tags``.

    Args:
        ds: a single ``Dataset`` or a ``DatasetDict`` of splits. For a
            ``DatasetDict`` the features are read from the ``"train"`` split
            when it exists, otherwise from the first split.
        tags_col: name of the column holding lists of integer tag IDs.

    Returns:
        The result of ``ds.map``: the same kind of object as ``ds`` with an
        extra ``f"{tags_col}_str"`` column of tag-name lists.

    Raises:
        KeyError: if ``ds`` is an empty ``DatasetDict`` (reported as 'train').
    """
    if isinstance(ds, dict):
        # DatasetDict is a dict subclass. Read the schema from "train" when
        # present, else from the first split; an empty dict still falls back
        # to "train" and raises KeyError. Assumes all splits share one schema
        # -- TODO confirm for hand-assembled DatasetDicts.
        split = "train" if "train" in ds else next(iter(ds), "train")
        features = ds[split].features
    else:
        features = ds.features
    # pick out the ClassLabel nested inside the sequence feature
    tags = features[tags_col].feature
    out_col = f"{tags_col}_str"

    def to_names(example):
        # apply ClassLabel.int2str to every tag ID of one example
        return {out_col: [tags.int2str(idx) for idx in example[tags_col]]}

    return ds.map(to_names)
# Add a human-readable "ner_tags_str" column, then show the first training example again.
conll = create_tag_names(ds=conll, tags_col="ner_tags")
conll["train"][0]
With some help from my partner-in-crime, I wrote the final step: iterate over each example, collect all the B- tags in a list (the I- tags continue the same entity, so counting only B- tags counts each entity exactly once), and then use a bit of `itertools.chain` magic to flatten the list of lists for each split:
import pandas as pd
from itertools import chain
from collections import Counter
def calculate_tag_frequencies(ds: "DatasetDict", tags_col: str) -> pd.DataFrame:
    """Count how many named entities of each type occur in every split.

    Tags are expected in IOB2 format (``"O"``, ``"B-<TYPE>"``, ``"I-<TYPE>"``).
    Only ``B-`` tags are counted, because the ``I-`` tags that follow belong to
    the same entity; ``O`` tags are ignored.

    Args:
        ds: mapping from split name to dataset (e.g. a ``DatasetDict``);
            ``ds[split][tags_col]`` must yield one list of string tags per
            example.
        tags_col: name of the column holding string tags, e.g. the
            ``"ner_tags_str"`` column produced by ``create_tag_names``.

    Returns:
        DataFrame indexed by split name with one integer column per entity
        type. A type that never occurs in a split gets a count of 0 there.
    """
    split2freqs = {}
    for split in ds.keys():
        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F: flatten the
        # per-example tag lists of this split into one stream of tags.
        all_tags = chain.from_iterable(ds[split][tags_col])
        # Require the full "B-" prefix and split only on the first hyphen so
        # hyphenated types such as "B-creative-work" stay intact.
        split2freqs[split] = Counter(
            tag.split("-", 1)[1] for tag in all_tags if tag.startswith("B-")
        )
    freqs = pd.DataFrame.from_dict(split2freqs, orient="index")
    # from_dict leaves NaN where a type is absent from a split; 0 is the
    # correct count and keeps the columns integral.
    return freqs.fillna(0).astype(int)
# Tabulate entity counts per split from the string tags added above.
calculate_tag_frequencies(ds=conll, tags_col="ner_tags_str")