Reduce the memory footprint of the fused pattern of elementwise_add Op and activation Op (relu Op, for example) (#29885)

wangchaochaohu 5 years ago committed by GitHub
parent 5932fee60a
commit af80859dd6

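This commit registers a dedicated fused_elemwise_add_activation op (plus its grad op) so that the existing FuseElewiseAddActPass can mark forward buffers it no longer needs. The fused op is produced by the graph pass rather than called directly; as a hedged sketch (using the Paddle 2.x static-graph API of that era, not code from this commit), the fusion is switched on through BuildStrategy.fuse_elewise_add_act_ops:

import paddle
import paddle.nn.functional as F

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 784], dtype='float32')
    hidden = paddle.static.nn.fc(x, size=256)    # bias add -> elementwise_add
    out = F.relu(hidden)                         # activation that gets fused
    loss = paddle.mean(out)
    paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True   # turn on the fuse pass

compiled = paddle.static.CompiledProgram(main_prog).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)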
@@ -183,7 +183,7 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::string d_ele_y_n = d_ele_y->Name();
OpDesc desc;
desc.SetType("fused_elemwise_activation_grad");
desc.SetType("fused_elemwise_add_activation_grad");
desc.SetInput("IntermediateOut", {});
desc.SetInput("X", {});
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
@@ -231,7 +231,7 @@ Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
desc.SetOutput("Out", std::vector<std::string>({act_out_n}));
desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
desc.SetType("fused_elemwise_activation");
desc.SetType("fused_elemwise_add_activation");
desc.SetAttr("save_intermediate_out", true);
desc.SetAttr("functor_list", std::vector<std::string>(
{op_1->Op()->Type(), op_2->Op()->Type()}));
@@ -251,7 +251,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
std::unordered_set<const Node *> need_removed_nodes;
for (auto &cur_node : graph->Nodes()) {
if (cur_node->IsVar()) continue;
- if (cur_node->Name() == "fused_elemwise_activation") {
+ if (cur_node->Name() == "fused_elemwise_add_activation") {
bool save_intermediate_out = BOOST_GET_CONST(
bool, cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
@@ -272,7 +272,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
}
}
}
- } else if (cur_node->Name() == "fused_elemwise_activation_grad") {
+ } else if (cur_node->Name() == "fused_elemwise_add_activation_grad") {
auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE_EQ(

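Taken together, the pass hunks above now emit graph nodes typed fused_elemwise_add_activation (and its _grad counterpart) instead of the generic fused_elemwise_activation. For reference, an illustrative summary of the node built by CreateFuseElewiseAddActNode, written as a plain Python dict; the field names mirror the diff, while the variable names and the functor order are hypothetical:

fused_node_desc = {
    "type": "fused_elemwise_add_activation",
    "inputs": {"X": ["ele_x"], "Y": ["ele_y"]},
    "outputs": {"Out": ["act_out"], "IntermediateOut": ["ele_out"]},
    "attrs": {
        "save_intermediate_out": True,
        # functor_list carries the two original op types (op_1, op_2)
        "functor_list": ["elementwise_add", "relu"],
    },
}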
@@ -361,10 +361,14 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
- return framework::OpKernelType(
- OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
+ return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+ ctx, framework::GradVarName("Out")),
+ ctx.GetPlace());
}
};
+ DECLARE_NO_NEED_BUFFER_VARS_INFERER(
+ FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y");
} // namespace operators
} // namespace paddle
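The two grad-op changes above are where the memory saving comes from: the kernel dtype of fused_elemwise_add_activation_grad is now inferred from GradVarName("Out") instead of from Y, and X / Y are declared as no-need-buffer variables, so their forward buffers can be reclaimed before the backward pass runs. That is safe because the backward of add + relu never reads X or Y, as this small numpy illustration (not Paddle code) shows:

import numpy as np

x = np.random.randn(4, 8).astype('float32')
y = np.random.randn(8).astype('float32')           # broadcast like a bias

out = np.maximum(x + y, 0.0)                        # forward: relu(x + y)
d_out = np.random.randn(4, 8).astype('float32')     # upstream gradient

# backward: relu'(z) is 1 where out > 0, and d(x + y)/dx = d(x + y)/dy = 1,
# so only `out` and `d_out` are read here -- x and y are never touched.
d_intermediate = d_out * (out > 0)
d_x = d_intermediate
d_y = d_intermediate.sum(axis=0)                    # reduce over the broadcast axis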
@@ -390,3 +394,27 @@ REGISTER_OP_CPU_KERNEL(
float>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
double>);
+ // for memory optimization, we register the fused_elemwise_add_activation OP
+ REGISTER_OPERATOR(
+ fused_elemwise_add_activation, ops::FusedElemwiseActivationOp,
+ ops::FusedElemwiseActivationMaker,
+ ops::FusedElemwiseActivationGradMaker<paddle::framework::OpDesc>,
+ ops::FusedElemwiseActivationGradMaker<paddle::imperative::OpBase>);
+ REGISTER_OPERATOR(fused_elemwise_add_activation_grad,
+ ops::FusedElemwiseAddActivationNoNeddBufVarInferer,
+ ops::FusedElemwiseActivationOpGrad);
+ REGISTER_OP_CPU_KERNEL(
+ fused_elemwise_add_activation,
+ ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
+ float>,
+ ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
+ double>);
+ REGISTER_OP_CPU_KERNEL(
+ fused_elemwise_add_activation_grad,
+ ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
+ float>,
+ ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
+ double>);

@@ -32,3 +32,21 @@ REGISTER_OP_CUDA_KERNEL(
double>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
+ REGISTER_OP_CUDA_KERNEL(
+ fused_elemwise_add_activation,
+ ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+ float>,
+ ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+ double>,
+ ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+ paddle::platform::float16>);
+ REGISTER_OP_CUDA_KERNEL(
+ fused_elemwise_add_activation_grad,
+ ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+ float>,
+ ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+ double>,
+ ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+ paddle::platform::float16>);

@@ -77,4 +77,6 @@ class TestMNIST(TestParallelExecutorBase):
if __name__ == '__main__':
+ import paddle
+ paddle.enable_static()
unittest.main()

@@ -390,4 +390,6 @@ for mode in {0, 1}:
grad_chek=False)
if __name__ == '__main__':
+ import paddle
+ paddle.enable_static()
unittest.main()
