# Model identity.
model_name: xdxd_spacenet4

# No model file path is supplied (null).
model_path: null

# Run-mode switches: training is off, inference is on.
train: false
infer: true

# NOTE(review): with model_path unset, `pretrained: true` presumably selects
# stock weights for model_name — confirm against the consumer.
pretrained: true
nn_framework: torch
batch_size: 12

# Input geometry, pixel handling, and label layout.
data_specs:
  width: 512
  height: 512
  dtype: null
  image_type: zscore
  # Rescaling is switched off; minima/maxima are left at 'auto'.
  rescale: false
  rescale_minima: auto
  rescale_maxima: auto
  # Four channels are declared; every augmentation pipeline in this file
  # begins with DropChannel (idx: 3).
  channels: 4
  label_type: mask
  is_categorical: false
  mask_channels: 1
  # Presumably the share of training data held out for validation when no
  # validation CSV is supplied — TODO confirm against the consumer.
  val_holdout_frac: 0.2
  data_workers: null

# Dataset CSV locations. The two paths are placeholders; the validation CSV
# is unset (null).
training_data_csv: '/path/to/training_df.csv'
validation_data_csv: null
inference_data_csv: '/path/to/test_df.csv'

# Augmentations listed for training, in the order written.
training_augmentation:
  augmentations:
    DropChannel:
      idx: 3
      axis: 2
      # Stated explicitly to match inference_augmentation. Normalize below
      # supplies three mean/std values, so the drop presumably has to run on
      # every sample.
      p: 1.0
    HorizontalFlip:
      p: 0.5
    RandomRotate90:
      p: 0.5
    # Crop size equals data_specs width/height (512 x 512).
    RandomCrop:
      height: 512
      width: 512
      p: 1.0
    # NOTE(review): mean/std look like fractions of max_pixel_value, as in
    # albumentations' Normalize — confirm.
    Normalize:
      mean: [0.006479, 0.009328, 0.01123]
      std: [0.004986, 0.004964, 0.004950]
      max_pixel_value: 65535.0
      p: 1.0
  # Pipeline-level probability and shuffle flag.
  p: 1.0
  shuffle: true
# Augmentations listed for validation: no flips or rotations, and a centre
# crop instead of a random one.
validation_augmentation:
  augmentations:
    DropChannel:
      idx: 3
      axis: 2
      # Stated explicitly to match inference_augmentation. Normalize below
      # supplies three mean/std values, so the drop presumably has to run on
      # every sample.
      p: 1.0
    # Crop size equals data_specs width/height (512 x 512).
    CenterCrop:
      height: 512
      width: 512
      p: 1.0
    # NOTE(review): mean/std look like fractions of max_pixel_value, as in
    # albumentations' Normalize — confirm.
    Normalize:
      mean: [0.006479, 0.009328, 0.01123]
      std: [0.004986, 0.004964, 0.004950]
      max_pixel_value: 65535.0
      p: 1.0
  # Pipeline-level probability.
  p: 1.0
# Augmentations listed for inference: channel drop and normalisation only,
# with no cropping step.
inference_augmentation:
  augmentations:
    # data_specs declares 4 channels; Normalize below lists three mean/std
    # values.
    DropChannel:
      idx: 3
      axis: 2
      p: 1.0
    # NOTE(review): mean/std look like fractions of max_pixel_value, as in
    # albumentations' Normalize — confirm.
    Normalize:
      mean: [0.006479, 0.009328, 0.01123]
      std: [0.004986, 0.004964, 0.004950]
      max_pixel_value: 65535.0
      p: 1.0
  # Pipeline-level probability.
  p: 1.0
# Training hyperparameters, loss composition, and checkpointing.
training:
  epochs: 60
  steps_per_epoch: null
  optimizer: Adam
  # The mantissa carries an explicit dot: YAML 1.1 resolvers (e.g. PyYAML)
  # only read exponent notation as a float when one is present, and load a
  # bare `1e-4` as the string "1e-4".
  lr: 1.0e-4
  opt_args: null
  # Loss terms are selected by key; neither is given arguments here.
  loss:
    bcewithlogits: null
    jaccard: null
  # Per-term weights, keyed the same way as `loss`.
  loss_weights:
    bcewithlogits: 10
    jaccard: 2.5
  # No metrics are configured for either phase.
  metrics:
    training: null
    validation: null
  checkpoint_frequency: 10
  callbacks:
    model_checkpoint:
      filepath: 'xdxd_best.pth'
      monitor: val_loss
  model_dest_path: 'xdxd.pth'
  verbose: true

# Inference settings. Both window step sizes are unset (null).
# NOTE(review): the consumer presumably falls back to a default step —
# confirm.
inference:
  window_step_size_x: null
  window_step_size_y: null
  output_dir: 'inference_out/'