# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import argparse
import json
import logging
import os
import sys

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.utils.data
import torch.utils.data.distributed

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def _get_tensor(rank, rows, columns):
    device = torch.device(
        "cuda:{}".format(dist.get_rank() % torch.cuda.device_count())
        if torch.cuda.is_available()
        else "cpu"
    )
    tensor = torch.ones(rows, columns) * (rank + 1)
    return tensor.to(device)


def _get_zeros_tensor(rows, columns):
    device = torch.device(
        "cuda:{}".format(dist.get_rank() % torch.cuda.device_count())
        if torch.cuda.is_available()
        else "cpu"
    )
    tensor = torch.zeros(rows, columns)
    return tensor.to(device)


def _get_zeros_tensors_list(rows, columns):
    return [_get_zeros_tensor(rows, columns) for _ in range(dist.get_world_size())]


def _get_tensors_sum(rows, columns):
    device = torch.device(
        "cuda:{}".format(dist.get_rank() % torch.cuda.device_count())
        if torch.cuda.is_available()
        else "cpu"
    )
    # Rank i contributes a tensor filled with (i + 1), so the element-wise sum
    # over all ranks is 1 + 2 + ... + world_size.
    result = (1 + dist.get_world_size()) * dist.get_world_size() / 2
    tensor = torch.ones(rows, columns) * result
    return tensor.to(device)


def _send_recv(rank, rows, columns):
    source = 0
    tensor = _get_tensor(rank, rows, columns)
    logger.debug("Rank: {},\nTensor BEFORE send_recv: {}".format(rank, tensor))
    if rank == 0:
        for i in range(1, dist.get_world_size()):
            dist.send(tensor=tensor, dst=i)
    else:
        dist.recv(tensor=tensor, src=source)
    logger.debug("Rank: {},\nTensor AFTER send_recv: {}\n".format(rank, tensor))

    assert torch.equal(
        tensor, _get_tensor(source, rows, columns)
    ), "Rank {}: Tensor was not equal to rank {} tensor after send-recv.".format(rank, source)


def _broadcast(rank, rows, columns):
    source = 0
    tensor = _get_tensor(rank, rows, columns)
    logger.debug("Rank: {},\nTensor BEFORE broadcast: {}".format(rank, tensor))
    dist.broadcast(tensor, src=source)
    logger.debug("Rank: {},\nTensor AFTER broadcast: {}\n".format(rank, tensor))

    assert torch.equal(
        tensor, _get_tensor(source, rows, columns)
    ), "Rank {}: Tensor was not equal to rank {} tensor after broadcast.".format(rank, source)


def _all_reduce(rank, rows, columns):
    tensor = _get_tensor(rank, rows, columns)
    logger.debug("Rank: {},\nTensor BEFORE all_reduce: {}".format(rank, tensor))
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    logger.debug("Rank: {},\nTensor AFTER all_reduce: {}\n".format(rank, tensor))

    assert torch.equal(
        tensor, _get_tensors_sum(rows, columns)
    ), "Rank {}: Tensor was not equal to SUM of {} tensors after all_reduce.".format(
        rank, dist.get_world_size()
    )


def _reduce(rank, rows, columns):
    dest = 0
    tensor = _get_tensor(rank, rows, columns)
    logger.debug("Rank: {},\nTensor BEFORE reduce: {}".format(rank, tensor))
    # dist.reduce is an in-place operation
    dist.reduce(tensor, op=dist.ReduceOp.SUM, dst=dest)
    logger.debug("Rank: {},\nTensor AFTER reduce: {}\n".format(rank, tensor))

    if rank == dest:
        assert torch.equal(
            tensor, _get_tensors_sum(rows, columns)
        ), "Rank {}: Tensor was not equal to SUM of {} tensors after reduce.".format(
            rank, dist.get_world_size()
        )


def _all_gather(rank, rows, columns):
    tensor = _get_tensor(rank, rows, columns)
    tensors_list = _get_zeros_tensors_list(rows, columns)
    logger.debug("Rank: {},\nTensor BEFORE all_gather: {}".format(rank, tensor))
    dist.all_gather(tensors_list, tensor)
    logger.debug(
        "Rank: {},\nTensor AFTER all_gather: {}. tensors_list: {}\n".format(
            rank, tensor, tensors_list
        )
    )

    # tensor shouldn't have changed
    assert torch.equal(
        tensor, _get_tensor(rank, rows, columns)
    ), "Rank {}: Tensor got changed after all_gather.".format(rank)
    for i in range(dist.get_world_size()):
        assert torch.equal(
            tensors_list[i], _get_tensor(i, rows, columns)
        ), "Rank {}: tensors lists are not the same after all_gather.".format(rank)


def _gather(rank, rows, columns):
    dest = 0
    tensor = _get_tensor(rank, rows, columns)
    if rank == dest:
        tensors_list = _get_zeros_tensors_list(rows, columns)
        logger.debug(
            "Rank: {},\nTensor BEFORE gather: {}. tensors_list: {}".format(
                rank, tensor, tensors_list
            )
        )
        dist.gather(tensor=tensor, gather_list=tensors_list)
        logger.debug(
            "Rank: {},\nTensor AFTER gather: {}. tensors_list: {}\n".format(
                rank, tensor, tensors_list
            )
        )
        for i in range(dist.get_world_size()):
            assert torch.equal(
                tensors_list[i], _get_tensor(i, rows, columns)
            ), "Rank {}: tensors lists are not the same after gather.".format(rank)
    else:
        logger.debug("Rank: {},\nTensor BEFORE gather: {}\n".format(rank, tensor))
        dist.gather(tensor=tensor, dst=dest)
        logger.debug("Rank: {},\nTensor AFTER gather: {}\n".format(rank, tensor))

    # tensor shouldn't have changed
    assert torch.equal(
        tensor, _get_tensor(rank, rows, columns)
    ), "Rank {}: Tensor got changed after gather.".format(rank)


def _scatter(rank, rows, columns):
    source = 0
    tensor = _get_tensor(rank, rows, columns)
    if rank == source:
        tensors_list = _get_zeros_tensors_list(rows, columns)
        logger.debug(
            "Rank: {},\nTensor BEFORE scatter: {}. tensors_list: {}".format(
                rank, tensor, tensors_list
            )
        )
        dist.scatter(tensor=tensor, scatter_list=tensors_list)
    else:
        logger.debug("Rank: {},\nTensor BEFORE scatter: {}\n".format(rank, tensor))
        dist.scatter(tensor=tensor, src=source)
    logger.debug("Rank: {},\nTensor AFTER scatter: {}\n".format(rank, tensor))

    assert torch.equal(
        tensor, _get_zeros_tensor(rows, columns)
    ), "Rank {}: Tensor should be all zeroes after scatter.".format(rank)


def _barrier(rank):
    logger.debug("Rank: {}, Waiting for other processes before the barrier.".format(rank))
    dist.barrier()
    logger.debug("Rank: {}, Passing the barrier".format(rank))


def main():
    print("Starting")
    parser = argparse.ArgumentParser()

    # Configurable hyperparameters
    parser.add_argument("--rows", type=int, default=1, help="Number of rows in the tensor.")
    parser.add_argument("--columns", type=int, default=1, help="Number of columns in the tensor.")
    parser.add_argument(
        "--backend", type=str, default=None, help="Backend for distributed operations."
    )

    # Container environment
    # SM_HOSTS is a JSON-encoded list of host names; json.loads parses it into a
    # Python list (type=list would split a command-line value into characters).
    parser.add_argument("--hosts", type=json.loads, default=json.loads(os.environ["SM_HOSTS"]))
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--num-cpus", type=int, default=os.environ["SM_NUM_CPUS"])
    args = parser.parse_args()

    # One process per GPU if GPUs are available, otherwise one per CPU.
    number_of_processes = args.num_gpus if args.num_gpus > 0 else args.num_cpus
    world_size = number_of_processes * len(args.hosts)
    logger.info(
        "Running '{}' backend on {} nodes and {} processes. World size is {}.".format(
            args.backend, len(args.hosts), number_of_processes, world_size
        )
    )
    host_rank = args.hosts.index(args.current_host)
    master_addr = args.hosts[0]
    master_port = "55555"

    processes = []
    for rank in range(number_of_processes):
        process_rank = host_rank * number_of_processes + rank
        p = Process(
            target=init_processes,
            args=(
                args.backend,
                master_addr,
                master_port,
                process_rank,
                world_size,
                args.rows,
                args.columns,
                args.current_host,
                args.num_gpus,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    save("success", args.model_dir)


def init_processes(
    backend, master_addr, master_port, rank, world_size, rows, columns, host, num_gpus
):
    # Initialize the distributed environment.
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    logger.info("Init process rank {} on host '{}'".format(rank, host))
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    run(backend, rank, rows, columns, num_gpus)


def run(backend, rank, rows, columns, num_gpus):
    # https://pytorch.org/docs/master/distributed.html
    if backend == "gloo":
        print("Run operations supported by 'gloo' backend.")
        _broadcast(rank, rows, columns)
        _all_reduce(rank, rows, columns)
        _barrier(rank)

        # send/recv is supported on CPU tensors only
        if num_gpus == 0:
            _send_recv(rank, rows, columns)
    elif backend == "nccl":
        print("Run operations supported by 'nccl' backend.")
        # Note: nccl does not support send/recv, gather, or scatter:
        # https://github.com/pytorch/pytorch/blob/v0.4.0/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
        _broadcast(rank, rows, columns)
        _all_reduce(rank, rows, columns)
        _reduce(rank, rows, columns)
        _all_gather(rank, rows, columns)


def save(result, model_dir):
    filename = os.path.join(model_dir, result)
    if not os.path.exists(filename):
        logger.info("Saving success result")
        with open(filename, "w") as f:
            f.write(result)


if __name__ == "__main__":
    from torch.multiprocessing import set_start_method

    try:
        set_start_method("spawn")
    except RuntimeError:
        pass

    main()
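
# A minimal local smoke test, as a sketch (assumptions: the file is saved as
# distributed_operations.py, /tmp/model exists, and two CPU processes suffice;
# in a real SageMaker job the SM_* variables are set by the container):
#
#   SM_HOSTS='["localhost"]' SM_CURRENT_HOST=localhost SM_MODEL_DIR=/tmp/model \
#   SM_NUM_GPUS=0 SM_NUM_CPUS=2 python distributed_operations.py --backend gloo
#
# With one host and two CPU processes this exercises broadcast, all_reduce,
# barrier, and send_recv at world_size == 2, then writes a "success" file to
# SM_MODEL_DIR.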