If your classes are not ordered by decreasing frequency, do not use this op.

Additionally, it returns the number of times each of the true classes and the
sampled classes is expected to occur.

Parameters
----------
true_classes : Symbol
    The target classes in 1-D.
num_sampled: int
    The number of classes to randomly sample.
range_max: int
    The number of possible classes.

Returns
-------
samples: Symbol
    The sampled candidate classes in 1-D `int64` dtype.
expected_count_true: Symbol
    The expected count for true classes in 1-D `float64` dtype.
expected_count_sample: Symbol
    The expected count for sampled candidates in 1-D `float64` dtype.

Examples
--------
>>> true_cls = mx.nd.array([3])
>>> samples, exp_count_true, exp_count_sample = mx.nd.contrib.rand_zipfian(true_cls, 4, 5)
>>> samples
[1 3 3 3]
>>> exp_count_true
[ 0.12453879]
>>> exp_count_sample
[ 0.22629439 0.12453879 0.12453879 0.12453879]

This operator simulates a for loop, and `body` holds the computation for one
iteration of the loop. It runs the computation in `body` on each slice of the
input NDArrays.

`body` takes two arguments as input and outputs a tuple of two elements,
as illustrated below:

    out, states = body(data1, states)

`data1` can be either a symbol or a list of symbols. If `data` is a symbol,
`data1` is a symbol. Otherwise, `data1` is a list of symbols and has the same
size as `data`. `states` is a list of symbols and has the same size as
`init_states`. Similarly, `out` can be either a symbol or a list of symbols.
The `out` values from all iterations are concatenated as the first output of
foreach; the `states` from the last execution of `body` are the second output
of foreach.

foreach can return only the output data or only the states. If a user only
wants states, the body function can return ([], states). Similarly, if a user
only wants output data, the body function can return (out, []).
The computation done by this operator is equivalent to the pseudo code below
when the input data is an NDArray:

    states = init_states
    outs = []
    for i in range(data.shape[0]):
        s = data[i]
        out, states = body(s, states)
        outs.append(out)
    outs = stack(*outs)

Parameters
----------
body : a Python function.
    Define computation in an iteration.
data: a symbol or a list of symbols.
    The input data.
init_states: a Symbol or nested lists of symbols.
    The initial values of the loop states.
name: string.
    The name of the operator.

Returns
-------
outputs: a Symbol or nested lists of Symbols.
    The output data concatenated from the output of all iterations.
states: a Symbol or nested lists of Symbols.
    The loop states in the last iteration.

Examples
--------
>>> step = lambda data, states: (data + states[0], [states[0] * 2])
>>> data = mx.sym.var('data')
>>> states = [mx.sym.var('state')]
>>> outs, states = mx.sym.contrib.foreach(step, data, states)

This operator simulates a while loop which repeatedly runs a customized
computation as long as the condition is satisfied.

`loop_vars` is a Symbol or nested lists of Symbols that the computation uses.

`cond` is a user-defined function, used as the loop condition.
It consumes `loop_vars` and produces a scalar MXNet symbol
that indicates whether the loop should continue.
The loop ends when `cond` returns false (zero).
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => Symbol`.

`func` is a user-defined function, used as the loop body.
It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
In each step, `step_output` should contain the same number of elements.
Through all steps, the i-th element of `step_output` should have the same shape and dtype.
Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
and each corresponding element should have the same shape and dtype.
The `func` is variadic, and its signature should be
`func(*loop_vars) =>
(Symbol or nested List[Symbol] step_output, Symbol or nested List[Symbol] new_loop_vars)`.

`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns two lists.
The first list has the length of `|step_output|`;
its i-th element is the i-th elements of `step_output` from all steps, stacked along axis 0.
The second list has the length of `|loop_vars|`
and holds the final states of the loop variables.

.. warning::

   For now, axis 0 of every Symbol in the first list has size `max_iterations`,
   due to the lack of dynamic shape inference.

.. warning::

   Even if `cond` is never satisfied,
   while_loop returns a list of outputs with inferred dtype and shape.
   This is different from the NDArray version,
   where in this case `step_outputs` is assumed to be an empty list.

Parameters
----------
cond: a Python function.
    The loop condition.
func: a Python function.
    The loop body.
loop_vars: a Symbol or nested lists of Symbol.
    The initial values of the loop variables.
max_iterations: a python int.
    Maximum number of iterations.

Returns
-------
outputs: a Symbol or nested lists of Symbols
    stacked output from each step
states: a Symbol or nested lists of Symbols
    final state

Examples
--------
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
>>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)

`then_func` produces `outputs`, which is a list of Symbols.
The signature of `then_func` should be
`then_func() => nested List[Symbol]`.

`else_func` is a user-defined function, used as computation of the else branch.
It produces `outputs`, which is a list of Symbols.
The signature of `else_func` should be
`else_func() => nested List[Symbol]`.
The `outputs` produced by `then_func` and `else_func` should have the same number
of elements, all of which should have the same shape, dtype and stype.

This function returns a list of symbols, representing the computation result.

Parameters
----------
pred: a MXNet Symbol representing a scalar.
    The branch condition.
then_func: a Python function.
    The computation to be executed if `pred` is true.
else_func: a Python function.
    The computation to be executed if `pred` is false.

Returns
-------
outputs: a Symbol or nested lists of Symbols, representing the result of computation.

Examples
--------
>>> a, b = mx.sym.var('a'), mx.sym.var('b')
>>> pred = a * b < 5
>>> then_func = lambda: (a + 5) * (b + 5)
>>> else_func = lambda: (a - 5) * (b - 5)
>>> outputs = mx.sym.contrib.cond(pred, then_func, else_func)
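
To make the `while_loop` and `cond` examples above concrete, the sketch below traces them with plain Python numbers instead of Symbols. It is an illustration of the described semantics only, not MXNet code: `eager_while_loop` and `eager_cond` are hypothetical helper names, the starting values `i = 0`, `s = 1`, `a = 1`, `b = 2` are assumptions chosen for the trace, and stacking along axis 0 is modelled as collecting values into a Python list.

```python
# Plain-Python stand-ins for the semantics described above. The helper names
# eager_while_loop and eager_cond are hypothetical and are not MXNet APIs.

def eager_while_loop(cond, func, loop_vars, max_iterations):
    """Run func while cond holds, for at most max_iterations steps."""
    loop_vars = list(loop_vars)
    step_outputs = []
    for _ in range(max_iterations):
        if not cond(*loop_vars):
            break
        step_output, new_loop_vars = func(*loop_vars)
        # new_loop_vars must have as many elements as loop_vars.
        assert len(new_loop_vars) == len(loop_vars)
        step_outputs.append(step_output)
        loop_vars = list(new_loop_vars)
    # Group the i-th element of every step together; this models the
    # "stacked along axis 0" behaviour with a list per output.
    outputs = [list(column) for column in zip(*step_outputs)]
    return outputs, loop_vars


def eager_cond(pred, then_func, else_func):
    """Evaluate only the branch selected by pred."""
    return then_func() if pred else else_func()


# The while_loop example from the text, started from i = 0, s = 1 (assumed).
outputs, states = eager_while_loop(
    cond=lambda i, s: i <= 5,
    func=lambda i, s: ([i + s], [i + 1, s + i]),
    loop_vars=(0, 1),
    max_iterations=10,
)

# The cond example from the text, evaluated at a = 1, b = 2 (assumed).
a, b = 1, 2
result = eager_cond(a * b < 5,
                    lambda: (a + 5) * (b + 5),
                    lambda: (a - 5) * (b - 5))
```

With these starting values the loop body runs six times before `i <= 5` fails, so the sketch collects six step outputs. Per the warning above, the symbolic operator would instead report `max_iterations` (here 10) along axis 0 of each output.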