diff --git a/CMakeLists.txt b/CMakeLists.txt
index 48e52961a9..317f7f9eb4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -138,12 +138,6 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release)
endif()
-if(WITH_MKL)
- option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
- if (MKL_SPLIT_GEMM)
- add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
- endif()
-endif()
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
if (WITH_MKL AND AVX2_FOUND)
diff --git a/doc/fluid/dev/new_op_cn.md b/doc/fluid/dev/new_op_cn.md
index c00f73be95..ff7408111f 100644
--- a/doc/fluid/dev/new_op_cn.md
+++ b/doc/fluid/dev/new_op_cn.md
@@ -36,19 +36,19 @@
OpProtoMake definition |
-`.cc` file; the backward Op does not need to define OpProtoMake |
+.cc file; the backward Op does not need to define OpProtoMake |
Op definition |
- `.cc` file |
+ .cc file |
Kernel implementation |
- If CPU and CUDA share a Kernel, the implementation goes in the `.h` file; otherwise the CPU implementation goes in the `.cc` file and the CUDA implementation in the `.cu` file. |
+ If CPU and CUDA share a Kernel, the implementation goes in the .h file; otherwise the CPU implementation goes in the .cc file and the CUDA implementation in the .cu file. |
Op registration |
- Op registration is implemented in the `.cc` file; CPU Kernel registration goes in the `.cc` file and CUDA Kernel registration in the `.cu` file |
+ Op registration is implemented in the .cc file; CPU Kernel registration goes in the .cc file and CUDA Kernel registration in the .cu file |
@@ -391,7 +391,7 @@ PADDLE_ENFORCE(ctx->HasInput("X"), "");
```
Problem example 2: the error message is too terse
```
-PADDLE_ENFORCE(i != nullptr, "I must be set");  // What is I?
+PADDLE_ENFORCE(i != nullptr, "i must be set");  // What is i?
```
2. Using a developer-defined variable abbreviation in the error message makes it hard to understand!
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 01b6053524..37c2523c9f 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -163,6 +163,7 @@ paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], v
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
+paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@@ -192,7 +193,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/framework/array.h b/paddle/fluid/framework/array.h
new file mode 100644
index 0000000000..be9efcd749
--- /dev/null
+++ b/paddle/fluid/framework/array.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace framework {
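+// Array<T, N> is a small fixed-size array whose members are all HOSTDEVICE,
+// so it can be passed by value into CUDA kernels as a lightweight substitute
+// for std::array.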
+template <typename T, size_t N>
+class Array {
+ static_assert(N > 0, "The size of array must be larger than 0");
+
+ public:
+ HOSTDEVICE Array() {}
+
+ HOSTDEVICE explicit Array(const T &val) {
+ for (size_t i = 0; i < N; ++i) data_[i] = val;
+ }
+
+ HOSTDEVICE const T *Get() const { return data_; }
+
+ HOSTDEVICE T *GetMutable() { return data_; }
+
+ HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
+
+ HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
+
+ HOSTDEVICE constexpr size_t size() const { return N; }
+
+ private:
+ T data_[N];
+};
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index c5a13e7e1f..bc61b0eacb 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
// Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const {
+ // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
+ // put them into transpiler.
int op_dev_id = -1;
if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
@@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
"This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by
// split_byref op
- // so that we can balance the variable blocks to all the pserver
- // instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
node->inputs[0]->Name().find(".block") == std::string::npos) {
      std::vector<std::string> input_var_names;
for (ir::Node *n : node->inputs) {
input_var_names.push_back(n->Name());
}
- op_dev_id = GetAppropriateDeviceID(input_var_names);
+      auto send_param_grad = boost::get<std::vector<std::string>>(
+ node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+ PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
+ op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
+ VLOG(10) << "send grad " << input_var_names[0] << " origin "
+ << send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) {
result->Get(kShardedVarDevice)
.emplace(varname, op_dev_id);
}
+ result->Get(kShardedVarDevice)
+ .emplace(send_param_grad[1], op_dev_id);
}
} else if (node->Op()->Type() == "recv") {
    std::vector<std::string> output_var_names;
for (ir::Node *n : node->outputs) {
output_var_names.push_back(n->Name());
}
- op_dev_id = GetAppropriateDeviceID(output_var_names);
+    auto recv_param_grad = boost::get<std::vector<std::string>>(
+ node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+ // FIXME(typhoonzero): assume each recv op output one param
+ // Use the same place as send.
+ if (recv_param_grad.size() == 2U) {
+ op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
+ VLOG(10) << "recv param " << recv_param_grad[0]
+ << " get grad place: " << recv_param_grad[1]
+ << " place: " << op_dev_id;
+ } else {
+ op_dev_id = GetAppropriateDeviceID(output_var_names);
+ }
for (auto &varname : output_var_names) {
result->Get(kShardedVarDevice)
.emplace(varname, op_dev_id);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
index 69944a42b6..361c91dc78 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
@@ -54,7 +54,8 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
- << var_handle_ptr->version_ << "\"]" << std::endl;
+ << "scope: " << var_handle_ptr->scope_idx_ << "\\n"
+ << "v" << var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
index 993c885a81..06f9df5546 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc
@@ -163,8 +163,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
// 3. Detect op2 -> var2 -> op4
// 4. Detect op2 -> var3 -> op5
// But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2
- ASSERT_GE(count, 1UL);
- ASSERT_LE(count, 2UL);
+ ASSERT_GE(count, 1);
+ ASSERT_LE(count, 2);
}
} // namespace ir
diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc
index aca77da8d6..65c45c7d20 100644
--- a/paddle/fluid/framework/ir/node.cc
+++ b/paddle/fluid/framework/ir/node.cc
@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
-const char Node::kControlDepVarName[] = "__control_var";
+constexpr char Node::kControlDepVarName[];
} // namespace ir
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 63277d2d01..aab3180e7e 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -27,7 +27,7 @@ namespace ir {
class Node {
public:
enum class Type { kOperation, kVariable };
- static const char kControlDepVarName[];
+ static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc
index c202b0a5be..a4319ffabb 100644
--- a/paddle/fluid/framework/selected_rows.cc
+++ b/paddle/fluid/framework/selected_rows.cc
@@ -139,7 +139,7 @@ int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
}
auto write_iter = id_to_index_.find(key);
if (write_iter == id_to_index_.end()) {
- size_t row_num = rows_.size();
+ int row_num = rows_.size();
if (row_num == value_->dims()[0]) {
rwlock_->UNLock();
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
@@ -182,7 +182,7 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
"except the dims[0].");
- for (size_t i = 0; i < ids.numel(); ++i) {
+ for (int i = 0; i < ids.numel(); ++i) {
    int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
framework::VisitDataType(
framework::ToDataType(value_->type()),
diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index 52f5c4f5ae..baa7600283 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -23,6 +23,8 @@
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
+DEFINE_int32(batch_size, 10, "batch size.");
+DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
namespace paddle {
namespace inference {
@@ -92,7 +94,7 @@ struct DataRecord {
size_t batch_iter{0};
size_t batch_size{1};
DataRecord() = default;
- DataRecord(const std::string &path, int batch_size = 1)
+ explicit DataRecord(const std::string &path, int batch_size = 1)
: batch_size(batch_size) {
Load(path);
}
@@ -165,7 +167,6 @@ struct DataRecord {
};
void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
int batch_size) {
- // DataRecord data(FLAGS_datapath, batch_size);
PaddleTensor lod_attention_tensor, init_zero_tensor, lod_tensor_tensor,
week_tensor, minute_tensor;
lod_attention_tensor.name = "data_lod_attention";
@@ -174,28 +175,33 @@ void PrepareInputs(std::vector *input_slots, DataRecord *data,
week_tensor.name = "week";
minute_tensor.name = "minute";
auto one_batch = data->NextBatch();
- // clang-format off
-  std::vector<int> rnn_link_data_shape
-      ({static_cast<int>(one_batch.rnn_link_data.size()), static_cast<int>(one_batch.rnn_link_data.front().size())});
+  std::vector<int> rnn_link_data_shape(
+      {static_cast<int>(one_batch.rnn_link_data.size()),
+       static_cast<int>(one_batch.rnn_link_data.front().size())});
lod_attention_tensor.shape.assign({1, 2});
lod_attention_tensor.lod.assign({one_batch.lod1, one_batch.lod2});
init_zero_tensor.shape.assign({batch_size, 15});
init_zero_tensor.lod.assign({one_batch.lod3});
lod_tensor_tensor.shape = rnn_link_data_shape;
lod_tensor_tensor.lod.assign({one_batch.lod1});
-  week_tensor.shape.assign({(int) one_batch.rnn_week_datas.size(), (int) one_batch.rnn_week_datas.front().size()});
+  // clang-format off
+  week_tensor.shape.assign(
+      {static_cast<int>(one_batch.rnn_week_datas.size()),
+       static_cast<int>(one_batch.rnn_week_datas.front().size())});
   week_tensor.lod.assign({one_batch.lod3});
-  minute_tensor.shape.assign({(int) one_batch.rnn_minute_datas.size(),
-                              (int) one_batch.rnn_minute_datas.front().size()});
+  minute_tensor.shape.assign(
+      {static_cast<int>(one_batch.rnn_minute_datas.size()),
+       static_cast<int>(one_batch.rnn_minute_datas.front().size())});
minute_tensor.lod.assign({one_batch.lod3});
+ // clang-format on
// assign data
-  TensorAssignData(&lod_attention_tensor, std::vector<std::vector<float>>({{0, 0}}));
+  TensorAssignData(&lod_attention_tensor,
+                   std::vector<std::vector<float>>({{0, 0}}));
std::vector tmp_zeros(batch_size * 15, 0.);
TensorAssignData(&init_zero_tensor, {tmp_zeros});
TensorAssignData(&lod_tensor_tensor, one_batch.rnn_link_data);
TensorAssignData(&week_tensor, one_batch.rnn_week_datas);
TensorAssignData(&minute_tensor, one_batch.rnn_minute_datas);
- // clang-format on
// Set inputs.
auto init_zero_tensor1 = init_zero_tensor;
init_zero_tensor1.name = "hidden_init";
@@ -231,12 +237,9 @@ std::string DescribeTensor(const PaddleTensor &tensor) {
os << "\n";
os << " - data: ";
- // clang-format off
- int dim = std::accumulate(tensor.shape.begin(),
- tensor.shape.end(),
- 1,
- [](int a, int b) { return a * b; }); // clang-format on
- for (size_t i = 0; i < dim; i++) {
+ int dim = std::accumulate(tensor.shape.begin(), tensor.shape.end(), 1,
+ [](int a, int b) { return a * b; });
+ for (int i = 0; i < dim; i++) {
    os << static_cast<float *>(tensor.data.data())[i] << " ";
}
os << '\n';
@@ -300,13 +303,16 @@ void TestDituRNNPrediction(const std::string &model_path,
for (int i = 0; i < num_times; i++) {
predictor->Run(input_slots, &outputs);
}
- LOG(INFO) << "time/batch: " << timer.toc() / num_times;
+ LOG(INFO) << "===========profile result===========";
+ LOG(INFO) << "batch_size: " << batch_size << ", repeat: " << num_times
+ << ", latency: " << timer.toc() / num_times << "ms";
+ LOG(INFO) << "=====================================";
for (auto &out : outputs) {
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; });
    float *data = static_cast<float *>(out.data.data());
- for (int i = 0;
+ for (size_t i = 0;
i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size);
i++) {
EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
@@ -336,7 +342,7 @@ TEST(Analyzer, SupportIRPass) {
// Directly infer with the original model.
TEST(Analyzer, DituRNN_without_analysis) {
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
- 10, false, false);
+ FLAGS_batch_size, false, false, FLAGS_repeat);
}
// Inference with the original model with the analysis turned on, the analysis
@@ -344,14 +350,14 @@ TEST(Analyzer, DituRNN_without_analysis) {
TEST(Analyzer, DituRNN_with_analysis) {
LOG(INFO) << "ditu rnn with analysis";
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
- 10, true, false, 1);
+ FLAGS_batch_size, true, false, FLAGS_repeat);
}
// Inference with analysis and IR. The IR module will fuse some large kernels.
TEST(Analyzer, DituRNN_with_analysis_with_IR) {
LOG(INFO) << "ditu rnn with analysis and IR fuse";
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
- 10, true, true, 1);
+ FLAGS_batch_size, true, true, FLAGS_repeat);
}
} // namespace analysis
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
new file mode 100644
index 0000000000..1cb65346ee
--- /dev/null
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -0,0 +1,422 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/attention_lstm_op.h"
+#include
+#include
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("C0"),
+ "Input(C0) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
+ "Input(LSTMWeight) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
+ "Input(LSTMBias) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
+ "Input(AttentionWeight) of AttentionLSTM should not be null.");
+
+ PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+ "Output(Hidden) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+ "Output(Cell) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
+ "Output(AttentionedX) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
+ "Output(AttentionFCOut) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
+ "Output(LSTMX) of AttentionLSTM should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
+ "Output(LSTMOUT) of AttentionLSTM should not be null.");
+
+ auto x_dims = ctx->GetInputDim("X");
+ const int M = x_dims[1];
+ PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+
+ auto w_dims = ctx->GetInputDim("LSTMWeight");
+ const int D = w_dims[1] / 4;
+ PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(w_dims[0], D + M,
+ "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
+
+ auto b_dims = ctx->GetInputDim("LSTMBias");
+ PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+ PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
+
+ auto c_dims = ctx->GetInputDim("C0");
+ PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+ if (ctx->HasInput("H0")) {
+ auto h_dims = ctx->GetInputDim("H0");
+ PADDLE_ENFORCE(h_dims == c_dims,
+ "The dimension of Input(H0) and Input(C0) "
+ "should be the same.");
+ }
+
+ auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
+ PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
+ "Input(AttentionWeight)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
+ "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+ PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
+ "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+ if (ctx->HasInput("AttentionBias")) {
+ auto atten_b_dims = ctx->GetInputDim("AttentionBias");
+ PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
+ "Input(AttentionBias)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
+ "AttentionBias shapes must be 1 * 1.");
+ PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
+ "AttentionBias shapes must be 1 * 1.");
+ }
+
+ if (ctx->HasInput("AttentionScalar")) {
+ auto dims = ctx->GetInputDim("AttentionScalar");
+ PADDLE_ENFORCE_EQ(dims.size(), 2,
+ "Input(AttentionScalar)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
+ PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+ }
+
+ if (ctx->HasInput("AttentionScalarBias")) {
+ auto dims = ctx->GetInputDim("AttentionScalarBias");
+ PADDLE_ENFORCE(
+ ctx->HasInput("AttentionScalar"),
+ "AttentionScalar should not be null when have AttentionScalarBias.");
+ PADDLE_ENFORCE_EQ(dims.size(), 2,
+ "Input(AttentionScalarBias)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
+ PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+ }
+
+ framework::DDim out_dims({x_dims[0], D});
+ ctx->SetOutputDim("Hidden", out_dims);
+ ctx->SetOutputDim("Cell", out_dims);
+ ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
+ ctx->SetOutputDim("LSTMX", {1, M});
+ ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
+  // AttentionFCOut should be reshaped to (maxseqlen, 1) at runtime
+ ctx->ShareLoD("X", "Hidden");
+ ctx->ShareLoD("X", "Cell");
+}
+
+framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const {
+ return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+ ctx.device_context());
+}
+
+void AttentionLSTMOpMaker::Make() {
+ AddInput("X",
+ "(LoDTensor) the input is a LodTensor, which support "
+ "variable-time length input sequence. The underlying tensor in "
+ "this LoDTensor is a matrix with shape (T X M), where T is the "
+ "total time steps in this mini-batch, M is the dim size of x.");
+ AddInput("C0",
+ "(Tensor) LSTM C0"
+ "This is a tensor with shape (N x D), where N is the batch size, D "
+ "is the gate size."
+ "C0 is necessary because of attention.");
+ AddInput("H0",
+ "(Tensor, optional) LSTM H0"
+ "This is a tensor with shape (N x D), where N is the "
+ "batch size and D is the gate size.")
+ .AsDispensable();
+ AddInput("AttentionWeight",
+ "(Tensor) the weights of attention fc. Always relu the fc result."
+ "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
+ "gate size of LSTM.");
+ AddInput("AttentionBias",
+ "(Tensor, optional) the bias of attention fc."
+ "The shape is (1 x 1)")
+ .AsDispensable();
+ AddInput("AttentionScalar",
+ "(Tensor, optional) the scalar on the result of attentioned fc. "
+ "Always relu the Scalar."
+ "The shape is (1 x 1)")
+ .AsDispensable();
+ AddInput("AttentionScalarBias",
+ "(Tensor, optional) the scalar bias of attention fc."
+ "The shape is (1 x 1)")
+ .AsDispensable();
+ AddInput("LSTMWeight",
+ "(Tensor) the combined weight of LSTM"
+ " - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
+ "is the dim size of x"
+ " - Weight = {W_forget, W_input, W_output, W_cell}");
+ AddInput("LSTMBias",
+ "(Tensor) the combined bias of LSTM, shape (1x4D)."
+           "Note: we should add the bias of hidden and context according to "
+ "the same gate: "
+ "{B_forget, B_input, B_output, B_cell}");
+ AddOutput("Hidden",
+ "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
+ "The shape is (T x D), and lod is the same with the `Input`.");
+ AddOutput("Cell",
+ "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
+ "The shape is (T x D), and lod is the same with the `Input`.");
+ AddOutput("AttentionedX",
+ "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
+ " where T is the total time steps in this mini-batch,"
+ " D is the hidden size.")
+ .AsIntermediate();
+ AddOutput("AttentionFCOut",
+ "(Tensor) (max_seq_len, 1), compute at each step.")
+ .AsIntermediate();
+ AddOutput("LSTMX",
+ "(Tensor) the input X of LSTM for each step."
+ "Shape is (1 x M), where M is the x frame size")
+ .AsIntermediate();
+ AddOutput(
+ "LSTMOUT",
+ "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
+ "Shape is (1 x 4D), where M is the x frame size")
+ .AsIntermediate();
+  AddAttr<std::string>("gate_activation",
+ "(string, default: sigmoid)"
+ "The activation for input gate, forget gate and output "
+ "gate, `sigmoid` by default.")
+ .SetDefault("sigmoid")
+ .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddAttr<std::string>("cell_activation",
+           "(string, default: tanh)"
+           "The activation for cell output, `tanh` by default.")
+ .SetDefault("tanh")
+ .InEnum({"sigmoid", "tanh", "relu", "identity"});
+  AddAttr<std::string>("candidate_activation",
+ "(string, default: tanh)"
+ "The activation for candidate hidden state, "
+ "`tanh` by default.")
+ .SetDefault("tanh")
+ .InEnum({"sigmoid", "tanh", "relu", "identity"});
+ AddComment(R"DOC(
+Attention Long-Short Term Memory (LSTM) Operator.
+
+Attention part:
+concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))
+
+tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu
+
+fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu
+
+dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M)
+
+LSTM part:
+use lstm_x_t as input and compute as standard LSTM.
+
+)DOC");
+}
+
+// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
+template <typename T>
+inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
+ if (bias) {
+ for (int i = 0; i < n; ++i) {
+ y[i] = x[i] + bias[0];
+ }
+ math::vec_relu(n, y, y);
+ } else {
+ math::vec_relu(n, x, y);
+ }
+}
+
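+// Numerically stable softmax over a length-n vector: subtract the max before
+// exponentiating so exp() cannot overflow, then normalize by the sum.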
+template <typename DeviceContext, typename T>
+inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
+                        const T* x, T* y) {
+ T scalar = x[0];
+ // max
+ for (int i = 1; i < n; ++i) {
+ scalar = scalar < x[i] ? x[i] : scalar;
+ }
+
+ // sub
+ for (int i = 0; i < n; ++i) {
+ y[i] = x[i] - scalar;
+ }
+
+ // exp
+ blas.VEXP(n, y, y);
+
+ // sum
+ scalar = T(0);
+ for (int i = 0; i < n; ++i) {
+ scalar += y[i];
+ }
+
+ // scale
+ blas.SCAL(n, static_cast(1) / scalar, y);
+}
+
+template <typename T>
+class AttentionLSTMKernel : public framework::OpKernel<T> {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ using DeviceContext = paddle::platform::CPUDeviceContext;
+
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
+    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
+    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
+    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
+    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
+    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
+
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
+    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
+    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
+    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
+
+    // some shapes must be reshaped here because InferShape cannot get the LoD info
+ auto x_lod = x->lod();
+ const int N = x_lod[0].size() - 1; // batch size
+ auto x_dims = x->dims(); // T x M
+ auto w_dims = lstm_w->dims(); // (D+M) x 4D
+ const int total_T = x_dims[0];
+ const int M = x_dims[1]; // x frame size
+ const int D = w_dims[1] / 4; // gate frame size
+ const int D2 = D * 2;
+ const int D3 = D * 3;
+ const int D4 = w_dims[1];
+ int max_seq_len = x_lod[0][1];
+ for (int i = 1; i < N; ++i) {
+ int len = x_lod[0][i + 1] - x_lod[0][i];
+ max_seq_len = max_seq_len < len ? len : max_seq_len;
+ }
+ PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
+ PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+ fc_out->Resize({max_seq_len, 1});
+
+    math::VecActivations<T> act_functor;
+    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
+    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
+    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
+    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
+
+    const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* c0_data = c0->data<T>();
+    const T* lstm_w_data = lstm_w->data<T>();
+    const T* lstm_b_data = lstm_b->data<T>();
+    const T* atten_w_data = atten_w->data<T>();
+    const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
+    const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
+    const T* atten_scalar_bias_data =
+        atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
+
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
+    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
+    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
+    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
+    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
+
+ // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data,
+                                      atten_w_data, atted_x_data, atten_b_data);
+
+ const T* cur_atten_x_data = atted_x_data;
+ const T* cur_x_data = x_data;
+ const T* prev_cell_data = NULL;
+ const T* prev_hidden_data = NULL;
+ T* cur_cell_out_data = cell_out_data;
+ T* cur_hidden_out_data = hidden_out_data;
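+    // Process each sequence of the batch step by step: attention over the
+    // sequence's time steps produces lstm_x for the current step, which then
+    // drives one standard LSTM update.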
+ for (int i = 0; i < N; ++i) {
+ int seq_len = x_lod[0][i + 1] - x_lod[0][i];
+ prev_cell_data = c0_data + i * D;
+ prev_hidden_data = h0_data ? h0_data + i * D : NULL;
+ for (int step = 0; step < seq_len; ++step) {
+ /// 1. compute attention vector
+ // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
+ T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
+ // 1b. add cell bias and relu
+ bias_relu(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+ // 1c. fc scalar
+ if (atten_scalar_data) {
+ blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
+ bias_relu(seq_len, fc_out_data, atten_scalar_bias_data,
+ fc_out_data);
+ }
+ // 1d. softmax
+ vec_softmax(blas, seq_len, fc_out_data, fc_out_data);
+ // mul x(seq_len*M) and sum pool
+        math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
+                                          cur_x_data, lstm_x_data);
+
+ /// 2. compute LSTM step
+ // lstm weight : concat[forget , input , output , tilde]
+ // shape : (D + M) x (4 * D)
+ // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
+ blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
+ if (prev_hidden_data) {
+ blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1),
+ prev_hidden_data, D, lstm_w_data, D4, static_cast(1),
+ lstm_out_data, D4);
+ }
+ // since input is 1xM, so can use add bias
+ blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
+
+ // gate act: sigmoid
+ act_gate(D3, lstm_out_data, lstm_out_data);
+        // candidate act: tanh
+ act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
+
+ // a = forget * prev_cell
+ blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
+
+ // b = input * tilde
+ blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
+
+ // cell_out = a + b
+ blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
+
+ // state act tanh(cell_out) * output_gate
+ act_cell(D, cur_cell_out_data, lstm_out_data);
+ blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
+
+ prev_hidden_data = cur_hidden_out_data;
+ prev_cell_data = cur_cell_out_data;
+ cur_cell_out_data = cur_cell_out_data + D;
+ cur_hidden_out_data = cur_hidden_out_data + D;
+ }
+ cur_x_data = cur_x_data + seq_len * M;
+ cur_atten_x_data = cur_atten_x_data + seq_len;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
+                  ops::AttentionLSTMOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
+                       ops::AttentionLSTMKernel<double>);
diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h
new file mode 100644
index 0000000000..6ede3a7f3c
--- /dev/null
+++ b/paddle/fluid/operators/attention_lstm_op.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class AttentionLSTMOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override;
+};
+
+class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override;
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fusion_lstm_op.h
index 39dc09b4d1..7f79601602 100644
--- a/paddle/fluid/operators/fusion_lstm_op.h
+++ b/paddle/fluid/operators/fusion_lstm_op.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-// #include
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 8dcf7c99f3..da185d93c0 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -90,6 +90,11 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
+  template <typename T>
+ void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
+ T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
+ int ldc) const;
+
#ifdef PADDLE_WITH_MKLML
template
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
@@ -109,6 +114,10 @@ class Blas {
void GEMM_FREE(T* data) const;
#endif
+  template <typename T>
+ void MatMul(const int M, const int N, const int K, const T* A, const T* B,
+ T* C) const;
+
template
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
@@ -140,10 +149,19 @@ class Blas {
  template <typename T>
  void VCOPY(int n, const T* x, T* y) const;
+  template <typename T>
+  void VEXP(int n, const T* x, T* y) const;
+
  template <typename T>
  void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
            T* C) const;
+  template <typename T>
+  T DOT(int n, const T* x, const T* y) const;
+
+  template <typename T>
+  void SCAL(int n, const T a, T* x) const;
+
  template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C,
@@ -215,11 +233,26 @@ class BlasT : private Blas {
    Base()->template VCOPY<T>(args...);
  }
+  template <typename... ARGS>
+  void VEXP(ARGS... args) const {
+    Base()->template VEXP<T>(args...);
+  }
+
  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
  }
+  template <typename... ARGS>
+  T DOT(ARGS... args) const {
+    return Base()->template DOT<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void SCAL(ARGS... args) const {
+    Base()->template SCAL<T>(args...);
+  }
+
  template <typename... ARGS>
  void BatchedGEMM(ARGS... args) const {
    Base()->template BatchedGEMM<T>(args...);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index dc77b6d793..e1df78d11e 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -73,6 +73,16 @@ struct CBlas {
platform::dynload::cblas_sgemv(args...);
}
+  template <typename... ARGS>
+  static float DOT(ARGS... args) {
+    return platform::dynload::cblas_sdot(args...);
+  }
+
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_sscal(args...);
+  }
+
template
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
@@ -87,6 +97,11 @@ struct CBlas {
static void VMUL(ARGS... args) {
platform::dynload::vsMul(args...);
}
+
+  template <typename... ARGS>
+ static void VEXP(ARGS... args) {
+ platform::dynload::vsExp(args...);
+ }
};
template <>
@@ -138,6 +153,16 @@ struct CBlas {
platform::dynload::cblas_dgemv(args...);
}
+  template <typename... ARGS>
+  static double DOT(ARGS... args) {
+    return platform::dynload::cblas_ddot(args...);
+  }
+
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_dscal(args...);
+  }
+
template
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...);
@@ -152,6 +177,11 @@ struct CBlas {
static void VMUL(ARGS... args) {
platform::dynload::vdMul(args...);
}
+
+  template <typename... ARGS>
+ static void VEXP(ARGS... args) {
+ platform::dynload::vdExp(args...);
+ }
};
#else
@@ -210,6 +240,9 @@ struct CBlas {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
+ static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
+ static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
+ static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -217,64 +250,6 @@ struct CBlas {
#endif
};
-template
-inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
- bool transb, const T &alpha, const T &beta) {
-#ifdef PADDLE_WITH_LIBXSMM
- // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
- // But the threshold is custom
- constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
- if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
- std::abs(alpha - static_cast(1) >
- std::numeric_limits::epsilon()) ||
- std::abs(beta) > std::numeric_limits::epsilon()) {
- return false;
- } else {
- return true;
- }
-#endif
- return false;
-}
-
-template <>
-inline bool UseXSMM(const int &m, const int &n, const int &k,
- bool transa, bool transb,
- const platform::float16 &alpha,
- const platform::float16 &beta) {
- return false;
-}
-
-template
-inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
- CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
- const T *A, int lda, const T *B, int ldb, T beta, T *C,
- int ldc) {
-#ifdef PADDLE_WITH_LIBXSMM
- if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
- beta)) {
- // Note: SMM use ColMajor
- const char transa = 'N';
- const char transb = 'N';
- CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
- &beta, C, &ldc);
- return;
- }
-#endif
-
-#ifdef PADDLE_MKL_SPLIT_GEMM
- constexpr int bs = 2;
- if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
- for (int off = 0; off < M; off += bs) {
- CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
- A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
- }
- return;
- }
-#endif
- CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
- beta, C, ldc);
-}
-
#ifdef PADDLE_WITH_MKLML
template <>
template
@@ -319,8 +294,8 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA,
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
- GEMM_WARP(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
- beta, C, ldc);
+  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+                 beta, C, ldc);
}
template <>
@@ -329,9 +304,20 @@ void Blas::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
- GEMM_WARP(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
- transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
- lda, B, ldb, beta, C, ldc);
+  CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+                 transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+                 lda, B, ldb, beta, C, ldc);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
+                                            CBLAS_TRANSPOSE transB, int M,
+                                            int N, int K, T alpha, const T *A,
+                                            int lda, const T *B, int ldb,
+                                            T beta, T *C, int ldc) const {
+  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+                 beta, C, ldc);
}
template
@@ -399,6 +385,47 @@ void Blas::VMUL(int n, const T *x, const T *y,
#endif
}
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VEXP(n, x, y);
+#else
+ // try to find if openblas support vexp
+ for (int i = 0; i < n; ++i) {
+ y[i] = std::exp(x[i]);
+ }
+#endif
+}
+
+template <>
+template <typename T>
+T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  return CBlas<T>::DOT(n, x, 1, y, 1);
+#else
+ // try to find if openblas support cblas_dot
+ T sum = 0;
+ for (int i = 0; i < n; ++i) {
+ sum += x[i] * y[i];
+ }
+ return sum;
+#endif
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::SCAL(n, a, x, 1);
+#else
+ // try to find if openblas support cblas_scal
+ for (int i = 0; i < n; ++i) {
+ x[i] = a * x[i];
+ }
+#endif
+}
+
template <>
template
void Blas::GEMV(bool trans_a, int M, int N, T alpha,
@@ -440,6 +467,42 @@ void Blas::BatchedGEMM(
#endif
}
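+// MatMul(M, N, K, A, B, C) computes the plain row-major product C = A * B
+// (no transposes, alpha = 1, beta = 0); the CPU specialization below routes
+// these small products to libxsmm when it is available.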
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
+                                 const T *A, const T *B, T *C) const {
+  this->template GEMM<T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
+                         static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
+                         N);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
+                                              const int K, const T *A,
+                                              const T *B, T *C) const {
+#ifdef PADDLE_WITH_LIBXSMM
+  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
+  // The custom threshold used to be: constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
+
+  // Since these matrices are very small, each product is already fast and a
+  // check like if (M * N * K < LIBXSMM_THRESHOLD) would only add overhead,
+  // so use xsmm directly.
+  // Note: SMM uses ColMajor
+ const char transa = 'N';
+ const char transb = 'N';
+ const T alpha = static_cast(1);
+ const T beta = static_cast(0);
+  CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
+                     C, &N);
+ return;
+#endif
+
+  CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
+                 static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
+}
+
template
template
void Blas::MatMul(const framework::Tensor &mat_a,
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
new file mode 100644
index 0000000000..48c0da0e36
--- /dev/null
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -0,0 +1,105 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
+
+template <typename T>
+inline T sigmoid(T x) {
+ return 1. / (1. + exp(-x));
+}
+
+template <typename T>
+inline T tanh(T x) {
+ return 2. * sigmoid(2. * x) - 1.;
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_identity(const int n, const T* x, T* y) {
+ // do nothing
+ return;
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_sigmoid(const int n, const T* x, T* y) {
+ const T min = SIGMOID_THRESHOLD_MIN;
+ const T max = SIGMOID_THRESHOLD_MAX;
+ for (int i = 0; i < n; ++i) {
+ T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
+ y[i] = 1.0 / (1.0 + std::exp(-tmp));
+ }
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_tanh(const int n, const T* x, T* y) {
+ for (int i = 0; i < n; ++i) {
+ y[i] = tanh(x[i]);
+ }
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_relu(const int n, const T* x, T* y) {
+ for (int i = 0; i < n; ++i) {
+ y[i] = x[i] > 0 ? x[i] : 0;
+ }
+}
+
+template <>
+inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
+                                                float* y) {
+ // TODO(TJ): complete me
+ for (int i = 0; i < n; ++i) {
+ y[i] = x[i] > 0 ? x[i] : 0;
+ }
+}
+
+template <>
+inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
+                                                 float* y) {
+ // TODO(TJ): complete me
+ for (int i = 0; i < n; ++i) {
+ y[i] = x[i] > 0 ? x[i] : 0;
+ }
+}
+
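+// Maps an activation name ("sigmoid", "relu", "tanh", "identity") to the
+// corresponding vec_* kernel above.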
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+class VecActivations {
+ public:
+  std::function<void(const int, const T*, T*)> operator()(
+      const std::string& type) {
+    if (type == "sigmoid") {
+      return vec_sigmoid<T, isa>;
+    } else if (type == "relu") {
+      return vec_relu<T, isa>;
+    } else if (type == "tanh") {
+      return vec_tanh<T, isa>;
+    } else if (type == "identity" || type == "") {
+      return vec_identity<T, isa>;
+    }
+ PADDLE_THROW("Not support type %s.", type);
+ }
+};
+
+} // namespace math
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
index 8600fa9e2c..1f5a49c0ab 100644
--- a/paddle/fluid/operators/math/fc_compute.h
+++ b/paddle/fluid/operators/math/fc_compute.h
@@ -25,17 +25,25 @@ namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
                      const int N, const int K, const T* X, const T* W, T* Y,
-                      const T* B = NULL) {
-  blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
-            static_cast<T>(0), Y);
- if (B) {
+ const T* B = NULL, bool relu = false) {
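+  // Y = X * W (an M x K by K x N product); if a bias B is given it is added
+  // to every row with AXPY. The relu flag is reserved for a future fused
+  // bias + relu path (not implemented yet, see below).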
+ blas.MatMul(M, N, K, X, W, Y);
+ if (B == NULL) {
+ return;
+ }
+
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
- for (int i = 0; i < M; i++) {
- blas.AXPY(N, static_cast(1), B, Y + i * N);
- }
+ for (int i = 0; i < M; i++) {
+    blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
+
+ if (!relu) {
+ return;
+ }
+
+ // TODO(TJ): fuse relu
+ LOG(FATAL) << "Not implemented!";
}
} // namespace math
diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h
index f730a9746d..01308e416a 100644
--- a/paddle/fluid/operators/sampling_id_op.h
+++ b/paddle/fluid/operators/sampling_id_op.h
@@ -54,7 +54,7 @@ class SamplingIdKernel : public framework::OpKernel {
static_cast(context.Attr("max")));
std::vector ids(batch_size);
- for (size_t i = 0; i < batch_size; ++i) {
+ for (int i = 0; i < batch_size; ++i) {
T r = dist(engine);
int idx = width - 1;
for (int j = 0; j < width; ++j) {
@@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel {
break;
}
}
- ids[i] = ins_vector[i * width + idx];
+ ids[i] = ins_vector[idx];
}
std::vector out_dim;
diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc
new file mode 100644
index 0000000000..3f4b48bc73
--- /dev/null
+++ b/paddle/fluid/operators/stack_op.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/stack_op.h"
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
+ ops::StackGradOpDescMaker);
+REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
+
+REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
+                       ops::StackKernel<plat::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(stack_grad,
+                       ops::StackGradKernel<plat::CPUDeviceContext, float>,
+                       ops::StackGradKernel<plat::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu
new file mode 100644
index 0000000000..92c1bde2bc
--- /dev/null
+++ b/paddle/fluid/operators/stack_op.cu
@@ -0,0 +1,25 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/stack_op.h"
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>,
+                        ops::StackKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(stack_grad,
+                        ops::StackGradKernel<plat::CUDADeviceContext, float>,
+                        ops::StackGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h
new file mode 100644
index 0000000000..c777d5feae
--- /dev/null
+++ b/paddle/fluid/operators/stack_op.h
@@ -0,0 +1,278 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/for_range.h"
+
+#ifdef __NVCC__
+#include <thrust/device_vector.h>
+#include "paddle/fluid/framework/array.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+class StackOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext *ctx) const override {
+ PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
+ "Number of Inputs(X) must be larger than 0");
+ PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
+
+ auto input_dims = ctx->GetInputsDim("X");
+ for (size_t i = 1; i < input_dims.size(); ++i) {
+ PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
+ "Dims of all Inputs(X) must be the same");
+ }
+
+ // Only lod of X[0] would be shared with Y
+ ctx->ShareLoD("X", /*->*/ "Y");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+ int rank = input_dims[0].size();
+ PADDLE_ENFORCE(
+ axis >= -(rank + 1) && axis < rank + 1,
+ "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+ if (axis < 0) axis += (rank + 1);
+
+ auto vec = framework::vectorize2int(input_dims[0]);
+ vec.insert(vec.begin() + axis, input_dims.size());
+ ctx->SetOutputDim("Y", framework::make_ddim(vec));
+ }
+};
+
+class StackOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddInput("X", "The input of stack op.").AsDuplicable();
+ AddOutput("Y", "The output of stack op.");
+    AddAttr<int>("axis",
+ "The axis along which all of the Inputs(X) should be stacked.")
+ .SetDefault(0);
+ AddComment(R"DOC(
+ Stack Operator.
+
+ Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
+ )DOC");
+ }
+};
+
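+// StackFunctor writes y[idx] = x[which_x][x_index], decomposing the flat
+// output index as idx = (i * n + which_x) * post + offset, so that
+// x_index = i * post + offset; StackGradFunctor below is the inverse scatter.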
+template <typename VecXType, typename T>
+struct StackFunctor {
+ HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
+ : x_(x), y_(y), n_(n), post_(post) {}
+
+ HOSTDEVICE void operator()(int idx) {
+ int i = idx / (n_ * post_);
+ int which_x = idx / post_ - i * n_;
+ int x_index = i * post_ + idx % post_;
+ y_[idx] = x_[which_x][x_index];
+ }
+
+ private:
+ VecXType x_;
+ T *y_;
+ int n_;
+ int post_;
+};
+
+template <typename VecDxType, typename T>
+struct StackGradFunctor {
+ HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
+ : dx_(dx), dy_(dy), n_(n), post_(post) {}
+
+ HOSTDEVICE void operator()(int idx) {
+ int i = idx / (n_ * post_);
+ int which_x = idx / post_ - i * n_;
+ int x_index = i * post_ + idx % post_;
+ dx_[which_x][x_index] = dy_[idx];
+ }
+
+ private:
+ VecDxType dx_;
+ const T *dy_;
+ int n_;
+ int post_;
+};
+
+template <typename DeviceContext, typename VecXType, typename T>
+static inline void StackFunctorForRange(const DeviceContext &ctx,
+                                        const VecXType &x, T *y, int total_num,
+                                        int n, int post) {
+  platform::ForRange<DeviceContext> for_range(ctx, total_num);
+  for_range(StackFunctor<VecXType, T>(x, y, n, post));
+}
+
+template <typename DeviceContext, typename VecDxType, typename T>
+static inline void StackGradFunctorForRange(const DeviceContext &ctx,
+                                            const VecDxType &dx, const T *dy,
+                                            int total_num, int n, int post) {
+  platform::ForRange<DeviceContext> for_range(ctx, total_num);
+  for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
+}
+
+template <typename DeviceContext, typename T>
+class StackKernel : public framework::OpKernel<T> {
+ using Tensor = framework::LoDTensor;
+
+ public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+    auto x = ctx.MultiInput<Tensor>("X");
+    auto *y = ctx.Output<Tensor>("Y");
+
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += (x[0]->dims().size() + 1);
+
+    int n = static_cast<int>(x.size());
+    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
+    std::vector<const T *> x_datas(n);
+    for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
+
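+    // pre = product of the input dims before the stacking axis, post = product
+    // of the dims from the axis onward; the stacked output therefore has
+    // pre * n * post elements.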
+ int pre = 1, post = 1;
+ auto &dim = x[0]->dims();
+ for (auto i = 0; i < axis; ++i) pre *= dim[i];
+ for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
+ int total_num = pre * n * post;
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    constexpr auto kMaxThreshold = 16;
+    if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
+        n > kMaxThreshold) {
+#ifdef __NVCC__
+      VLOG(10) << "Stack more than " << kMaxThreshold
+               << " tensors on GPU may be slow.";
+      thrust::device_vector<const T *> device_x_vec(x_datas);
+ auto x_data_arr = device_x_vec.data().get();
+#else
+ auto x_data_arr = x_datas.data();
+#endif
+ StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
+#ifdef __NVCC__
+ // Wait() must be called because device_x_vec may be destructed before
+ // kernel ends
+ dev_ctx.Wait();
+#endif
+ }
+#ifdef __NVCC__
+ else { // NOLINT
+      framework::Array<const T *, kMaxThreshold> x_data_arr;
+ for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i];
+ StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
+ }
+#endif
+ }
+};
+
+class StackOpGrad : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+ "Input(Y@Grad) must exist.");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+ auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
+ int rank = dy_dim.size();
+ PADDLE_ENFORCE(axis >= -rank && axis < rank,
+ "Attr(axis) must be inside [-rank, rank), where rank = %d",
+ rank);
+ if (axis < 0) axis += rank;
+
+ PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
+                      static_cast<size_t>(dy_dim[axis]),
+ "Number of Outputs(X@Grad) is wrong");
+ auto vec = framework::vectorize2int(dy_dim);
+ vec.erase(vec.begin() + axis);
+ ctx->SetOutputsDim(
+ framework::GradVarName("X"),
+        std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
+ }
+};
+
+class StackGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+ using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+ op->SetType("stack_grad");
+ op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+ op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
+ op->SetAttrMap(Attrs());
+ return op;
+ }
+};
+
+template <typename DeviceContext, typename T>
+class StackGradKernel : public framework::OpKernel<T> {
+ using Tensor = framework::LoDTensor;
+
+ public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += dy->dims().size();
+
+    int n = dy->dims()[axis];
+    std::vector<T *> dx_datas(n);  // NOLINT
+    for (int i = 0; i < n; i++) {
+      dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
+    }
+    auto dy_data = dy->data<T>();
+
+ int pre = 1;
+ for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
+ int total_num = dy->numel();
+ int post = total_num / (n * pre);
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    constexpr auto kMaxThreshold = 16;
+    if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
+        n > kMaxThreshold) {
+#ifdef __NVCC__
+      VLOG(10) << "Stack more than " << kMaxThreshold
+               << " tensors on GPU may be slow.";
+      thrust::device_vector<T *> device_dx_vec(dx_datas);
+ auto dx_data_arr = device_dx_vec.data().get();
+#else
+ auto dx_data_arr = dx_datas.data();
+#endif
+ StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
+ post);
+#ifdef __NVCC__
+ // Wait() must be called because device_dx_vec may be destructed before
+ // kernel ends
+ dev_ctx.Wait();
+#endif
+ }
+#ifdef __NVCC__
+ else { // NOLINT
+ framework::Array