import os

import numpy as np
import skimage.io
import torch
from tqdm import tqdm
from warnings import warn

from .model_io import get_model
from .transform import process_aug_dict
from .datagen import InferenceTiler
from ..raster.image import stitch_images
from ..utils.core import get_data_paths


class Inferer(object):
    """Object for performing inference with `solaris` models using PyTorch or Keras."""

    def __init__(self, config, custom_model_dict=None):
        self.config = config
        self.batch_size = self.config['batch_size']
        self.framework = self.config['nn_framework']
        self.model_name = self.config['model_name']
        # Check if the model was trained as part of the same pipeline; if so,
        # use the output from that. If not, use the pre-trained model directly.
        if self.config['train']:
            warn('Because the configuration specifies both training and '
                 'inference, solaris is switching the model weights path '
                 'to the training output path.')
            self.model_path = self.config['training']['model_dest_path']
            if custom_model_dict is not None:
                custom_model_dict['weight_path'] = self.config[
                    'training']['model_dest_path']
        else:
            self.model_path = self.config.get('model_path', None)
        self.model = get_model(self.model_name, self.framework,
                               self.model_path, pretrained=True,
                               custom_model_dict=custom_model_dict)

        self.window_step_x = self.config['inference'].get(
            'window_step_size_x', None)
        self.window_step_y = self.config['inference'].get(
            'window_step_size_y', None)
        if self.window_step_x is None:
            self.window_step_x = self.config['data_specs']['width']
        if self.window_step_y is None:
            self.window_step_y = self.config['data_specs']['height']
        self.stitching_method = self.config['inference'].get(
            'stitching_method', 'average')
        self.output_dir = self.config['inference']['output_dir']
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)

    def __call__(self, infer_df=None):
        """Run inference.

        Arguments
        ---------
        infer_df : :class:`pandas.DataFrame` or `str`
            A :class:`pandas.DataFrame` with a column, ``'image'``, specifying
            paths to images for inference. Alternatively, `infer_df` can be
            a path to a CSV file containing the same information. Defaults to
            ``None``, in which case the file path specified in the Inferer's
            configuration dict is used.
        """
        if infer_df is None:
            infer_df = get_infer_df(self.config)

        inf_tiler = InferenceTiler(
            self.framework,
            width=self.config['data_specs']['width'],
            height=self.config['data_specs']['height'],
            x_step=self.window_step_x,
            y_step=self.window_step_y,
            augmentations=process_aug_dict(
                self.config['inference_augmentation']))
        for idx, im_path in tqdm(enumerate(infer_df['image']),
                                 total=len(infer_df['image'])):
            # Tile the source image into model-sized windows and keep the
            # index references needed to stitch predictions back together.
            inf_input, idx_refs, (
                src_im_height, src_im_width) = inf_tiler(im_path)

            if self.framework == 'keras':
                subarr_preds = self.model.predict(inf_input,
                                                  batch_size=self.batch_size)

            elif self.framework in ['torch', 'pytorch']:
                with torch.no_grad():
                    self.model.eval()
                    if torch.cuda.is_available():
                        device = torch.device('cuda')
                        self.model = self.model.cuda()
                    else:
                        device = torch.device('cpu')
                    inf_input = torch.from_numpy(inf_input).float().to(device)
                    # add additional input data, if applicable
                    if self.config['data_specs'].get('additional_inputs',
                                                     None) is not None:
                        inf_input = [inf_input]
                        for i in self.config['data_specs']['additional_inputs']:
                            inf_input.append(
                                infer_df[i].iloc[idx].to(device))

                    # Run the model on batches of tiles rather than all tiles
                    # at once to limit memory use. Slicing past the end of the
                    # tensor is safe, so the final (possibly smaller) batch
                    # needs no special case.
                    subarr_preds_list = []
                    for batch_i in range(0, inf_input.shape[0],
                                         self.batch_size):
                        subarr_pred = self.model(
                            inf_input[batch_i:batch_i + self.batch_size, ...])
                        subarr_preds_list.append(
                            subarr_pred.cpu().data.numpy())
                    subarr_preds = np.concatenate(subarr_preds_list, axis=0)

            stitched_result = stitch_images(subarr_preds,
                                            idx_refs=idx_refs,
                                            out_width=src_im_width,
                                            out_height=src_im_height,
                                            method=self.stitching_method)
            skimage.io.imsave(os.path.join(self.output_dir,
                                           os.path.split(im_path)[1]),
                              stitched_result,
                              check_contrast=False)


def get_infer_df(config):
    """Get the inference df based on the contents of ``config``.

    This function uses the logic described in the documentation for the config
    file to determine where to find images to be used for inference. See the
    docs and the comments in solaris/data/config_skeleton.yml for details.

    Arguments
    ---------
    config : dict
        The loaded configuration dict for model training and/or inference.

    Returns
    -------
    infer_df : :class:`pandas.DataFrame`
        A :class:`pandas.DataFrame` with at least one column, ``'image'``,
        whose values are the paths to the images to perform inference on.
    """

    infer_df = get_data_paths(config['inference_data_csv'], infer=True)
    return infer_df
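

# Usage sketch (illustrative only, not executed as part of this module):
# a minimal example of driving ``Inferer`` from a parsed configuration dict.
# It assumes the config is loaded with solaris's YAML config parser
# (``solaris.utils.config.parse``); the config filename below is hypothetical,
# and ``get_infer_df`` can be replaced with any DataFrame that has an
# ``'image'`` column of input paths.
#
#     from solaris.utils.config import parse
#     from solaris.nets.infer import Inferer, get_infer_df
#
#     config = parse('my_inference_config.yml')  # hypothetical config file
#     inferer = Inferer(config)
#     infer_df = get_infer_df(config)  # or supply your own DataFrame
#     inferer(infer_df)  # predictions are written to config['inference']['output_dir']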