[CustomOp] Support Compile multi ops at same time (#30920)
* add more unittest for ABI compatibility * add more unittest * refine warning style * support compiling multiple custom ops at the same time * fix missing paddle import in unittest * fix typo * add more unittest * add comment for details revert-31068-fix_conv3d_windows
parent
caf9d39839
commit
4c9f96c902
@ -0,0 +1,115 @@
|
|||||||
|
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <algorithm>
#include <cstdint>

#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class Relu3Op : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
auto in_dims = ctx->GetInputDim("X");
|
||||||
|
ctx->SetOutputDim("Y", in_dims);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declares the proto of the `relu3` operator: one input "X", one output "Y",
// and a short operator comment.
class Relu3OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    this->AddInput("X", "The input tensor.");
    this->AddOutput("Y", "Output of relu_op");
    this->AddComment(R"DOC(
Relu3 Operator.
)DOC");
  }
};
|
||||||
|
|
||||||
|
class Relu3GradOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class Relu3GradMaker : public framework::SingleGradOpMaker<T> {
|
||||||
|
public:
|
||||||
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||||
|
|
||||||
|
void Apply(GradOpPtr<T> op) const override {
|
||||||
|
op->SetType("relu3_grad");
|
||||||
|
op->SetInput("Y", this->Output("Y"));
|
||||||
|
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
|
||||||
|
op->SetAttrMap(this->Attrs());
|
||||||
|
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class Relu3Kernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* in_t = ctx.Input<Tensor>("X");
|
||||||
|
auto* out_t = ctx.Output<Tensor>("Y");
|
||||||
|
auto x = in_t->data<T>();
|
||||||
|
auto y = out_t->mutable_data<T>(ctx.GetPlace());
|
||||||
|
for (int i = 0; i < in_t->numel(); ++i) {
|
||||||
|
y[i] = std::max(static_cast<T>(0.), x[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class Relu3GradKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
||||||
|
auto* y_t = ctx.Input<Tensor>("Y");
|
||||||
|
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
|
||||||
|
auto dy = dy_t->data<T>();
|
||||||
|
auto y = y_t->data<T>();
|
||||||
|
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
for (int i = 0; i < y_t->numel(); ++i) {
|
||||||
|
dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
using CPU = paddle::platform::CPUDeviceContext;
|
||||||
|
REGISTER_OPERATOR(relu3,
|
||||||
|
ops::Relu3Op,
|
||||||
|
ops::Relu3OpMaker,
|
||||||
|
ops::Relu3GradMaker<paddle::framework::OpDesc>,
|
||||||
|
ops::Relu3GradMaker<paddle::imperative::OpBase>);
|
||||||
|
REGISTER_OPERATOR(relu3_grad, ops::Relu3GradOp);
|
||||||
|
REGISTER_OP_CPU_KERNEL(relu3,
|
||||||
|
ops::Relu3Kernel<CPU, float>,
|
||||||
|
ops::Relu3Kernel<CPU, double>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(relu3_grad,
|
||||||
|
ops::Relu3GradKernel<CPU, float>,
|
||||||
|
ops::Relu3GradKernel<CPU, double>);
|
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void KeRelu3(const T* x, const int num, T* y) {
|
||||||
|
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
|
||||||
|
y[i] = max(x[i], static_cast<T>(0.));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class Relu3CUDAKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* in_t = ctx.Input<Tensor>("X");
|
||||||
|
auto* out_t = ctx.Output<Tensor>("Y");
|
||||||
|
auto x = in_t->data<T>();
|
||||||
|
auto y = out_t->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||||
|
|
||||||
|
int num = in_t->numel();
|
||||||
|
int block = 512;
|
||||||
|
int grid = (num + block - 1) / block;
|
||||||
|
KeRelu3<T><<<grid, block, 0, dev_ctx.stream()>>>(x, num, y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void KeRelu3Grad(const T* y, const T* dy, const int num, T* dx) {
|
||||||
|
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
|
||||||
|
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class Relu3GradCUDAKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
||||||
|
auto* y_t = ctx.Input<Tensor>("Y");
|
||||||
|
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
|
||||||
|
auto dy = dy_t->data<T>();
|
||||||
|
auto y = y_t->data<T>();
|
||||||
|
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||||
|
|
||||||
|
int num = dy_t->numel();
|
||||||
|
int block = 512;
|
||||||
|
int grid = (num + block - 1) / block;
|
||||||
|
KeRelu3Grad<T><<<grid, block, 0, dev_ctx.stream()>>>(y, dy, num, dx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
using CUDA = paddle::platform::CUDADeviceContext;
|
||||||
|
REGISTER_OP_CUDA_KERNEL(relu3,
|
||||||
|
paddle::operators::Relu3CUDAKernel<CUDA, float>,
|
||||||
|
paddle::operators::Relu3CUDAKernel<CUDA, double>);
|
||||||
|
|
||||||
|
REGISTER_OP_CUDA_KERNEL(relu3_grad,
|
||||||
|
paddle::operators::Relu3GradCUDAKernel<CUDA, float>,
|
||||||
|
paddle::operators::Relu3GradCUDAKernel<CUDA, double>);
|
@ -0,0 +1,43 @@
|
|||||||
|
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
|
||||||
|
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
|
||||||
|
const paddle::Tensor& out,
|
||||||
|
const paddle::Tensor& grad_out);
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x);
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
|
||||||
|
const paddle::Tensor& out,
|
||||||
|
const paddle::Tensor& grad_out);
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
|
||||||
|
|
||||||
|
// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator
|
||||||
|
// to test jointly compile multi operators at same time.
|
||||||
|
PD_BUILD_OPERATOR("relu3")
|
||||||
|
.Inputs({"X"})
|
||||||
|
.Outputs({"Out"})
|
||||||
|
.SetKernelFn(PD_KERNEL(ReluForward))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
|
||||||
|
.SetBackwardOp("relu3_grad")
|
||||||
|
.Inputs({"X", "Out", paddle::Grad("Out")})
|
||||||
|
.Outputs({paddle::Grad("X")})
|
||||||
|
.SetKernelFn(PD_KERNEL(ReluBackward));
|
Loading…
Reference in new issue