diff --git a/paddle/fluid/framework/details/memory_optimize_helper.cc b/paddle/fluid/framework/details/memory_optimize_helper.cc
index c89a33fc95..533d3269be 100644
--- a/paddle/fluid/framework/details/memory_optimize_helper.cc
+++ b/paddle/fluid/framework/details/memory_optimize_helper.cc
@@ -337,7 +337,6 @@ bool NodeCanReused(const VarDesc& node) {
   auto type = node.GetType();
   // only these types holds bulk of gpu memory
   if (!(type == proto::VarType::LOD_TENSOR ||
-        type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY)) {
     return false;
   }
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index f7d82d5ead..ca6950b1a4 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -46,6 +46,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
 pass_library(lock_free_optimize_pass base)
+pass_library(cpu_quantize_squash_pass inference)
 pass_library(fc_fuse_pass inference)
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
@@ -101,6 +102,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
 cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
 cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
+cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
 if (WITH_MKLDNN)
   cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
   cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
diff --git a/paddle/fluid/framework/ir/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/cpu_quantize_squash_pass.cc
new file mode 100644
index 0000000000..de62a69de4
--- /dev/null
+++ b/paddle/fluid/framework/ir/cpu_quantize_squash_pass.cc
@@ -0,0 +1,146 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void CPUQuantizeSquashPass::FindNodesToKeep( + Graph* graph, + std::unordered_map* nodes_keep_counter) const { + GraphPatternDetector gpd; + patterns::DequantAny deq_any_pattern{gpd.mutable_pattern(), "deqant_any"}; + deq_any_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern); + + if (nodes_keep_counter->find(dequant_out) == nodes_keep_counter->end()) + (*nodes_keep_counter)[dequant_out] = 1; + else + (*nodes_keep_counter)[dequant_out] += 1; + + found_count++; + }; + gpd(graph, handler); + AddStatis(found_count); +} + +void CPUQuantizeSquashPass::Squash( + Graph* graph, + std::unordered_map* nodes_keep_counter) const { + GraphPatternDetector gpd; + patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"}; + squash_pattern(); + + int found_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash requantize-quantize ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern); + + auto* next_op_desc = next_op->Op(); + float dequant_scale = boost::get(dequant_op->Op()->GetAttr("Scale")); + float quant_scale = boost::get(quant_op->Op()->GetAttr("Scale")); + PADDLE_ENFORCE(nodes_keep_counter->find(dequant_out) != + nodes_keep_counter->end()); + + // check if dequantize op should be kept or removed, decrease the counter + bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; + + if (dequant_scale == quant_scale) { + // squash dequantize-quantize to nothing + auto quant_out_var_name = quant_out->Name(); + auto next_op_inputs = next_op_desc->InputNames(); + for (const auto& name : next_op_inputs) { + auto var_name = next_op_desc->Input(name)[0]; + if (var_name.compare(quant_out_var_name) == 0) { + next_op_desc->SetInput( + name, std::vector({dequant_in->Name()})); + break; + } + } + + if (keep_dequant) + GraphSafeRemoveNodes(graph, {quant_op, quant_out}); + else + GraphSafeRemoveNodes(graph, + {dequant_op, quant_op, dequant_out, quant_out}); + + IR_NODE_LINK_TO(dequant_in, next_op); + + found_squash_count++; + } else { + // squash dequantize-quantize to requantize op + OpDesc desc; + desc.SetType("requantize"); + desc.SetInput("Input", std::vector({dequant_in->Name()})); + desc.SetOutput("Output", std::vector({quant_out->Name()})); + desc.SetAttr("Scale_in", dequant_scale); + desc.SetAttr("Scale_out", quant_scale); + + auto requant_op = g->CreateOpNode(&desc); + + if (keep_dequant) + GraphSafeRemoveNodes(graph, {quant_op}); + else + GraphSafeRemoveNodes(graph, {dequant_op, quant_op, dequant_out}); + + IR_NODE_LINK_TO(dequant_in, requant_op); + IR_NODE_LINK_TO(requant_op, quant_out); + + found_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_squash_count); + PrettyLogDetail("--- squashed %d dequantize-quantize pairs", + 
found_squash_count); +} + +std::unique_ptr CPUQuantizeSquashPass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("cpu_quantize_squash_pass", graph.get()); + + std::unordered_map nodes_keep_counter; + FindNodesToKeep(graph.get(), &nodes_keep_counter); + Squash(graph.get(), &nodes_keep_counter); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(cpu_quantize_squash_pass, + paddle::framework::ir::CPUQuantizeSquashPass); diff --git a/paddle/fluid/framework/ir/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/cpu_quantize_squash_pass.h new file mode 100644 index 0000000000..b823a2cef3 --- /dev/null +++ b/paddle/fluid/framework/ir/cpu_quantize_squash_pass.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Squash dequantize->quantize pair pattern into requantize op + */ +class CPUQuantizeSquashPass : public FusePassBase { + public: + virtual ~CPUQuantizeSquashPass() {} + + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; + + /* + * For each dequantize's output find the number of operators it is an input to + */ + void FindNodesToKeep( + Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + /* + * Squash dequantize-quantize ops pairs into requantize or nothing + */ + void Squash(Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + const std::string name_scope_{"squash"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/cpu_quantize_squash_pass_tester.cc new file mode 100644 index 0000000000..3a3eb53f79 --- /dev/null +++ b/paddle/fluid/framework/ir/cpu_quantize_squash_pass_tester.cc @@ -0,0 +1,179 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h" +#include +#include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, bool use_mkldnn, + float scale = 0) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("name", name); + if (type == "conv2d") { + op->SetInput("Input", {inputs[0]}); + if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); + if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); + op->SetOutput("Output", {outputs[0]}); + } else if (type == "quantize") { + op->SetInput("Input", {inputs[0]}); + op->SetOutput("Output", {outputs[0]}); + op->SetAttr("Scale", scale); + } else if (type == "dequantize") { + op->SetInput("Input", {inputs[0]}); + op->SetOutput("Output", {outputs[0]}); + op->SetAttr("Scale", scale); + } +} + +// (a,w1,b1)->Conv1->d +// d->Dequant->e +// e->Quant->f +// (f,w2,b2)->Conv2->i +ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) { + ProgramDesc prog; + for (auto& v : std::initializer_list( + {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) { + auto* var = prog.MutableBlock(0)->Var(v); + if (v.find("w") == 0 || v.find("b") == 0) { + var->SetPersistable(true); + } + } + + SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn); + SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1); + SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2); + SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn); + return prog; +} + +static const std::initializer_list variable_names{ + "a", "b", "c", "d", "e", "f", "g", "h"}; +// a->Conv1->b +// b->Dequant->c +// +// c->Quant1->d and d->Conv2->e +// +// c->Conv3->f +// +// c->Quant2->g and g->Conv4->h +// +ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2, + float scale3) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn); + SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1); + + SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2); + SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn); + + SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn); + + SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3); + SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn); + + return prog; +} + +void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, + const char* var_name) { + auto x = scope->Var(var_name); + auto tensor = x->GetMutable(); + tensor->mutable_data(place, proto::VarType::FP32, + ::paddle::memory::Allocator::kDefault, 1); +} + +void MainTest(const ProgramDesc& prog, int removed_nodes_num) { + std::unique_ptr graph(new ir::Graph(prog)); + + // Init scope, as it is used in pass + auto place = paddle::platform::CPUPlace(); + NaiveExecutor exe{place}; + Scope scope; + exe.CreateVariables(prog, 0, true, &scope); + + for (auto& v : variable_names) { + InitTensorHolder(&scope, place, v.c_str()); + } + + graph->Set(kParamScopeAttr, new framework::Scope*(&scope)); + + auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass"); + + int original_nodes_num = graph->Nodes().size(); + + graph = 
pass->Apply(std::move(graph)); + + int current_nodes_num = graph->Nodes().size(); + + EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num); +} + +TEST(CpuQuantizeSquashPass, equal_scales) { + auto scale = 1.2345f; + auto use_mkldnn = true; + // Remove 4 nodes: Dequant, Quant, e, f + auto remove_nodes = 4; + MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes); + + use_mkldnn = !use_mkldnn; + MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes); +} + +TEST(CpuQuantizeSquashPass, inequal_scales) { + auto scale1 = 1.2345f; + auto scale2 = 21.0f; + auto use_mkldnn = true; + // Remove 3 nodes: Dequant, Quant, e + // Insert 1 node: requantize + auto remove_nodes = 2; + MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes); + + use_mkldnn = !use_mkldnn; + MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes); +} + +TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) { + // Delete both quantize ops, + // bypass dequantize in both branches, + // insert requantize on one branch + auto scale = 1.2345f; + auto scale2 = 21.0f; + auto use_mkldnn = true; + // Remove 3 nodes: Quant1, Quant2, g + // Insert 1 node: requantize + auto remove_nodes = 2; + MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes); + + use_mkldnn = !use_mkldnn; + MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(cpu_quantize_squash_pass); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index c0c34d186b..08354b526a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1301,6 +1301,51 @@ PDNode *patterns::ConvAffineChannel::operator()( return ac_out_var; } +PDNode *patterns::DequantQuantAny::operator()() { + auto *dequant_in = pattern->NewNode(dequant_in_repr()) + ->AsInput() + ->assert_is_op_input("dequantize", "Input"); + + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); + + auto *dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize", "Output"); + + auto *quant_op = pattern->NewNode(quant_op_repr()) + ->assert_is_op("quantize") + ->AsIntermediate(); + + auto *quant_out = pattern->NewNode(quant_out_repr()) + ->AsOutput() + ->assert_is_op_output("quantize"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); + quant_op->LinksFrom({dequant_out}).LinksTo({quant_out}); + next_op->LinksFrom({quant_out}); + + return quant_out; +} + +PDNode *patterns::DequantAny::operator()() { + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); + + auto *dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize", "Output"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + dequant_op->LinksTo({dequant_out}); + next_op->LinksFrom({dequant_out}); + + return dequant_out; +} + // a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a // b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b // ... 
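A minimal Python sketch (not part of the patch) of the rewrite decision that the pass and the two patterns above implement; the dict-based graph description and the function name squash_pair are purely illustrative:

# Toy model of CPUQuantizeSquashPass::Squash: given the scales of a
# dequantize->quantize pair and the consumer count computed by
# FindNodesToKeep, decide which nodes are removed and what is inserted.
def squash_pair(dequant_scale, quant_scale, dequant_out_consumers):
    # Keep the dequantize op if its output also feeds other operators.
    keep_dequant = dequant_out_consumers > 1
    if dequant_scale == quant_scale:
        # Identical scales: the pair is a no-op; the INT8 input is wired
        # straight into the next operator.
        removed = ["quantize", "quant_out"]
        if not keep_dequant:
            removed += ["dequantize", "dequant_out"]
        return {"removed": removed, "inserted": []}
    # Different scales: replace the pair with a single requantize op that
    # carries both scales as Scale_in / Scale_out.
    removed = ["quantize"]
    if not keep_dequant:
        removed += ["dequantize", "dequant_out"]
    return {"removed": removed,
            "inserted": [("requantize",
                          {"Scale_in": dequant_scale,
                           "Scale_out": quant_scale})]}

if __name__ == "__main__":
    # Mirrors the "equal_scales" and "inequal_scales" tester cases.
    print(squash_pair(1.2345, 1.2345, dequant_out_consumers=1))
    print(squash_pair(1.2345, 21.0, dequant_out_consumers=1))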
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index c8be586f54..3db4bba10d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -18,8 +18,11 @@ #include #endif +#include #include #include +#include +#include #include #include #include "paddle/fluid/framework/ir/graph.h" @@ -766,6 +769,34 @@ struct ConvAffineChannel : public PatternBase { PATTERN_DECL_NODE(ac_out); // Out }; +// Dequantize + Quantize + anyOP +// This pattern is used for squashing the dequantize-quantize pairs. +struct DequantQuantAny : public PatternBase { + DequantQuantAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_quant_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(dequant_in); + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(quant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Dequantize + anyOP +// This quantize is used for getting number of ops the Dequantize's +// output is an input to. +struct DequantAny : public PatternBase { + DequantAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); + PATTERN_DECL_NODE(next_op); +}; + struct TransposeFlattenConcat : public PatternBase { TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "transpose_flatten_concat") {} diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 3adc7baebd..a617b9fb1d 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -13,18 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/cross_entropy_op.h" +#include #include +#include namespace paddle { namespace operators { -class CrossEntropyOp : public framework::OperatorWithKernel { +class CrossEntropyOpBase : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); auto x_dims = ctx->GetInputDim("X"); @@ -43,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { "Input(X) and Input(Label) shall have the same shape " "except the last dimension."); } - if (ctx->Attrs().Get("soft_label")) { + + if (IsSoftLabel(ctx)) { if (check) { PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], "If Attr(soft_label) == true, the last dimension of " @@ -69,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel { return framework::OpKernelType(ctx.Input("X")->type(), ctx.device_context()); } + + virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { + return ctx->Attrs().Get("soft_label"); + } }; -class CrossEntropyGradientOp : public framework::OperatorWithKernel { +class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + void InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "Input(Y@GRAD) shoudl be not null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Output(X@GRAD) should be not null."); - auto x_dims = ctx->GetInputDim("X"); + auto x_dims = GetXDim(ctx); auto label_dims = ctx->GetInputDim("Label"); auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); int rank = x_dims.size(); @@ -108,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { "The Input(X) and Input(Y@Grad) should have the same " "shape except the last dimension."); } - PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, - "The last dimension of Input(Y@Grad) should be 1."); - if (ctx->Attrs().Get("soft_label")) { + if (IsSoftLabel(ctx)) { if (check) { PADDLE_ENFORCE_EQ( x_dims[rank - 1], label_dims[rank - 1], @@ -123,7 +128,10 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { "Input(Label) should be 1."); } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - ctx->ShareLoD("X", framework::GradVarName("X")); + PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, + "The last dimension of Input(Y@Grad) should be 1."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X")); } protected: @@ -131,8 +139,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". 
framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Y"))->type(), + ctx.device_context()); + } + + virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { + return ctx->GetInputDim("X"); + } + + virtual const char* VarNameWithXLoD() const { return "X"; } + + virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { + return ctx->Attrs().Get("soft_label"); + } +}; + +class CrossEntropyOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Y"}}; } }; @@ -200,22 +228,132 @@ or not. But the output only shares the LoD information with input X. } }; -class CrossEntropyOpInferVarType - : public framework::PassInDtypeAndVarTypeToOutput { +class CrossEntropyGradientOp : public CrossEntropyGradientOpBase { + public: + using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + CrossEntropyGradientOpBase::InferShape(ctx); + } +}; + +class CrossEntropyOp2 : public CrossEntropyOpBase { + public: + using CrossEntropyOpBase::CrossEntropyOpBase; + + void InferShape(framework::InferShapeContext* ctx) const override { + CrossEntropyOpBase::InferShape(ctx); + + PADDLE_ENFORCE(ctx->HasOutput("XShape"), + "Output(XShape) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("MatchX"), + "Output(MatchX) should be not null."); + auto x_dims = ctx->GetInputDim("X"); + auto x_dims_vec = framework::vectorize(x_dims); + x_dims_vec.push_back(0); + ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec)); + x_dims[x_dims.size() - 1] = 1; + ctx->SetOutputDim("MatchX", x_dims); + ctx->ShareLoD("X", /*->*/ "XShape"); + } + protected: - std::unordered_map GetInputOutputWithSameType() - const override { - return std::unordered_map{{"X", /*->*/ "Y"}}; + bool IsSoftLabel(framework::InferShapeContext* ctx) const override { + return false; + } +}; + +class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase { + public: + using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist"); + CrossEntropyGradientOpBase::InferShape(ctx); + } + + protected: + virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { + auto x_shape = ctx->GetInputDim("XShape"); + return framework::DDim(x_shape.Get(), x_shape.size() - 1); + } + + virtual const char* VarNameWithXLoD() const { return "XShape"; } + + virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { + return false; } }; + +class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a tensor whose last dimension " + "size is equal to the number of classes. This input is a " + "probability computed by the previous operator, which is almost " + "always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor), the tensor which represents the ground truth. It has the " + "same shape with 'X' except the last dimension. 
One hot Tensor."); + AddOutput("Y", + "(Tensor, default Tensor), a tensor whose shape is same " + "with 'X' except that the last dimension size is 1. It " + "represents the cross entropy loss."); + AddOutput("XShape", "Temporaily variable to save shape and LoD of X."); + AddOutput("MatchX", + "X value that matches label, used for gradient computation."); + AddAttr("ignore_index", + "(int, default -100), Specifies a target value that is" + "ignored and does not contribute to the input gradient." + "Only valid if soft_label is set to False") + .SetDefault(-100); + AddComment(R"DOC( +Hard-label CrossEntropy Operator. + +The input 'X' and 'Label' will first be logically flattened to 2-D matrixs. +The matrix's second dimension(row length) is as same as the original last +dimension, and the first dimension(column length) is the product of all other +original dimensions. Then the softmax computation will take palce on each raw +of flattened matrixs. + +Only support hard label. + +Both the input X and Label can carry the LoD (Level of Details) information, +or not. But the output only shares the LoD information with input X. + +)DOC"); + } +}; + +class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("cross_entropy_grad2"); + op->SetInput("Label", Input("Label")); + op->SetInput("MatchX", Output("MatchX")); + op->SetInput("XShape", Output("XShape")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; using CPUCtx = paddle::platform::CPUDeviceContext; -REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, - ops::CrossEntropyOpInferVarType, +REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase, + ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, @@ -223,3 +361,14 @@ REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, REGISTER_OP_CPU_KERNEL(cross_entropy_grad, ops::CrossEntropyGradientOpKernel, ops::CrossEntropyGradientOpKernel); + +REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2, + ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType, + ops::CrossEntropyGradOpDescMaker2); +REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2); +REGISTER_OP_CPU_KERNEL(cross_entropy2, + ops::CrossEntropyOpKernel2, + ops::CrossEntropyOpKernel2); +REGISTER_OP_CPU_KERNEL(cross_entropy_grad2, + ops::CrossEntropyGradientOpKernel2, + ops::CrossEntropyGradientOpKernel2); diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu index fcd34383a8..243e7f52c1 100644 --- a/paddle/fluid/operators/cross_entropy_op.cu +++ b/paddle/fluid/operators/cross_entropy_op.cu @@ -27,3 +27,13 @@ REGISTER_OP_CUDA_KERNEL( cross_entropy_grad, ops::CrossEntropyGradientOpKernel, ops::CrossEntropyGradientOpKernel, ops::CrossEntropyGradientOpKernel); + +REGISTER_OP_CUDA_KERNEL(cross_entropy2, + ops::CrossEntropyOpKernel2, + ops::CrossEntropyOpKernel2, + ops::CrossEntropyOpKernel2); + +REGISTER_OP_CUDA_KERNEL( + cross_entropy_grad2, 
ops::CrossEntropyGradientOpKernel2, + ops::CrossEntropyGradientOpKernel2, + ops::CrossEntropyGradientOpKernel2); diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index f123e11542..7eb663773e 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" @@ -137,5 +138,124 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { } }; +template +struct HardLabelCrossEntropyForwardFunctor { + HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x, + const int64_t* label, + int64_t ignore_index, + int64_t feature_size) + : x_(x), + y_(y), + match_x_(match_x), + label_(label), + ignore_index_(ignore_index), + feature_size_(feature_size) {} + + HOSTDEVICE void operator()(int64_t idx) const { + auto label = label_[idx]; + if (label != ignore_index_) { + auto match_x = x_[idx * feature_size_ + label]; + y_[idx] = -math::TolerableValue()(real_log(match_x)); + match_x_[idx] = match_x; + } else { + y_[idx] = 0; + match_x_[idx] = 0; // any value is ok + } + } + + const T* x_; + T* y_; + T* match_x_; + const int64_t* label_; + int64_t ignore_index_; + int64_t feature_size_; +}; + +template +struct HardLabelCrossEntropyBackwardFunctor { + HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x, + const int64_t* label, + int64_t ignore_index, + int64_t feature_size) + : dx_(dx), + dy_(dy), + match_x_(match_x), + label_(label), + ignore_index_(ignore_index), + feature_size_(feature_size) {} + + HOSTDEVICE void operator()(int64_t idx) const { + auto row_idx = idx / feature_size_; + auto col_idx = idx % feature_size_; + auto label = label_[row_idx]; + if (label == col_idx && label != ignore_index_) { + dx_[idx] = -dy_[row_idx] / match_x_[row_idx]; + } else { + dx_[idx] = 0; + } + } + + T* dx_; + const T* dy_; + const T* match_x_; + const int64_t* label_; + int64_t ignore_index_; + int64_t feature_size_; +}; + +template +class CrossEntropyOpKernel2 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* label = ctx.Input("Label"); + auto* y = ctx.Output("Y"); + auto* match_x = ctx.Output("MatchX"); + + auto& x_dims = x->dims(); + auto feature_size = x_dims[x_dims.size() - 1]; + auto batch_size = framework::product(x->dims()) / feature_size; + + auto* p_x = x->data(); + auto* p_label = label->data(); + auto* p_y = y->mutable_data(ctx.GetPlace()); + auto* p_match_x = match_x->mutable_data(ctx.GetPlace()); + + auto ignore_index = ctx.Attr("ignore_index"); + + platform::ForRange for_range( + ctx.template device_context(), batch_size); + for_range(HardLabelCrossEntropyForwardFunctor( + p_x, p_y, p_match_x, p_label, ignore_index, feature_size)); + } +}; + +template +class CrossEntropyGradientOpKernel2 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Input(framework::GradVarName("Y")); + auto* match_x = ctx.Input("MatchX"); + auto* label = ctx.Input("Label"); + + auto* p_dx = dx->mutable_data(ctx.GetPlace()); + auto* 
p_dy = dy->data(); + auto* p_match_x = match_x->data(); + auto* p_label = label->data(); + + int64_t ignore_index = ctx.Attr("ignore_index"); + int rank = dx->dims().size(); + int64_t feature_size = dx->dims()[rank - 1]; + int64_t batch_size = framework::product(dx->dims()) / feature_size; + + platform::ForRange for_range( + ctx.template device_context(), + batch_size * feature_size); + for_range(HardLabelCrossEntropyBackwardFunctor( + p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size)); + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 44a2f37b66..fcb2be9363 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/expand_op.h" +#include #include namespace paddle { @@ -138,12 +139,28 @@ class ExpandGradOp : public framework::OperatorWithKernel { } }; +class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("expand_grad"); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker, - paddle::framework::DefaultGradOpDescMaker); + ops::ExpandGradOpDescMaker); REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp); REGISTER_OP_CPU_KERNEL( expand, ops::ExpandKernel, diff --git a/paddle/fluid/operators/math.h b/paddle/fluid/operators/math.h new file mode 100644 index 0000000000..8cc24200d3 --- /dev/null +++ b/paddle/fluid/operators/math.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/hostdevice.h" + +#include "math.h" // NOLINT + +namespace paddle { +namespace operators { + +inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) { + return static_cast(::expf(static_cast(x))); +} + +inline HOSTDEVICE float real_exp(float x) { return ::expf(x); } + +inline HOSTDEVICE double real_exp(double x) { return ::exp(x); } + +inline HOSTDEVICE platform::float16 real_log(platform::float16 x) { + return static_cast(::logf(static_cast(x))); +} + +inline HOSTDEVICE float real_log(float x) { return ::logf(x); } + +inline HOSTDEVICE double real_log(double x) { return ::log(x); } + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index cb200ec8d6..44cbdf2e98 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -20,17 +21,6 @@ namespace paddle { namespace operators { namespace math { -namespace { - -__device__ __forceinline__ float real_log(float x) { return logf(x); } - -__device__ __forceinline__ double real_log(double x) { return log(x); } - -__device__ __forceinline__ platform::float16 real_log( - const platform::float16& val) { - return static_cast(logf(static_cast(val))); -} - template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, const int N, const int D, @@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, Y[blockIdx.x] = -val; } } -} // namespace template class CrossEntropyFunctor { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 37f69426b6..2b429380fb 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -219,14 +219,6 @@ class ReshapeKernel { std::vector(shape_data, shape_data + shape_tensor->numel()); out_dims = ReshapeOp::ValidateShape(shape, in->dims()); } - if (!in->lod().empty()) { - PADDLE_ENFORCE_EQ( - out_dims[0], in->dims()[0], - "Reshape operator cannot reshape an input sequence batch " - "into an output sequence batch that has a different " - "number of time steps. Please consider using " - "sequence_reshape op."); - } out->mutable_data(ctx.GetPlace(), in->type()); framework::TensorCopy( diff --git a/paddle/fluid/operators/selu_op.h b/paddle/fluid/operators/selu_op.h index bdb506885c..b2fc834c42 100644 --- a/paddle/fluid/operators/selu_op.h +++ b/paddle/fluid/operators/selu_op.h @@ -15,13 +15,12 @@ limitations under the License. 
*/ #pragma once #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/platform/for_range.h" + namespace paddle { namespace operators { -static HOSTDEVICE float real_exp(float x) { return expf(x); } -static HOSTDEVICE float real_exp(double x) { return exp(x); } - template struct SeluFunctor { SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr) diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu index cc5e982190..a9dc0a4fda 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include #include // NOLINT +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h" namespace paddle { @@ -21,9 +22,6 @@ namespace operators { using LoDTensor = framework::LoDTensor; -__device__ __forceinline__ float real_exp(float x) { return expf(x); } -__device__ __forceinline__ double real_exp(double x) { return exp(x); } - template using BlockReduce = cub::BlockReduce; diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu index 2a4570ef5c..aea69de643 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "cub/cub.cuh" +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/hostdevice.h" @@ -21,11 +22,6 @@ namespace operators { using Tensor = framework::Tensor; -static HOSTDEVICE float real_exp(float x) { return expf(x); } -static HOSTDEVICE float real_exp(double x) { return exp(x); } -static HOSTDEVICE float real_log(float x) { return logf(x); } -static HOSTDEVICE float real_log(double x) { return log(x); } - static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaxinumNumBlocks = 4096; diff --git a/paddle/fluid/operators/slice_op.cu b/paddle/fluid/operators/slice_op.cu index 5efecb78d1..1af57b89a3 100644 --- a/paddle/fluid/operators/slice_op.cu +++ b/paddle/fluid/operators/slice_op.cu @@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/slice_op.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void Padding(const paddle::platform::float16* d_out, + const int* out_dims, const int* in_dims, + const int* offsets, int64_t n, + paddle::platform::float16* d_in) { + int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x; + if (out_idx < n) { + int coords[D] = {0}; + for (int i = D - 1; i >= 0; --i) { + coords[i] = out_idx % out_dims[i]; + out_idx /= out_dims[i]; + coords[i] += offsets[i]; + } + + int64_t in_idx = 0; + for (int i = 0; i < D - 1; ++i) { + in_idx += coords[i] * in_dims[i + 1]; + } + in_idx += coords[D - 1]; + + d_in[in_idx] = d_out[out_idx]; + } +} + +template <> +class SliceGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_in = ctx.Output(framework::GradVarName("Input")); + d_in->mutable_data(ctx.GetPlace()); + + auto out_dims = d_out->dims(); + auto in_dims = d_in->dims(); + int rank = out_dims.size(); + std::vector offsets(rank, 0); + auto axes = ctx.Attr>("axes"); + auto starts = ctx.Attr>("starts"); + + for (size_t i = 0; i < starts.size(); ++i) { + if (starts[i] < 0) { + starts[i] += in_dims[axes[i]]; + } + offsets[axes[i]] = std::max(starts[i], 0); + } + + math::SetConstant + set_zero; + auto& dev_ctx = + ctx.template device_context(); + set_zero(dev_ctx, d_in, static_cast(0)); + + int64_t numel = d_out->numel(); + dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1); + dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1); + auto stream = ctx.cuda_device_context().stream(); + + auto out_shape = framework::vectorize2int(out_dims); + thrust::device_vector out_dims_vec(out_shape.begin(), out_shape.end()); + auto in_shape = framework::vectorize2int(in_dims); + thrust::device_vector in_dims_vec(in_shape.begin(), in_shape.end()); + thrust::device_vector offsets_vec(offsets.begin(), offsets.end()); + const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data()); + const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data()); + const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data()); + + switch (rank) { + case 1: + Padding<1><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 2: + Padding<2><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 3: + Padding<3><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 4: + Padding<4><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 5: + Padding<5><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 6: + Padding<6><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + } + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( slice, ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, - ops::SliceKernel); + ops::SliceKernel, + ops::SliceKernel); REGISTER_OP_CUDA_KERNEL( slice_grad, 
ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, - ops::SliceGradKernel); + ops::SliceGradKernel, + ops::SliceGradKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d0bff52e43..9886f4e84c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1432,6 +1432,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): predict = fluid.layers.fc(input=net, size=classdim, act='softmax') cost = fluid.layers.cross_entropy(input=predict, label=label) """ + if not soft_label: + return cross_entropy2(input, label, ignore_index) helper = LayerHelper('cross_entropy', **locals()) out = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( @@ -1444,6 +1446,22 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): return out +def cross_entropy2(input, label, ignore_index=kIgnoreIndex): + helper = LayerHelper('cross_entropy2', **locals()) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + xshape = helper.create_variable_for_type_inference(dtype=input.dtype) + match_x = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='cross_entropy2', + inputs={'X': [input], + 'Label': [label]}, + outputs={'Y': [out], + 'MatchX': [match_x], + 'XShape': [xshape]}, + attrs={'ignore_index': ignore_index}) + return out + + def bpr_loss(input, label, name=None): """ Bayesian Personalized Ranking Loss Operator. diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py new file mode 100644 index 0000000000..55029c18d6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py @@ -0,0 +1,82 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from op_test import OpTest +import unittest +import numpy as np +import six + + +class CrossEntropy2OpTestBase(OpTest): + def initParameters(self): + return [32, 64], 'float32', -100 + + def calc_output(self, logits, label, ignore_index): + ret = np.zeros(shape=label.shape, dtype=logits.dtype) + match_x = np.zeros(shape=label.shape, dtype=logits.dtype) + for idx in six.moves.range(label.shape[0]): + if label[idx] == ignore_index: + continue + match_x[idx] = logits[idx][label[idx]] + ret[idx] = -np.log(match_x[idx]) + return ret, match_x + + def setUp(self): + self.shape, self.dtype, self.ignore_index = self.initParameters() + self.op_type = 'cross_entropy2' + feature_size = int(self.shape[-1]) + batch_size = int(np.prod(self.shape) / feature_size) + logits = (np.random.random(size=self.shape) + 1).astype(self.dtype) + label = np.random.random_integers( + low=0, high=feature_size - 1, + size=self.shape[0:-1] + [1]).astype('int64') + outputs, match_x = self.calc_output( + np.reshape(logits, [batch_size, feature_size]), + np.reshape(label, [batch_size, 1]), self.ignore_index) + self.inputs = {'X': logits, 'Label': label} + self.outputs = { + 'Y': np.reshape(outputs, label.shape), + 'MatchX': np.reshape(match_x, label.shape), + 'XShape': np.zeros( + shape=logits.shape, dtype=logits.dtype) + } + self.attrs = {'ignore_index': self.ignore_index} + + def test_check_output(self): + self.check_output(no_check_set=['XShape']) + + def test_check_grad(self): + self.check_grad( + inputs_to_check=['X'], + output_names=['Y'], + no_grad_set=['XShape', 'MatchX', 'Label']) + + +class CrossEntropy2OpTest2(CrossEntropy2OpTestBase): + def initParameters(self): + return [32, 64], 'float64', 3 + + +class CrossEntropy2OpTest3(CrossEntropy2OpTestBase): + def initParameters(self): + return [4, 8, 16, 32], 'float32', -100 + + +class CrossEntropy2OpTest4(CrossEntropy2OpTestBase): + def initParameters(self): + return [4, 8, 16, 32], 'float32', 3 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 12132477d2..f81d4fda50 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -524,8 +524,8 @@ class TestLocalLookupTable(TestDistLookupTableBase): ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', - 'cross_entropy', 'mean', 'fill_constant', 'mean_grad', - 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', + 'cross_entropy2', 'mean', 'fill_constant', 'mean_grad', + 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'split_selected_rows', 'send', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', @@ -564,8 +564,8 @@ class TestDistLookupTable(TestDistLookupTableBase): ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', - 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', - 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', + 'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant', + 'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'split_selected_rows', 'send', 
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', @@ -612,8 +612,8 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', - 'cross_entropy', 'mean', 'fill_constant', 'mean_grad', - 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', + 'cross_entropy2', 'mean', 'fill_constant', 'mean_grad', + 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'split_selected_rows', 'send', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', @@ -652,8 +652,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', - 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', - 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', + 'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant', + 'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'split_selected_rows', 'send', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', @@ -841,8 +841,8 @@ class TestRemoteLookupTable(TestDistLookupTableBase): ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', - 'cross_entropy', 'mean', 'fill_constant', 'mean_grad', - 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', + 'cross_entropy2', 'mean', 'fill_constant', 'mean_grad', + 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'split_selected_rows', 'send', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 4e6ed3a74b..5fdabbabed 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest @@ -63,5 +64,28 @@ class TestCase2(TestSliceOp): self.out = self.input[-3:3, 0:100, :, 2:-1] +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFP16(TestSliceOp): + def config(self): + self.dtype = "float16" + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [-3, 0, 2] + self.ends = [3, 100, -1] + self.axes = [0, 1, 3] + self.out = self.input[-3:3, 0:100, :, 2:-1] + + def test_check_output(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-5) + + def test_check_grad_normal(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, ['Input'], 'Out', max_relative_error=0.006) + + if __name__ == '__main__': unittest.main() diff --git a/tools/timeline.py b/tools/timeline.py index 694ab1d50f..44c1c09b80 100644 --- a/tools/timeline.py +++ b/tools/timeline.py @@ -160,6 +160,8 @@ class Timeline(object): self._devices[(k, event.device_id, "GPUKernel")] = pid self._chrome_trace.emit_pid("%s:gpu:%d" % 
(k, event.device_id), pid) + if not hasattr(profile_pb, "mem_events"): + continue for mevent in profile_pb.mem_events: if mevent.place == profiler_pb2.MemEvent.CUDAPlace: if (k, mevent.device_id, "GPU") not in self._mem_devices: @@ -211,7 +213,7 @@ class Timeline(object): args = {'name': event.name} if event.memcopy.bytes > 0: args['mem_bytes'] = event.memcopy.bytes - if event.detail_info: + if hasattr(event, "detail_info") and event.detail_info: args['detail_info'] = event.detail_info # TODO(panyx0718): Chrome tracing only handles ms. However, some # ops takes micro-seconds. Hence, we keep the ns here. @@ -220,6 +222,8 @@ class Timeline(object): event.sub_device_id, 'Op', event.name, args) def _allocate_memory_event(self): + if not hasattr(profiler_pb2, "MemEvent"): + return place_to_str = { profiler_pb2.MemEvent.CPUPlace: "CPU", profiler_pb2.MemEvent.CUDAPlace: "GPU",
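As a supplementary note on the cross_entropy2 kernels introduced earlier in this patch, a small NumPy sketch (not part of the patch; the function names are illustrative) of the math that HardLabelCrossEntropyForwardFunctor and HardLabelCrossEntropyBackwardFunctor compute:

import numpy as np

def hard_label_cross_entropy(x, label, ignore_index=-100):
    # Forward: y = -log(x[i, label[i]]), and MatchX caches the matched
    # probability so the backward pass does not need to re-read X.
    batch = x.shape[0]
    y = np.zeros(batch, dtype=x.dtype)
    match_x = np.zeros(batch, dtype=x.dtype)
    for i in range(batch):
        if label[i] == ignore_index:
            continue  # loss and match_x stay 0 for ignored rows
        match_x[i] = x[i, label[i]]
        y[i] = -np.log(match_x[i])
    return y, match_x

def hard_label_cross_entropy_grad(dy, match_x, label, num_classes,
                                  ignore_index=-100):
    # Backward: dx = -dy / match_x at the labelled column, 0 elsewhere.
    batch = dy.shape[0]
    dx = np.zeros((batch, num_classes), dtype=dy.dtype)
    for i in range(batch):
        if label[i] != ignore_index:
            dx[i, label[i]] = -dy[i] / match_x[i]
    return dx

if __name__ == "__main__":
    x = np.array([[0.2, 0.5, 0.3], [0.7, 0.1, 0.2]], dtype=np.float32)
    label = np.array([1, 0], dtype=np.int64)
    y, match_x = hard_label_cross_entropy(x, label)
    dx = hard_label_cross_entropy_grad(np.ones_like(y), match_x, label, 3)
    print(y, match_x, dx, sep="\n")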