import torch
from torch import nn
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader, RandomSampler, TensorDataset


class ProteinSequenceDataset(Dataset):
    """Wraps raw protein sequences and labels as a PyTorch Dataset,
    tokenizing each sequence on the fly."""

    def __init__(self, sequence, targets, tokenizer, max_len):
        self.sequence = sequence
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        # Number of sequences in the dataset
        return len(self.sequence)

    def __getitem__(self, item):
        sequence = str(self.sequence[item])
        target = self.targets[item]

        # Tokenize the sequence, truncating/padding to a fixed length so that
        # every example in a batch has the same shape.
        encoding = self.tokenizer.encode_plus(
            sequence,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

        # encode_plus returns tensors of shape (1, max_len); flatten them to
        # (max_len,) so the DataLoader can stack examples into a batch.
        return {
            'protein_sequence': sequence,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long)
        }
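A minimal usage sketch follows, showing how the dataset can be paired with the imported DataLoader and RandomSampler. The tokenizer choice (Rostlab/prot_bert via transformers.AutoTokenizer), the example sequences, labels, and max_len value are assumptions for illustration only, not part of the script above.

# Hypothetical usage sketch: assumes a Hugging Face tokenizer and small
# in-memory example data (not taken from the original script).
from transformers import AutoTokenizer

# ProtBert expects space-separated amino acids and no lowercasing;
# the model name here is an assumed example.
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

sequences = ["M K T A Y I A K Q R", "G S H M L E V L F Q"]  # hypothetical examples
labels = [0, 1]                                             # hypothetical integer class labels

dataset = ProteinSequenceDataset(sequences, labels, tokenizer, max_len=512)
loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=2)

batch = next(iter(loader))
print(batch["input_ids"].shape)       # torch.Size([2, 512])
print(batch["attention_mask"].shape)  # torch.Size([2, 512])
print(batch["targets"])               # tensor of the two labels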