# coding: utf8
from __future__ import unicode_literals

from .model import Model
from .._lsuv import do_lsuv
from ... import describe
from ...describe import Weights, Dimension, Gradient
from ..util import copy_array


def LSUVinit(model, X, y=None):
    # Layer-sequential unit-variance initialization of the embedding table,
    # driven by a sample batch X.
    if model.vectors is not None:
        do_lsuv(model.ops, model.vectors, model, X)
    return X


def _uniform_init(lo, hi):
    def wrapped(W, ops):
        # Only initialize if the table is still all zeros.
        if (W ** 2).sum() == 0.0:
            copy_array(W, ops.xp.random.uniform(lo, hi, W.shape))

    return wrapped


# @describe.on_data(LSUVinit)
@describe.attributes(
    nO=Dimension("Vector dimensions"),
    nV=Dimension("Number of vectors"),
    vectors=Weights(
        "Embedding table", lambda obj: (obj.nV, obj.nO), _uniform_init(-0.1, 0.1)
    ),
    d_vectors=Gradient("vectors"),
)
class HashEmbed(Model):
    name = "hash-embed"

    def __init__(self, nO, nV, seed=None, **kwargs):
        Model.__init__(self, **kwargs)
        self.column = kwargs.get("column", 0)
        self.nO = nO
        self.nV = nV
        # Each instance hashes with its own seed by default, so separate
        # HashEmbed layers map the same IDs to different rows of their tables.
        if seed is not None:
            self.seed = seed
        else:
            self.seed = self.id

    def predict(self, ids):
        if ids.ndim >= 2:
            ids = self.ops.xp.ascontiguousarray(ids[:, self.column], dtype="uint64")
        # Each ID is hashed into several keys into the table; the rows for
        # those keys are summed to produce the output vector.
        keys = self.ops.hash(ids, self.seed) % self.nV
        vectors = self.vectors[keys]
        summed = vectors.sum(axis=1)
        return summed

    def begin_update(self, ids, drop=0.0):
        if ids.ndim >= 2:
            ids = self.ops.xp.ascontiguousarray(ids[:, self.column], dtype="uint64")
        keys = self.ops.hash(ids, self.seed) % self.nV
        vectors = self.vectors[keys].sum(axis=1)
        mask = self.ops.get_dropout_mask((vectors.shape[1],), drop)
        if mask is not None:
            vectors *= mask

        def finish_update(delta, sgd=None):
            if mask is not None:
                delta *= mask
            keys = self.ops.hash(ids, self.seed) % self.nV
            d_vectors = self.d_vectors
            keys = self.ops.xp.ascontiguousarray(keys.T, dtype="i")
            # Accumulate the gradient into every row that contributed to
            # each output vector.
            for i in range(keys.shape[0]):
                self.ops.scatter_add(d_vectors, keys[i], delta)
            if sgd is not None:
                sgd(self._mem.weights, self._mem.gradient, key=self.id)
            return None

        return vectors, finish_update
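

# --- Illustrative usage sketch (assumption; not part of the original module) ---
# A minimal sketch of how HashEmbed might be called, assuming NumpyOps-style
# input: a contiguous uint64 array of IDs with the key of interest in the
# configured column (0 by default). The sizes nO=64, nV=2000 and the batch of
# 10 IDs are arbitrary example values.
#
#     import numpy
#
#     embed = HashEmbed(nO=64, nV=2000)
#     ids = numpy.arange(10, dtype="uint64").reshape(10, 1)
#     vectors, finish_update = embed.begin_update(ids)  # vectors: (10, 64)
#     finish_update(numpy.ones((10, 64), dtype="f"))    # accumulate gradients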