Fix python wrapper for layer_norm

Branch: emailweixu-patch-1
Author: guosheng, 7 years ago
Parent: d63b7c6042
Commit: 0999347910

@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     // check input
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Scale"),
-                   "Input(Scale) of LayerNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Mean"),
                    "Input(Mean) of LayerNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Variance"),

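Note: the hunk above drops the hard requirement on the Scale input from LayerNormGradOp's input checks, which is consistent with the Python wrapper treating the scale parameter as optional. For reference, a minimal NumPy sketch of the forward computation (the name layer_norm_ref and the reshaping are illustrative assumptions, not the Paddle kernel; begin_norm_axis and epsilon mirror the fluid wrapper's parameters):

import numpy as np

def layer_norm_ref(x, scale=None, shift=None, begin_norm_axis=1, epsilon=1e-5):
    # Hypothetical reference, not Paddle's kernel: flatten to [N, D],
    # where D spans the normalized trailing axes.
    shape = x.shape
    n = int(np.prod(shape[:begin_norm_axis]))
    d = int(np.prod(shape[begin_norm_axis:]))
    flat = x.reshape(n, d)
    mean = flat.mean(axis=1, keepdims=True)
    var = flat.var(axis=1, keepdims=True)
    out = (flat - mean) / np.sqrt(var + epsilon)
    if scale is not None:  # optional elementwise gain (Scale)
        out = out * scale.reshape(1, d)
    if shift is not None:  # optional elementwise bias (Shift)
        out = out + shift.reshape(1, d)
    return out.reshape(shape)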
@@ -1637,7 +1637,7 @@ def layer_norm(input,
             dtype=dtype,
             default_initializer=Constant(1.0))
         inputs['Scale'] = scale
-    if center:
+    if shift:
         assert bias_attr is not False
         bias = helper.create_parameter(
             attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
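Note: the wrapper's keyword argument is shift, so the old `if center:` test never matched the parameter callers actually pass; with the fix, shift=True reaches the bias-creation branch. A hedged usage sketch against the 2018-era fluid interface (module paths assumed from that era and may differ in later releases):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32, 128], dtype='float32')
# After the fix, shift=True creates the Shift (bias) parameter and
# scale=True creates the Scale (gain) parameter.
y = fluid.layers.layer_norm(
    x, scale=True, shift=True, begin_norm_axis=1, epsilon=1e-5)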
