Text Classification

The Task

The Text Classification Task fine-tunes the model to predict probabilities across a set of labels given input text. The task supports both binary and multi-class/multi-label classification.

Datasets

The task currently supports the XNLI, GLUE and emotion datasets, as well as custom input files.

Input: I don't like this at all!

Model answer: {"label": "angry", "score": 0.8}
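To illustrate how a label and score like the one above can arise, here is a minimal sketch, assuming a single-label setup in which the model's raw outputs (logits) are passed through a softmax. The label set, the logit values, and the helper names `softmax` and `top_prediction` are hypothetical and not part of the library.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() numerically stable.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, labels):
    """Return the most probable label in the same shape as the example answer."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": labels[best], "score": probs[best]}


# Hypothetical labels and logits, chosen only for illustration.
labels = ["angry", "happy", "sad"]
prediction = top_prediction([2.0, 0.1, 0.4], labels)
print(prediction)
```

The highest logit determines the predicted label, and its softmax probability becomes the reported score.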

Training

Use this task when you would like to fine-tune Transformers on a labeled text classification task. For this task, you can rely on most Transformer models as your backbone.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="glue",
    dataset_config_name="sst2",
    max_length=512,
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", num_labels=dm.num_classes)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)

We report the Precision, Recall, Accuracy and Cross Entropy Loss for validation.
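As a rough guide to what these validation metrics measure, the sketch below computes them by hand for a toy batch. It assumes macro-averaged precision and recall; the averaging mode actually used by the task is not stated here, and the helper names and data are hypothetical.

```python
import math


def accuracy(preds, targets):
    """Fraction of examples whose predicted class matches the target."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def precision_recall(preds, targets, num_classes):
    """Macro-averaged precision and recall (the averaging mode is an assumption)."""
    precisions, recalls = [], []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, targets))
        fp = sum(p == c and t != c for p, t in zip(preds, targets))
        fn = sum(p != c and t == c for p, t in zip(preds, targets))
        precisions.append(tp / (tp + fp) if tp + fp else 0.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
    return sum(precisions) / num_classes, sum(recalls) / num_classes


def cross_entropy(probs, targets):
    """Mean negative log-probability assigned to the true class."""
    return -sum(math.log(p[t]) for p, t in zip(probs, targets)) / len(targets)


# Toy validation batch: predicted class probabilities and true class ids.
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
targets = [0, 1, 1, 1]
preds = [max(range(len(p)), key=p.__getitem__) for p in probs]

acc = accuracy(preds, targets)
prec, rec = precision_recall(preds, targets, num_classes=2)
loss = cross_entropy(probs, targets)
print(acc, prec, rec, loss)
```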

Text Classification Using Your Own Files

To use custom text files, each file should contain newline-delimited JSON objects, one object per line.

The label mapping is automatically generated from the training dataset labels if no mapping is given.

{"label": "sad", "text": "I'm feeling quite sad and sorry for myself but I'll snap out of it soon."}

from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

# Same tokenizer as in the training example above.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
dm = TextClassificationDataModule(
    batch_size=1,
    max_length=512,
    train_file="path/train.json",
    validation_file="/path/valid.json",
    tokenizer=tokenizer,
)
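The file format and the automatic label mapping described above can be sketched with the standard library alone. This is only an illustration: the records, the helper names, and the choice of sorting labels alphabetically before assigning ids are assumptions, and the library may order labels differently.

```python
import json
import os
import tempfile

# Hypothetical records, used only for illustration.
records = [
    {"label": "sad", "text": "I'm feeling quite sad and sorry for myself but I'll snap out of it soon."},
    {"label": "angry", "text": "I don't like this at all!"},
    {"label": "sad", "text": "Nothing went right today."},
]


def write_json_lines(path, rows):
    """Write one JSON object per line (newline-delimited JSON)."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def read_json_lines(path):
    """Read newline-delimited JSON back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def build_label_mapping(rows):
    """Map each distinct label to an integer id (sorted order is an assumption)."""
    unique_labels = sorted({row["label"] for row in rows})
    return {label: idx for idx, label in enumerate(unique_labels)}


train_path = os.path.join(tempfile.mkdtemp(), "train.json")
write_json_lines(train_path, records)
loaded = read_json_lines(train_path)
label2id = build_label_mapping(loaded)
print(label2id)
```

A file written this way could then be passed as `train_file` to the data module.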

Text Classification Inference Pipeline

By default we use the sentiment-analysis pipeline, which requires an input string.

from transformers import AutoTokenizer
from lightning_transformers.task.nlp.text_classification import TextClassificationTransformer

model = TextClassificationTransformer(
    pretrained_model_name_or_path="prajjwal1/bert-tiny",
    tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny"),
)
model.hf_predict("Lightning rocks!")