A a pass to enable the use of cudnn (#19346)
* Add a interface to enable cudnn for inference. * Add cudnn_placement_pass. test=develop * Set the default value of cudnn_enabled_op_types to null. test=develop * Write the common basic class, placement_pass_base, to refine the codes. test=develop * Call EnableCUDNN in unittest. test=develop * Refine cudnn_placement_pass tester. * Enable the testing of cudnn_placement_pass in inference's unittest. test=develop * Add the check of op kernels. test=developfix_crf_doc
parent
cc443675e9
commit
c5548178b0
@ -0,0 +1,18 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
|
||||
|
||||
REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass)
|
||||
.RequirePassAttr("cudnn_enabled_op_types");
|
@ -0,0 +1,41 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include "paddle/fluid/framework/ir/placement_pass_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Specifies which operators should use cuDNN.
|
||||
*/
|
||||
class CUDNNPlacementPass : public PlacementPassBase {
|
||||
private:
|
||||
const std::string GetPlacementName() const { return "cuDNN"; }
|
||||
|
||||
const std::string GetAttrName() const { return "use_cudnn"; }
|
||||
|
||||
const std::unordered_set<std::string> GetOpTypesList() const {
|
||||
return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,119 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
void RegisterOpKernel() {
|
||||
static bool is_registered = false;
|
||||
if (!is_registered) {
|
||||
auto& all_kernels = OperatorWithKernel::AllOpKernels();
|
||||
|
||||
platform::CUDAPlace place = platform::CUDAPlace(0);
|
||||
OpKernelType plain_kernel_type =
|
||||
OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
|
||||
LibraryType::kPlain);
|
||||
OpKernelType cudnn_kernel_type =
|
||||
OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
|
||||
LibraryType::kCUDNN);
|
||||
|
||||
auto fake_kernel_func = [](const ExecutionContext&) -> void {
|
||||
static int num_calls = 0;
|
||||
num_calls++;
|
||||
};
|
||||
|
||||
all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
|
||||
all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
|
||||
all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
|
||||
all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
|
||||
|
||||
is_registered = true;
|
||||
}
|
||||
}
|
||||
|
||||
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
|
||||
unsigned expected_use_cudnn_true_count) {
|
||||
// operator use_cudnn
|
||||
// --------------------------------------------------
|
||||
// (a,b)->concat->c -
|
||||
// (c,weights,bias)->conv2d->f false
|
||||
// f->relu->g -
|
||||
// g->pool2d->h false
|
||||
// (h,weights2,bias2)->depthwise_conv2d->k false
|
||||
// k->relu->l -
|
||||
Layers layers;
|
||||
VarDesc* a = layers.data("a");
|
||||
VarDesc* b = layers.data("b");
|
||||
VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
|
||||
VarDesc* weights_0 = layers.data("weights_0");
|
||||
VarDesc* bias_0 = layers.data("bias_0");
|
||||
VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
|
||||
VarDesc* g = layers.relu(f);
|
||||
VarDesc* h = layers.pool2d(g, false);
|
||||
VarDesc* weights_1 = layers.data("weights_1");
|
||||
VarDesc* bias_1 = layers.data("bias_1");
|
||||
VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
|
||||
layers.relu(k);
|
||||
|
||||
RegisterOpKernel();
|
||||
|
||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
|
||||
auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
|
||||
pass->Set("cudnn_enabled_op_types",
|
||||
new std::unordered_set<std::string>(cudnn_enabled_op_types));
|
||||
|
||||
graph.reset(pass->Apply(graph.release()));
|
||||
|
||||
unsigned use_cudnn_true_count = 0;
|
||||
for (auto* node : graph->Nodes()) {
|
||||
if (node->IsOp() && node->Op()) {
|
||||
auto* op = node->Op();
|
||||
if (op->HasAttr("use_cudnn") &&
|
||||
boost::get<bool>(op->GetAttr("use_cudnn"))) {
|
||||
++use_cudnn_true_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
|
||||
}
|
||||
|
||||
TEST(CUDNNPlacementPass, enable_conv2d) {
|
||||
// 1 conv2d
|
||||
MainTest({"conv2d"}, 1);
|
||||
}
|
||||
|
||||
TEST(CUDNNPlacementPass, enable_relu_pool) {
|
||||
// 1 conv2d + 1 pool2d
|
||||
MainTest({"conv2d", "pool2d"}, 2);
|
||||
}
|
||||
|
||||
TEST(CUDNNPlacementPass, enable_all) {
|
||||
// 1 conv2d + 1 pool2d
|
||||
// depthwise_conv2d doesnot have CUDNN kernel.
|
||||
MainTest({}, 2);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
USE_PASS(cudnn_placement_pass);
|
@ -0,0 +1,69 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/placement_pass_base.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
void PlacementPassBase::ApplyImpl(ir::Graph* graph) const {
|
||||
VLOG(3) << "Applies " << GetPlacementName() << " placement strategy.";
|
||||
std::string attr_name = GetAttrName();
|
||||
const auto& op_types_list = GetOpTypesList();
|
||||
if (!graph->Has(attr_name)) {
|
||||
graph->Set<bool>(attr_name, new bool(true));
|
||||
}
|
||||
for (const Node* n : graph->Nodes()) {
|
||||
if (n->IsOp()) {
|
||||
auto* op = n->Op();
|
||||
if ((op->HasAttr(attr_name) || op->HasProtoAttr(attr_name)) &&
|
||||
IsSupport(op->Type())) {
|
||||
if (op_types_list.empty()) {
|
||||
op->SetAttr(attr_name, true);
|
||||
} else if (std::find(op_types_list.begin(), op_types_list.end(),
|
||||
n->Name()) != op_types_list.end()) {
|
||||
op->SetAttr(attr_name, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool PlacementPassBase::IsSupport(const std::string& op_type) const {
|
||||
if (GetAttrName() == "use_cudnn") {
|
||||
auto& all_kernels = OperatorWithKernel::AllOpKernels();
|
||||
auto it = all_kernels.find(op_type);
|
||||
if (it == all_kernels.end()) {
|
||||
// All control operators don't have kernel.
|
||||
return false;
|
||||
}
|
||||
for (auto& kernel_pair : it->second) {
|
||||
if (platform::is_gpu_place(kernel_pair.first.place_) &&
|
||||
(kernel_pair.first.library_type_ == LibraryType::kCUDNN)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} else if (GetAttrName() == "use_mkldnn") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,42 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Specifies which operators should use cuDNN.
|
||||
*/
|
||||
class PlacementPassBase : public Pass {
|
||||
protected:
|
||||
void ApplyImpl(ir::Graph* graph) const override;
|
||||
|
||||
virtual const std::string GetPlacementName() const = 0;
|
||||
virtual const std::string GetAttrName() const = 0;
|
||||
virtual const std::unordered_set<std::string> GetOpTypesList() const = 0;
|
||||
|
||||
private:
|
||||
bool IsSupport(const std::string& op_type) const;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue