# Source code for quanda.utils.tokenization

"""Utils for tokenization of HuggingFace datasets."""

from typing import Any, Optional, Tuple

import datasets as hf_datasets  # type: ignore
from transformers import AutoTokenizer


class _TikTokenHFAdapter:
    """Wrap a tiktoken encoding behind an HF-tokenizer-like call interface.

    Only the slice of ``AutoTokenizer.__call__`` used by
    :class:`~quanda.benchmarks.config_parser.FactTracingConfigParser` is
    provided: calling the adapter with ``(text, padding, truncation,
    max_length)`` yields ``{"input_ids": ..., "attention_mask": ...}``.
    This lets the parser handle tiktoken and HF tokenizers the same way.

    Attributes
    ----------
    pad_token_id : int
        Id used for padding; taken from the encoding's ``eot_token``.
    """

    def __init__(self, encoding_name: str = "gpt2"):
        # Imported lazily so tiktoken is only needed when this backend is
        # actually instantiated.
        import tiktoken  # type: ignore

        encoding = tiktoken.get_encoding(encoding_name)
        self._enc = encoding
        self.pad_token_id: int = encoding.eot_token

    def __call__(
        self,
        text: str,
        padding: Any = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        **_: Any,
    ) -> dict:
        """Encode ``text`` into token ids plus an attention mask.

        Parameters
        ----------
        text : str
            Text to encode (via ``encode_ordinary``).
        padding : Any
            Only the value ``"max_length"`` triggers padding, and only
            when ``max_length`` is given.
        truncation : bool
            If truthy and ``max_length`` is given, keep at most
            ``max_length`` leading tokens.
        max_length : int, optional
            Target length for truncation and/or padding.
        **_ : Any
            Extra keyword arguments are accepted and ignored.

        Returns
        -------
        dict
            ``"input_ids"`` and ``"attention_mask"`` as plain lists.

        """
        token_ids = list(self._enc.encode_ordinary(text))
        if truncation and max_length is not None:
            token_ids = token_ids[:max_length]

        real_count = len(token_ids)
        if padding == "max_length" and max_length is not None:
            # May be negative when the sequence is already longer than
            # max_length; multiplying a list by a negative count gives an
            # empty list, so nothing is appended in that case.
            filler_count = max_length - real_count
        else:
            filler_count = 0

        return {
            "input_ids": token_ids + [self.pad_token_id] * filler_count,
            "attention_mask": [1] * real_count + [0] * filler_count,
        }


def _hf_tokenizer(tokenizer_name: str):
    """Load an HF ``AutoTokenizer`` and make sure it has a pad token.

    The tokenizer is loaded with ``use_fast=True``. If it reports no
    ``pad_token_id``, its EOS token is assigned as the pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    lacks_pad_token = tokenizer.pad_token_id is None
    if lacks_pad_token:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def resolve_tokenizer(tokenizer_cfg: dict) -> Tuple[Any, int]:
    """Resolve a tokenizer config to ``(tokenizer, pad_token_id)``.

    ``tokenizer`` exposes HF's ``__call__(text, padding, truncation,
    max_length) -> {"input_ids", "attention_mask"}``.

    Supported backends:

    - ``backend: hf`` with ``name`` (HF tokenizer repo) — returns the
      ``AutoTokenizer`` directly.
    - ``backend: tiktoken`` with ``encoding`` (default ``gpt2``) — returns
      a :class:`_TikTokenHFAdapter` with the same interface.

    Parameters
    ----------
    tokenizer_cfg : dict
        Keys: ``backend`` (default ``"hf"``), plus ``name`` for the ``hf``
        backend or ``encoding`` for the ``tiktoken`` backend.

    Returns
    -------
    Tuple[Any, int]
        The tokenizer object and its ``pad_token_id``.

    Raises
    ------
    KeyError
        If the ``hf`` backend is selected but ``name`` is missing.
    ValueError
        If ``backend`` is not one of the supported values.

    """
    backend = tokenizer_cfg.get("backend", "hf")

    if backend == "tiktoken":
        adapter = _TikTokenHFAdapter(tokenizer_cfg.get("encoding", "gpt2"))
        return adapter, adapter.pad_token_id

    if backend == "hf":
        # Still a KeyError, as before, but with a message that tells the
        # user which config entry is missing and why it is needed.
        if "name" not in tokenizer_cfg:
            raise KeyError(
                "tokenizer config with backend 'hf' requires a 'name' entry"
            )
        tok = _hf_tokenizer(tokenizer_cfg["name"])
        return tok, tok.pad_token_id

    raise ValueError(f"Unknown tokenizer backend: {backend}")
def tokenize_dataset(
    hf_dataset: hf_datasets.Dataset,
    tokenizer_cfg: dict,
) -> hf_datasets.Dataset:
    """Tokenize an HF dataset for transformer models.

    Parameters
    ----------
    hf_dataset : datasets.Dataset
        Raw HuggingFace dataset.
    tokenizer_cfg : dict
        Keys: ``name``, ``text_fields``, ``max_length``, ``label_field``.
        ``text_fields`` is a sequence of column names (a single column
        name given as a plain string is also accepted). ``max_length``
        defaults to 128 and ``label_field`` to ``"label"``.

    Returns
    -------
    datasets.Dataset
        Tokenized dataset formatted as torch tensors. The original
        columns are removed; labels are stored under ``"labels"``.

    Raises
    ------
    KeyError
        If ``name`` or ``text_fields`` is missing from ``tokenizer_cfg``.

    """
    # Load through the shared helper: every batch below is padded to
    # max_length, which needs a pad token, and the helper falls back to
    # EOS for tokenizers that do not define one.
    tokenizer = _hf_tokenizer(tokenizer_cfg["name"])

    text_fields = tokenizer_cfg["text_fields"]
    if isinstance(text_fields, str):
        # A bare string would otherwise be iterated character by character.
        text_fields = [text_fields]
    max_length = tokenizer_cfg.get("max_length", 128)
    label_field = tokenizer_cfg.get("label_field", "label")

    def tokenize_fn(examples: dict) -> dict:
        # One column per text field, passed positionally to the tokenizer
        # (e.g. a single text column, or a text / text-pair combination).
        texts = [examples[field] for field in text_fields]
        result = tokenizer(
            *texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        result["labels"] = examples[label_field]
        return result

    tokenized = hf_dataset.map(
        tokenize_fn,
        batched=True,
        remove_columns=hf_dataset.column_names,
    )
    tokenized.set_format("torch")
    return tokenized