Summarization

The Task

The Summarization task fine-tunes the model to condense a long document into a shorter summary.

Datasets

Currently supports the CNN/DailyMail and XSUM datasets, or custom input text files.

In the CNN/DailyMail dataset, this means condensing long news articles into short summaries, for example:

Document: "The car was racing towards the tunnel, whilst blue lights were flashing behind it. The car entered the tunnel and vanished..."

Model answer: "Police are chasing a car entering a tunnel."
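
If you want to inspect the raw data before training, the dataset can be loaded directly with the Hugging Face datasets library. The snippet below is a minimal sketch; the "article" and "highlights" field names follow the public cnn_dailymail dataset card.

from datasets import load_dataset

# Load the validation split of CNN/DailyMail (version "3.0.0").
dataset = load_dataset("cnn_dailymail", "3.0.0", split="validation")

sample = dataset[0]
print(sample["article"][:200])  # the long source document
print(sample["highlights"])     # the reference summary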

Training

To use this task, we must select a Seq2Seq encoder-decoder model, such as T5 or BART. Single-stack models such as BERT (encoder-only) or GPT (decoder-only) will not work, because the task requires both an encoder and a decoder.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.summarization import (
    SummarizationTransformer,
    XsumSummarizationDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
model = SummarizationTransformer(
    pretrained_model_name_or_path="t5-base",
    use_stemmer=True,               # apply stemming when computing ROUGE
    val_target_max_length=142,      # max length of summaries generated during validation
    num_beams=None,                 # use the model's default beam search setting
    compute_generate_metrics=True,  # generate summaries during validation to compute ROUGE
)
dm = XsumSummarizationDataModule(
    batch_size=1,
    max_source_length=128,  # truncate source documents to 128 tokens
    max_target_length=128,  # truncate target summaries to 128 tokens
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=1)

trainer.fit(model, dm)
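
To train on CNN/DailyMail instead, only the data module needs to change. The sketch below assumes that a CNNDailyMailSummarizationDataModule class is exported from the same package; check your installed lightning-transformers version for the exact name.

from lightning_transformers.task.nlp.summarization import (
    CNNDailyMailSummarizationDataModule,  # assumed class name; verify against your installed version
)

dm = CNNDailyMailSummarizationDataModule(
    batch_size=1,
    max_source_length=512,  # CNN/DailyMail articles are longer than XSUM documents
    max_target_length=142,
    tokenizer=tokenizer,
)
trainer.fit(model, dm)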

Summarization Using Your Own Files

To use custom text files, the files should contain newline-delimited JSON objects, one record per line, for example:

{
    "source": "some-body",
    "target": "some-sentence"
}
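
Such a file can be produced with a few lines of Python. This is a minimal sketch; the records and the output path are placeholders.

import json

# Placeholder records; each one becomes a single line in the output file.
records = [
    {"source": "The car was racing towards the tunnel...", "target": "Police are chasing a car entering a tunnel."},
    {"source": "some-body", "target": "some-sentence"},
]

with open("path/train.json", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")  # newline-delimited JSON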

We override the dataset files, which allows us to reuse the data transforms defined for this dataset.

from transformers import AutoTokenizer
from lightning_transformers.task.nlp.summarization import (
    XsumSummarizationDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
dm = XsumSummarizationDataModule(
    batch_size=1,
    max_source_length=128,
    max_target_length=128,
    train_file="path/train.json",       # newline-delimited JSON file
    validation_file="path/valid.json",  # newline-delimited JSON file
    tokenizer=tokenizer,
)
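
The resulting data module can then be passed to trainer.fit together with the SummarizationTransformer, exactly as in the training example above.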

Summarization Inference Pipeline

By default we use the Hugging Face summarization pipeline, which requires an input document as text.

from transformers import AutoTokenizer
from lightning_transformers.task.nlp.summarization import SummarizationTransformer

model = SummarizationTransformer(
    pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random",
    tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random"),
)

model.hf_predict(
    "The results found significant improvements over all tasks evaluated",
    min_length=2,
    max_length=12,
)
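
To run inference with a model you have fine-tuned yourself, you can restore it from a checkpoint first. This is a sketch using the standard PyTorch Lightning load_from_checkpoint API; the checkpoint path is a placeholder.

# Restore a fine-tuned model from a Lightning checkpoint (placeholder path).
model = SummarizationTransformer.load_from_checkpoint("path/to/checkpoint.ckpt")

summary = model.hf_predict(
    "The results found significant improvements over all tasks evaluated",
    min_length=2,
    max_length=12,
)
print(summary)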