|
|
|
@ -31,11 +31,16 @@ class ShrinkRNNMemoryOp : public ArrayOp {
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto *x_var = scope.FindVar(Input("X"));
|
|
|
|
|
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(x_var,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(X) of ShrinkRNNMemoryOp is not found."));
|
|
|
|
|
auto &x_tensor = x_var->Get<framework::LoDTensor>();
|
|
|
|
|
size_t offset = this->GetOffset(scope, place);
|
|
|
|
|
auto *rank_table_var = scope.FindVar(Input("RankTable"));
|
|
|
|
|
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
rank_table_var,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(RankTable) of ShrinkRNNMemoryOp is not found."));
|
|
|
|
|
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
|
|
|
|
|
|
|
|
|
|
auto &rank_items = rank_table.items();
|
|
|
|
@ -46,7 +51,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
|
|
|
|
|
rank_items.begin();
|
|
|
|
|
|
|
|
|
|
auto *out_var = scope.FindVar(Output("Out"));
|
|
|
|
|
PADDLE_ENFORCE(out_var != nullptr, "Output(Out) must be set.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
out_var, platform::errors::NotFound(
|
|
|
|
|
"Output(Out) of ShrinkRNNMemoryOp is not found."));
|
|
|
|
|
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
size_t height = dst_num_rows;
|
|
|
|
@ -96,9 +103,10 @@ batch size for the next time step.
|
|
|
|
|
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"));
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("I"));
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("RankTable"));
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "ShrinkRNNMemory");
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput("I"), "Input", "I", "ShrinkRNNMemory");
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput("RankTable"), "Input", "RankTable",
|
|
|
|
|
"ShrinkRNNMemory");
|
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
|
// For runtime, output's lod is computed according to input's lod, but
|
|
|
|
|
// remove the finished sequence. It is set in detail kernel implementation.
|
|
|
|
@ -121,10 +129,13 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
|
|
|
|
|
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
|
|
|
|
|
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
dx_var, platform::errors::NotFound(
|
|
|
|
|
"Input(X@GRAD) of ShrinkRNNMemoryGradOp is not found."));
|
|
|
|
|
auto *x_var = scope.FindVar(Input("X"));
|
|
|
|
|
PADDLE_ENFORCE(x_var != nullptr);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
x_var, platform::errors::NotFound(
|
|
|
|
|
"Input(x) of ShrinkRNNMemoryGradOp is not found."));
|
|
|
|
|
auto &x_tensor = x_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
dx_tensor.Resize(x_tensor.dims());
|
|
|
|
@ -154,8 +165,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
|
|
|
|
|
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"));
|
|
|
|
|
PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "ShrinkRNNMemoryGrad");
|
|
|
|
|
OP_INOUT_CHECK(context->HasOutput(framework::GradVarName("X")), "Output",
|
|
|
|
|
"X", "ShrinkRNNMemoryGrad");
|
|
|
|
|
|
|
|
|
|
context->ShareDim("X", /*->*/ framework::GradVarName("X"));
|
|
|
|
|
context->ShareLoD("X", /*->*/ framework::GradVarName("X"));
|
|
|
|
|