""" A deepforest callback Callbacks must have the following methods on_epoch_begin, on_epoch_end, on_fit_end, on_fit_begin methods and inject model and epoch kwargs. """ from deepforest import visualize from matplotlib import pyplot as plt import pandas as pd import numpy as np import glob import tempfile from pytorch_lightning import Callback from libs.deepforest import dataset from libs.deepforest import utilities from libs.deepforest import predict import torch class images_callback(Callback): """Run evaluation on a file of annotations during training Args: model: pytorch model csv_file: path to csv with columns, image_path, xmin, ymin, xmax, ymax, label epoch: integer. current epoch experiment: optional comet_ml experiment savedir: optional, directory to save predicted images project: whether to project image coordinates into geographic coordinations, see deepforest.evaluate root_dir: root directory of images to search for 'image path' values from the csv file iou_threshold: intersection-over-union threshold, see deepforest.evaluate probability_threshold: minimum probablity for inclusion, see deepforest.evaluate n: number of images to upload every_n_epochs: run epoch interval Returns: None: either prints validation scores or logs them to a comet experiment """ def __init__(self, csv_file, root_dir, savedir, n=2, every_n_epochs=5): self.savedir = savedir self.root_dir = root_dir self.n = n #limit to n images df = pd.read_csv(csv_file) selected_images = np.random.choice(df.image_path.unique(), self.n) df = df[df.image_path.isin(selected_images)] df.to_csv("{}/image_callback.csv".format(savedir)) self.csv_file = "{}/image_callback.csv".format(savedir) self.every_n_epochs = every_n_epochs def log_images(self, pl_module): boxes = predict.predict_file(model = pl_module.model, csv_file=self.csv_file, root_dir=self.root_dir, savedir=self.savedir, device=pl_module.device) try: saved_plots = glob.glob("{}/*.png".format(self.savedir)) for x in saved_plots: pl_module.logger.experiment.log_image(x) except Exception as e: print( "Could not find logger in lightning module, skipping upload, images were saved to {}, error was rasied {}" .format(self.savedir, e)) def on_epoch_end(self, trainer, pl_module): if trainer.current_epoch % self.every_n_epochs == 0: print("Running image callback") self.log_images(pl_module)