# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Smoke tests: every pretrained entry point loads and produces logits."""

from pathlib import Path

import pytest
import torch

import esm

# Directly from hubconf.py
# Kept as a comma-separated string so the list can be pasted over verbatim.
_HUB_MODEL_NAMES = """
esm1_t6_43M_UR50S,
esm1_t12_85M_UR50S,
esm1_t34_670M_UR50S,
esm1_t34_670M_UR50D,
esm1_t34_670M_UR100,
esm1b_t33_650M_UR50S,
esm_msa1_t12_100M_UR50S,
esm_msa1b_t12_100M_UR50S,
esm1v_t33_650M_UR90S,
esm1v_t33_650M_UR90S_1,
esm1v_t33_650M_UR90S_2,
esm1v_t33_650M_UR90S_3,
esm1v_t33_650M_UR90S_4,
esm1v_t33_650M_UR90S_5,
esm_if1_gvp4_t16_142M_UR50,
esm2_t6_8M_UR50D,
esm2_t12_35M_UR50D,
esm2_t30_150M_UR50D,
esm2_t33_650M_UR50D,
esm2_t36_3B_UR50D,
esm2_t48_15B_UR50D
"""

# Parsed list of entry-point names; surrounding whitespace and empty entries
# (e.g. from a stray leading/trailing comma) are dropped.
model_names = [
    name.strip() for name in _HUB_MODEL_NAMES.split(",") if name.strip()
]


@pytest.mark.parametrize("model_name", model_names)
def test_load_hub_fwd_model(model_name: str) -> None:
    """Load ``model_name`` via ``esm.pretrained`` and check the logits shape.

    Args:
        model_name: Name of a loader attribute on ``esm.pretrained``.
    """
    model, alphabet = getattr(esm.pretrained, model_name)()

    # batch_size = 2, seq_len = 3, tokens within vocab
    dummy_inp = torch.tensor([[0, 1, 2], [3, 4, 5]])
    if "esm_msa" in model_name:
        # MSA models get one extra leading dimension; it is squeezed away
        # again below so the same shape assertion covers both cases.
        # NOTE(review): presumably (batch, num_sequences, seq_len) -- confirm.
        dummy_inp = dummy_inp.unsqueeze(0)

    # Only the output shape is checked, so skip autograd bookkeeping.
    with torch.no_grad():
        output = model(dummy_inp)  # dict

    logits = output["logits"].squeeze(0)
    assert logits.shape == (2, 3, len(alphabet))


@pytest.mark.parametrize("model_name", model_names)
def test_load_local(model_name: str) -> None:
    """Load ``model_name`` from its checkpoint file in the torch hub cache.

    Assumes everything has already been loaded & cached.

    Args:
        model_name: Checkpoint stem; ``<model_name>.pt`` is looked up under
            ``~/.cache/torch/hub/checkpoints``.
    """
    if model_name.endswith("esm1v_t33_650M_UR90S"):
        # Report as skipped rather than silently passing.
        pytest.skip("needs to get rerouted to specific instance")

    local_path = Path.home() / ".cache/torch/hub/checkpoints" / (model_name + ".pt")
    # Loading without raising is the whole check; the result is not inspected.
    esm.pretrained.load_model_and_alphabet_local(local_path)