################################################################################
################# SOLARIS MODEL CONFIGURATION SKELETON #########################
################################################################################

# This skeleton lays out the required instructions for running a model using
# solaris. See the full documentation at [INCLUDE DOC LINK HERE] for details on
# options, required arguments, and sample usage.

# Model selection and run mode. Keys with no value parse as null; fill in the
# ones your run requires.
model_name: # include the name of the model to be used here. See the docs
            # for options.
model_path: # leave this blank unless you're using a custom model not
            # native to solaris. When left blank, solaris will automatically
            # find your model.
train: true  # set to false for inference only
infer: true  # set to false for training only

pretrained: true  # use pretrained weights associated with the model?
nn_framework:  # if not using a model included with the package, use this
               # argument to specify if it uses keras, pytorch, tf, etc.
batch_size:  # size of each batch fed into the neural net.

data_specs:
  # Properties of the input imagery and labels, plus data-loading options.
  width: # width of the input images taken in by the neural net.
  height:  # height of the input images taken in by the neural net.
  dtype:  # dtype of the inputs ingested by the neural net.
  rescale: false  # should image pixel values be rescaled before pre-processing?
                  # If so, the image will be rescaled to the pixel range defined
                  # by rescale_minima and rescale_maxima below.
  rescale_minima: auto  # the minimum values to use in rescaling (if
                        # rescale=true). If 'auto', the minimum pixel intensity
                        # in each channel of the image will be subtracted. If
                        # a single value is provided, that value will be set to
                        # zero for each channel in the input image.
                        # If a list of values is provided, then the minima in
                        # the separate channels (in that order) will be set to
                        # the corresponding value PRIOR to any augmentation.
  rescale_maxima: auto  # same as rescale_minima, but for the maximum value for
                        # each channel in the image.
  channels:  # number of channels in the input imagery.
  label_type: mask  # one of ['mask', 'bbox'] (CURRENTLY ONLY MASK IMPLEMENTED)
  is_categorical: false  # are the labels binary (default) or categorical?
  mask_channels: 1  # number of channels in the training mask
  val_holdout_frac:  # if empty, assumes that separate data ref files define the
                     # training and validation dataset. If a float between 0 and
                     # 1, indicates the fraction of training data that's held
                     # out for validation (and validation_data_csv will be
                     # ignored)
  data_workers:  # number of cpu threads to use for loading and preprocessing
                 # input images.
  # The other_inputs key is intentionally left commented out (see NOTE).
#  other_inputs:  # this can provide a list of additional inputs to pass to the
                 # neural net for training. These inputs should be specified in
                 # extra columns of the csv files (denoted below), either as
                 # filepaths to additional data to load or as values to include.
                 # NOTE: This is not currently implemented.

# Reference csv files that point solaris at the data for each stage.
training_data_csv:  # path to the reference csv that defines training data.
                    # See the documentation for the specifications of this file.
validation_data_csv:  # path to the validation reference csv. See the docs. If
                      # val_holdout_frac is specified (under data_specs), then
                      # this argument will be ignored.
inference_data_csv:  # path to the reference csv that defines inference data.
                     # See the documentation for the specs of this file.

training_augmentation:  # augmentations for use with training data
  augmentations:
    # include augmentations here. See the documentation for options and
    # required arguments.
  p: 1.0  # probability of applying the entire training augmentation pipeline.
  shuffle: true  # should the image order be shuffled in each epoch?
validation_augmentation:  # augmentations for use with validation data
  augmentations:
    # include augmentations here. See the documentation for options and
    # required arguments.
  p:  # probability of applying the full validation augmentation pipeline.
inference_augmentation:  # this is optional. If not provided,
                         # validation_augmentation will be used instead.

training:
  # Hyperparameters, monitoring and output options for model training.
  epochs:  # number of epochs. A list can also be provided here indicating
           # distinct sets of epochs at different learning rates, etc; if so,
           # each parameter that varies between sets (e.g. lr) must be given
           # as a list of equal length holding the value for each set.
  steps_per_epoch:  # optional argument defining the number of steps per
                    # epoch. If not provided, each epoch will include the
                    # number of steps needed to go through the entire
                    # training dataset.
  optimizer:  # optimizer function name. See docs for options.
  lr:  # learning rate.
  opt_args:  # dictionary of values (e.g. alpha, gamma, momentum) specific to
             # the optimizer.
  loss:  # loss function(s). See docs for options. This should be a list of loss
         # names with sublists of loss function hyperparameters (if applicable).
         # See the docs for more details and usage examples.
  loss_weights:  # (optional) weights to use for each loss function if using
                 # loss: composite. This must be a set of key:value pairs
                 # defining the weight for each sub-loss within the composite.
                 # If using a composite and a value isn't provided here, all
                 # losses will be weighted equally.
  metrics:  # metrics to monitor on the training and validation sets.
    training:  # training set metrics.
    validation:  # validation set metrics.
  checkpoint_frequency:  # how frequently should checkpoints be saved?
                         # This can be an int, in which case a checkpoint will
                         # be made every n epochs, or a list, in which case
                         # checkpoints will be saved on those specific epochs.
                         # If not provided, only the final model is saved.
  callbacks:  # a list of callbacks to use.
  model_dest_path:   # path to save the trained model output and checkpoint(s)
                     # to. Should be a filename ending in .h5 or .hdf5 for
                     # keras, or .pth or .pt for torch. Epoch numbers will be
                     # appended for checkpoints.
  verbose: true  # verbose text output during training

inference:
  # Sliding-window and output options used when generating predictions.
  window_step_size_x:  # size of each sliding-window step in the x direction
                       # during inference. Set to the same size as the input
                       # image width for zero overlap; to average predictions
                       # across overlapping windows, use a smaller step size.
  window_step_size_y:  # size of each sliding-window step in the y direction
                       # during inference. Set to the same size as the input
                       # image height for zero overlap; to average predictions
                       # across overlapping windows, use a smaller step size.
  stitching_method:    # the method to use to stitch together tiles used during
                       # inference. Defaults to average if not provided. See
                       # the documentation for sol.raster.image.stitch_images()
                       # for more.
  output_dir:  inference_out # the path to save inference outputs to.