# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Inference handlers for summarizing text with a BART model.

Defines ``model_fn``, ``input_fn``, ``predict_fn`` and ``output_fn``
(presumably the hooks called by a SageMaker-style inference container --
confirm against the deployment), plus the ``get_device`` helper.
"""

import subprocess

# Runs at import time and must stay ahead of the ``transformers`` import
# below so the pinned version is installed before it is loaded.
subprocess.call(["pip", "install", "transformers~=4.5.1"])

import logging
import json
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

logger = logging.getLogger(__name__)

# Placeholder; replaced with the real BartTokenizer by model_fn().
tokenizer = ''


def model_fn(model_dir):
    """Load the BART model and tokenizer stored under *model_dir*.

    The model is read from ``<model_dir>/bart_model/`` and moved to the
    device chosen by ``get_device()``. The tokenizer is read from
    ``<model_dir>/bart_tokenizer/`` and stored in the module-level
    ``tokenizer`` global so that ``predict_fn`` can use it.

    Args:
        model_dir: Directory containing the ``bart_model`` and
            ``bart_tokenizer`` sub-directories.

    Returns:
        The loaded ``BartForConditionalGeneration`` instance.
    """
    global tokenizer

    logger.info(model_dir)
    model_path = f'{model_dir}/bart_model/'
    logger.info(model_path)
    tokenizer_path = f'{model_dir}/bart_tokenizer/'
    logger.info(tokenizer_path)

    device = get_device()
    model = BartForConditionalGeneration.from_pretrained(model_path).to(device)
    tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
    logger.info('model loaded')
    return model


def input_fn(json_request_data, content_type='application/json'):
    """Extract the text to summarize from a JSON request body.

    Args:
        json_request_data: JSON document with a top-level ``"text"`` key.
        content_type: Declared content type of the request; not inspected.

    Returns:
        The value stored under ``"text"``.

    Raises:
        json.JSONDecodeError: If the body is not valid JSON.
        KeyError: If the ``"text"`` key is missing.
    """
    input_data = json.loads(json_request_data)
    return input_data['text']


def predict_fn(text_to_summarize, model):
    """Generate a summary of *text_to_summarize* with *model*.

    Args:
        text_to_summarize: The text to summarize.
        model: The model returned by ``model_fn``.

    Returns:
        The decoded summary text with special tokens removed.
    """
    device = get_device()
    # ``max_length`` is only applied when truncation is enabled; without
    # ``truncation=True`` an over-long text is encoded in full and can
    # exceed what the model accepts.
    text_input_ids = tokenizer.batch_encode_plus(
        [text_to_summarize],
        return_tensors='pt',
        max_length=1024,
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=False,
    ).to(device)
    summary_ids = model.generate(
        text_input_ids['input_ids'],
        num_beams=4,
        length_penalty=2.0,
        max_length=256,
        min_length=56,
        no_repeat_ngram_size=3,
    )
    # A single text was encoded, so squeeze() drops the batch dimension
    # before decoding.
    return tokenizer.decode(summary_ids.squeeze(), skip_special_tokens=True)


def output_fn(summary_txt, accept='application/json'):
    """Serialize the summary as JSON.

    Args:
        summary_txt: The summary produced by ``predict_fn``.
        accept: Content type returned alongside the body; passed through
            unchanged.

    Returns:
        A ``(body, content_type)`` tuple where ``body`` is the
        JSON-encoded summary.
    """
    return json.dumps(summary_txt), accept


def get_device():
    """Return ``'cuda:0'`` when CUDA is available, otherwise ``'cpu'``."""
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'