# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# -*- coding: utf-8 -*-
from ..dist_utils import master_only
from .hook import Hook
import os.path as osp


class CheckpointHook(Hook):
    """Hook that asks the runner to write a checkpoint after training epochs.

    The checkpoint path is ``<out_dir>/<epoch>``, where ``<epoch>`` is
    ``runner.epoch`` zero-padded to three digits.

    NOTE(review): ``save_optimizer`` and the extra keyword arguments are
    stored on the instance but are not read anywhere in this class as
    written -- confirm whether they should be forwarded to
    ``runner.save_checkpoint``.
    """

    def __init__(self, interval=-1, save_optimizer=True, out_dir=None, **kwargs):
        """Store the hook configuration.

        Args:
            interval: Value handed to ``every_n_epochs`` to decide whether a
                checkpoint is written for the current epoch.
            save_optimizer: Stored as ``self.save_optimizer``.
            out_dir: Directory checkpoints are written under. When falsy, the
                runner's ``work_dir`` is used instead (resolved lazily in
                ``after_train_epoch``).
            **kwargs: Stored as ``self.args``.
        """
        self.interval = interval
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.args = kwargs

    @master_only
    def after_train_epoch(self, runner):
        """Save a checkpoint when the epoch interval check passes.

        Args:
            runner: Object exposing ``epoch``, ``work_dir`` and
                ``save_checkpoint(path)``.
        """
        # every_n_epochs is not defined in this class; presumably inherited
        # from Hook -- verify against the base class.
        if self.every_n_epochs(runner, self.interval):
            # The fallback directory is resolved once and then cached on the
            # instance, so later epochs reuse the same location.
            if not self.out_dir:
                self.out_dir = runner.work_dir
            checkpoint_name = "{:03d}".format(runner.epoch)
            runner.save_checkpoint(osp.join(self.out_dir, checkpoint_name))