enhance attention_lstm and param_attr error message (#23678)

* enhance attention_lstm and param_attr error message
* fix: fix param_attr type check
xiaogang 5 years ago committed by GitHub
parent 600cb8c828
commit f11af6a935

paddle/fluid/operators/attention_lstm_op.cc
@@ -23,97 +23,119 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Assert only one Input(X) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Assert only one Input(C0) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
+  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("LSTMWeight"), "Input", "LSTMWeight",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("LSTMBias"), "Input", "LSTMBias",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("AttentionWeight"), "Input", "AttentionWeight",
+                 "AttentionLstm");
 
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Assert only one Output(Hidden) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Assert only one Output(Cell) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Assert only one Output(LSTMX) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");
+  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("AttentionedX"), "Output", "AttentionedX",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("AttentionFCOut"), "Output", "AttentionFCOut",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("LSTMX"), "Output", "LSTMX", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("LSTMOUT"), "Output", "LSTMOUT",
+                 "AttentionLstm");
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
-  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
+                                          "Input(X)'s rank must be 2."));
 
   auto w_dims = ctx->GetInputDim("LSTMWeight");
   const int D = w_dims[1] / 4;
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
+  PADDLE_ENFORCE_EQ(
+      w_dims.size(), 2,
+      platform::errors::InvalidArgument("Input(LSTMWeight)'s rank must be 2."));
+  PADDLE_ENFORCE_EQ(
+      w_dims[0], D + M,
+      platform::errors::InvalidArgument(
+          "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D));
 
   auto b_dims = ctx->GetInputDim("LSTMBias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, platform::errors::InvalidArgument(
+                                          "Input(LSTMBias)'s rank must be 2."));
+  PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                    platform::errors::InvalidArgument(
+                        "LSTMBias dims should be 1 x %d.", 4 * D));
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D,
+                    platform::errors::InvalidArgument(
+                        "LSTMBias dims should be 1 x %d.", 4 * D));
 
   auto c_dims = ctx->GetInputDim("C0");
-  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(c_dims.size(), 2, platform::errors::InvalidArgument(
+                                          "Input(C0)'s rank must be 2."));
   if (ctx->IsRuntime()) {
-    PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+    PADDLE_ENFORCE_EQ(c_dims[1], D, platform::errors::InvalidArgument(
+                                        "C0 dims should be N x %d.", D));
   }
 
   if (ctx->HasInput("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, platform::errors::InvalidArgument(
+                                              "Input(H0)'s rank must be 2."));
     if (ctx->IsRuntime() ||
         (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
-      PADDLE_ENFORCE(h_dims == c_dims,
-                     "The dimension of Input(H0) and Input(C0) "
-                     "should be the same.");
+      PADDLE_ENFORCE_EQ(h_dims, c_dims,
+                        platform::errors::InvalidArgument(
+                            "The dimension of Input(H0) and Input(C0) "
+                            "should be the same."));
     }
   }
 
   auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
   PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
-                    "Input(AttentionWeight)'s rank must be 2.");
+                    platform::errors::InvalidArgument(
+                        "Input(AttentionWeight)'s rank must be 2."));
   PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+                    platform::errors::InvalidArgument(
+                        "AttentionWeight shapes must be (%d + %d) * 1.", M, D));
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+                    platform::errors::InvalidArgument(
+                        "AttentionWeight shapes must be (%d + %d) * 1.", M, D));
 
   if (ctx->HasInput("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
-                      "Input(AttentionBias)'s rank must be 2.");
+                      platform::errors::InvalidArgument(
+                          "Input(AttentionBias)'s rank must be 2."));
     PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
-                      "AttentionBias shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "AttentionBias shapes must be 1 * 1."));
     PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
-                      "AttentionBias shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "AttentionBias shapes must be 1 * 1."));
   }
 
   if (ctx->HasInput("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalar)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "Input(AttentionScalar)'s rank must be 2."));
+    PADDLE_ENFORCE_EQ(dims[0], 1, platform::errors::InvalidArgument(
+                                      "AttentionScalar shapes must be 1 * 1."));
+    PADDLE_ENFORCE_EQ(dims[1], 1, platform::errors::InvalidArgument(
+                                      "AttentionScalar shapes must be 1 * 1."));
   }
 
   if (ctx->HasInput("AttentionScalarBias")) {
     auto dims = ctx->GetInputDim("AttentionScalarBias");
-    PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
-        "AttentionScalar should not be null when have AttentionScalarBias.");
+    OP_INOUT_CHECK(ctx->HasInput("AttentionScalar"), "Input", "AttentionScalar",
+                   "AttentionLstm");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalarBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "Input(AttentionScalarBias)'s rank must be 2."));
+    PADDLE_ENFORCE_EQ(dims[0], 1,
+                      platform::errors::InvalidArgument(
+                          "AttentionScalarBias shapes must be 1 * 1."));
+    PADDLE_ENFORCE_EQ(dims[1], 1,
+                      platform::errors::InvalidArgument(
+                          "AttentionScalarBias shapes must be 1 * 1."));
   }
 
   framework::DDim out_dims({x_dims[0], D});
@@ -301,8 +323,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
       int len = x_lod[0][i + 1] - x_lod[0][i];
       max_seq_len = max_seq_len < len ? len : max_seq_len;
     }
-    PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, "Input(X)'s lod size must be 1.");
-    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+    PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, platform::errors::InvalidArgument(
+                                             "Input(X)'s lod size must be 1."));
+    PADDLE_ENFORCE_EQ(
+        c0->dims()[0], N,
+        platform::errors::InvalidArgument("C0 dims should be %d x %d.", N, D));
     fc_out->Resize({max_seq_len, 1});
 
     std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;

python/paddle/fluid/param_attr.py
@@ -16,9 +16,11 @@ from __future__ import print_function
 
 import six
 import warnings
+import sys
 
 from .initializer import Initializer, Xavier, Constant
 from .regularizer import WeightDecayRegularizer
+from paddle.fluid.data_feeder import check_type
 
 __all__ = [
     'ParamAttr',
@@ -77,8 +79,17 @@ class ParamAttr(object):
                  regularizer=None,
                  trainable=True,
                  do_model_average=True):
+        if sys.version_info.major == 2:
+            check_type(name, "name", (str, type(None), unicode), "ParamAttr")
+        else:
+            check_type(name, "name", (str, type(None)), "ParamAttr")
+        check_type(learning_rate, "learning_rate", (float, int), "ParamAttr")
+        check_type(trainable, "trainable", (bool), "ParamAttr")
+        check_type(do_model_average, "do_model_average", (bool), "ParamAttr")
         self.name = name
-        if isinstance(self.name, six.string_types) and self.name == "":
+        if self.name == "":
             raise ValueError("name of ParamAttr can not be empty str")
         self.initializer = initializer
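
Note: with the check_type calls added above, a wrongly typed ParamAttr argument now fails fast with a TypeError that names the parameter, the expected types, and the API ("ParamAttr"), instead of surfacing as an obscure error later. A minimal usage sketch of that behavior; the parameter name "fc_w" and the values are illustrative, not part of this patch:

import paddle.fluid as fluid

# learning_rate must be a float or an int; the new check_type call in
# ParamAttr.__init__ rejects the string immediately with a TypeError.
try:
    fluid.ParamAttr(name="fc_w", learning_rate="0.1")
except TypeError as e:
    print(e)

# Correctly typed arguments behave exactly as before.
w_attr = fluid.ParamAttr(name="fc_w", learning_rate=0.1, trainable=True)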
