# -*- coding: utf-8 -*-
# File: tflayer.py

import functools
import six
import tensorflow as tf

from ..tfutils.varreplace import custom_getter_scope
from ..utils.argtools import get_data_format

__all__ = []


def map_common_tfargs(kwargs):
    """
    Translate tensorpack-style layer arguments to the names used by ``tf.layers``.

    ``kwargs`` is modified in place and also returned. The translations are:

    1. ``data_format`` is normalized with ``get_data_format(..., keras_mode=True)``.
       An explicit ``data_format=None`` is dropped from ``kwargs``.
    2. ``nl`` is removed and turned into an ``activation`` callable.
       If both ``nl`` and ``activation`` are given, ``nl`` wins.
    3. ``W_init`` is renamed to ``kernel_initializer``.
    4. ``b_init`` is renamed to ``bias_initializer``.

    Args:
        kwargs (dict): keyword arguments passed to a layer.

    Returns:
        dict: the same ``kwargs`` object, after translation.
    """
    df = kwargs.pop('data_format', None)
    if df is not None:
        df = get_data_format(df, keras_mode=True)
        kwargs['data_format'] = df

    old_nl = kwargs.pop('nl', None)
    if old_nl is not None:
        # The wrapper accepts an optional ``name`` and always forwards it to
        # ``nl``, so the given ``nl`` must accept a ``name`` keyword argument.
        kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)

    if 'W_init' in kwargs:
        kwargs['kernel_initializer'] = kwargs.pop('W_init')

    if 'b_init' in kwargs:
        kwargs['bias_initializer'] = kwargs.pop('b_init')
    return kwargs


def convert_to_tflayer_args(args_names, name_mapping):
    """
    After applying this decorator:
    1. data_format becomes tf.layers style
    2. nl becomes activation
    3. initializers are renamed
    4. positional args are transformed to corresponding kwargs, according to args_names
    5. kwargs are mapped to tf.layers names if needed, by name_mapping

    Args:
        args_names (list[str]): names given, in order, to the positional
            arguments that follow ``inputs``. At most ``len(args_names)``
            positional arguments are accepted.
        name_mapping (dict): an old -> new mapping for keyword argument names.

    Returns:
        A decorator. The decorated function is called as
        ``func(inputs, **translated_kwargs)``.

    Raises:
        AssertionError: if more positional arguments than ``args_names`` are
            given, or if both an old name and the new name it maps to are given.
    """
    def decorator(func):
        @functools.wraps(func)
        def decorated_func(inputs, *args, **kwargs):
            kwargs = map_common_tfargs(kwargs)

            assert len(args) <= len(args_names), \
                "Please use kwargs instead of positional args to call this model, " \
                "except for the following arguments: {}".format(', '.join(args_names))
            posarg_dic = dict(zip(args_names, args))

            ret = {}
            for name, arg in six.iteritems(kwargs):
                newname = name_mapping.get(name)
                if newname is None:
                    newname = name
                else:
                    # Reject calls that pass both the old and the new spelling.
                    assert newname not in kwargs, \
                        "Argument {} and {} conflicts!".format(name, newname)
                ret[newname] = arg
            ret.update(posarg_dic)  # Let pos arg overwrite kw arg, for argscope to work

            return func(inputs, **ret)

        return decorated_func

    return decorator


def rename_get_variable(mapping):
    """
    Args:
        mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}

    Returns:
        A context where the variables are renamed.
    """
    def custom_getter(getter, name, *args, **kwargs):
        # Only the last '/'-separated component (the basename) is renamed;
        # the enclosing scope path is kept unchanged.
        splits = name.split('/')
        basename = splits[-1]
        if basename in mapping:
            basename = mapping[basename]
            splits[-1] = basename
            name = '/'.join(splits)
        return getter(name, *args, **kwargs)
    return custom_getter_scope(custom_getter)


def rename_tflayer_get_variable():
    """
    Rename all :func:`tf.get_variable` with rules that transform tflayer style to tensorpack style.

    Returns:
        A context where the variables are renamed.

    Example:

    .. code-block:: python

        with rename_tflayer_get_variable():
            x = tf.layers.conv2d(input, 3, 3, name='conv0')
            # variables will be named 'conv0/W', 'conv0/b'
    """
    mapping = {
        'kernel': 'W',
        'bias': 'b',
        'moving_mean': 'mean/EMA',
        'moving_variance': 'variance/EMA',
    }
    return rename_get_variable(mapping)