test=developrevert-14324-fix_vlog
commit
d4e8d7077f
@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_proto_maker.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Returns true when the two op descs agree on type, inputs and outputs.
// The comparisons are evaluated in that order and stop at the first mismatch.
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
  const bool same_type = op1->Type() == op2->Type();
  if (!same_type) return false;

  const bool same_inputs = op1->Inputs() == op2->Inputs();
  if (!same_inputs) return false;

  return op1->Outputs() == op2->Outputs();
}
|
||||
|
||||
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
// FIXME(zjl): Insert dependencies between some distributed ops may cause
|
||||
// the multi_devices_graph_pass fails. So we skip these ops here.
|
||||
// Indeed, maybe we should not insert dependencies between these ops
|
||||
// casually, which may cause deadlock easily.
|
||||
// We should add more skipped distributed ops when found errors in
|
||||
// multi_devices_graph_pass
|
||||
static std::unordered_set<std::string> skip_dist_ops{
|
||||
"send", "recv", "send_barrier", "fetch_barrier"};
|
||||
|
||||
auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
|
||||
std::vector<ir::Node *> op_node_list;
|
||||
op_node_list.reserve(ops.size());
|
||||
|
||||
std::unordered_map<ir::Node *, size_t> op_deps;
|
||||
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
|
||||
std::unordered_set<ir::Node *> ready_ops;
|
||||
|
||||
for (ir::Node *node : graph->Nodes()) {
|
||||
if (!node->IsOp()) continue;
|
||||
std::unordered_set<ir::Node *> preceding_ops;
|
||||
for (auto *in : node->inputs) {
|
||||
PADDLE_ENFORCE(in->IsVar(),
|
||||
"Preceding Node of Op Nodes must be Var Node");
|
||||
if (in->inputs.empty()) continue;
|
||||
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
|
||||
"Preceding Op Node of Var Node must be unique");
|
||||
preceding_ops.insert(in->inputs[0]);
|
||||
pending_ops[in->inputs[0]].insert(node);
|
||||
}
|
||||
op_deps[node] = preceding_ops.size();
|
||||
if (preceding_ops.empty()) {
|
||||
ready_ops.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto *op_desc : ops) {
|
||||
ir::Node *found_node = nullptr;
|
||||
for (auto *node : ready_ops) {
|
||||
if (IsSameOpDesc(op_desc, node->Op())) {
|
||||
PADDLE_ENFORCE(found_node == nullptr,
|
||||
"Found multiple op_desc in graph: %s", op_desc->Type());
|
||||
found_node = node;
|
||||
}
|
||||
}
|
||||
|
||||
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
|
||||
op_desc->Type());
|
||||
for (auto *pending_op : pending_ops[found_node]) {
|
||||
if (--op_deps.at(pending_op) == 0) {
|
||||
ready_ops.insert(pending_op);
|
||||
}
|
||||
}
|
||||
ready_ops.erase(found_node);
|
||||
if (skip_dist_ops.count(op_desc->Type()) == 0) {
|
||||
op_node_list.push_back(found_node);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < op_node_list.size(); ++i) {
|
||||
auto *dep_var = graph->CreateControlDepVar();
|
||||
op_node_list[i]->inputs.push_back(dep_var);
|
||||
op_node_list[i - 1]->outputs.push_back(dep_var);
|
||||
dep_var->outputs.push_back(op_node_list[i]);
|
||||
dep_var->inputs.push_back(op_node_list[i - 1]);
|
||||
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
|
||||
<< " and " << op_node_list[i]->Name();
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Registers the pass under the name "sequential_execution_pass" and declares
// kAllOpDescs as an attribute the pass requires (it is read in ApplyImpl).
REGISTER_PASS(sequential_execution_pass,
              paddle::framework::details::SequentialExecutionPass)
    .RequirePassAttr(paddle::framework::details::kAllOpDescs);
|
@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Name of the pass attribute that SequentialExecutionPass requires; the pass
// reads it as a `const std::vector<OpDesc *>` describing the desired op order.
constexpr char kAllOpDescs[] = "all_op_descs";
|
||||
|
||||
// Graph pass whose ApplyImpl (defined in the matching .cc file) inserts
// control dependencies between op nodes following the kAllOpDescs order.
class SequentialExecutionPass : public ir::Pass {
 protected:
  // Takes ownership of `graph` and hands the processed graph back.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,58 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Declares a local `auto* id` bound to the graph node that the detector
// matched for the pattern node named #id.
// Requirements at the expansion site:
//   * a variable called `subgraph` (the match map) must be in scope;
//   * `pattern` must provide RetrieveNode(name).
// Fails via PADDLE_ENFORCE if the pattern node is absent from `subgraph` or
// the matched node is null. Expands to several statements, so it must not be
// used as the unbraced body of an `if`/`for`.
#define GET_NODE(id, pattern)                               \
  PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
                 "pattern has no Node called %s", #id);     \
  auto* id = subgraph.at(pattern.RetrieveNode(#id));        \
  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
|
||||
|
||||
std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
PADDLE_ENFORCE(graph.get());
|
||||
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
|
||||
GraphPatternDetector gpd;
|
||||
|
||||
auto* pattern = gpd.mutable_pattern();
|
||||
pattern->NewNode("depthwise_conv")
|
||||
->assert_is_op("depthwise_conv2d")
|
||||
->assert_op_attr("use_mkldnn", true);
|
||||
|
||||
int found_depthwise_conv_mkldnn_count = 0;
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
VLOG(3) << "handle DepthwiseConvMKLDNN fuse";
|
||||
GET_NODE(depthwise_conv, (*pattern));
|
||||
depthwise_conv->Op()->SetType("conv2d");
|
||||
found_depthwise_conv_mkldnn_count++;
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
AddStatis(found_depthwise_conv_mkldnn_count);
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Registers the pass under the name "depthwise_conv_mkldnn_pass".
REGISTER_PASS(depthwise_conv_mkldnn_pass,
              paddle::framework::ir::DepthwiseConvMKLDNNPass);
|
@ -0,0 +1,34 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class DepthwiseConvMKLDNNPass : public FusePassBase {
|
||||
public:
|
||||
virtual ~DepthwiseConvMKLDNNPass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,123 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Test helper: appends an op of the given `type` to block 0 of `prog`.
//   name       - stored in the op's "name" attribute.
//   inputs     - exactly the first three entries are used, as
//                "Input", "Filter" and "Bias" respectively (no bounds check).
//   outputs    - bound to the "Out" output slot.
//   use_mkldnn - stored in the op's "use_mkldnn" attribute.
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
  auto* new_op = prog->MutableBlock(0)->AppendOp();

  new_op->SetType(type);
  new_op->SetAttr("use_mkldnn", use_mkldnn);
  new_op->SetAttr("name", name);

  new_op->SetInput("Input", {inputs[0]});
  new_op->SetInput("Filter", {inputs[1]});
  new_op->SetInput("Bias", {inputs[2]});

  new_op->SetOutput("Out", outputs);
}
|
||||
|
||||
// (a, weights, bias)->depthwise conv mkldnn->b
|
||||
// (b, weights2, bias2)->depthwise conv no mkldnn->c
|
||||
// (c, weights3, bias3)->conv mkldnn->d
|
||||
// (d, weights4, bias4)->conv no mkldnn->e
|
||||
// Builds the four-op test program: two depthwise_conv2d ops followed by two
// conv2d ops, each pair having one op with use_mkldnn=true and one with
// use_mkldnn=false. Activations a..e chain the ops together.
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;

  // Declares one SELECTED_ROWS variable in block 0, optionally persistable.
  const auto declare_var = [&prog](const std::string& var_name,
                                   bool persistable) {
    auto* var = prog.MutableBlock(0)->Var(var_name);
    var->SetType(proto::VarType::SELECTED_ROWS);
    if (persistable) {
      var->SetPersistable(true);
    }
  };

  // Activations first, then parameters (weights/biases are persistable).
  for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e"})) {
    declare_var(v, false);
  }
  for (auto& v : std::vector<std::string>(
           {"weights", "bias", "weights2", "bias2", "weights3", "bias3",
            "weights4", "bias4"})) {
    declare_var(v, true);
  }

  // depthwise conv with MKL-DNN
  SetOp(&prog, "depthwise_conv2d", "conv1",
        std::vector<std::string>({"a", "weights", "bias"}),
        std::vector<std::string>({"b"}), true);
  // depthwise conv without MKL-DNN
  SetOp(&prog, "depthwise_conv2d", "conv2",
        std::vector<std::string>({"b", "weights2", "bias2"}),
        std::vector<std::string>({"c"}), false);
  // conv with MKL-DNN
  SetOp(&prog, "conv2d", "conv3",
        std::vector<std::string>({"c", "weights3", "bias3"}),
        std::vector<std::string>({"d"}), true);
  // conv without MKL-DNN
  SetOp(&prog, "conv2d", "conv4",
        std::vector<std::string>({"d", "weights4", "bias4"}),
        std::vector<std::string>({"e"}), false);

  return prog;
}
|
||||
|
||||
// Applies depthwise_conv_mkldnn_pass to the program from BuildProgramDesc and
// checks the op census: the MKL-DNN depthwise conv must have become a conv2d
// (still flagged use_mkldnn), while the other three ops stay as they were.
TEST(DepthwiseConvMKLDNNPass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("depthwise_conv_mkldnn_pass");

  struct counters {
    int mkldnn_depthwise_conv_nodes;
    int other_depthwise_conv_nodes;
    int mkldnn_conv_nodes;
    int other_conv_nodes;
  };

  // BuildProgramDesc creates exactly one op of each category.
  counters before{1, 1, 1, 1};

  graph = pass->Apply(std::move(graph));

  // Tally the op nodes present after the pass has run.
  counters after{0, 0, 0, 0};

  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;

    auto* op = node->Op();
    if (op->Type() == "conv2d") {
      const bool mkldnn = boost::get<bool>(op->GetAttr("use_mkldnn"));
      int& slot = mkldnn ? after.mkldnn_conv_nodes : after.other_conv_nodes;
      ++slot;
    } else if (op->Type() == "depthwise_conv2d") {
      const bool mkldnn = boost::get<bool>(op->GetAttr("use_mkldnn"));
      int& slot = mkldnn ? after.mkldnn_depthwise_conv_nodes
                         : after.other_depthwise_conv_nodes;
      ++slot;
    }
  }

  EXPECT_EQ(after.other_depthwise_conv_nodes,
            before.other_depthwise_conv_nodes);
  EXPECT_EQ(after.other_conv_nodes, before.other_conv_nodes);
  EXPECT_EQ(after.mkldnn_depthwise_conv_nodes,
            before.mkldnn_depthwise_conv_nodes - 1);
  EXPECT_EQ(after.mkldnn_conv_nodes, before.mkldnn_conv_nodes + 1);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// NOTE(review): presumably makes the pass registered as
// "depthwise_conv_mkldnn_pass" available to this test binary so that
// PassRegistry::Instance().Get() can find it -- confirm against USE_PASS in
// pass.h.
USE_PASS(depthwise_conv_mkldnn_pass);
|
@ -0,0 +1,132 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/cudnn_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
|
||||
using DataLayout = platform::DataLayout;
|
||||
using ScopedSpatialTransformerDescriptor =
|
||||
platform::ScopedSpatialTransformerDescriptor;
|
||||
template <typename T>
|
||||
using CudnnDataType = platform::CudnnDataType<T>;
|
||||
|
||||
template <typename T>
|
||||
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"It must use CUDAPlace");
|
||||
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
||||
auto handle = dev_ctx.cudnn_handle();
|
||||
auto* input = ctx.Input<Tensor>("X");
|
||||
auto* grid = ctx.Input<Tensor>("Grid");
|
||||
auto* output = ctx.Output<Tensor>("Output");
|
||||
|
||||
int n = input->dims()[0];
|
||||
int c = input->dims()[1];
|
||||
int h = input->dims()[2];
|
||||
int w = input->dims()[3];
|
||||
const int size[4] = {n, c, h, w};
|
||||
|
||||
const T* input_data = input->data<T>();
|
||||
const T* grid_data = grid->data<T>();
|
||||
T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
|
||||
|
||||
ScopedSpatialTransformerDescriptor st_desc;
|
||||
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
|
||||
st_desc.descriptor<T>(4, size);
|
||||
|
||||
ScopedTensorDescriptor input_desc;
|
||||
ScopedTensorDescriptor output_desc;
|
||||
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
||||
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
|
||||
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
|
||||
DataLayout::kNCHW, framework::vectorize2int(output->dims()));
|
||||
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
|
||||
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
|
||||
input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
|
||||
output_data));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"It must use CUDAPlace");
|
||||
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
||||
auto handle = dev_ctx.cudnn_handle();
|
||||
auto* input = ctx.Input<Tensor>("X");
|
||||
auto* grid = ctx.Input<Tensor>("Grid");
|
||||
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
|
||||
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
|
||||
|
||||
auto output_grad_dims = output_grad->dims();
|
||||
const int n = output_grad_dims[0];
|
||||
const int c = output_grad_dims[1];
|
||||
const int h = output_grad_dims[2];
|
||||
const int w = output_grad_dims[3];
|
||||
const int size[4] = {n, c, h, w};
|
||||
|
||||
ScopedSpatialTransformerDescriptor st_dest;
|
||||
cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
|
||||
st_dest.descriptor<T>(4, size);
|
||||
|
||||
const T* input_data = input->data<T>();
|
||||
const T* grid_data = grid->data<T>();
|
||||
const T* output_grad_data = output_grad->data<T>();
|
||||
T* input_grad_data =
|
||||
input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
|
||||
T* grid_grad_data =
|
||||
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
|
||||
|
||||
ScopedTensorDescriptor input_desc;
|
||||
ScopedTensorDescriptor input_grad_desc;
|
||||
ScopedTensorDescriptor output_grad_desc;
|
||||
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
||||
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
|
||||
cudnnTensorDescriptor_t cudnn_input_grad_desc =
|
||||
input_grad_desc.descriptor<T>(
|
||||
DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
|
||||
cudnnTensorDescriptor_t cudnn_output_grad_desc =
|
||||
output_grad_desc.descriptor<T>(
|
||||
DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));
|
||||
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
|
||||
handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
|
||||
input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
|
||||
input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
|
||||
output_grad_data, grid_data, CudnnDataType<T>::kZero(),
|
||||
grid_grad_data));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace plat = paddle::platform;
|
||||
REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
|
||||
paddle::operators::CUDNNGridSampleOpKernel<float>,
|
||||
paddle::operators::CUDNNGridSampleOpKernel<double>);
|
||||
REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
|
||||
paddle::operators::CUDNNGridSampleGradOpKernel<float>,
|
||||
paddle::operators::CUDNNGridSampleGradOpKernel<double>);
|
@ -0,0 +1,203 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/grid_sampler_op.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include "paddle/fluid/platform/cudnn_helper.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Short alias used by the operator classes in this file.
using Tensor = framework::Tensor;
|
||||
|
||||
class GridSampleOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of GridSampleOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Grid"),
|
||||
"Input(Grid) of GridSampleOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Output"),
|
||||
"Output(Output) of GridSampleOp should not be null.");
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto grid_dims = ctx->GetInputDim("Grid");
|
||||
PADDLE_ENFORCE(x_dims.size() == 4,
|
||||
"Input(X) of GridSampleOp should be 4-D Tensor.");
|
||||
PADDLE_ENFORCE(grid_dims.size() == 4,
|
||||
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
|
||||
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
|
||||
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
|
||||
"Input(X) and Input(Grid) dims[0] should be equal.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
grid_dims[1], x_dims[2],
|
||||
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
grid_dims[2], x_dims[3],
|
||||
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
|
||||
|
||||
ctx->SetOutputDim("Output", x_dims);
|
||||
ctx->ShareLoD("X", "Output");
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
framework::LibraryType library_{framework::LibraryType::kPlain};
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (platform::CanCUDNNBeUsed(ctx)) {
|
||||
library_ = framework::LibraryType::kCUDNN;
|
||||
}
|
||||
#endif
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
|
||||
framework::DataLayout::kAnyLayout, library_);
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the proto of the grid_sampler operator: inputs "X" and "Grid",
// output "Output", the boolean attribute "use_cudnn" (default true), and the
// user-facing operator documentation.
//
// The documentation text fixes several errors of the earlier wording:
//   * y_s was defined in terms of itself (now y_n + 1);
//   * the south-east corner `es` was indexed with x_w (now x_e) and the
//     ws/es corners were mislabelled;
//   * the bottom-right corner of the diagram was labelled `wn` (now `es`);
//   * spelling mistakes ("gennerated", "indexng", "dimention").
class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input data of GridSampleOp, "
             "This is a 4-D tensor with shape of [N, C, H, W]");
    AddInput(
        "Grid",
        "(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
        "This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
        "of x and y coordinates with shape [N, H, W] in last dimension");
    AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]");
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default true) Only used in cudnn kernel, need install cudnn")
        .SetDefault(true);

    AddComment(R"DOC(
      This operation samples input X by using bilinear interpolation based on
      flow field grid, which is usually generated by affine_grid. The grid of
      shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
      with shape [N, H, W] each, where grid_x is indexing the 4th dimension
      (in width dimension) of input data x and grid_y is indexing the 3rd
      dimension (in height dimension), finally results is the bilinear
      interpolation value of 4 nearest corner points.

      Step 1:
        Get (x, y) grid coordinates and scale to [0, H-1/W-1].

        grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
        grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)

      Step 2:
        Indices input data X with grid (x, y) in each [H, W] area, and bilinear
        interpolate point value by 4 nearest points.

          wn ------- y_n ------- en
          |           |           |
          |          d_n          |
          |           |           |
         x_w --d_w-- grid--d_e-- x_e
          |           |           |
          |          d_s          |
          |           |           |
          ws ------- y_s ------- es

        x_w = floor(x)              // west side x coord
        x_e = x_w + 1               // east side x coord
        y_n = floor(y)              // north side y coord
        y_s = y_n + 1               // south side y coord

        d_w = grid_x - x_w          // distance to west side
        d_e = x_e - grid_x          // distance to east side
        d_n = grid_y - y_n          // distance to north side
        d_s = y_s - grid_y          // distance to south side

        wn = X[:, :, y_n, x_w]      // north-west point value
        en = X[:, :, y_n, x_e]      // north-east point value
        ws = X[:, :, y_s, x_w]      // south-west point value
        es = X[:, :, y_s, x_e]      // south-east point value

        output = wn * d_e * d_s + en * d_w * d_s
               + ws * d_e * d_n + es * d_w * d_n
        )DOC");
  }
};
|
||||
|
||||
class GridSampleOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
auto grid_dims = ctx->GetInputDim("Grid");
|
||||
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
|
||||
}
|
||||
if (ctx->HasOutput(framework::GradVarName("Grid"))) {
|
||||
ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
framework::LibraryType library_{framework::LibraryType::kPlain};
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (platform::CanCUDNNBeUsed(ctx)) {
|
||||
library_ = framework::LibraryType::kCUDNN;
|
||||
}
|
||||
#endif
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
|
||||
framework::DataLayout::kAnyLayout, library_);
|
||||
}
|
||||
};
|
||||
|
||||
class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||
auto* op = new framework::OpDesc();
|
||||
op->SetType("grid_sampler_grad");
|
||||
op->SetInput("X", Input("X"));
|
||||
op->SetInput("Grid", Input("Grid"));
|
||||
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
|
||||
|
||||
op->SetAttrMap(Attrs());
|
||||
|
||||
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
||||
op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
|
||||
return std::unique_ptr<framework::OpDesc>(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
|
||||
ops::GridSampleGradMaker);
|
||||
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
grid_sampler,
|
||||
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
grid_sampler_grad,
|
||||
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue