--------------------------------------------------------------------------------
Adapted from https://github.com/alpa-projects/alpa/tree/main/examples/language_model
and https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling
More on Alpa: https://ai.googleblog.com/2022/05/alpa-automated-model-parallel-deep.html
Use `alpa.parallelize` to parallelize the training loop.
--------------------------------------------------------------------------------
# Language model training examples
The following example showcases how to train a language model from scratch
using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
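As the banner above notes, this example relies on `alpa.parallelize` to distribute the training step across devices. The snippet below is a minimal sketch of what such a step looks like for causal language modeling; the loss computation and state handling here are illustrative assumptions, not the exact code in `run_clm_flax.py`.

```python
import alpa
import jax
import optax


# Illustrative sketch: `state` is a flax TrainState, `batch` a dict of jnp arrays.
@alpa.parallelize
def train_step(state, batch):
    def loss_fn(params):
        # Forward pass through the Flax causal LM head model.
        logits = state.apply_fn(batch["input_ids"], params=params, train=True).logits
        # Next-token prediction: shift logits and labels by one position.
        return optax.softmax_cross_entropy_with_integer_labels(
            logits[:, :-1], batch["labels"][:, 1:]
        ).mean()

    grads = jax.grad(loss_fn)(state.params)
    # Purely functional update: a new state is returned, nothing is mutated.
    return state.apply_gradients(grads=grads)
```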
## Causal language modeling
In the following, we demonstrate how to train an auto-regressive causal transformer model
in JAX/Flax.
More specifically, we pretrain a randomly initialized 124M-parameter [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian.
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if your dataset requires extra processing.
To set up all relevant files for training, let's create a directory.
```bash
mkdir ./norwegian-gpt2
```
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian portion of OSCAR
and then saved in the model directory we just created.
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50256, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])
# Save files to disk
tokenizer.save("./norwegian-gpt2/tokenizer.json")
```
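As a quick sanity check (not part of the example script), the saved `tokenizer.json` can be loaded back with 🤗 Transformers and used to tokenize a sample sentence:

```python
from transformers import PreTrainedTokenizerFast

# Load the trained tokenizer from the file we just saved.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./norwegian-gpt2/tokenizer.json")
print(tokenizer.tokenize("Dette er en test."))
```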
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading the [**`gpt2`**](https://huggingface.co/gpt2) configuration and storing it
in the local model folder:
```python
from transformers import GPT2Config
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50256)
config.save_pretrained("./norwegian-gpt2")
```
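If you want to confirm the model size before training, the stored configuration can be used to randomly initialize a Flax GPT-2 model and count its parameters. This is only an illustration; the training script performs its own initialization:

```python
import jax
from transformers import FlaxGPT2LMHeadModel, GPT2Config

# Re-load the configuration we just saved and build a randomly initialized model.
config = GPT2Config.from_pretrained("./norwegian-gpt2")
model = FlaxGPT2LMHeadModel(config, seed=0)

# Count parameters; for this configuration it should be roughly 124M.
num_params = sum(p.size for p in jax.tree_util.tree_leaves(model.params))
print(f"{num_params / 1e6:.0f}M parameters")
```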
Great, we have set up our model directory. It now contains the tokenizer and the model configuration,
which the training script will load below.
### Train model
Finally, we can run the example script to pretrain the model:
#### Launch a Ray cluster
1. Use the command below to launch Ray on the head node:
   ```bash
   ray start --head
   ```
2. (Optional) If you have more nodes, connect them to the head node. The command should look like the following, but with the IP address and password printed by the previous command:
   ```bash
   ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'
   ```
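Before launching training, you can optionally verify that all nodes have joined the cluster. A small sanity check using the Ray Python API (not part of the example script):

```python
import ray

# Connect to the cluster started with `ray start --head`.
ray.init(address="auto")
# Prints the aggregate CPUs/GPUs visible to Alpa.
print(ray.cluster_resources())
```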
#### Run
```bash
python3 run_clm_flax.py \
--output_dir="./norwegian-gpt2" \
--model_type="gpt2" \
--config_name="./norwegian-gpt2" \
--tokenizer_name="./norwegian-gpt2" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--do_train --do_eval \
--block_size="512" \
--per_device_train_batch_size="96" \
--per_device_eval_batch_size="96" \
--num_micro_batches="4" \
--dtype="float16" \
--learning_rate="1e-3" --warmup_steps="1000" \
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
--overwrite_output_dir \
--num_train_epochs="20" \
--logging_steps="100" \
--save_steps="2500" \
--eval_steps="2500"
```
Training should converge at a loss and perplexity
of 3.24 and 25.72, respectively, after 20 epochs.
This should take less than 21 hours on a machine with 8 V100 GPUs.
Training statistics can be accessed on [tfhub.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).
For a step-by-step walkthrough of how to do causal language modeling in Flax, please have a
look at [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb).