For named entity recognition tasks, a handy way to gauge class imbalance is to count how often each entity type appears in the data. I wanted to do this with the datasets library for documents annotated in the "inside-outside-beginning" (IOB2) format.
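To make the IOB2 scheme concrete, here is a tiny hand-made example (the sentence and tags are invented for illustration): each entity opens with a B- tag, continuation tokens get I-, and everything else is O, so counting B- tags counts entities.

```python
# Hypothetical sentence annotated in IOB2: B- opens an entity,
# I- continues it, and O marks tokens outside any entity
tokens = ["Angela", "Merkel", "visited", "Paris"]
tags = ["B-PER", "I-PER", "O", "B-LOC"]

# Each entity contributes exactly one B- tag, so counting B- tags
# counts entities: two here (one PER, one LOC)
n_entities = sum(tag.startswith("B-") for tag in tags)
```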

One problem I encountered was that datasets tends to represent the entities in terms of label IDs:

from datasets import load_dataset

conll = load_dataset("conll2003")
conll['train'][0]
{'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'id': '0',
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}
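Under the hood, the ner_tags column is a Sequence of ClassLabel, and ClassLabel.int2str is essentially an index lookup into a fixed list of names. A minimal pure-Python stand-in, using the standard conll2003 NER label order, looks like this:

```python
# The conll2003 NER feature stores this fixed list of tag names;
# int2str is just a lookup of the label ID in that list
ner_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG',
             'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

def int2str(idx: int) -> str:
    return ner_names[idx]

# Decoding the first training example's ner_tags from above
decoded = [int2str(i) for i in [3, 0, 7, 0, 0, 0, 7, 0, 0]]
# → ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
```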

so I created a simple function that uses the Dataset.features attribute and the ClassLabel.int2str method to map each label ID to a human-readable string:

from datasets import DatasetDict

def create_tag_names(ds: DatasetDict, tags_col: str) -> DatasetDict:
    # pick out the ClassLabel feature for the tags column
    tags = ds["train"].features[tags_col].feature
    # apply ClassLabel.int2str to every label ID in each example
    proc_fn = lambda x: {f"{tags_col}_str": [tags.int2str(idx) for idx in x[tags_col]]}
    return ds.map(proc_fn)


conll = create_tag_names(conll, 'ner_tags')
conll['train'][0]
{'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'id': '0',
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'ner_tags_str': ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}

With some help from my partner-in-crime, the final step was to iterate over each example, collect all the B- tags in a list (since the I- tags refer to the same entity), and then use a bit of chain magic to flatten the list of lists per split:

import pandas as pd
from itertools import chain
from collections import Counter

def calculate_tag_frequencies(ds: DatasetDict, tags_col: str) -> pd.DataFrame:
    split2freqs = {}

    for split in ds.keys():
        tag_names = []
        for row in ds[split][tags_col]:
            # keep only the B- tags, since each I- tag continues the same entity
            tag_names.append([tag.split('-')[1] for tag in row if tag.startswith("B")])
        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
        split2freqs[split] = Counter(chain.from_iterable(tag_names))

    return pd.DataFrame.from_dict(split2freqs, orient="index")

calculate_tag_frequencies(conll, 'ner_tags_str')
             ORG  MISC   PER   LOC
train       6321  3438  6600  7140
validation  1341   922  1842  1837
test        1661   702  1617  1668
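One caveat worth noting: if a split were missing a tag entirely, DataFrame.from_dict would leave a NaN in that cell, silently turning the column into floats. A small sketch with made-up counts shows the fix:

```python
import pandas as pd
from collections import Counter

# Made-up counts for illustration: the "test" split has no ORG entities
split2freqs = {"train": Counter({"ORG": 2, "PER": 1}),
               "test": Counter({"PER": 3})}

# Without fillna, the missing ORG count would appear as NaN;
# fillna(0) plus astype(int) keeps the table integer-valued
df = pd.DataFrame.from_dict(split2freqs, orient="index").fillna(0).astype(int)
```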

As a sanity check, the frequencies line up with the entity counts reported in Table 2 of the CoNLL-2003 paper.

It works!