
Big Transformers Model Inference

Lightning Transformers provides out-of-the-box support for running inference with very large, billion-parameter models. Under the hood, we use Hugging Face Transformers' large model support to automatically select devices for the best throughput and memory usage.

Below is an example of how to run generation with a large 6B-parameter transformer model using Lightning Transformers. We've also managed to run bigscience/bloom, which has 176B parameters, on 8 A100s with the code below.
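
As a rough back-of-the-envelope check, the memory needed just to hold a model's weights is the parameter count times the bytes per parameter for the chosen dtype. This sketch is not part of Lightning Transformers. The helper name weight_memory_gb and the dtype table are illustrative assumptions, and the parameter counts are the nominal 6B and 176B. Activations, the KV cache and framework overhead come on top of these numbers, so treat them as lower bounds.

```python
# Bytes needed to store one parameter in a few common dtypes (illustrative table).
BYTES_PER_PARAM = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int8": 1,
}


def weight_memory_gb(num_params: float, dtype: str = "float32") -> float:
    """Return the memory in GB (1 GB = 10**9 bytes) needed for the weights alone."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    # Nominal parameter counts; real checkpoints differ slightly.
    for name, params in [("GPT-J 6B", 6e9), ("BLOOM 176B", 176e9)]:
        for dtype in ("float32", "float16"):
            print(f"{name} in {dtype}: ~{weight_memory_gb(params, dtype):.0f} GB")
```

For a nominal 6B parameters this gives about 24 GB in float32 and 12 GB in float16. For 176B it gives about 352 GB in float16, which is why a model of that size has to be spread over several devices.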

pip install accelerate

import torch
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import LanguageModelingTransformer

model = LanguageModelingTransformer(
    pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
    tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
    low_cpu_mem_usage=True,
    device_map="auto",
)

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))

With device_map="auto", the model's layers can be split across GPUs and the CPU, and can even be offloaded to disk, to save memory.
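
To build intuition for this kind of placement, here is a deliberately simplified sketch. It walks the layers in order and puts each one on the first device that still has room. Layers spill from the GPU to the CPU, and anything left over goes to "disk". This is not Accelerate's actual algorithm, which also accounts for tied weights, reserved memory and modules that must not be split. The function name, layer sizes and memory budgets are all hypothetical, and sizes are in arbitrary units.

```python
from typing import Dict, List, Tuple


def greedy_device_map(
    layers: List[Tuple[str, int]],
    budgets: List[Tuple[str, int]],
) -> Dict[str, str]:
    """Assign each (layer_name, size) to a device, filling devices in order.

    `budgets` is an ordered list of (device_name, capacity). Once a layer spills
    to the next device, later layers never go back to an earlier one. Layers
    that fit on none of the listed devices are mapped to "disk".
    """
    device_map: Dict[str, str] = {}
    idx, used = 0, 0
    for name, size in layers:
        # Advance to the next device while the current one cannot fit this layer.
        while idx < len(budgets) and used + size > budgets[idx][1]:
            idx += 1
            used = 0
        device_map[name] = budgets[idx][0] if idx < len(budgets) else "disk"
        used += size
    return device_map


if __name__ == "__main__":
    layers = [("embed", 4), ("block.0", 6), ("block.1", 6), ("block.2", 6), ("lm_head", 4)]
    budgets = [("cuda:0", 12), ("cpu", 12)]
    print(greedy_device_map(layers, budgets))
```

With these made-up sizes, the embedding and first block fit on the GPU. The next two blocks go to the CPU, and the output head falls through to disk. Filling devices in order keeps consecutive layers together, which limits how often activations have to move between devices during a forward pass.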

Inference with Manual Checkpoints

Download the sharded checkpoint weights that we’ll be using:

git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
cd sharded-gpt-j-6B
git lfs install
git lfs pull
cd ..

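
A sharded checkpoint splits the weights over several files. In the Hugging Face format, an index JSON file (for PyTorch weights, typically pytorch_model.bin.index.json) maps each parameter name to its shard under a "weight_map" key. The sketch below groups parameters by shard, which is the kind of lookup that lets a loader open one shard at a time instead of reading the whole model at once. It uses a small made-up index rather than the real GPT-J one, and the helper name params_by_shard is hypothetical.

```python
import json
from collections import defaultdict
from typing import Dict, List


def params_by_shard(index_json: str) -> Dict[str, List[str]]:
    """Group parameter names by the shard file that stores them."""
    index = json.loads(index_json)
    groups: Dict[str, List[str]] = defaultdict(list)
    for param_name, shard_file in index["weight_map"].items():
        groups[shard_file].append(param_name)
    # Sort shards and names so the result is deterministic.
    return {shard: sorted(names) for shard, names in sorted(groups.items())}


# A tiny, made-up index with the same shape as a Hugging Face sharded-checkpoint index.
EXAMPLE_INDEX = json.dumps({
    "metadata": {"total_size": 1000},
    "weight_map": {
        "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.0.attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.1.attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    },
})

if __name__ == "__main__":
    for shard, names in params_by_shard(EXAMPLE_INDEX).items():
        print(shard, names)
```
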
import torch
from accelerate import init_empty_weights
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import LanguageModelingTransformer

# Initialize the model with empty weights so the checkpoint can be loaded into it.
with init_empty_weights():
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
    )

# Load the sharded checkpoint, automatically placing model layers on the best devices (GPU/CPU)
# based on available memory. GPTJBlock modules are never split across devices.
model.load_checkpoint_and_dispatch(
    "sharded-gpt-j-6B",
    device_map="auto",
    no_split_module_classes=["GPTJBlock"],
)

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))
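
The no_split_module_classes=["GPTJBlock"] argument tells Accelerate never to split a GPTJBlock across devices, so all of a block's weights end up on a single device. The sketch below illustrates the idea in plain Python by grouping parameter names into placement units by block prefix. A placement routine can then treat each block as one indivisible item. The transformer.h.<i>. prefix mirrors Hugging Face's GPT-J parameter names. The helper, the regex and the example names are hypothetical, and this is not how Accelerate implements the option internally.

```python
import re
from typing import Dict, List

# Hypothetical pattern for GPT-J-style block prefixes such as "transformer.h.3.".
BLOCK_PATTERN = re.compile(r"^(transformer\.h\.\d+)\.")


def placement_units(param_names: List[str]) -> Dict[str, List[str]]:
    """Group parameters into units that must stay together on one device.

    Parameters inside a block (matching BLOCK_PATTERN) share one unit keyed by the
    block prefix; every other parameter forms its own unit. Insertion order is kept.
    """
    units: Dict[str, List[str]] = {}
    for name in param_names:
        match = BLOCK_PATTERN.match(name)
        key = match.group(1) if match else name
        units.setdefault(key, []).append(name)
    return units


if __name__ == "__main__":
    names = [
        "transformer.wte.weight",
        "transformer.h.0.attn.q_proj.weight",
        "transformer.h.0.mlp.fc_in.weight",
        "transformer.h.1.attn.q_proj.weight",
        "transformer.ln_f.weight",
        "lm_head.weight",
    ]
    for unit, params in placement_units(names).items():
        print(unit, params)
```

Feeding these units, rather than individual parameters, to a placement routine guarantees that a block's weights are never divided between two devices.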

For more details about the API, see here.
