#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2019, Open-MMLab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta, abstractmethod

from ..hook import Hook


class LoggerHook(Hook):
    """Common driver for hooks that report the runner's log buffer.

    The callbacks defined here decide *when* the buffer is averaged and
    when :meth:`log` is invoked; subclasses supply :meth:`log` to decide
    *where* the averaged values go.

    Args:
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
        reset_flag (bool): Whether to clear the output buffer after logging.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling. Python 3 does
    # not consult this attribute when creating the class, so unless ``Hook``
    # already provides an ABC metaclass, ``log`` is not enforced as abstract
    # there -- confirm before relying on instantiation errors.
    __metaclass__ = ABCMeta

    def __init__(self, interval=10, ignore_last=True, reset_flag=False):
        self.interval = interval
        self.ignore_last = ignore_last
        self.reset_flag = reset_flag

    @abstractmethod
    def log(self, runner):
        """Emit the current contents of ``runner.log_buffer``.

        Subclasses are expected to override this; the base version is a
        no-op.
        """
        pass

    def before_run(self, runner):
        """Set ``reset_flag`` on the last ``LoggerHook`` in ``runner.hooks``.

        Only the logger hook closest to the end of the hook list is touched.
        Presumably this lets earlier loggers read the buffer output before
        the final one clears it -- verify against the runner's call order.
        """
        final_logger = next(
            (candidate for candidate in runner.hooks[::-1]
             if isinstance(candidate, LoggerHook)),
            None,
        )
        if final_logger is not None:
            final_logger.reset_flag = True

    def before_epoch(self, runner):
        """Drop whatever the previous epoch left in the log buffer."""
        runner.log_buffer.clear()  # clear logs of last epoch

    def after_train_iter(self, runner):
        """Average the buffer when due, then log it if it is ready.

        Averaging happens on every ``interval``-th inner iteration, or at
        the end of an epoch when ``ignore_last`` is false.
        """
        # The end-of-epoch case still averages over ``self.interval``
        # entries even though fewer iterations may have elapsed: not precise
        # but more stable.
        should_average = (
            self.every_n_inner_iters(runner, self.interval)
            or (self.end_of_epoch(runner) and not self.ignore_last)
        )
        if should_average:
            runner.log_buffer.average(self.interval)

        if not runner.log_buffer.ready:
            return
        self.log(runner)
        if self.reset_flag:
            runner.log_buffer.clear_output()

    def after_train_epoch(self, runner):
        """Log the buffer if it is ready, clearing its output when flagged."""
        if not runner.log_buffer.ready:
            return
        self.log(runner)
        if self.reset_flag:
            runner.log_buffer.clear_output()

    def after_val_epoch(self, runner):
        """Average the whole buffer and log it unconditionally."""
        runner.log_buffer.average()
        self.log(runner)
        if self.reset_flag:
            runner.log_buffer.clear_output()