|
|
@ -26,24 +26,35 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
"Input(X) of SequencePadOp should not be null.");
|
|
|
|
platform::errors::NotFound(
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("PadValue"), true,
|
|
|
|
"Input(X) of SequencePadOp should not be null."));
|
|
|
|
"Input(PadValue) of SequencePadOp should not be null.");
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
ctx->HasInput("PadValue"), true,
|
|
|
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
|
|
|
"Input(PadValue) of SequencePadOp should not be null."));
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
"Output(Out) of SequencePadOp should not be null.");
|
|
|
|
platform::errors::NotFound(
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Length"), true,
|
|
|
|
"Output(Out) of SequencePadOp should not be null."));
|
|
|
|
"Output(Length) of SequencePadOp should not be null.");
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
ctx->HasOutput("Length"), true,
|
|
|
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
|
|
|
"Output(Length) of SequencePadOp should not be null."));
|
|
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
"The rank of Input(X) can't be less than 2.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The rank of SequencePadOp Input(X) can't be less "
|
|
|
|
|
|
|
|
"than 2. But the rank we received is %d",
|
|
|
|
|
|
|
|
x_dims.size()));
|
|
|
|
auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
|
|
|
|
auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
|
|
|
|
auto pad_value_dims = ctx->GetInputDim("PadValue");
|
|
|
|
auto pad_value_dims = ctx->GetInputDim("PadValue");
|
|
|
|
PADDLE_ENFORCE_EQ(pad_value_dims == framework::make_ddim({1}) ||
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
pad_value_dims == time_step_dims,
|
|
|
|
pad_value_dims == framework::make_ddim({1}) ||
|
|
|
|
true,
|
|
|
|
pad_value_dims == time_step_dims,
|
|
|
|
"The Input(PadValue) must be a scalar or a tensor whose "
|
|
|
|
true,
|
|
|
|
"shape equals to time steps in sequences");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The SequencePadOp Input(PadValue) must be a scalar or a tensor "
|
|
|
|
|
|
|
|
"whose shape equals to time steps in sequences"));
|
|
|
|
|
|
|
|
|
|
|
|
int out_dim_0 = -1;
|
|
|
|
int out_dim_0 = -1;
|
|
|
|
|
|
|
|
|
|
|
@ -54,22 +65,37 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
|
|
|
|
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
|
|
|
|
const auto& x_lod = x_var->Get<LoDTensor>().lod();
|
|
|
|
const auto& x_lod = x_var->Get<LoDTensor>().lod();
|
|
|
|
PADDLE_ENFORCE_EQ(x_lod.empty(), false,
|
|
|
|
PADDLE_ENFORCE_EQ(x_lod.empty(), false,
|
|
|
|
"The Input(X) must hold lod info.");
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
|
|
|
"The SequencePadOp Input(X) must hold lod info."));
|
|
|
|
const auto& x_lod_0 = x_lod[0];
|
|
|
|
const auto& x_lod_0 = x_lod[0];
|
|
|
|
PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
"The Input(X)'s lod info is corrupted.");
|
|
|
|
x_lod_0.size(), 2,
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
x_dims[0], static_cast<int64_t>(x_lod_0.back()),
|
|
|
|
"The size of SequencePadOp Input(X)'s lod info can't be less "
|
|
|
|
"The Input(X)'s lod info mismatches the actual tensor shape.");
|
|
|
|
"than 2. But the size we received is %d",
|
|
|
|
|
|
|
|
x_lod_0.size()));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()),
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The SequencePadOp Input(X)'s lod info mismatches "
|
|
|
|
|
|
|
|
"the actual tensor shape. The 1st dimension of "
|
|
|
|
|
|
|
|
"Input(X)'s lod info is %d, the 1st dimension of "
|
|
|
|
|
|
|
|
"actual tensor shape is %d",
|
|
|
|
|
|
|
|
x_dims[0], static_cast<int64_t>(x_lod_0.back())));
|
|
|
|
|
|
|
|
|
|
|
|
int seq_num = x_lod_0.size() - 1;
|
|
|
|
int seq_num = x_lod_0.size() - 1;
|
|
|
|
int max_seq_len = math::MaximumSequenceLength(x_lod_0);
|
|
|
|
int max_seq_len = math::MaximumSequenceLength(x_lod_0);
|
|
|
|
if (padded_length == -1) {
|
|
|
|
if (padded_length == -1) {
|
|
|
|
padded_length = max_seq_len;
|
|
|
|
padded_length = max_seq_len;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
PADDLE_ENFORCE_GE(padded_length, max_seq_len,
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
"The Attr(padded_length) must be -1 or an int greater "
|
|
|
|
padded_length, max_seq_len,
|
|
|
|
"than the length of the longest original sequence.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The SequencePadOp Attr(padded_length) should be greater than or "
|
|
|
|
|
|
|
|
"equal to the "
|
|
|
|
|
|
|
|
"length of the longest original sequence. But the padded_length "
|
|
|
|
|
|
|
|
"we received is %d, the length of the longest original sequence "
|
|
|
|
|
|
|
|
"is %d",
|
|
|
|
|
|
|
|
padded_length, max_seq_len));
|
|
|
|
out_dim_0 = seq_num;
|
|
|
|
out_dim_0 = seq_num;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
// compile time
|
|
|
|
// compile time
|
|
|
@ -78,7 +104,10 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
ctx->GetLoDLevel("X"), 0,
|
|
|
|
ctx->GetLoDLevel("X"), 0,
|
|
|
|
"The LoD level Input(X) of sequence_pad should be larger than 0.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The LoD level of SequencePadOp Input(X) should be "
|
|
|
|
|
|
|
|
"larger than 0. But the LoD level we received is %d",
|
|
|
|
|
|
|
|
ctx->GetLoDLevel("X")));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> out_dims_vec{out_dim_0, padded_length};
|
|
|
|
std::vector<int> out_dims_vec{out_dim_0, padded_length};
|
|
|
@ -185,10 +214,12 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
"Input(X) of SequencePadGradOp should not be null.");
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
|
|
|
"Input(X) of SequencePadGradOp should not be null."));
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
"Input(Out@GRAD) of SequencePadGradOp should not be null.");
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
|
|
|
"Input(Out@GRAD) of SequencePadGradOp should not be null."));
|
|
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|