"""Export the pretrained ``bert-base-uncased`` encoder to an ONNX file.

Usage:
    python <this script> [--save PATH]
"""
import argparse
import os

import torch
from transformers import BertModel

# Shape of the dummy tensors used to trace the model. Both axes are declared
# dynamic at export time, so these values only drive the trace itself.
DUMMY_BATCH_SIZE = 1
DUMMY_SEQ_LEN = 128
# Exclusive upper bound for the random dummy token ids.
DUMMY_TOKEN_ID_BOUND = 1000


def main():
    """Parse the CLI arguments, trace BERT and write the ONNX graph to disk."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--save", default="model.onnx")
    args = parser.parse_args()

    # torchscript=True makes the model return plain tuples, which is what the
    # tracing-based ONNX exporter expects.
    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

    token_ids = torch.randint(
        DUMMY_TOKEN_ID_BOUND, (DUMMY_BATCH_SIZE, DUMMY_SEQ_LEN)
    )
    # NOTE(review): an all-zero mask marks every position as padding. The
    # values are not baked into the traced graph, but an all-ones mask would
    # be a more representative example input -- confirm before changing.
    attn_mask = torch.zeros(DUMMY_BATCH_SIZE, DUMMY_SEQ_LEN, dtype=torch.int)

    torch.onnx.export(
        model,
        (token_ids, attn_mask),
        args.save,
        export_params=True,
        opset_version=10,
        input_names=["token_ids", "attn_mask"],
        output_names=["output"],
        dynamic_axes={
            "token_ids": [0, 1],
            "attn_mask": [0, 1],
            # "output" names the first model output, the per-token hidden
            # states (batch, seq_len, hidden). Its sequence axis follows the
            # inputs, so it must be declared dynamic too; otherwise the graph
            # advertises a fixed length equal to DUMMY_SEQ_LEN.
            "output": [0, 1],
        },
    )
    print(f"Saved {args.save}")


if __name__ == "__main__":
    main()