#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2019, Open-MMLab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict

import numpy as np

try:
    import tensorflow as tf
except ImportError:  # pragma: no cover - depends on the environment
    # TensorFlow is only used to detect/convert tensors in ``update``;
    # plain Python and NumPy values are handled without it.
    tf = None


class LogBuffer(object):
    """Record logged values per key and reduce them to averages.

    ``update`` appends one value (and its count/weight) per key to the
    history; ``average`` reduces that history into ``output`` and sets
    ``ready``.
    """

    def __init__(self):
        # key -> list of recorded values, in the order they were logged
        self.val_history = OrderedDict()
        # key -> list of counts, parallel to ``val_history[key]``
        self.n_history = OrderedDict()
        # key -> reduced value written by the last ``average`` call
        self.output = OrderedDict()
        # True after ``average`` has run; reset by ``clear_output``
        self.ready = False

    def clear(self):
        """Drop the recorded history and the reduced output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self):
        """Drop the reduced output while keeping the recorded history."""
        self.output.clear()
        self.ready = False

    def update(self, vars, count=1):
        """Append one value per key to the history.

        Args:
            vars (dict): mapping of key -> value to record. TensorFlow
                tensors are converted with ``.numpy()`` (this assumes
                eager execution).
            count (int): weight stored alongside every value in this call;
                ``average`` uses it for the weighted mean.
        """
        assert isinstance(vars, dict)
        for key, var in vars.items():
            if tf is not None and tf.is_tensor(var):
                var = var.numpy()
            self.val_history.setdefault(key, []).append(var)
            self.n_history.setdefault(key, []).append(count)

    def average(self, n=0):
        """Average latest n values or all values.

        Args:
            n (int): number of most recent entries to use per key; ``0``
                uses the whole history.

        Keys containing ``'time'`` report their most recent value instead
        of an average, and keys containing ``'image'`` are skipped. Every
        other key gets the count-weighted mean of its selected values;
        array-valued entries (all of one shape) are averaged element-wise.
        """
        assert n >= 0
        for key in self.val_history:
            if 'time' in key:
                self.output[key] = self.val_history[key][-1]
            elif 'image' not in key:  # skip images
                # ``[-0:]`` is ``[0:]``, so n == 0 selects the full history.
                values = np.array(self.val_history[key][-n:])
                nums = np.array(self.n_history[key][-n:])
                # Shape the weights as (k, 1, ..., 1) so they apply along
                # the history axis only. For scalar entries this is exactly
                # sum(values * nums); for array entries it avoids
                # broadcasting the weights against the trailing axis.
                weights = nums.reshape((-1,) + (1,) * (values.ndim - 1))
                avg = np.sum(values * weights, axis=0) / np.sum(nums)
                self.output[key] = avg
        self.ready = True