Translation¶
The Task¶
The Translation task fine-tunes the model to translate text from one language to another.
Datasets¶
Currently supports the WMT16 dataset or custom input text files.
Input Text (English): "The ground is black, the sky is blue and the car is red."
Model Output (German): "Der Boden ist schwarz, der Himmel ist blau und das Auto ist rot."
Training¶
To use this task, select a Seq2Seq encoder-decoder model, such as multilingual T5 (mT5) or BART (mBART). Encoder-only models like BERT and decoder-only models like GPT will not work, since translation requires both an encoder for the source text and a decoder to generate the target text. In addition, you need a tokenizer that has been trained on multilingual text; this is the case for mT5 and mBART.
import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
model = TranslationTransformer(
    pretrained_model_name_or_path="t5-base",
    n_gram=4,
    smooth=False,
    val_target_max_length=142,
    num_beams=None,
    compute_generate_metrics=True,
)
dm = WMT16TranslationDataModule(
    # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
    dataset_config_name="ro-en",
    source_language="en",
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
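The `n_gram` and `smooth` arguments configure the BLEU score computed over generated validation outputs. As a rough, self-contained sketch of what an unsmoothed sentence-level BLEU with `n_gram=4` measures (an illustration only, not the library's actual implementation):

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams of a token list."""
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def bleu(candidate, reference, max_n=4):
    """Unsmoothed sentence-level BLEU against a single reference."""
    cand, ref = candidate.split(), reference.split()
    precisions = []
    for n in range(1, max_n + 1):
        cand_ng, ref_ng = ngrams(cand, n), ngrams(ref, n)
        # Clipped overlap: each candidate n-gram counts at most as often
        # as it appears in the reference.
        overlap = sum((cand_ng & ref_ng).values())
        if overlap == 0:
            return 0.0  # without smoothing, any zero precision zeroes the score
        precisions.append(overlap / sum(cand_ng.values()))
    # Brevity penalty discourages candidates shorter than the reference.
    bp = 1.0 if len(cand) > len(ref) else math.exp(1 - len(ref) / max(len(cand), 1))
    # Geometric mean of the n-gram precisions.
    return bp * math.exp(sum(math.log(p) for p in precisions) / max_n)


bleu("the cat sat on the mat", "the cat sat on the mat")  # → 1.0
```

Setting `smooth=True` applies smoothing so that a single missing n-gram order does not zero out the whole score, which matters for short generated sentences.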
Translation Using Your Own Files¶
To use custom text files, each file should contain newline-delimited JSON objects, one translation pair per line:
{"source": "example source text", "target": "example target text"}
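As a quick illustration of producing files in this format (the example pairs and file name are hypothetical), write one JSON object per line:

```python
import json

# Hypothetical example pairs; in practice these come from your parallel corpus.
pairs = [
    {"source": "The house is blue.", "target": "Casa este albastră."},
    {"source": "Good morning!", "target": "Bună dimineața!"},
]

# Write newline-delimited JSON: one object per line, no surrounding array.
with open("train.json", "w", encoding="utf-8") as f:
    for pair in pairs:
        f.write(json.dumps(pair, ensure_ascii=False) + "\n")
```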
We override the dataset files, which lets us reuse the data transforms defined for this dataset.
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import WMT16TranslationDataModule

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
dm = WMT16TranslationDataModule(
    # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
    dataset_config_name="ro-en",
    source_language="en",
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    train_file="path/train.json",
    validation_file="path/valid.json",
    tokenizer=tokenizer,
)
Translation Inference Pipeline¶
By default we use the translation pipeline, which requires a source text string.
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import TranslationTransformer

model = TranslationTransformer(
    pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random",
    tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random"),
)
model.hf_predict("¡Hola Sean!")