Add the first implementation of fusion_group op (#19621)
* Add dynamic loading of NVRTC, and support runtime compilation of CUDA kernels using NVRTC. test=develop
* Call the CUDA driver API to launch the kernel compiled by NVRTC. test=develop
* Disable for Mac and Windows. test=develop
* Refine the code to support manually specified num_threads and workload_per_thread. test=develop
* Refine the CUDA kernel to support large dims. test=develop
* Add DeviceCodePool to manage all device codes.
* Add the first implementation of the fusion_group op.
* Add a unit test for the fusion_group op.
* Add a check of the result.
* Add a check for NVRTC in the unit test. test=develop
* Add a comment to explain the inputs, outputs and features of the fusion_group op. test=develop
* Disable the fusion_group op for Mac and Windows. test=develop
* Make device code compilation return a status instead of hanging. test=develop
* Add a check for the CUDA driver library, and do not core dump when a CUDA driver API call fails.
* Unify fusion_group_op's input and output names. test=develop
* Add a check for the CUDA driver library in the unit test. test=develop
* Refine the calls to PADDLE_ENFORCE. test=develop
parent
6192108408
commit
d48320777e
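The op delegates runtime compilation and launching to the DeviceCode/DeviceCodePool abstractions added by this series of commits. For readers unfamiliar with the mechanism, the sketch below shows the raw NVRTC + CUDA driver API flow those abstractions wrap. The function name CompileAndLaunch and the compile options are illustrative only and are not part of this PR; the PR additionally loads libnvrtc and the CUDA driver library dynamically and reports failures through PADDLE_ENFORCE rather than asserts.

#include <cassert>
#include <string>
#include <vector>

#include <cuda.h>
#include <nvrtc.h>

// Hypothetical sketch: compile a CUDA C++ string with NVRTC and launch the
// resulting kernel via the CUDA driver API. Assumes cuInit() was called and a
// CUDA context is current (in Paddle, the device context takes care of that).
// "args" follows the same convention as FusionGroupKernel::Compute below:
// one entry per kernel argument, holding the address of the scalar or of the
// device pointer.
void CompileAndLaunch(const std::string& src, const std::string& func_name,
                      size_t n, std::vector<void*>* args) {
  // 1. Runtime-compile the source to PTX.
  nvrtcProgram prog;
  assert(nvrtcCreateProgram(&prog, src.c_str(), "fused_kernel.cu", 0, nullptr,
                            nullptr) == NVRTC_SUCCESS);
  const char* opts[] = {"--std=c++11"};
  assert(nvrtcCompileProgram(prog, 1, opts) == NVRTC_SUCCESS);
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // 2. Load the PTX and look up the kernel with the driver API. The kernel
  //    must be declared extern "C" so the name is not mangled.
  CUmodule module;
  CUfunction kernel;
  assert(cuModuleLoadData(&module, ptx.data()) == CUDA_SUCCESS);
  assert(cuModuleGetFunction(&kernel, module, func_name.c_str()) ==
         CUDA_SUCCESS);

  // 3. Launch. kernelParams is an array of pointers to each argument.
  unsigned int threads = 1024;
  unsigned int blocks = static_cast<unsigned int>((n + threads - 1) / threads);
  assert(cuLaunchKernel(kernel, blocks, 1, 1, threads, 1, 1,
                        0 /* sharedMemBytes */, nullptr /* stream */,
                        args->data(), nullptr /* extra */) == CUDA_SUCCESS);
}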
@@ -0,0 +1,90 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_group_op.h"

namespace paddle {
namespace operators {

class FusionGroupOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    const size_t num_ins = ctx->Inputs("Inputs").size();
    const size_t num_outs = ctx->Outputs("Outs").size();

    PADDLE_ENFORCE_GE(
        num_ins, 1UL,
        platform::errors::InvalidArgument(
            "Expected the number of inputs >= 1. Received %d.", num_ins));
    PADDLE_ENFORCE_GE(
        num_outs, 1UL,
        platform::errors::InvalidArgument(
            "Expected the number of outputs >= 1. Received %d.", num_outs));

    int type = ctx->Attrs().Get<int>("type");
    PADDLE_ENFORCE_EQ(type, 0UL,
                      platform::errors::InvalidArgument(
                          "Only support fusion of elementwise operations."));

    std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
    if (type == 0) {
      for (size_t i = 1; i < num_ins; ++i) {
        PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i],
                          platform::errors::InvalidArgument(
                              "All the inputs' dims should be the same."));
      }
      std::vector<framework::DDim> out_dims;
      for (size_t j = 0; j < num_outs; ++j) {
        out_dims.push_back(x_dims[0]);
      }
      ctx->SetOutputsDim("Outs", out_dims);
    }

    // Only the LoD of Inputs[0] is shared with Outs.
    for (size_t j = 0; j < num_outs; ++j) {
      ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
    }
  }
};

class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Inputs",
             "(std::vector<LoDTensor>) The inputs of fusion_group op.")
        .AsDuplicable();
    AddOutput("Outs",
              "(std::vector<LoDTensor>) The outputs of fusion_group op.")
        .AsDuplicable();
    AddAttr<int>("type", "Fusion type.").SetDefault(0);
    AddAttr<std::string>("func_name", "Name of the generated functions.")
        .SetDefault("");
    AddComment(R"DOC(
fusion_group Operator.

It is used to execute a generated CUDA kernel which fuses the computation of
multiple operators into one. It supports several types:
0, fused computation of elementwise operations in which all the dims of inputs
and outputs should be exactly the same.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker);
@@ -0,0 +1,22 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_group_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    fusion_group,
    ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, double>,
    ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, float>);
@@ -0,0 +1,65 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class FusionGroupKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
    auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
    int type = ctx.Attr<int>("type");

    size_t num_ins = ins.size();
    size_t num_outs = outs.size();

    auto place = ctx.GetPlace();
    for (size_t i = 0; i < num_outs; ++i) {
      outs[i]->mutable_data<T>(place);
    }

    std::string func_name = ctx.Attr<std::string>("func_name");
    platform::DeviceCode* dev_code =
        platform::DeviceCodePool::Instance().Get(place, func_name);
    VLOG(3) << "func_name: " << func_name;

    if (type == 0) {
      size_t n = ins[0]->numel();
      std::vector<void*> args;
      args.push_back(&n);
      std::vector<const T*> ptrs(num_ins + num_outs);
      for (size_t i = 0; i < num_ins; ++i) {
        ptrs[i] = ins[i]->data<T>();
        args.push_back(&ptrs[i]);
      }
      for (size_t j = 0; j < num_outs; ++j) {
        ptrs[num_ins + j] = outs[j]->data<T>();
        args.push_back(&ptrs[num_ins + j]);
      }
      dev_code->Launch(n, &args);
    }
  }
};

}  // namespace operators
}  // namespace paddle
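A note on the argument packing in Compute() above: per the commit message, the kernel is launched through the CUDA driver API, and cuLaunchKernel expects kernelParams to be an array of pointers to each kernel argument. That is why the vector collects &n and &ptrs[i] rather than the raw values. A minimal, hypothetical illustration against the kernel signature used in the unit test below (the device pointers are placeholders):

// Kernel declared in the test:
//   extern "C" __global__ void elementwise_cuda_kernel_0(
//       size_t n, float* x, float* y, float* z);
size_t n = 256 * 256;
float* x = nullptr;  // device pointer in practice
float* y = nullptr;  // device pointer in practice
float* z = nullptr;  // device pointer in practice
std::vector<void*> args = {&n, &x, &y, &z};  // one address per kernel argument
// cuLaunchKernel(kernel, grid, 1, 1, block, 1, 1, 0, stream,
//                args.data(), nullptr);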
@@ -0,0 +1,220 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"

namespace paddle {
namespace operators {

using CPUKernelFunc = std::function<void(size_t n, std::vector<void*> args)>;

template <typename T>
framework::Tensor* CreateTensor(framework::Scope* scope,
                                const platform::Place& place,
                                const std::string& name,
                                const std::vector<int64_t>& shape) {
  auto* var = scope->Var(name);
  auto* tensor = var->GetMutable<framework::LoDTensor>();
  if (shape.size() > 0) {
    tensor->mutable_data<T>(framework::make_ddim(shape), place);
  }
  return tensor;
}

template <typename T>
void SetupRandomCPUTensor(framework::Tensor* tensor,
                          const std::vector<int64_t>& shape) {
  static unsigned int seed = 100;
  std::mt19937 rng(seed++);
  std::uniform_real_distribution<double> uniform_dist(0, 1);

  T* ptr = tensor->mutable_data<T>(framework::make_ddim(shape),
                                   platform::CPUPlace());
  for (int64_t i = 0; i < tensor->numel(); ++i) {
    ptr[i] = static_cast<T>(uniform_dist(rng)) - static_cast<T>(0.5);
  }
}

framework::OpDesc* CreateFusionGroupOp(
    framework::ProgramDesc* program,
    const std::vector<std::string>& input_names,
    const std::vector<std::vector<int64_t>>& input_shapes,
    const std::vector<std::string>& output_names, int type,
    std::string func_name) {
  EXPECT_EQ(input_names.size(), input_shapes.size());

  for (size_t i = 0; i < input_names.size(); ++i) {
    auto* var = program->MutableBlock(0)->Var(input_names[i]);
    var->SetType(framework::proto::VarType::LOD_TENSOR);
    var->SetDataType(framework::proto::VarType::FP32);
    var->SetShape(input_shapes[i]);
  }
  for (size_t j = 0; j < output_names.size(); ++j) {
    auto* var = program->MutableBlock(0)->Var(output_names[j]);
    var->SetType(framework::proto::VarType::LOD_TENSOR);
    var->SetDataType(framework::proto::VarType::FP32);
  }

  auto* op = program->MutableBlock(0)->AppendOp();
  op->SetType("fusion_group");
  op->SetInput("Inputs", input_names);
  op->SetOutput("Outs", output_names);
  op->SetAttr("type", type);
  op->SetAttr("func_name", func_name);
  op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
              static_cast<int>(framework::OpRole::kForward));
  return op;
}

void PrepareDeviceCode(platform::Place place, std::string func_name,
                       std::string cuda_kernel_str) {
  paddle::platform::DeviceCodePool& pool =
      paddle::platform::DeviceCodePool::Init({place});

  std::unique_ptr<paddle::platform::DeviceCode> code(
      new paddle::platform::CUDADeviceCode(place, func_name, cuda_kernel_str));
  code->Compile();
  pool.Set(std::move(code));
}

void CheckOutputs(framework::Scope* scope,
                  const std::vector<std::string>& output_names,
                  std::vector<framework::Tensor>* cpu_tensors,
                  size_t num_inputs, CPUKernelFunc cpu_kernel_func) {
  std::vector<framework::Tensor> cpu_outputs;
  cpu_outputs.resize(output_names.size());
  for (size_t j = 0; j < output_names.size(); ++j) {
    auto* var = scope->Var(output_names[j]);
    const auto& dev_tensor = var->Get<framework::LoDTensor>();
    TensorCopySync(dev_tensor, platform::CPUPlace(), &(cpu_outputs[j]));

    cpu_tensors->at(num_inputs + j)
        .mutable_data<float>(dev_tensor.dims(), platform::CPUPlace());
  }

  size_t n = cpu_tensors->at(0).numel();
  std::vector<void*> args;
  for (size_t i = 0; i < cpu_tensors->size(); ++i) {
    args.push_back(cpu_tensors->at(i).data<float>());
  }
  cpu_kernel_func(n, args);

  for (size_t j = 0; j < output_names.size(); ++j) {
    auto* dev_ptr = cpu_outputs[j].data<float>();
    auto* cpu_ptr = cpu_tensors->at(num_inputs + j).data<float>();
    int64_t length = cpu_outputs[j].numel();
    LOG(INFO) << "Check the " << j << "th output...";
    for (int64_t i = 0; i < length; ++i) {
      EXPECT_NEAR(dev_ptr[i], cpu_ptr[i], 1.E-05);
    }
  }
}

void TestMain(const std::vector<std::string>& input_names,
              const std::vector<std::vector<int64_t>>& input_shapes,
              const std::vector<std::string>& output_names, int type,
              std::string func_name, std::string cuda_kernel_str,
              CPUKernelFunc cpu_kernel_func) {
  // Compile the device code.
  paddle::framework::InitDevices(false, {0});
  platform::CUDAPlace place = platform::CUDAPlace(0);
  PrepareDeviceCode(place, func_name, cuda_kernel_str);

  // Create a ProgramDesc that has a fusion_group_op.
  framework::ProgramDesc program;
  framework::OpDesc* op_desc = CreateFusionGroupOp(
      &program, input_names, input_shapes, output_names, type, func_name);
  auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);

  framework::Scope scope;

  // Prepare input tensors.
  std::vector<framework::Tensor> cpu_tensors;
  cpu_tensors.resize(input_names.size() + output_names.size());
  for (size_t i = 0; i < input_names.size(); ++i) {
    SetupRandomCPUTensor<float>(&(cpu_tensors[i]), input_shapes[i]);
    framework::Tensor* dev_tensor =
        CreateTensor<float>(&scope, place, input_names[i], input_shapes[i]);
    TensorCopySync(cpu_tensors[i], place, dev_tensor);
  }
  // Create output tensors.
  std::vector<int64_t> empty_shape;
  for (size_t j = 0; j < output_names.size(); ++j) {
    CreateTensor<float>(&scope, place, output_names[j], empty_shape);
  }

  fusion_group_op->Run(scope, place);

  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  dev_ctx->Wait();

  // Check the output.
  CheckOutputs(&scope, output_names, &cpu_tensors, input_names.size(),
               cpu_kernel_func);
}

TEST(FusionGroupOp, elementwise) {
  if (!platform::dynload::HasNVRTC() || !platform::dynload::HasCUDADriver()) {
    return;
  }

  // z = relu(x + y)
  std::vector<std::string> input_names = {"x", "y"};
  std::vector<std::string> output_names = {"z"};
  std::vector<std::vector<int64_t>> input_shapes = {{256, 256}, {256, 256}};
  constexpr auto kernel = R"(
static inline __device__ float relu(float x) {
  return x * (x > 0);
}

extern "C" __global__
void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
       tid += blockDim.x * gridDim.x) {
    float tmp_0 = x[tid];
    float tmp_1 = y[tid];
    float tmp_2 = tmp_0 + tmp_1;
    float tmp_3 = relu(tmp_2);
    z[tid] = tmp_3;
  }
})";

  auto elementwise_cpu_kernel_0 = [](size_t n,
                                     std::vector<void*> args) -> void {
    float* x = static_cast<float*>(args[0]);
    float* y = static_cast<float*>(args[1]);
    float* z = static_cast<float*>(args[2]);
    for (size_t i = 0; i < n; ++i) {
      float tmp_0 = x[i];
      float tmp_1 = y[i];
      float tmp_2 = tmp_0 + tmp_1;
      float tmp_3 = tmp_2 > 0 ? tmp_2 : 0;
      z[i] = tmp_3;
    }
  };

  TestMain(input_names, input_shapes, output_names, 0,
           "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
}

}  // namespace operators
}  // namespace paddle

USE_CUDA_ONLY_OP(fusion_group);