# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, too-many-lines # pylint: disable=too-many-branches, too-many-statements """MXNet model module""" from __future__ import absolute_import, print_function import os import time import logging import warnings from collections import namedtuple import numpy as np from . import io from . import ndarray as nd from . import symbol as sym from . import optimizer as opt from . import metric from . import kvstore as kvs from .context import Context, cpu from .initializer import Uniform from .optimizer import get_updater from .executor_manager import DataParallelExecutorManager, _check_arguments, _load_data from .io import DataDesc from .base import mx_real_t BASE_ESTIMATOR = object try: from sklearn.base import BaseEstimator BASE_ESTIMATOR = BaseEstimator except ImportError: SKLEARN_INSTALLED = False # Parameter to pass to batch_end_callback BatchEndParam = namedtuple('BatchEndParams', ['epoch', 'nbatch', 'eval_metric', 'locals']) def _create_sparse_kvstore(kvstore): """Create kvstore assuming some parameters' storage types are row_sparse. Parameters ---------- kvstore : KVStore or str The kvstore. """ # always update on kvstore update_on_kvstore = True if isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): kv = kvs.create(kvstore) else: raise TypeError("Cannot create '%s' KVStore with row_sparse parameters. " "The type must be KVStore or str." % kvstore) return (kv, update_on_kvstore) def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore This function select and create a proper kvstore if given the kvstore type. Parameters ---------- kvstore : KVStore or str The kvstore. num_device : int The number of devices arg_params : dict of str to `NDArray`. Model parameter, dict of name to `NDArray` of net's weights. """ update_on_kvstore = True if kvstore is None: kv = None elif isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): # create kvstore using the string type if num_device is 1 and 'dist' not in kvstore: # no need to use kv for single device and single machine kv = None else: kv = kvs.create(kvstore) if kvstore == 'local': # automatically select a proper local max_size = max(np.prod(param.shape) for param in arg_params.values()) if max_size > 1024 * 1024 * 16: update_on_kvstore = False else: raise TypeError('kvstore must be KVStore, str or None') if kv is None: update_on_kvstore = False return (kv, update_on_kvstore) def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore): """Initialize kvstore""" for idx, param_on_devs in enumerate(param_arrays): name = param_names[idx] kvstore.init(name, arg_params[name]) if update_on_kvstore: kvstore.pull(name, param_on_devs, priority=-idx) def _update_params_on_kvstore_nccl(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on NCCL kvstore.""" valid_indices = [index for index, grad_list in enumerate(grad_arrays) if grad_list[0] is not None] valid_grad_arrays = [grad_arrays[i] for i in valid_indices] valid_param_arrays = [param_arrays[i] for i in valid_indices] valid_param_names = [param_names[i] for i in valid_indices] size = len(valid_grad_arrays) start = 0 # Use aggregation by default only with NCCL default_batch = '16' batch = int(os.getenv('MXNET_UPDATE_AGGREGATION_SIZE', default_batch)) while start < size: end = start + batch if start + batch < size else size # push gradient, priority is negative index kvstore.push(valid_param_names[start:end], valid_grad_arrays[start:end], priority=-start) # pull back the weights kvstore.pull(valid_param_names[start:end], valid_param_arrays[start:end], priority=-start) start = end def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue name = param_names[index] # push gradient, priority is negative index kvstore.push(name, grad_list, priority=-index) # pull back the weights kvstore.pull(name, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue index = i if kvstore: name = param_names[index] # push gradient, priority is negative index kvstore.push(name, grad_list, priority=-index) # pull back the sum gradients, to the same locations. kvstore.pull(name, grad_list, priority=-index) for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) # use a better solution later w, g = p updater(index*num_device+k, g, w) def _multiple_callbacks(callbacks, *args, **kwargs): """Sends args and kwargs to any configured callbacks. This handles the cases where the 'callbacks' variable is ``None``, a single function, or a list. """ if isinstance(callbacks, list): for cb in callbacks: cb(*args, **kwargs) return if callbacks: callbacks(*args, **kwargs) def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, arg_params, aux_params, begin_epoch, end_epoch, epoch_size, optimizer, kvstore, update_on_kvstore, train_data, eval_data=None, eval_metric=None, epoch_end_callback=None, batch_end_callback=None, logger=None, work_load_list=None, monitor=None, eval_end_callback=None, eval_batch_end_callback=None, sym_gen=None): """Internal training function on multiple devices. This function will also work for single device as well. Parameters ---------- symbol : Symbol The network configuration. ctx : list of Context The training devices. arg_names: list of str Name of all arguments of the network. param_names: list of str Name of all trainable parameters of the network. aux_names: list of str Name of all auxiliary states of the network. arg_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's weights. aux_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's auxiliary states. begin_epoch : int The begining training epoch. end_epoch : int The end training epoch. epoch_size : int, optional Number of batches in a epoch. In default, it is set to ``ceil(num_train_examples / batch_size)``. optimizer : Optimizer The optimization algorithm train_data : DataIter Training data iterator. eval_data : DataIter Validation data iterator. eval_metric : EvalMetric An evaluation function or a list of evaluation functions. epoch_end_callback : callable(epoch, symbol, arg_params, aux_states) A callback that is invoked at end of each epoch. This can be used to checkpoint model each epoch. batch_end_callback : callable(BatchEndParams) A callback that is invoked at end of each batch. This can be used to measure speed, get result from evaluation metric. etc. kvstore : KVStore The KVStore. update_on_kvstore : bool Whether or not perform weight updating on kvstore. logger : logging logger When not specified, default logger will be used. work_load_list : list of float or int, optional The list of work load for different devices, in the same order as ``ctx``. monitor : Monitor, optional Monitor installed to executor, for monitoring outputs, weights, and gradients for debugging. Notes ----- - This function will inplace update the NDArrays in `arg_params` and `aux_states`. """ if logger is None: logger = logging executor_manager = DataParallelExecutorManager(symbol=symbol, sym_gen=sym_gen, ctx=ctx, train_data=train_data, param_names=param_names, arg_names=arg_names, aux_names=aux_names, work_load_list=work_load_list, logger=logger) if monitor: executor_manager.install_monitor(monitor) executor_manager.set_params(arg_params, aux_params) if not update_on_kvstore: updater = get_updater(optimizer) else: kvstore.set_optimizer(optimizer) if kvstore: _initialize_kvstore(kvstore=kvstore, param_arrays=executor_manager.param_arrays, arg_params=arg_params, param_names=executor_manager.param_names, update_on_kvstore=update_on_kvstore) # Now start training train_data.reset() for epoch in range(begin_epoch, end_epoch): # Training phase tic = time.time() eval_metric.reset() nbatch = 0 # Iterate over training data. while True: do_reset = True for data_batch in train_data: executor_manager.load_data_batch(data_batch) if monitor is not None: monitor.tic() executor_manager.forward(is_train=True) executor_manager.backward() if update_on_kvstore: if 'nccl' in kvstore.type: _update_params_on_kvstore_nccl(executor_manager.param_arrays, executor_manager.grad_arrays, kvstore, executor_manager.param_names) else: _update_params_on_kvstore(executor_manager.param_arrays, executor_manager.grad_arrays, kvstore, executor_manager.param_names) else: _update_params(executor_manager.param_arrays, executor_manager.grad_arrays, updater=updater, num_device=len(ctx), kvstore=kvstore, param_names=executor_manager.param_names) if monitor is not None: monitor.toc_print() # evaluate at end, so we can lazy copy executor_manager.update_metric(eval_metric, data_batch.label) nbatch += 1 # batch callback (for print purpose) if batch_end_callback is not None: batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=eval_metric, locals=locals()) _multiple_callbacks(batch_end_callback, batch_end_params) # this epoch is done possibly earlier if epoch_size is not None and nbatch >= epoch_size: do_reset = False break if do_reset: logger.info('Epoch[%d] Resetting Data Iterator', epoch) train_data.reset() # this epoch is done if epoch_size is None or nbatch >= epoch_size: break toc = time.time() logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic)) if epoch_end_callback or epoch + 1 == end_epoch: executor_manager.copy_to(arg_params, aux_params) _multiple_callbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params) # evaluation if eval_data: eval_metric.reset() eval_data.reset() total_num_batch = 0 for i, eval_batch in enumerate(eval_data): executor_manager.load_data_batch(eval_batch) executor_manager.forward(is_train=False) executor_manager.update_metric(eval_metric, eval_batch.label) if eval_batch_end_callback is not None: batch_end_params = BatchEndParam(epoch=epoch, nbatch=i, eval_metric=eval_metric, locals=locals()) _multiple_callbacks(eval_batch_end_callback, batch_end_params) total_num_batch += 1 if eval_end_callback is not None: eval_end_params = BatchEndParam(epoch=epoch, nbatch=total_num_batch, eval_metric=eval_metric, locals=locals()) _multiple_callbacks(eval_end_callback, eval_end_params) eval_data.reset() # end of all epochs def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params): """Checkpoint the model data into file. Parameters ---------- prefix : str Prefix of model name. epoch : int The epoch number of the model. symbol : Symbol The input Symbol. arg_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's weights. aux_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's auxiliary states. Notes ----- - ``prefix-symbol.json`` will be saved for symbol. - ``prefix-epoch.params`` will be saved for parameters. """ if symbol is not None: symbol.save('%s-symbol.json' % prefix) save_dict = {('arg:%s' % k) : v.as_in_context(cpu()) for k, v in arg_params.items()} save_dict.update({('aux:%s' % k) : v.as_in_context(cpu()) for k, v in aux_params.items()}) param_name = '%s-%04d.params' % (prefix, epoch) nd.save(param_name, save_dict) logging.info('Saved checkpoint to \"%s\"', param_name) def load_checkpoint(prefix, epoch): """Load model checkpoint from file. Parameters ---------- prefix : str Prefix of model name. epoch : int Epoch number of model we would like to load. Returns ------- symbol : Symbol The symbol configuration of computation network. arg_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's weights. aux_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's auxiliary states. Notes ----- - Symbol will be loaded from ``prefix-symbol.json``. - Parameters will be loaded from ``prefix-epoch.params``. """ symbol = sym.load('%s-symbol.json' % prefix) save_dict = nd.load('%s-%04d.params' % (prefix, epoch)) arg_params = {} aux_params = {} for k, v in save_dict.items(): tp, name = k.split(':', 1) if tp == 'arg': arg_params[name] = v if tp == 'aux': aux_params[name] = v return (symbol, arg_params, aux_params) from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position class FeedForward(BASE_ESTIMATOR): """Model class of MXNet for training and predicting feedforward nets. This class is designed for a single-data single output supervised network. Parameters ---------- symbol : Symbol The symbol configuration of computation network. ctx : Context or list of Context, optional The device context of training and prediction. To use multi GPU training, pass in a list of gpu contexts. num_epoch : int, optional Training parameter, number of training epochs(epochs). epoch_size : int, optional Number of batches in a epoch. In default, it is set to ``ceil(num_train_examples / batch_size)``. optimizer : str or Optimizer, optional Training parameter, name or optimizer object for training. initializer : initializer function, optional Training parameter, the initialization scheme used. numpy_batch_size : int, optional The batch size of training data. Only needed when input array is numpy. arg_params : dict of str to NDArray, optional Model parameter, dict of name to NDArray of net's weights. aux_params : dict of str to NDArray, optional Model parameter, dict of name to NDArray of net's auxiliary states. allow_extra_params : boolean, optional Whether allow extra parameters that are not needed by symbol to be passed by aux_params and ``arg_params``. If this is True, no error will be thrown when ``aux_params`` and ``arg_params`` contain more parameters than needed. begin_epoch : int, optional The begining training epoch. kwargs : dict The additional keyword arguments passed to optimizer. """ def __init__(self, symbol, ctx=None, num_epoch=None, epoch_size=None, optimizer='sgd', initializer=Uniform(0.01), numpy_batch_size=128, arg_params=None, aux_params=None, allow_extra_params=False, begin_epoch=0, **kwargs): warnings.warn( '\033[91mmxnet.model.FeedForward has been deprecated. ' + \ 'Please use mxnet.mod.Module instead.\033[0m', DeprecationWarning, stacklevel=2) if isinstance(symbol, sym.Symbol): self.symbol = symbol self.sym_gen = None else: assert(callable(symbol)) self.symbol = None self.sym_gen = symbol # model parameters self.arg_params = arg_params self.aux_params = aux_params self.allow_extra_params = allow_extra_params self.argument_checked = False if self.sym_gen is None: self._check_arguments() # basic configuration if ctx is None: ctx = [cpu()] elif isinstance(ctx, Context): ctx = [ctx] self.ctx = ctx # training parameters self.num_epoch = num_epoch self.epoch_size = epoch_size self.kwargs = kwargs.copy() self.optimizer = optimizer self.initializer = initializer self.numpy_batch_size = numpy_batch_size # internal helper state self._pred_exec = None self.begin_epoch = begin_epoch def _check_arguments(self): """verify the argument of the default symbol and user provided parameters""" if self.argument_checked: return assert(self.symbol is not None) self.argument_checked = True # check if symbol contain duplicated names. _check_arguments(self.symbol) # rematch parameters to delete useless ones if self.allow_extra_params: if self.arg_params: arg_names = set(self.symbol.list_arguments()) self.arg_params = {k : v for k, v in self.arg_params.items() if k in arg_names} if self.aux_params: aux_names = set(self.symbol.list_auxiliary_states()) self.aux_params = {k : v for k, v in self.aux_params.items() if k in aux_names} @staticmethod def _is_data_arg(name): """Check if name is a data argument.""" return name.endswith('data') or name.endswith('label') def _init_params(self, inputs, overwrite=False): """Initialize weight parameters and auxiliary states.""" inputs = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in inputs] input_shapes = {item.name: item.shape for item in inputs} arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes) assert arg_shapes is not None input_dtypes = {item.name: item.dtype for item in inputs} arg_dtypes, _, aux_dtypes = self.symbol.infer_type(**input_dtypes) assert arg_dtypes is not None arg_names = self.symbol.list_arguments() input_names = input_shapes.keys() param_names = [key for key in arg_names if key not in input_names] aux_names = self.symbol.list_auxiliary_states() param_name_attrs = [x for x in zip(arg_names, arg_shapes, arg_dtypes) if x[0] in param_names] arg_params = {k : nd.zeros(shape=s, dtype=t) for k, s, t in param_name_attrs} aux_name_attrs = [x for x in zip(aux_names, aux_shapes, aux_dtypes) if x[0] in aux_names] aux_params = {k : nd.zeros(shape=s, dtype=t) for k, s, t in aux_name_attrs} for k, v in arg_params.items(): if self.arg_params and k in self.arg_params and (not overwrite): arg_params[k][:] = self.arg_params[k][:] else: self.initializer(k, v) for k, v in aux_params.items(): if self.aux_params and k in self.aux_params and (not overwrite): aux_params[k][:] = self.aux_params[k][:] else: self.initializer(k, v) self.arg_params = arg_params self.aux_params = aux_params return (arg_names, list(param_names), aux_names) def __getstate__(self): this = self.__dict__.copy() this['_pred_exec'] = None return this def __setstate__(self, state): self.__dict__.update(state) def _init_predictor(self, input_shapes, type_dict=None): """Initialize the predictor module for running prediction.""" shapes = {name: self.arg_params[name].shape for name in self.arg_params} shapes.update(dict(input_shapes)) if self._pred_exec is not None: arg_shapes, _, _ = self.symbol.infer_shape(**shapes) assert arg_shapes is not None, "Incomplete input shapes" pred_shapes = [x.shape for x in self._pred_exec.arg_arrays] if arg_shapes == pred_shapes: return # for now only use the first device pred_exec = self.symbol.simple_bind( self.ctx[0], grad_req='null', type_dict=type_dict, **shapes) pred_exec.copy_params_from(self.arg_params, self.aux_params) _check_arguments(self.symbol) self._pred_exec = pred_exec def _init_iter(self, X, y, is_train): """Initialize the iterator given input.""" if isinstance(X, (np.ndarray, nd.NDArray)): if y is None: if is_train: raise ValueError('y must be specified when X is numpy.ndarray') else: y = np.zeros(X.shape[0]) if not isinstance(y, (np.ndarray, nd.NDArray)): raise TypeError('y must be ndarray when X is numpy.ndarray') if X.shape[0] != y.shape[0]: raise ValueError("The numbers of data points and labels not equal") if y.ndim == 2 and y.shape[1] == 1: y = y.flatten() if y.ndim != 1: raise ValueError("Label must be 1D or 2D (with 2nd dimension being 1)") if is_train: return io.NDArrayIter(X, y, min(X.shape[0], self.numpy_batch_size), shuffle=is_train, last_batch_handle='roll_over') else: return io.NDArrayIter(X, y, min(X.shape[0], self.numpy_batch_size), shuffle=False) if not isinstance(X, io.DataIter): raise TypeError('X must be DataIter, NDArray or numpy.ndarray') return X def _init_eval_iter(self, eval_data): """Initialize the iterator given eval_data.""" if eval_data is None: return eval_data if isinstance(eval_data, (tuple, list)) and len(eval_data) == 2: if eval_data[0] is not None: if eval_data[1] is None and isinstance(eval_data[0], io.DataIter): return eval_data[0] input_data = (np.array(eval_data[0]) if isinstance(eval_data[0], list) else eval_data[0]) input_label = (np.array(eval_data[1]) if isinstance(eval_data[1], list) else eval_data[1]) return self._init_iter(input_data, input_label, is_train=True) else: raise ValueError("Eval data is NONE") if not isinstance(eval_data, io.DataIter): raise TypeError('Eval data must be DataIter, or ' \ 'NDArray/numpy.ndarray/list pair (i.e. tuple/list of length 2)') return eval_data def predict(self, X, num_batch=None, return_data=False, reset=True): """Run the prediction, always only use one device. Parameters ---------- X : mxnet.DataIter num_batch : int or None The number of batch to run. Go though all batches if ``None``. Returns ------- y : numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs. The predicted value of the output. """ X = self._init_iter(X, None, is_train=False) if reset: X.reset() data_shapes = X.provide_data data_names = [x[0] for x in data_shapes] type_dict = dict((key, value.dtype) for (key, value) in self.arg_params.items()) for x in X.provide_data: if isinstance(x, DataDesc): type_dict[x.name] = x.dtype else: type_dict[x[0]] = mx_real_t self._init_predictor(data_shapes, type_dict) batch_size = X.batch_size data_arrays = [self._pred_exec.arg_dict[name] for name in data_names] output_list = [[] for _ in range(len(self._pred_exec.outputs))] if return_data: data_list = [[] for _ in X.provide_data] label_list = [[] for _ in X.provide_label] i = 0 for batch in X: _load_data(batch, data_arrays) self._pred_exec.forward(is_train=False) padded = batch.pad real_size = batch_size - padded for o_list, o_nd in zip(output_list, self._pred_exec.outputs): o_list.append(o_nd[0:real_size].asnumpy()) if return_data: for j, x in enumerate(batch.data): data_list[j].append(x[0:real_size].asnumpy()) for j, x in enumerate(batch.label): label_list[j].append(x[0:real_size].asnumpy()) i += 1 if num_batch is not None and i == num_batch: break outputs = [np.concatenate(x) for x in output_list] if len(outputs) == 1: outputs = outputs[0] if return_data: data = [np.concatenate(x) for x in data_list] label = [np.concatenate(x) for x in label_list] if len(data) == 1: data = data[0] if len(label) == 1: label = label[0] return outputs, data, label else: return outputs def score(self, X, eval_metric='acc', num_batch=None, batch_end_callback=None, reset=True): """Run the model given an input and calculate the score as assessed by an evaluation metric. Parameters ---------- X : mxnet.DataIter eval_metric : metric.metric The metric for calculating score. num_batch : int or None The number of batches to run. Go though all batches if ``None``. Returns ------- s : float The final score. """ # setup metric if not isinstance(eval_metric, metric.EvalMetric): eval_metric = metric.create(eval_metric) X = self._init_iter(X, None, is_train=False) if reset: X.reset() data_shapes = X.provide_data data_names = [x[0] for x in data_shapes] type_dict = dict((key, value.dtype) for (key, value) in self.arg_params.items()) for x in X.provide_data: if isinstance(x, DataDesc): type_dict[x.name] = x.dtype else: type_dict[x[0]] = mx_real_t self._init_predictor(data_shapes, type_dict) data_arrays = [self._pred_exec.arg_dict[name] for name in data_names] for i, batch in enumerate(X): if num_batch is not None and i == num_batch: break _load_data(batch, data_arrays) self._pred_exec.forward(is_train=False) eval_metric.update(batch.label, self._pred_exec.outputs) if batch_end_callback is not None: batch_end_params = BatchEndParam(epoch=0, nbatch=i, eval_metric=eval_metric, locals=locals()) _multiple_callbacks(batch_end_callback, batch_end_params) return eval_metric.get()[1] def fit(self, X, y=None, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None, work_load_list=None, monitor=None, eval_end_callback=LogValidationMetricsCallback(), eval_batch_end_callback=None): """Fit the model. Parameters ---------- X : DataIter, or numpy.ndarray/NDArray Training data. If `X` is a `DataIter`, the name or (if name not available) the position of its outputs should match the corresponding variable names defined in the symbolic graph. y : numpy.ndarray/NDArray, optional Training set label. If X is ``numpy.ndarray`` or `NDArray`, `y` is required to be set. While y can be 1D or 2D (with 2nd dimension as 1), its first dimension must be the same as `X`, i.e. the number of data points and labels should be equal. eval_data : DataIter or numpy.ndarray/list/NDArray pair If eval_data is numpy.ndarray/list/NDArray pair, it should be ``(valid_data, valid_label)``. eval_metric : metric.EvalMetric or str or callable The evaluation metric. This could be the name of evaluation metric or a custom evaluation function that returns statistics based on a minibatch. epoch_end_callback : callable(epoch, symbol, arg_params, aux_states) A callback that is invoked at end of each epoch. This can be used to checkpoint model each epoch. batch_end_callback: callable(epoch) A callback that is invoked at end of each batch for purposes of printing. kvstore: KVStore or str, optional The KVStore or a string kvstore type: 'local', 'dist_sync', 'dist_async' In default uses 'local', often no need to change for single machiine. logger : logging logger, optional When not specified, default logger will be used. work_load_list : float or int, optional The list of work load for different devices, in the same order as `ctx`. Note ---- KVStore behavior - 'local', multi-devices on a single machine, will automatically choose best type. - 'dist_sync', multiple machines communicating via BSP. - 'dist_async', multiple machines with asynchronous communication. """ data = self._init_iter(X, y, is_train=True) eval_data = self._init_eval_iter(eval_data) if self.sym_gen: self.symbol = self.sym_gen(data.default_bucket_key) # pylint: disable=no-member self._check_arguments() self.kwargs["sym"] = self.symbol arg_names, param_names, aux_names = \ self._init_params(data.provide_data+data.provide_label) # setup metric if not isinstance(eval_metric, metric.EvalMetric): eval_metric = metric.create(eval_metric) # create kvstore (kvstore, update_on_kvstore) = _create_kvstore( kvstore, len(self.ctx), self.arg_params) param_idx2name = {} if update_on_kvstore: param_idx2name.update(enumerate(param_names)) else: for i, n in enumerate(param_names): for k in range(len(self.ctx)): param_idx2name[i*len(self.ctx)+k] = n self.kwargs["param_idx2name"] = param_idx2name # init optmizer if isinstance(self.optimizer, str): batch_size = data.batch_size if kvstore and 'dist' in kvstore.type and '_async' not in kvstore.type: batch_size *= kvstore.num_workers optimizer = opt.create(self.optimizer, rescale_grad=(1.0/batch_size), **(self.kwargs)) elif isinstance(self.optimizer, opt.Optimizer): optimizer = self.optimizer # do training _train_multi_device(self.symbol, self.ctx, arg_names, param_names, aux_names, self.arg_params, self.aux_params, begin_epoch=self.begin_epoch, end_epoch=self.num_epoch, epoch_size=self.epoch_size, optimizer=optimizer, train_data=data, eval_data=eval_data, eval_metric=eval_metric, epoch_end_callback=epoch_end_callback, batch_end_callback=batch_end_callback, kvstore=kvstore, update_on_kvstore=update_on_kvstore, logger=logger, work_load_list=work_load_list, monitor=monitor, eval_end_callback=eval_end_callback, eval_batch_end_callback=eval_batch_end_callback, sym_gen=self.sym_gen) def save(self, prefix, epoch=None): """Checkpoint the model checkpoint into file. You can also use `pickle` to do the job if you only work on Python. The advantage of `load` and `save` (as compared to `pickle`) is that the resulting file can be loaded from other MXNet language bindings. One can also directly `load`/`save` from/to cloud storage(S3, HDFS) Parameters ---------- prefix : str Prefix of model name. Notes ----- - ``prefix-symbol.json`` will be saved for symbol. - ``prefix-epoch.params`` will be saved for parameters. """ if epoch is None: epoch = self.num_epoch assert epoch is not None save_checkpoint(prefix, epoch, self.symbol, self.arg_params, self.aux_params) @staticmethod def load(prefix, epoch, ctx=None, **kwargs): """Load model checkpoint from file. Parameters ---------- prefix : str Prefix of model name. epoch : int epoch number of model we would like to load. ctx : Context or list of Context, optional The device context of training and prediction. kwargs : dict Other parameters for model, including `num_epoch`, optimizer and `numpy_batch_size`. Returns ------- model : FeedForward The loaded model that can be used for prediction. Notes ----- - ``prefix-symbol.json`` will be saved for symbol. - ``prefix-epoch.params`` will be saved for parameters. """ symbol, arg_params, aux_params = load_checkpoint(prefix, epoch) return FeedForward(symbol, ctx=ctx, arg_params=arg_params, aux_params=aux_params, begin_epoch=epoch, **kwargs) @staticmethod def create(symbol, X, y=None, ctx=None, num_epoch=None, epoch_size=None, optimizer='sgd', initializer=Uniform(0.01), eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None, work_load_list=None, eval_end_callback=LogValidationMetricsCallback(), eval_batch_end_callback=None, **kwargs): """Functional style to create a model. This function is more consistent with functional languages such as R, where mutation is not allowed. Parameters ---------- symbol : Symbol The symbol configuration of a computation network. X : DataIter Training data. y : numpy.ndarray, optional If `X` is a ``numpy.ndarray``, `y` must be set. ctx : Context or list of Context, optional The device context of training and prediction. To use multi-GPU training, pass in a list of GPU contexts. num_epoch : int, optional The number of training epochs(epochs). epoch_size : int, optional Number of batches in a epoch. In default, it is set to ``ceil(num_train_examples / batch_size)``. optimizer : str or Optimizer, optional The name of the chosen optimizer, or an optimizer object, used for training. initializer : initializer function, optional The initialization scheme used. eval_data : DataIter or numpy.ndarray pair If `eval_set` is ``numpy.ndarray`` pair, it should be (`valid_data`, `valid_label`). eval_metric : metric.EvalMetric or str or callable The evaluation metric. Can be the name of an evaluation metric or a custom evaluation function that returns statistics based on a minibatch. epoch_end_callback : callable(epoch, symbol, arg_params, aux_states) A callback that is invoked at end of each epoch. This can be used to checkpoint model each epoch. batch_end_callback: callable(epoch) A callback that is invoked at end of each batch for print purposes. kvstore: KVStore or str, optional The KVStore or a string kvstore type: 'local', 'dist_sync', 'dis_async'. Defaults to 'local', often no need to change for single machine. logger : logging logger, optional When not specified, default logger will be used. work_load_list : list of float or int, optional The list of work load for different devices, in the same order as `ctx`. """ model = FeedForward(symbol, ctx=ctx, num_epoch=num_epoch, epoch_size=epoch_size, optimizer=optimizer, initializer=initializer, **kwargs) model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric, epoch_end_callback=epoch_end_callback, batch_end_callback=batch_end_callback, kvstore=kvstore, logger=logger, work_load_list=work_load_list, eval_end_callback=eval_end_callback, eval_batch_end_callback=eval_batch_end_callback) return model