import logging
import multiprocessing as mp
import os
import queue
import signal
import socket
import sys
from concurrent import futures

import grpc
import torch
from torchvision import datasets, transforms

import dataset_feed_pb2
import dataset_feed_pb2_grpc

# Logging initialization
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


# The following class implements the data-feeding service.
class DatasetFeedService(dataset_feed_pb2_grpc.DatasetFeedServicer):
    def __init__(self, q, kill_event):
        '''
        :param q: a shared queue containing data batches
        :param kill_event: event used to signal graceful shutdown
        '''
        self.q = q
        self.kill_event = kill_event

    def get_examples(self, request, context):
        # Stream batches from the shared queue for as long as the client reads.
        while True:
            example = self.q.get()
            yield dataset_feed_pb2.Example(image=example[0], label=example[1])

    def shutdown(self, request, context):
        logger.info('Received shutdown request - not implemented')
        context.set_code(grpc.StatusCode.OK)
        context.set_details('Shutting down')
        return dataset_feed_pb2.Dummy()


# The data loading and preprocessing logic. We keep the existing logic
# unchanged; instead of feeding the model, the DataLoader feeds a shared queue.
class MyMNIST(datasets.MNIST):
    '''
    A personalized extension of the MNIST dataset in which we override
    __len__ to return batch_size * iterations so that we do not run out
    of data, wrapping around the underlying samples as needed.
    '''
    def __init__(self, batch_size: int, iterations: int, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.iterations = iterations

    def __len__(self) -> int:
        return self.batch_size * self.iterations

    def __getitem__(self, index: int):
        return super().__getitem__(index % len(self.data))


def fill_queue(q, kill, args):
    MyMNIST.mirrors = ['https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/']
    train_kwargs = {'batch_size': args.batch_size,
                    'num_workers': args.num_data_workers}
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.GaussianBlur(11)
    ])
    dataset = MyMNIST(batch_size=args.batch_size, iterations=args.iterations,
                      root='./data', train=True, transform=transform,
                      download=True)
    loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
    for batch_idx, (data, target) in enumerate(loader):
        if kill.is_set():
            logger.info('Kill signal received, exiting fill_queue.')
            break
        added = False
        while not added and not kill.is_set():
            try:
                # Convert the batch to byte strings and add it to the queue.
                q.put((data.numpy().tobytes(),
                       target.type(torch.int8).numpy().tobytes()),
                      timeout=1)
                added = True
            except queue.Full:
                continue
    logger.info('Finished filling queue with dataset.')


def start(kill_event, args):
    q = mp.Queue(maxsize=32)
    queuing_process = mp.Process(target=fill_queue, args=(q, kill_event, args))
    queuing_process.start()
    logger.info('Started queuing process.')
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=args.grpc_workers))
    dataset_feed_pb2_grpc.add_DatasetFeedServicer_to_server(
        DatasetFeedService(q, kill_event), server)
    server.add_insecure_port('[::]:6000')
    server.start()
    logger.info('gRPC data server started on port 6000.')
    return queuing_process, server


def shutdown(queuing_process, grpc_server):
    logger.info('Shutting down...')
    logger.info('Stopping gRPC server...')
    grpc_server.stop(2).wait()
    logger.info('Stopping queuing process...')
    queuing_process.join(1)
    queuing_process.terminate()
    logger.info('Shutdown done.')
    # Open gRPC stream threads may keep the process alive after the server
    # stops, so force-kill our own process as a last resort.
    os.kill(os.getpid(), signal.SIGKILL)
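
# For reference, a minimal consumer sketch. The stub name (DatasetFeedStub)
# follows from the generated dataset_feed_pb2_grpc module used above, but the
# request type of get_examples is an assumption, since the .proto file is not
# shown here. Each streamed Example carries one float32 image batch and one
# int8 label batch, matching the byte strings produced in fill_queue.
def example_consume_stream(host='localhost', port=6000):
    # Illustration only - not called anywhere in this module.
    import numpy as np
    channel = grpc.insecure_channel('{}:{}'.format(host, port))
    stub = dataset_feed_pb2_grpc.DatasetFeedStub(channel)
    for example in stub.get_examples(dataset_feed_pb2.Dummy()):  # assumed request type
        # The reshape assumes MNIST batches of shape (batch_size, 1, 28, 28).
        data = torch.from_numpy(np.frombuffer(
            example.image, dtype=np.float32).copy()).view(-1, 1, 28, 28)
        target = torch.from_numpy(np.frombuffer(
            example.label, dtype=np.int8).copy()).long()
        yield data, target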
def wait_for_shutdown_signal():
    SHUTDOWN_PORT = 16000
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', SHUTDOWN_PORT))
    s.listen(1)
    logger.info('Awaiting shutdown signal on port {}'.format(SHUTDOWN_PORT))
    conn, addr = s.accept()
    logger.info('Received shutdown signal from: {}'.format(addr))
    try:
        conn.close()
        s.close()
    except Exception as e:
        logger.info(e)


def serve(args):
    kill_event = mp.Event()  # an mp.Event for graceful shutdown
    queue_data_loader_process, grpc_server = start(kill_event, args)
    wait_for_shutdown_signal()
    kill_event.set()
    shutdown(queue_data_loader_process, grpc_server)


def read_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training')
    parser.add_argument('--num-data-workers', type=int, default=1, metavar='N',
                        help='based on the number of CPUs per training instance')
    parser.add_argument('--num-dnn-workers', type=int, default=1,
                        help='based on the number of CPUs per training instance')
    parser.add_argument('--iterations', type=int, default=10, metavar='N',
                        help='the number of iterations per epoch (a multiple of 10)')
    parser.add_argument('--grpc-workers', type=int, default=1, metavar='N',
                        help='number of gRPC server workers')
    # Note: argparse's type=bool treats any non-empty string as True.
    parser.add_argument('--pin-memory', type=bool, default=True,
                        help='pin to GPU memory (default: True)')
    parser.add_argument('--region', type=str, help='AWS region')
    parser.add_argument('--first_data_host', type=str)
    args, unknown = parser.parse_known_args()
    return args


if __name__ == '__main__':
    serve(read_args())
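
# For reference, a graceful shutdown can be triggered from another process by
# opening a TCP connection to the shutdown port; merely connecting is enough
# to make wait_for_shutdown_signal() return. The hostname below is an
# assumption (the server binds all interfaces):
#
#     import socket
#     with socket.create_connection(('localhost', 16000)):
#         pass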