|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/gru_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
|
|
|
|
@ -221,6 +222,13 @@ class GRUGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (ctx->HasOutput(weight_grad_name))
|
|
|
|
|
ctx->SetOutputDim(weight_grad_name, weight_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
|
|
|
|
ctx, framework::GradVarName("Hidden")),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -376,15 +384,53 @@ class GRUCPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<T> Apply() const override {
|
|
|
|
|
auto* grad_op = new T();
|
|
|
|
|
grad_op->SetType("gru_grad");
|
|
|
|
|
grad_op->SetInput("Input", this->Input("Input"));
|
|
|
|
|
grad_op->SetInput("H0", this->Input("H0"));
|
|
|
|
|
grad_op->SetInput("Bias", this->Input("Bias"));
|
|
|
|
|
grad_op->SetInput("Weight", this->Input("Weight"));
|
|
|
|
|
|
|
|
|
|
grad_op->SetInput("BatchGate", this->Output("BatchGate"));
|
|
|
|
|
grad_op->SetInput("BatchResetHiddenPrev",
|
|
|
|
|
this->Output("BatchResetHiddenPrev"));
|
|
|
|
|
grad_op->SetInput("BatchHidden", this->Output("BatchHidden"));
|
|
|
|
|
grad_op->SetInput("Hidden", this->Output("Hidden"));
|
|
|
|
|
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("Hidden"),
|
|
|
|
|
this->OutputGrad("Hidden"));
|
|
|
|
|
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("Input"),
|
|
|
|
|
this->InputGrad("Input"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("Weight"),
|
|
|
|
|
this->InputGrad("Weight"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
|
|
|
|
|
|
|
|
|
|
grad_op->SetAttrMap(this->Attrs());
|
|
|
|
|
return std::unique_ptr<T>(grad_op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUGradOpNoNeedBufferVarInference,
|
|
|
|
|
"Input", "Bias");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
gru, ops::GRUOp, ops::GRUOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>)
|
|
|
|
|
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
|
|
|
|
|
REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
|
|
|
|
|
ops::GRUGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::GRUGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp,
|
|
|
|
|
ops::GRUGradOpNoNeedBufferVarInference);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
|
|
|
|
|
ops::GRUCPUKernel<double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|