|
|
|
@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto &mem_tensor = mem_var->Get<framework::LoDTensor>();
|
|
|
|
|
out_tensor->ShareDataWith(mem_tensor);
|
|
|
|
|
framework::TensorCopySync(mem_tensor, dev_place, out_tensor);
|
|
|
|
|
out_tensor->set_lod(mem_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
|
|
|
|
|
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of rnn_memory_helper op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output of rnn_memory_helper op should not be null.");
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
|
|
|
|
|
} else {
|
|
|
|
|
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
in_grad_tensor->ShareDataWith(out_grad_tensor);
|
|
|
|
|
framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor);
|
|
|
|
|
in_grad_tensor->set_lod(out_grad_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), "");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
|
|
|
|
|
"Gradient of Input(X) in rnn_memory_helper_grad of should "
|
|
|
|
|
"not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of rnn_memory_helper_grad of should not be null.");
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ x_grad_name);
|
|
|
|
|
}
|
|
|
|
|