Calculating named entity frequencies
Munging data with 🤗 Datasets
For named entity recognition tasks, a handy way to gauge class imbalance is to calculate the frequency of each named entity type in the data. I wanted to do this with the datasets library for documents annotated in the "inside-outside-beginning" (IOB2) format.
One problem I encountered was that datasets tends to represent the entities in terms of label IDs:
from datasets import load_dataset
conll = load_dataset("conll2003")
conll['train'][0]
so I created a simple function that makes use of the Dataset.features attribute and ClassLabel.int2str method to perform the mapping from ID to human-readable string:
from datasets import Dataset, DatasetDict

def create_tag_names(ds: DatasetDict, tags_col: str) -> DatasetDict:
    # tags_col is a sequence feature; its inner .feature is the ClassLabel
    # (assumes every split shares the label set of the "train" split)
    tags = ds["train"].features[tags_col].feature
    # apply the ClassLabel.int2str method to each token's label ID
    proc_fn = lambda x: {f"{tags_col}_str": [tags.int2str(idx) for idx in x[tags_col]]}
    return ds.map(proc_fn)

conll = create_tag_names(conll, 'ner_tags')
conll['train'][0]
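To make the ID-to-string mapping concrete without downloading anything, here is a minimal sketch that mimics the lookup with a plain Python list. The label order and the example IDs are assumptions written out by hand for illustration, and ids_to_tags is a hypothetical helper; in practice the names should be read from the dataset's features rather than hard-coded.

```python
# Hypothetical label list, written out by hand for illustration; in practice
# read it from ds["train"].features["ner_tags"].feature.names
label_names = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

def ids_to_tags(ids, names):
    """Map each integer label ID to its string name via a list lookup."""
    return [names[i] for i in ids]

# Illustrative label IDs for a nine-token sentence
example_ids = [3, 0, 7, 0, 0, 0, 7, 0, 0]
example_tags = ids_to_tags(example_ids, label_names)
print(example_tags)
```

The result is one IOB2 string per token, which is the shape of the new ner_tags_str column.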
With some help from my partner-in-crime, the final step was to iterate over each example, collect all the B- tags in a list (each entity starts with exactly one B- tag, and the I- tags only continue it), and then use a bit of chain magic to flatten the list of lists per split:
import pandas as pd
from itertools import chain
from collections import Counter
from datasets import DatasetDict

def calculate_tag_frequencies(ds: DatasetDict, tags_col: str) -> pd.DataFrame:
    split2freqs = {}
    for split in ds.keys():
        tag_names = []
        for row in ds[split][tags_col]:
            # keep one entry per entity: the type after the "B-" prefix
            tag_names.append([tag.split("-", 1)[1] for tag in row if tag.startswith("B-")])
        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
        split2freqs[split] = Counter(chain.from_iterable(tag_names))
    return pd.DataFrame.from_dict(split2freqs, orient="index")

calculate_tag_frequencies(conll, 'ner_tags_str')
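The counting logic can be checked on its own with a toy input. The sketch below uses only the standard library; the two tag sequences are made up for illustration, and count_entities is a hypothetical helper rather than part of any library.

```python
from collections import Counter
from itertools import chain

def count_entities(tag_sequences):
    """Count entities per type by counting the B- tags in IOB2 tag sequences."""
    per_sentence = [
        [tag.split("-", 1)[1] for tag in tags if tag.startswith("B-")]
        for tags in tag_sequences
    ]
    # flatten the list of lists before counting
    return Counter(chain.from_iterable(per_sentence))

# Two hand-written IOB2 sentences (made up for illustration)
toy_tags = [
    ["B-PER", "I-PER", "O", "B-LOC"],
    ["B-ORG", "O", "B-PER", "O", "B-LOC", "I-LOC"],
]
toy_counts = count_entities(toy_tags)
print(toy_counts)
```

Multi-token entities such as B-PER I-PER are counted once, because only the B- tag contributes to the tally.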
As a sanity check, let's compare with Table 2 from the CoNLL-2003 paper:
It works!