From 7c671466320ac257c6a8637ac283512702d5eecd Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 9 May 2018 14:12:46 +0000 Subject: [PATCH 001/140] Add forward and backward. --- paddle/fluid/operators/sequence_pad_op.cc | 131 ++++++++++++++++++++++ paddle/fluid/operators/sequence_pad_op.cu | 23 ++++ paddle/fluid/operators/sequence_pad_op.h | 97 ++++++++++++++++ 3 files changed, 251 insertions(+) create mode 100644 paddle/fluid/operators/sequence_pad_op.cc create mode 100644 paddle/fluid/operators/sequence_pad_op.cu create mode 100644 paddle/fluid/operators/sequence_pad_op.h diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc new file mode 100644 index 0000000000..183d38fcc9 --- /dev/null +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -0,0 +1,131 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/sequence_pad_op.h" + +namespace paddle { +namespace operators { + +class SequencePadOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequencePadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequencePadOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2, + "Only support 2-D tensor, rank of Input(X) should be 2."); + + auto out_dims = x_dims; + + if (ctx->IsRuntime()) { + framework::Variable* x_var = + boost::get(ctx->GetInputVarPtrs("X")[0]); + + auto& x_lod = x_var->Get().lod(); + + PADDLE_ENFORCE_GE(x_lod.size(), 1, + "Input(X) should be sequences containing lod."); + + auto last_level_lod = x_lod[x_lod.size() - 1]; + size_t max_len = 0; + + for (size_t i = 1; i < last_level_lod.size(); ++i) { + auto seq_len = last_level_lod[i] - last_level_lod[i - 1]; + max_len = max_len < seq_len ? seq_len : max_len; + } + + out_dims[0] = max_len * (last_level_lod.size() - 1); + } else { + framework::VarDesc* x_desc = + boost::get(ctx->GetInputVarPtrs("X")[0]); + PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1, + "Input(X) should be sequences containing lod."); + out_dims[0] = -1; + } + + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequencePadOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, default LoDTensor) Input variable which " + "should contain lod information. 
Length of each sequence would " + "be computed from the most bottom level lod."); + AddOutput("Out", + "(Tensor) Output variable which would be a common tensor " + "without lod. Each sequence would be padded to the maximum " + "length."); + AddAttr("pad_value", + "(float, default 0.0) Value to be padded " + "to the end of each sequence."); + AddComment(R"DOC( + + )DOC"); + } +}; + +class SequencePadGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequencePadGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequencePadGradOp should not be null."); + + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_pad, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel); +REGISTER_OP_CPU_KERNEL( + sequence_pad_grad, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.cu b/paddle/fluid/operators/sequence_pad_op.cu new file mode 100644 index 0000000000..a2fa62957e --- /dev/null +++ b/paddle/fluid/operators/sequence_pad_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_pad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + sequence_pad, + ops::SequencePadOpKernel); +REGISTER_OP_CUDA_KERNEL( + sequence_pad_grad, + ops::SequencePadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h new file mode 100644 index 0000000000..b36465d8e7 --- /dev/null +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; + +// @TODO clean code +template +class SequencePadOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_ptr = ctx.Input("X"); + auto* out_ptr = ctx.Output("Out"); + + out_ptr->mutable_data(ctx.GetPlace()); + + T pad_value = static_cast(ctx.Attr("pad_value")); + + math::SetConstant set_func; + set_func(ctx.template device_context(), out_ptr, pad_value); + + auto& x_lod = x_ptr->lod(); + auto& x_last_level_lod = x_lod[x_lod.size() - 1]; + auto seq_num = x_last_level_lod.size() - 1; + auto max_len = out_ptr->dims()[0] / seq_num; + + PADDLE_ENFORCE_EQ(max_len * seq_num, out_ptr->dims()[0], + "First dimension of `Out` should be equal to " + "maximum length mulplied by sequence number."); + + for (size_t i = 1; i < x_last_level_lod.size(); ++i) { + auto x_start = x_last_level_lod[i - 1]; + auto x_end = x_last_level_lod[i]; + auto out_start = (i - 1) * max_len; + auto out_end = out_start + (x_end - x_start); + auto x_sub_tensor = x_ptr->Slice(x_start, x_end); + auto out_sub_tensor = out_ptr->Slice(out_start, out_end); + framework::TensorCopy(x_sub_tensor, ctx.GetPlace(), &out_sub_tensor); + } + } +}; + +template +class SequencePadGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_ptr = ctx.Input("X"); + auto* g_out_ptr = ctx.Input(framework::GradVarName("Out")); + auto* g_x_ptr = ctx.Output(framework::GradVarName("X")); + + math::SetConstant set_func; + set_func(ctx.template device_context(), g_x_ptr, + static_cast(0)); + + auto& x_lod = x_ptr->lod(); + auto& x_last_level_lod = x_lod[x_lod.size() - 1]; + auto seq_num = 
x_last_level_lod.size() - 1; + int64_t max_len = g_out_ptr->dims()[0] / seq_num; + + PADDLE_ENFORCE_EQ(max_len * seq_num, g_out_ptr->dims()[0], + "First dimension of `Out` should be equal to " + "maximum length mulplied by sequence number."); + + for (size_t i = 1; i < x_last_level_lod.size(); ++i) { + auto x_start = x_last_level_lod[i - 1]; + auto x_end = x_last_level_lod[i]; + auto out_start = (i - 1) * max_len; + auto out_end = out_start + (x_end - x_start); + + auto g_out_sub = g_out_ptr->Slice(out_start, out_end); + auto g_x_sub = g_x_ptr->Slice(x_start, x_end); + framework::TensorCopy(g_x_sub, ctx.GetPlace(), &g_out_sub); + } + } +}; + +} // namespace operators +} // namespace paddle From 0797246704aad1392f8d410e5ba179db8592d2e0 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 11 May 2018 09:55:59 +0000 Subject: [PATCH 002/140] Enhance sequence_padding functor (CPU and GPU). --- .../fluid/operators/math/sequence_padding.cc | 203 +++++++-------- .../fluid/operators/math/sequence_padding.cu | 231 +++++++----------- .../fluid/operators/math/sequence_padding.h | 66 +++-- paddle/fluid/operators/sequence_pad_op.cc | 40 +-- paddle/fluid/operators/sequence_pad_op.h | 101 +++++--- paddle/fluid/operators/warpctc_op.h | 5 +- 6 files changed, 330 insertions(+), 316 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index d63c6c4ed5..2dd2cafa23 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,128 +18,111 @@ namespace paddle { namespace operators { namespace math { -template -class PaddingLoDTensorFunctor { - public: - void operator()(const platform::CPUDeviceContext& context, - const framework::LoDTensor& seq, framework::Tensor* padding, - bool norm_by_times) { - auto lod = seq.lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The LoD of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD 
abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq.dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding->dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequence_length, num_sequences, sequence_width]."); - - const int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be the " - "maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be the " - "number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq.numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - const T* seq_data = seq.data(); - T* padding_data = padding->data(); - for (int64_t i = 0; i < max_sequence_length; ++i) { - for (int64_t j = 0; j < num_sequences; ++j) { - int64_t start_pos = abs_offset_lod[level][j]; - int64_t sequence_length = abs_offset_lod[level][j + 1] - start_pos; - if (i < sequence_length) { - // i > 0 => sequence_length > 0 - T scale = - norm_by_times ? 
(1.0f / static_cast(sequence_length)) : 1.0f; - for (int64_t k = 0; k < sequence_width; ++k) { - padding_data[(i * num_sequences + j) * sequence_width + k] = - seq_data[(start_pos + i) * sequence_width + k] * scale; - } +template +void CopyDataCPU(framework::LoDTensor* seq_tensor, + framework::Tensor* padding_tensor, + const framework::Vector& abs_offset, + const int64_t& max_seq_len, const int64_t& seq_width, + bool seq_to_padding, bool norm_by_len) { + T* seq_data = seq_tensor->data(); + T* padding_data = padding_tensor->data(); + + int64_t seq_num = abs_offset.size() - 1; + + for (int64_t i = 0; i < seq_num; ++i) { + int64_t seq_start = abs_offset[i]; + int64_t seq_len = abs_offset[i + 1] - seq_start; + + T scale = norm_by_len ? (1.0f / static_cast(seq_len)) : 1.0f; + + for (int64_t j = 0; j < seq_len; ++j) { + for (int64_t k = 0; k < seq_width; ++k) { + size_t padding_offset = 0; + if (padding_layout == BATCH_LENGTH_WIDTH) { + padding_offset = (i * max_seq_len * seq_width) + j * seq_width + k; + } else { + padding_offset = (j * seq_num * seq_width) + i * seq_width + k; + } + if (seq_to_padding) { + padding_data[padding_offset] = + seq_data[(seq_start + j) * seq_width + k] * scale; } else { - memset(padding_data + (i * num_sequences + j) * sequence_width, 0, - sequence_width * sizeof(T)); + seq_data[(seq_start + j) * seq_width + k] = + padding_data[padding_offset] * scale; } } } } +} + +template +class PaddingLoDTensorFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& seq_tensor, + framework::Tensor* padding_tensor, + T padding_value = static_cast(0), + bool norm_by_times = false, size_t lod_level = 0) { + ValidateLoD(seq_tensor, lod_level); + + auto& lod = seq_tensor.lod(); + auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + + auto seq_dims = seq_tensor.dims(); + auto padding_dims = padding_tensor->dims(); + int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); + int64_t seq_num = 
abs_offset.size() - 1; + int64_t seq_width = seq_tensor.numel() / seq_dims[0]; + int64_t numel = max_seq_len * seq_num * seq_width; + + ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, + seq_num, seq_width, padding_layout); + + T* padding_data = padding_tensor->data(); + + memset(padding_data, padding_value, numel * sizeof(T)); + + CopyDataCPU( + const_cast(&seq_tensor), padding_tensor, + abs_offset, max_seq_len, seq_width, true /* seq_to_padding */, + norm_by_times); + } }; -template -class UnpaddingLoDTensorFunctor { +template +class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, - framework::LoDTensor* seq, const framework::Tensor& padding, - bool norm_by_times) { - auto lod = seq->lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The LoD of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq->dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding.dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequnece_length, num_sequences, sequence_width]."); - - const int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be " - "the maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be " - "the number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq->numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be 
the " - "width of sequence in LoDTensor seq."); - - const T* padding_data = padding.data(); - T* seq_data = seq->data(); - for (int64_t i = 0; i < num_sequences; ++i) { - int64_t start_pos = abs_offset_lod[level][i]; - int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos; - for (int64_t j = 0; j < sequence_length; ++j) { - // sequence_width > j > 0 - T scale = - norm_by_times ? (1.0f / static_cast(sequence_length)) : 1.0f; - for (int64_t k = 0; k < sequence_width; ++k) { - seq_data[(start_pos + j) * sequence_width + k] = - padding_data[(j * num_sequences + i) * sequence_width + k] * - scale; - } - } - } + framework::LoDTensor* seq_tensor, + const framework::Tensor& padding_tensor, + bool norm_by_times = false, size_t lod_level = 0) { + ValidateLoD(*seq_tensor, lod_level); + + auto& lod = seq_tensor->lod(); + auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + + auto& seq_dims = seq_tensor->dims(); + auto& padding_dims = padding_tensor.dims(); + int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); + int64_t seq_num = abs_offset.size() - 1; + int64_t seq_width = seq_tensor->numel() / seq_dims[0]; + + ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, + seq_num, seq_width, padding_layout); + + T* seq_data = seq_tensor->data(); + memset(seq_data, static_cast(0), seq_tensor->numel() * sizeof(T)); + + CopyDataCPU( + seq_tensor, const_cast(&padding_tensor), abs_offset, + max_seq_len, seq_width, false /* seq_to_padding */, norm_by_times); } }; -template class PaddingLoDTensorFunctor; -template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 0956a0c17d..2377bef024 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -19,87 +19,76 @@ 
namespace paddle { namespace operators { namespace math { -template -__global__ void SequencePaddingKernel(T* padding, T* sequence, - const size_t* sequence_start_positions, - const size_t sequence_width, - const size_t max_sequence_length, - const size_t num_sequences) { +template +__global__ void SequencePaddingKernel( + T* padding_data, T* seq_data, const size_t* abs_offset, + const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width, + const PaddingLayout& padding_layout, bool norm_by_times = false, + const T& padding_value = 0) { size_t padding_idx = blockIdx.y; - size_t start_pos = sequence_start_positions[padding_idx]; - size_t sequence_length = - sequence_start_positions[padding_idx + 1] - start_pos; + size_t seq_start = abs_offset[padding_idx]; + size_t seq_len = abs_offset[padding_idx + 1] - seq_start; - size_t sequence_idx = blockIdx.x * blockDim.y + threadIdx.y; - size_t padding_base_idx = - (sequence_idx * num_sequences + padding_idx) * sequence_width; - size_t sequence_base_idx = (start_pos + sequence_idx) * sequence_width; + size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y; - if (sequence_idx < sequence_length) { - T scale = NormByTimes ? (1.0f / static_cast(sequence_length)) : 1.0f; + size_t seq_offset = (seq_start + seq_idx) * seq_width; + + size_t padding_offset = 0; + + if (padding_layout == LENGTH_BATCH_WIDTH) { + padding_offset = (seq_idx * seq_num + padding_idx) * seq_width; + } else { + padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width; + } + + if (seq_idx < seq_len) { + T scale = norm_by_times ? 
(1.0f / static_cast(seq_len)) : 1.0f; if (Padding) { /* sequence -> padding */ - for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { - padding[padding_base_idx + i] = scale * sequence[sequence_base_idx + i]; + for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { + padding_data[padding_offset + i] = scale * seq_data[seq_offset + i]; } } else { /* padding -> sequence */ - for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { - sequence[sequence_base_idx + i] = scale * padding[padding_base_idx + i]; + for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { + seq_data[seq_offset + i] = scale * padding_data[padding_offset + i]; } } - } else if (sequence_idx < max_sequence_length) { + } else if (seq_idx < max_seq_len) { if (Padding) { /* sequence -> padding */ - for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { - padding[padding_base_idx + i] = 0; + for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { + padding_data[padding_offset + i] = padding_value; } } } } -template -class PaddingLoDTensorFunctor { +template +class PaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, - const framework::LoDTensor& seq, framework::Tensor* padding, - bool norm_by_times) { - auto lod = seq.lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The lod of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq.dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding->dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequence_length, num_sequences, sequence_width]."); - - int64_t max_sequence_length = MaximumSequenceLength(lod, level); - 
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be the " - "maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be the " - "number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq.numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - if (!norm_by_times && num_sequences == 1UL) { - TensorCopy(seq, context.GetPlace(), context, padding); - padding->Resize(padding_dims); + const framework::LoDTensor& seq_tensor, + framework::Tensor* padding_tensor, + T padding_value = static_cast(0), + bool norm_by_times = false, size_t lod_level = 0) { + ValidateLoD(seq_tensor, lod_level); + + auto& lod = seq_tensor.lod(); + auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + + auto seq_dims = seq_tensor.dims(); + auto padding_dims = padding_tensor->dims(); + int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); + const int64_t seq_num = abs_offset.size() - 1; + const int64_t seq_width = seq_tensor.numel() / seq_dims[0]; + + ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, + seq_num, seq_width, padding_layout); + + if (!norm_by_times && seq_num == 1UL) { + TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor); + padding_tensor->Resize(padding_dims); return; } @@ -109,72 +98,46 @@ class PaddingLoDTensorFunctor { * and at least 8 elements for each thread. 
*/ size_t block_dim_x = - std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y; - size_t grid_dim_y = num_sequences; + size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); - const T* seq_data = seq.data(); - T* padding_data = padding->data(); - if (norm_by_times) { - SequencePaddingKernel<<>>( - padding_data, const_cast(seq_data), - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } else { - SequencePaddingKernel<<>>( - padding_data, const_cast(seq_data), - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } + const T* seq_data = seq_tensor.data(); + T* padding_data = padding_tensor->data(); + + SequencePaddingKernel<<>>( + padding_data, const_cast(seq_data), + abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, + seq_width, padding_layout, norm_by_times, padding_value); } }; -template -class UnpaddingLoDTensorFunctor { +template +class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, - framework::LoDTensor* seq, const framework::Tensor& padding, - bool norm_by_times) { - auto lod = seq->lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The lod of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq->dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding.dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The 
input padding should be a 3-D Tensor of shape " - "[max_sequnece_length, num_sequences, sequence_width]."); - - int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be " - "the maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be " - "the number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq->numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - if (!norm_by_times && num_sequences == 1UL) { - TensorCopy(padding, context.GetPlace(), context, seq); - seq->Resize(seq_dims); + framework::LoDTensor* seq_tensor, + const framework::Tensor& padding_tensor, + bool norm_by_times = false, size_t lod_level = 0) { + ValidateLoD(*seq_tensor, lod_level); + + auto& lod = seq_tensor->lod(); + auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + + auto seq_dims = seq_tensor->dims(); + auto padding_dims = padding_tensor.dims(); + int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); + int64_t seq_num = abs_offset.size() - 1; + int64_t seq_width = seq_tensor->numel() / seq_dims[0]; + + if (!norm_by_times && seq_num == 1UL) { + TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor); + seq_tensor->Resize(seq_dims); return; } @@ -184,32 +147,28 @@ class UnpaddingLoDTensorFunctor { * and at least 8 elements for each thread. 
*/ size_t block_dim_x = - std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y; - size_t grid_dim_y = num_sequences; + size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); - const T* padding_data = padding.data(); - T* seq_data = seq->data(); - if (norm_by_times) { - SequencePaddingKernel<<>>( - const_cast(padding_data), seq_data, - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } else { - SequencePaddingKernel<<>>( - const_cast(padding_data), seq_data, - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } + const T* padding_data = padding_tensor.data(); + T* seq_data = seq_tensor->data(); + + SequencePaddingKernel<<>>( + const_cast(padding_data), seq_data, + abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, + seq_width, padding_layout, norm_by_times); } }; -template class PaddingLoDTensorFunctor; -template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index b56e6db1eb..91d205641a 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -22,17 +22,50 @@ namespace paddle { namespace operators { namespace math { +enum PaddingLayout { BATCH_LENGTH_WIDTH, LENGTH_BATCH_WIDTH }; + inline static size_t MaximumSequenceLength(const framework::LoD& lod, const size_t level) { - const size_t num_sequences = lod[level].size() - 1; - size_t 
max_sequence_length = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - for (size_t i = 0; i < num_sequences; ++i) { - max_sequence_length = - std::max(max_sequence_length, - abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]); + const size_t seq_num = lod[level].size() - 1; + size_t max_seq_len = 0; + auto abs_offset = framework::ToAbsOffset(lod)[level]; + for (size_t i = 0; i < seq_num; ++i) { + max_seq_len = std::max(max_seq_len, abs_offset[i + 1] - abs_offset[i]); + } + return max_seq_len; +} + +inline static void ValidateLoD(const framework::LoDTensor& seq_tensor, + const size_t& lod_level) { + PADDLE_ENFORCE(lod_level < seq_tensor.lod().size(), + "Invalid `lod_level` which should be at least 0 and less " + "than maximum lod level of `seq_tensor`."); +} + +inline static void ValidateShape(const framework::DDim& seq_tensor_dims, + const size_t& abs_offset_back_value, + const framework::DDim& padding_tensor_dims, + const int64_t& max_seq_len, + const int64_t& seq_num, + const int64_t& seq_width, + const PaddingLayout& padding_layout) { + PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), + abs_offset_back_value, + "The 1st dimension of `seq_tensor` should be equal to " + "sum of lengths of all sequences."); + + PADDLE_ENFORCE_EQ(padding_tensor_dims.size(), 3UL, + "`padding_tensor` should be a 3-D tensor."); + + if (padding_layout == BATCH_LENGTH_WIDTH) { + PADDLE_ENFORCE_EQ(padding_tensor_dims, + framework::make_ddim({seq_num, max_seq_len, seq_width})); + } else if (padding_layout == LENGTH_BATCH_WIDTH) { + PADDLE_ENFORCE_EQ(padding_tensor_dims, + framework::make_ddim({max_seq_len, seq_num, seq_width})); + } else { + PADDLE_THROW("Unsupported padding layout."); } - return max_sequence_length; } /* @@ -61,18 +94,23 @@ inline static size_t MaximumSequenceLength(const framework::LoD& lod, * * \note transposition is also done in this functor. 
*/ -template +template class PaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, const framework::LoDTensor& seq, - framework::Tensor* padding, bool norm_by_times); + void operator()(const DeviceContext& context, + const framework::LoDTensor& seq_tensor, + framework::Tensor* padding_tensor, + T padding_value = static_cast(0), + bool norm_by_times = false, size_t lod_level = 0); }; -template +template class UnpaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, framework::LoDTensor* seq, - const framework::Tensor& padding, bool norm_by_times); + void operator()(const DeviceContext& context, + framework::LoDTensor* seq_tensor, + const framework::Tensor& padding_tensor, + bool norm_by_times = false, size_t lod_level = 0); }; } // namespace math diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index 183d38fcc9..f3a6fff0e1 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -32,7 +32,11 @@ class SequencePadOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Only support 2-D tensor, rank of Input(X) should be 2."); - auto out_dims = x_dims; + int lod_level = ctx->Attrs().Get("lod_level"); + + int64_t max_len = -1; + int64_t seq_num = -1; + int x_lod_size = -1; if (ctx->IsRuntime()) { framework::Variable* x_var = @@ -40,27 +44,31 @@ class SequencePadOp : public framework::OperatorWithKernel { auto& x_lod = x_var->Get().lod(); - PADDLE_ENFORCE_GE(x_lod.size(), 1, - "Input(X) should be sequences containing lod."); + x_lod_size = x_lod.size(); + + auto x_abs_offset = framework::ToAbsOffset(x_lod)[lod_level]; + + PADDLE_ENFORCE_EQ(x_dims[0], static_cast(x_abs_offset.back()), + "The first dimension of `X` should be equal to sum " + "of all sequences' length."); - auto last_level_lod = x_lod[x_lod.size() - 1]; - size_t max_len = 0; + seq_num = x_abs_offset.size() - 1; - for (size_t i = 1; 
i < last_level_lod.size(); ++i) { - auto seq_len = last_level_lod[i] - last_level_lod[i - 1]; + for (size_t i = 1; i <= seq_num; ++i) { + int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1]; max_len = max_len < seq_len ? seq_len : max_len; } - - out_dims[0] = max_len * (last_level_lod.size() - 1); } else { framework::VarDesc* x_desc = boost::get(ctx->GetInputVarPtrs("X")[0]); - PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1, - "Input(X) should be sequences containing lod."); - out_dims[0] = -1; + x_lod_size = x_desc->GetLoDLevel(); } - ctx->SetOutputDim("Out", out_dims); + PADDLE_ENFORCE(lod_level >= 0 && lod_level < x_lod_size, + "Invalid `lod_level` which should be at least 0 and less " + "than maximum lod level of `X`"); + + ctx->SetOutputDim("Out", {seq_num, max_len, x_dims[1]}); } protected: @@ -84,9 +92,11 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) Output variable which would be a common tensor " "without lod. Each sequence would be padded to the maximum " "length."); + AddAttr("lod_level", + "(int, default 0) Specify which level lod to referred to."); AddAttr("pad_value", - "(float, default 0.0) Value to be padded " - "to the end of each sequence."); + "(float, default 0.0) Specify which value to be padded to " + "the end of each sequence."); AddComment(R"DOC( )DOC"); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h index b36465d8e7..6d136b65f1 100644 --- a/paddle/fluid/operators/sequence_pad_op.h +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sequence_padding.h" namespace paddle { namespace operators { @@ -23,39 +24,68 @@ namespace operators { using LoDTensor = framework::LoDTensor; using LoD = framework::LoD; -// @TODO clean code +template +struct CopyFunctor { + LoDTensor* lod_tensor_; + LoDTensor* pad_tensor_; + const LoD& ref_lod_; + const DeviceContext& ctx_; + bool is_lod_to_pad_; + + CopyFunctor(LoDTensor* lod_tensor, const LoD& ref_lod, LoDTensor* pad_tensor, + const DeviceContext& ctx, bool is_lod_to_pad) + : lod_tensor_(lod_tensor), + pad_tensor_(pad_tensor), + ref_lod_(ref_lod), + ctx_(ctx), + is_lod_to_pad_(is_lod_to_pad) {} + + void operator()() const { + /* + auto seq_num = ref_lod_.size() - 1; + auto max_len = pad_tensor_->dims()[0] / seq_num; + + PADDLE_ENFORCE_EQ(max_len * seq_num, pad_tensor_->dims()[0], + "First dimension of padded tensor should be equal to " + "maximum sequence length mulplied by sequence number."); + + for (size_t i = 1; i < ref_lod_.size(); ++i) { + auto seq_start = ref_lod_[i - 1]; + auto seq_end = ref_lod_[i]; + auto pad_start = (i - 1) * max_len; + auto pad_end = pad_start + (seq_end - seq_start); + auto sub_lod_tensor = lod_tensor_->Slice(seq_start, seq_end); + auto sub_pad_tensor = pad_tensor_->Slice(pad_start, pad_end); + if (is_lod_to_pad_) { + framework::TensorCopy(sub_lod_tensor, ctx.GetPlace(), &sub_pad_tensor); + } else { + framework::TensorCopy(sub_pad_tensor, ctx.GetPlace(), &sub_lod_tensor); + } + } + */ + } +}; + template class SequencePadOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* x_ptr = ctx.Input("X"); + /* + auto* x = ctx.Input("X"); auto* out_ptr = ctx.Output("Out"); out_ptr->mutable_data(ctx.GetPlace()); + // Resize(); + T pad_value = static_cast(ctx.Attr("pad_value")); + 
math::PaddingLoDTensorFunctor()( + ctx.template device_context(), *x, *, false); + math::SetConstant set_func; set_func(ctx.template device_context(), out_ptr, pad_value); - - auto& x_lod = x_ptr->lod(); - auto& x_last_level_lod = x_lod[x_lod.size() - 1]; - auto seq_num = x_last_level_lod.size() - 1; - auto max_len = out_ptr->dims()[0] / seq_num; - - PADDLE_ENFORCE_EQ(max_len * seq_num, out_ptr->dims()[0], - "First dimension of `Out` should be equal to " - "maximum length mulplied by sequence number."); - - for (size_t i = 1; i < x_last_level_lod.size(); ++i) { - auto x_start = x_last_level_lod[i - 1]; - auto x_end = x_last_level_lod[i]; - auto out_start = (i - 1) * max_len; - auto out_end = out_start + (x_end - x_start); - auto x_sub_tensor = x_ptr->Slice(x_start, x_end); - auto out_sub_tensor = out_ptr->Slice(out_start, out_end); - framework::TensorCopy(x_sub_tensor, ctx.GetPlace(), &out_sub_tensor); - } + */ } }; @@ -63,33 +93,26 @@ template class SequencePadGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + /* auto* x_ptr = ctx.Input("X"); auto* g_out_ptr = ctx.Input(framework::GradVarName("Out")); auto* g_x_ptr = ctx.Output(framework::GradVarName("X")); math::SetConstant set_func; - set_func(ctx.template device_context(), g_x_ptr, + set_func(ctx.template device_context(), + g_x_ptr, static_cast(0)); auto& x_lod = x_ptr->lod(); auto& x_last_level_lod = x_lod[x_lod.size() - 1]; - auto seq_num = x_last_level_lod.size() - 1; - int64_t max_len = g_out_ptr->dims()[0] / seq_num; - - PADDLE_ENFORCE_EQ(max_len * seq_num, g_out_ptr->dims()[0], - "First dimension of `Out` should be equal to " - "maximum length mulplied by sequence number."); - - for (size_t i = 1; i < x_last_level_lod.size(); ++i) { - auto x_start = x_last_level_lod[i - 1]; - auto x_end = x_last_level_lod[i]; - auto out_start = (i - 1) * max_len; - auto out_end = out_start + (x_end - x_start); - - auto g_out_sub = 
g_out_ptr->Slice(out_start, out_end); - auto g_x_sub = g_x_ptr->Slice(x_start, x_end); - framework::TensorCopy(g_x_sub, ctx.GetPlace(), &g_out_sub); - } + + CopyFunctor copy_func(g_out_ptr, + x_last_level_lod, + g_x_ptr, + ctx, + false); + copy_func(); + */ } }; diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 705cc894c0..1b649be203 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -161,7 +161,7 @@ class WarpCTCKernel : public framework::OpKernel { static_cast(num_sequences), static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); - math::PaddingLoDTensorFunctor()( + math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, false); const T* warpctc_logits_data = warpctc_logits.data(); @@ -216,7 +216,8 @@ class WarpCTCGradKernel : public framework::OpKernel { logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); - math::UnpaddingLoDTensorFunctor()( + math::UnpaddingLoDTensorFunctor()( ctx.template device_context(), logits_grad, *warpctc_grad, norm_by_times); From 10ec329b7d8613c60d7324395ecc42e10b3ce0c0 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 23 May 2018 14:28:14 +0000 Subject: [PATCH 003/140] Refine code. 
--- .../fluid/operators/math/sequence_padding.cc | 123 ++++++++-------- .../fluid/operators/math/sequence_padding.cu | 136 +++++++++--------- .../fluid/operators/math/sequence_padding.h | 69 +++++---- .../operators/math/sequence_padding_test.cc | 10 +- paddle/fluid/operators/sequence_pad_op.cc | 2 +- paddle/fluid/operators/warpctc_op.h | 12 +- 6 files changed, 183 insertions(+), 169 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index 2dd2cafa23..5ceb26553c 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,111 +18,114 @@ namespace paddle { namespace operators { namespace math { -template +template void CopyDataCPU(framework::LoDTensor* seq_tensor, - framework::Tensor* padding_tensor, - const framework::Vector& abs_offset, + framework::Tensor* pad_tensor, + const framework::Vector& seq_offset, const int64_t& max_seq_len, const int64_t& seq_width, - bool seq_to_padding, bool norm_by_len) { + bool seq_to_pad, bool norm_by_len, + OutputLayout output_layout) { T* seq_data = seq_tensor->data(); - T* padding_data = padding_tensor->data(); + T* pad_data = pad_tensor->data(); - int64_t seq_num = abs_offset.size() - 1; + int64_t seq_num = seq_offset.size() - 1; for (int64_t i = 0; i < seq_num; ++i) { - int64_t seq_start = abs_offset[i]; - int64_t seq_len = abs_offset[i + 1] - seq_start; - + int64_t seq_start = seq_offset[i]; + int64_t seq_len = seq_offset[i + 1] - seq_start; T scale = norm_by_len ? 
(1.0f / static_cast(seq_len)) : 1.0f; - for (int64_t j = 0; j < seq_len; ++j) { for (int64_t k = 0; k < seq_width; ++k) { - size_t padding_offset = 0; - if (padding_layout == BATCH_LENGTH_WIDTH) { - padding_offset = (i * max_seq_len * seq_width) + j * seq_width + k; + size_t pad_data_idx = 0; + size_t seq_data_idx = (seq_start + j) * seq_width + k; + if (output_layout == kBatchLengthWidth) { + pad_data_idx = (i * max_seq_len + j) * seq_width + k; } else { - padding_offset = (j * seq_num * seq_width) + i * seq_width + k; + pad_data_idx = (j * seq_num + i) * seq_width + k; } - if (seq_to_padding) { - padding_data[padding_offset] = - seq_data[(seq_start + j) * seq_width + k] * scale; + if (seq_to_pad) { + pad_data[pad_data_idx] = seq_data[seq_data_idx] * scale; } else { - seq_data[(seq_start + j) * seq_width + k] = - padding_data[padding_offset] * scale; + seq_data[seq_data_idx] = pad_data[pad_data_idx] * scale; } } } } } -template -class PaddingLoDTensorFunctor { +template +class PaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* padding_tensor, - T padding_value = static_cast(0), - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(seq_tensor, lod_level); + framework::Tensor* pad_tensor, + T pad_value = static_cast(0), bool norm_by_times = false, + size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(seq_tensor, lod_level); auto& lod = seq_tensor.lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto seq_dims = seq_tensor.dims(); - auto padding_dims = padding_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - int64_t seq_num = abs_offset.size() - 1; - int64_t seq_width = seq_tensor.numel() / seq_dims[0]; - int64_t numel = max_seq_len * seq_num * seq_width; + auto seq_tensor_dims = seq_tensor.dims(); + auto 
pad_tensor_dims = pad_tensor->dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; - ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, - seq_num, seq_width, padding_layout); + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); - T* padding_data = padding_tensor->data(); + T* pad_data = pad_tensor->data(); - memset(padding_data, padding_value, numel * sizeof(T)); + memset(pad_data, pad_value, max_seq_len * seq_num * seq_width * sizeof(T)); - CopyDataCPU( - const_cast(&seq_tensor), padding_tensor, - abs_offset, max_seq_len, seq_width, true /* seq_to_padding */, - norm_by_times); + CopyDataCPU(const_cast(&seq_tensor), pad_tensor, + seq_offset, max_seq_len, seq_width, true /* seq_to_pad */, + norm_by_times, output_layout); } }; -template -class UnpaddingLoDTensorFunctor { +template +class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, framework::LoDTensor* seq_tensor, - const framework::Tensor& padding_tensor, - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(*seq_tensor, lod_level); + const framework::Tensor& pad_tensor, + bool norm_by_times = false, size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(*seq_tensor, lod_level); auto& lod = seq_tensor->lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto& seq_dims = seq_tensor->dims(); - auto& padding_dims = padding_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - int64_t seq_num = abs_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_dims[0]; + auto& seq_tensor_dims = seq_tensor->dims(); + auto& pad_tensor_dims = pad_tensor.dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + 
int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; - ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, - seq_num, seq_width, padding_layout); + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); T* seq_data = seq_tensor->data(); memset(seq_data, static_cast(0), seq_tensor->numel() * sizeof(T)); - CopyDataCPU( - seq_tensor, const_cast(&padding_tensor), abs_offset, - max_seq_len, seq_width, false /* seq_to_padding */, norm_by_times); + CopyDataCPU(seq_tensor, const_cast(&pad_tensor), + seq_offset, max_seq_len, seq_width, false /* seq_to_pad */, + norm_by_times, output_layout); } }; -template class PaddingLoDTensorFunctor; -template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 2377bef024..20e3e3de2a 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -21,74 +21,74 @@ namespace math { template __global__ void SequencePaddingKernel( - T* padding_data, T* seq_data, const size_t* abs_offset, - const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width, - const PaddingLayout& padding_layout, bool norm_by_times = false, - const T& padding_value = 0) { - size_t padding_idx = blockIdx.y; - size_t seq_start = abs_offset[padding_idx]; - size_t seq_len = abs_offset[padding_idx + 1] - seq_start; + T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num, + const size_t& 
max_seq_len, const size_t& seq_width, bool norm_by_times, + const T& pad_value, const OutputLayout& output_layout) { + size_t seq_idx = blockIdx.y; + size_t seq_start = seq_offset[seq_idx]; + size_t seq_len = seq_offset[seq_idx + 1] - seq_start; - size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y; - size_t seq_offset = (seq_start + seq_idx) * seq_width; + size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width; - size_t padding_offset = 0; + size_t pad_data_offset = 0; - if (padding_layout == LENGTH_BATCH_WIDTH) { - padding_offset = (seq_idx * seq_num + padding_idx) * seq_width; + if (output_layout == kLengthBatchWidth) { + pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width; } else { - padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width; + pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width; } - if (seq_idx < seq_len) { + if (seq_step_idx < seq_len) { T scale = norm_by_times ? 
(1.0f / static_cast(seq_len)) : 1.0f; if (Padding) { - /* sequence -> padding */ + /* seq -> pad */ for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - padding_data[padding_offset + i] = scale * seq_data[seq_offset + i]; + pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i]; } } else { - /* padding -> sequence */ + /* pad -> seq */ for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - seq_data[seq_offset + i] = scale * padding_data[padding_offset + i]; + seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i]; } } - } else if (seq_idx < max_seq_len) { + } else if (seq_step_idx < max_seq_len) { if (Padding) { - /* sequence -> padding */ + /* seq -> pad */ for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - padding_data[padding_offset + i] = padding_value; + pad_data[pad_data_offset + i] = pad_value; } } } } -template -class PaddingLoDTensorFunctor { +template +class PaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* padding_tensor, - T padding_value = static_cast(0), - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(seq_tensor, lod_level); + framework::Tensor* pad_tensor, + T pad_value = static_cast(0), bool norm_by_times = false, + size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(seq_tensor, lod_level); auto& lod = seq_tensor.lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto seq_dims = seq_tensor.dims(); - auto padding_dims = padding_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - const int64_t seq_num = abs_offset.size() - 1; - const int64_t seq_width = seq_tensor.numel() / seq_dims[0]; + auto seq_tensor_dims = seq_tensor.dims(); + auto pad_tensor_dims = pad_tensor->dims(); + int64_t max_seq_len = 
MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; - ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, - seq_num, seq_width, padding_layout); + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); if (!norm_by_times && seq_num == 1UL) { - TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor); - padding_tensor->Resize(padding_dims); + TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor); + pad_tensor->Resize(pad_tensor_dims); return; } @@ -107,37 +107,40 @@ class PaddingLoDTensorFunctor { dim3 grid(grid_dim_x, grid_dim_y); const T* seq_data = seq_tensor.data(); - T* padding_data = padding_tensor->data(); + T* pad_data = pad_tensor->data(); SequencePaddingKernel<<>>( - padding_data, const_cast(seq_data), - abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, padding_layout, norm_by_times, padding_value); + pad_data, const_cast(seq_data), + seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, + seq_width, norm_by_times, pad_value, output_layout); } }; -template -class UnpaddingLoDTensorFunctor { +template +class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, framework::LoDTensor* seq_tensor, - const framework::Tensor& padding_tensor, - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(*seq_tensor, lod_level); + const framework::Tensor& pad_tensor, + bool norm_by_times = false, size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(*seq_tensor, lod_level); auto& lod = seq_tensor->lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto seq_dims = seq_tensor->dims(); - auto padding_dims = padding_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - 
int64_t seq_num = abs_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_dims[0]; + auto seq_tensor_dims = seq_tensor->dims(); + auto pad_tensor_dims = pad_tensor.dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); if (!norm_by_times && seq_num == 1UL) { - TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor); - seq_tensor->Resize(seq_dims); + TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); + seq_tensor->Resize(seq_tensor_dims); return; } @@ -155,20 +158,25 @@ class UnpaddingLoDTensorFunctor(); + const T* pad_data = pad_tensor.data(); T* seq_data = seq_tensor->data(); - SequencePaddingKernel<<>>( - const_cast(padding_data), seq_data, - abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, padding_layout, norm_by_times); + SequencePaddingKernel<<>>( + const_cast(pad_data), seq_data, + seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, + seq_width, norm_by_times, static_cast(0), output_layout); } }; -template class PaddingLoDTensorFunctor; -template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 91d205641a..44d6404335 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -22,49 +22,46 @@ namespace paddle { namespace 
operators { namespace math { -enum PaddingLayout { BATCH_LENGTH_WIDTH, LENGTH_BATCH_WIDTH }; +enum OutputLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; -inline static size_t MaximumSequenceLength(const framework::LoD& lod, - const size_t level) { - const size_t seq_num = lod[level].size() - 1; +inline static size_t MaximumSequenceLength( + const framework::Vector& seq_offset) { + size_t seq_num = seq_offset.size() - 1; size_t max_seq_len = 0; - auto abs_offset = framework::ToAbsOffset(lod)[level]; for (size_t i = 0; i < seq_num; ++i) { - max_seq_len = std::max(max_seq_len, abs_offset[i + 1] - abs_offset[i]); + max_seq_len = std::max(max_seq_len, seq_offset[i + 1] - seq_offset[i]); } return max_seq_len; } -inline static void ValidateLoD(const framework::LoDTensor& seq_tensor, - const size_t& lod_level) { +inline static void CheckLoD(const framework::LoDTensor& seq_tensor, + const size_t& lod_level) { PADDLE_ENFORCE(lod_level < seq_tensor.lod().size(), - "Invalid `lod_level` which should be at least 0 and less " - "than maximum lod level of `seq_tensor`."); + "Invalid lod level which should be at least 0 and less " + "than maximum lod level of sequence tensor."); } -inline static void ValidateShape(const framework::DDim& seq_tensor_dims, - const size_t& abs_offset_back_value, - const framework::DDim& padding_tensor_dims, - const int64_t& max_seq_len, - const int64_t& seq_num, - const int64_t& seq_width, - const PaddingLayout& padding_layout) { - PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), - abs_offset_back_value, - "The 1st dimension of `seq_tensor` should be equal to " - "sum of lengths of all sequences."); +inline static void CheckDims(const framework::DDim& seq_tensor_dims, + const size_t& last_offset, + const framework::DDim& pad_tensor_dims, + const int64_t& max_seq_len, const int64_t& seq_num, + const int64_t& seq_width, + const OutputLayout& output_layout) { + PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), last_offset, + "Value of 1st 
dimension of the sequence tensor should be " + "equal to sum of lengths of all sequences."); - PADDLE_ENFORCE_EQ(padding_tensor_dims.size(), 3UL, - "`padding_tensor` should be a 3-D tensor."); + PADDLE_ENFORCE_EQ(pad_tensor_dims.size(), 3UL, + "Padded tensor should be a 3-D tensor."); - if (padding_layout == BATCH_LENGTH_WIDTH) { - PADDLE_ENFORCE_EQ(padding_tensor_dims, + if (output_layout == kBatchLengthWidth) { + PADDLE_ENFORCE_EQ(pad_tensor_dims, framework::make_ddim({seq_num, max_seq_len, seq_width})); - } else if (padding_layout == LENGTH_BATCH_WIDTH) { - PADDLE_ENFORCE_EQ(padding_tensor_dims, + } else if (output_layout == kLengthBatchWidth) { + PADDLE_ENFORCE_EQ(pad_tensor_dims, framework::make_ddim({max_seq_len, seq_num, seq_width})); } else { - PADDLE_THROW("Unsupported padding layout."); + PADDLE_THROW("Unsupported output layout."); } } @@ -94,23 +91,25 @@ inline static void ValidateShape(const framework::DDim& seq_tensor_dims, * * \note transposition is also done in this functor. 
*/ -template +template class PaddingLoDTensorFunctor { public: void operator()(const DeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* padding_tensor, - T padding_value = static_cast(0), - bool norm_by_times = false, size_t lod_level = 0); + framework::Tensor* pad_tensor, + T pad_value = static_cast(0), bool norm_by_times = false, + size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth); }; -template +template class UnpaddingLoDTensorFunctor { public: void operator()(const DeviceContext& context, framework::LoDTensor* seq_tensor, - const framework::Tensor& padding_tensor, - bool norm_by_times = false, size_t lod_level = 0); + const framework::Tensor& pad_tensor, + bool norm_by_times = false, size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index b0c201db0c..82459274c4 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -46,20 +46,24 @@ void TestSequencePadding(const paddle::framework::LoD& lod, } const size_t max_sequence_length = - paddle::operators::math::MaximumSequenceLength(lod, level); + paddle::operators::math::MaximumSequenceLength(lod[level]); const size_t num_sequences = lod[level].size() - 1; auto padding_dims = paddle::framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); + padding.mutable_data(padding_dims, *place); + paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, false); + *context, seq, &padding, 0, false, 0, + paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); seq_back.mutable_data(seq_dims, *place); paddle::operators::math::UnpaddingLoDTensorFunctor()( - *context, &seq_back, padding, false); + *context, &seq_back, padding, false, 0, + 
paddle::operators::math::kLengthBatchWidth); if (paddle::platform::is_cpu_place(*place)) { cpu_seq_back = seq_back; diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index f3a6fff0e1..dc79b252c7 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -54,7 +54,7 @@ class SequencePadOp : public framework::OperatorWithKernel { seq_num = x_abs_offset.size() - 1; - for (size_t i = 1; i <= seq_num; ++i) { + for (int64_t i = 1; i <= seq_num; ++i) { int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1]; max_len = max_len < seq_len ? seq_len : max_len; } diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 1b649be203..075eb010c5 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -155,15 +155,16 @@ class WarpCTCKernel : public framework::OpKernel { // warpctc needs sequences data stored in transposed padding format Tensor warpctc_logits; const size_t max_sequence_length = - math::MaximumSequenceLength(logits_lod, level); + math::MaximumSequenceLength(logits_lod[level]); auto warpctc_logits_dims = framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); - math::PaddingLoDTensorFunctor()( + math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, - false); + static_cast(0), false /* norm_by_times */, 0, + math::kLengthBatchWidth); const T* warpctc_logits_data = warpctc_logits.data(); std::vector warpctc_label_lengths(num_sequences); @@ -216,10 +217,9 @@ class WarpCTCGradKernel : public framework::OpKernel { logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); - math::UnpaddingLoDTensorFunctor()( + math::UnpaddingLoDTensorFunctor()( ctx.template device_context(), logits_grad, - *warpctc_grad, 
norm_by_times); + *warpctc_grad, norm_by_times, 0, math::kLengthBatchWidth); const T* loss_grad_data = loss_grad->data(); math::ScaleLoDTensorFunctor()( From 94bc25d4bf53713fcfd9d022b308345dcdb2dc43 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 4 Jul 2018 18:48:36 +0800 Subject: [PATCH 004/140] add releasing for mac --- doc/fluid/dev/releasing_process_en.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/doc/fluid/dev/releasing_process_en.md b/doc/fluid/dev/releasing_process_en.md index f989b964d6..2c1c30c1ed 100644 --- a/doc/fluid/dev/releasing_process_en.md +++ b/doc/fluid/dev/releasing_process_en.md @@ -50,6 +50,33 @@ pop-up box, choose the current release branch and click "Run Build" button. You * pypi does not allow overwrite the already uploaded version of wheel package, even if you delete the old version. you must change the version number before upload a new one. +### Publish wheel Packages for MacOS + +You need to build the binary wheel package for MacOS before publishing. To +make sure that the package can be used by many versions of MacOS +(10.11, 10.12, 10.13) and different python installs (python.org, homebrew, etc.), +you must build the package by following the steps below ***exactly***: + +Build steps: + +1. install python from python.org downloads, and make sure it's currently in use + in your system. +1. `export MACOSX_DEPLOYMENT_TARGET=10.11`; using `10.11` is enough for recent versions. +1. `git clone https://github.com/PaddlePaddle/Paddle.git && cd Paddle && mkdir build && cd build` +1. `cmake -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_SYSTEM_BLAS=OFF ..`, and make sure the output of the `cmake` command shows it is using the correct python interpreter installed from python.org +1. `make -j` +1. `pip install delocate` +1. `mkdir fixed_wheel && delocate-wheel -w fixed_wheel python/dist/*.whl` + +Then the whl under `fixed_wheel` is ready to upload. + +Install steps: + +1. run `pip install paddlepaddle...whl` +1.
find the `libpython.dylib` that is currently in use: + - for python.org package installs, do nothing. + - for other python installs, find the path of `libpython*.dylib` and `export LD_LIBRARY_PATH=your path && DYLD_LIBRARY_PATH=your path` + ## Publish Docker Images Our CI tool will push latest images to DockerHub, so we only need to push a version tag like: From 3c749fae43765a1543b450a9a21ac514a1d9a535 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 16 Aug 2018 20:53:47 +0800 Subject: [PATCH 005/140] update CPU sequence_padding functor --- .../fluid/operators/math/sequence_padding.cc | 149 +++++++++--------- .../fluid/operators/math/sequence_padding.h | 56 +++---- .../operators/math/sequence_padding_test.cc | 6 +- paddle/fluid/operators/warpctc_op.h | 10 +- 4 files changed, 108 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index 5ceb26553c..e8ccf006ad 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,37 +18,45 @@ namespace paddle { namespace operators { namespace math { +enum CopyType { kSeqToPad, kPadToSeq }; + template -void CopyDataCPU(framework::LoDTensor* seq_tensor, - framework::Tensor* pad_tensor, - const framework::Vector& seq_offset, - const int64_t& max_seq_len, const int64_t& seq_width, - bool seq_to_pad, bool norm_by_len, - OutputLayout output_layout) { - T* seq_data = seq_tensor->data(); - T* pad_data = pad_tensor->data(); - - int64_t seq_num = seq_offset.size() - 1; - - for (int64_t i = 0; i < seq_num; ++i) { - int64_t seq_start = seq_offset[i]; - int64_t seq_len = seq_offset[i + 1] - seq_start; - T scale = norm_by_len ?
(1.0f / static_cast(seq_len)) : 1.0f; - for (int64_t j = 0; j < seq_len; ++j) { - for (int64_t k = 0; k < seq_width; ++k) { - size_t pad_data_idx = 0; - size_t seq_data_idx = (seq_start + j) * seq_width + k; - if (output_layout == kBatchLengthWidth) { - pad_data_idx = (i * max_seq_len + j) * seq_width + k; - } else { - pad_data_idx = (j * seq_num + i) * seq_width + k; - } - if (seq_to_pad) { - pad_data[pad_data_idx] = seq_data[seq_data_idx] * scale; - } else { - seq_data[seq_data_idx] = pad_data[pad_data_idx] * scale; +void CopyValidData(framework::Tensor* dst_tensor, + const framework::Tensor* src_tensor, + const framework::Vector& seq_offsets, + int pad_seq_len, int step_width, bool norm_by_len, + CopyType type, PadLayout layout) { + int seq_num = seq_offsets.size() - 1; + const T* src_data = src_tensor->data(); + T* dst_data = dst_tensor->data(); + + int seq_cpy_gap = step_width; + int pad_cpy_gap = + layout == kBatchLengthWidth ? step_width : seq_num * step_width; + for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) { + int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + PADDLE_ENFORCE_GE( + pad_seq_len, valid_seq_len, + "The padded sequence length can not be less than its original length."); + int seq_data_offset = seq_offsets[seq_idx] * step_width; + int pad_data_offset = layout == kBatchLengthWidth + ? seq_idx * pad_seq_len * step_width + : seq_idx * step_width; + float scale = 1.0f / static_cast(valid_seq_len); + + for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) { + const T* src = + src_data + (type == kSeqToPad ? seq_data_offset : pad_data_offset); + T* dst = + dst_data + (type == kSeqToPad ? 
pad_data_offset : seq_data_offset); + memcpy(dst, src, step_width * sizeof(T)); + if (norm_by_len) { + for (int i = 0; i < step_width; ++i) { + *(dst + i) *= scale; } } + seq_data_offset += seq_cpy_gap; + pad_data_offset += pad_cpy_gap; } } } @@ -58,31 +66,37 @@ class PaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* pad_tensor, - T pad_value = static_cast(0), bool norm_by_times = false, - size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(seq_tensor, lod_level); - - auto& lod = seq_tensor.lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - + framework::LoDTensor* pad_tensor, + std::vector pad_value = {0}, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor.lod())[lod_level]; auto seq_tensor_dims = seq_tensor.dims(); auto pad_tensor_dims = pad_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE(pad_value.size() == 1 || + static_cast(pad_value.size()) == step_width, + "The size of 'pad_value' can only be 1 or be equal to the " + "'step_width'."); - T* pad_data = pad_tensor->data(); + if (pad_value.size() == 1) { + pad_value = std::vector(step_width, pad_value[0]); + } - memset(pad_data, pad_value, max_seq_len * seq_num * seq_width * sizeof(T)); + // fill padding value + T* pad_data = pad_tensor->data(); 
+ for (int i = 0; i < pad_tensor->numel() / step_width; ++i) { + memcpy(pad_data, pad_value.data(), step_width * sizeof(T)); + } - CopyDataCPU(const_cast(&seq_tensor), pad_tensor, - seq_offset, max_seq_len, seq_width, true /* seq_to_pad */, - norm_by_times, output_layout); + CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kSeqToPad, layout); } }; @@ -90,30 +104,23 @@ template class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, - framework::LoDTensor* seq_tensor, - const framework::Tensor& pad_tensor, - bool norm_by_times = false, size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(*seq_tensor, lod_level); - - auto& lod = seq_tensor->lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - - auto& seq_tensor_dims = seq_tensor->dims(); - auto& pad_tensor_dims = pad_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; - - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); - - T* seq_data = seq_tensor->data(); - memset(seq_data, static_cast(0), seq_tensor->numel() * sizeof(T)); - - CopyDataCPU(seq_tensor, const_cast(&pad_tensor), - seq_offset, max_seq_len, seq_width, false /* seq_to_pad */, - norm_by_times, output_layout); + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout& layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + auto seq_tensor_dims = seq_tensor->dims(); + auto pad_tensor_dims = pad_tensor.dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + + 
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + + CopyValidData(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kPadToSeq, layout); } }; diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 44d6404335..d5790e2ba2 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -22,7 +23,7 @@ namespace paddle { namespace operators { namespace math { -enum OutputLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; +enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; inline static size_t MaximumSequenceLength( const framework::Vector& seq_offset) { @@ -34,35 +35,22 @@ inline static size_t MaximumSequenceLength( return max_seq_len; } -inline static void CheckLoD(const framework::LoDTensor& seq_tensor, - const size_t& lod_level) { - PADDLE_ENFORCE(lod_level < seq_tensor.lod().size(), - "Invalid lod level which should be at least 0 and less " - "than maximum lod level of sequence tensor."); -} - inline static void CheckDims(const framework::DDim& seq_tensor_dims, - const size_t& last_offset, const framework::DDim& pad_tensor_dims, - const int64_t& max_seq_len, const int64_t& seq_num, - const int64_t& seq_width, - const OutputLayout& output_layout) { - PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), last_offset, + const framework::Vector& seq_offset, + int64_t padded_seq_len, int64_t step_width, + const PadLayout& layout) { + PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), seq_offset.back(), "Value of 1st dimension of the sequence tensor should be " "equal to sum of lengths of all sequences."); - PADDLE_ENFORCE_EQ(pad_tensor_dims.size(), 3UL, - "Padded tensor should be a 3-D tensor."); + 
PADDLE_ENFORCE(seq_tensor_dims.size() == 1 || seq_tensor_dims.size() == 2, + "seq_tensor's rank should be 1 or 2."); - if (output_layout == kBatchLengthWidth) { - PADDLE_ENFORCE_EQ(pad_tensor_dims, - framework::make_ddim({seq_num, max_seq_len, seq_width})); - } else if (output_layout == kLengthBatchWidth) { - PADDLE_ENFORCE_EQ(pad_tensor_dims, - framework::make_ddim({max_seq_len, seq_num, seq_width})); - } else { - PADDLE_THROW("Unsupported output layout."); - } + PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() || + seq_tensor_dims.size() == pad_tensor_dims.size(), + "pad_tensor's rank should be 1 greater than seq_tensor's " + "rank, or be equal with it."); } /* @@ -94,22 +82,22 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims, template class PaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, + void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* pad_tensor, - T pad_value = static_cast(0), bool norm_by_times = false, - size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth); + framework::LoDTensor* pad_tensor, + std::vector pad_value = {0}, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth); }; template class UnpaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, - framework::LoDTensor* seq_tensor, - const framework::Tensor& pad_tensor, - bool norm_by_times = false, size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth); + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout& layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc 
b/paddle/fluid/operators/math/sequence_padding_test.cc index 82459274c4..3171c7c33e 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -23,7 +23,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, paddle::framework::LoDTensor cpu_seq_back; paddle::framework::LoDTensor seq; paddle::framework::LoDTensor seq_back; - paddle::framework::Tensor padding; + paddle::framework::LoDTensor padding; const size_t level = lod.size() - 1; auto seq_dims = @@ -56,13 +56,13 @@ void TestSequencePadding(const paddle::framework::LoD& lod, padding.mutable_data(padding_dims, *place); paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, 0, false, 0, + *context, seq, &padding, {0}, -1, 0, false, paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); seq_back.mutable_data(seq_dims, *place); paddle::operators::math::UnpaddingLoDTensorFunctor()( - *context, &seq_back, padding, false, 0, + *context, padding, &seq_back, -1, 0, false, paddle::operators::math::kLengthBatchWidth); if (paddle::platform::is_cpu_place(*place)) { diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index cb56f42a8d..6cbf985039 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -153,7 +153,7 @@ class WarpCTCKernel : public framework::OpKernel { framework::make_ddim({static_cast(num_sequences), 1}); // warpctc needs sequences data stored in transposed padding format - Tensor warpctc_logits; + LoDTensor warpctc_logits; const size_t max_sequence_length = math::MaximumSequenceLength(logits_lod[level]); auto warpctc_logits_dims = @@ -163,7 +163,7 @@ class WarpCTCKernel : public framework::OpKernel { warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, - static_cast(0), false /* norm_by_times */, 0, + {static_cast(0)}, -1, 
0, false /* norm_by_times */, math::kLengthBatchWidth); const T* warpctc_logits_data = warpctc_logits.data(); @@ -210,15 +210,15 @@ template class WarpCTCGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* warpctc_grad = ctx.Input("WarpCTCGrad"); + auto* warpctc_grad = ctx.Input("WarpCTCGrad"); auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); const Tensor* loss_grad = ctx.Input(framework::GradVarName("Loss")); logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), logits_grad, - *warpctc_grad, norm_by_times, 0, math::kLengthBatchWidth); + ctx.template device_context(), *warpctc_grad, + logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth); const T* loss_grad_data = loss_grad->data(); math::ScaleLoDTensorFunctor()( From 6588d0e039b36be9febd51683b6cad17264628ab Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Mon, 13 Aug 2018 12:20:06 +0200 Subject: [PATCH 006/140] Update MKLDNN to 0.15, fix conv integration --- cmake/external/mkldnn.cmake | 2 +- paddle/fluid/framework/tensor.cc | 9 ++++---- paddle/fluid/framework/tensor.h | 14 +++++++----- paddle/fluid/framework/tensor_impl.h | 9 ++++---- paddle/fluid/operators/conv_mkldnn_op.cc | 28 +++++++++++++++++------- 5 files changed, 39 insertions(+), 23 deletions(-) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index 260985cc8a..baf253df27 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -54,7 +54,7 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS ${MKLDNN_DEPENDS} GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" - GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51" + GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944" PREFIX ${MKLDNN_SOURCES_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} diff --git 
a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 56bb9142da..222a51672f 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -31,7 +31,8 @@ size_t Tensor::memory_size() const { return holder_ == nullptr ? 0UL : holder_->size() - offset_; } -void* Tensor::mutable_data(platform::Place place, std::type_index type) { +void* Tensor::mutable_data(platform::Place place, std::type_index type, + int64_t requested_size) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -39,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) { "When calling this method, the Tensor's numel must be " "equal or larger than zero. " "Please check Tensor::Resize has been called first."); - int64_t size = numel() * SizeOfType(type); + int64_t size = requested_size ? requested_size : numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || holder_->size() < size + offset_) { @@ -68,10 +69,10 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) { offset_); } -void* Tensor::mutable_data(platform::Place place) { +void* Tensor::mutable_data(platform::Place place, int64_t requested_size) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing."); - return mutable_data(place, holder_->type()); + return mutable_data(place, holder_->type(), requested_size); } Tensor& Tensor::ShareDataWith(const Tensor& src) { diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 0bbfd66148..a4454c90b0 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -89,22 +89,24 @@ class Tensor { * @note If not exist, then allocation. 
*/ template - T* mutable_data(platform::Place place); + T* mutable_data(platform::Place place, int64_t requested_size = 0); - void* mutable_data(platform::Place place, std::type_index type); + void* mutable_data(platform::Place place, std::type_index type, + int64_t requested_size = 0); - void* mutable_data(platform::Place place); + void* mutable_data(platform::Place place, int64_t requested_size = 0); /** * @brief Return a pointer to mutable memory block. * - * @param[in] dims The dimensions of the memory block. - * @param[in] place The place of the memory block. + * @param[in] dims The dimensions of the memory block. + * @param[in] place The place of the memory block. + * @param[in] requested_size The size of the block in bytes. * * @note If not exist, then allocation. */ template - T* mutable_data(DDim dims, platform::Place place); + T* mutable_data(DDim dims, platform::Place place, int64_t requested_size = 0); /*! Return the dimensions of the memory block. */ const DDim& dims() const; diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index b7b62eef23..ea10c9a265 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -46,16 +46,17 @@ inline T* Tensor::data() { } template -inline T* Tensor::mutable_data(DDim dims, platform::Place place) { +inline T* Tensor::mutable_data(DDim dims, platform::Place place, + int64_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); - return mutable_data(place); + return mutable_data(place, requested_size); } template -inline T* Tensor::mutable_data(platform::Place place) { +inline T* Tensor::mutable_data(platform::Place place, int64_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(place, typeid(T))); + return reinterpret_cast(mutable_data(place, typeid(T), requested_size)); } inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { diff --git 
a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index f07ab5a33b..77d0cf07a8 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { key_ += "-BWD"; } + size_t GetDstMemorySize() { + return conv_pd_->dst_primitive_desc().get_size(); + } + + size_t GetDiffWeightsMemorySize() { + return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size(); + } + + size_t GetDiffSourceMemorySize() { + return conv_bwd_data_pd_->diff_src_primitive_desc().get_size(); + } + std::shared_ptr AcquireSrcMemoryFromWeightsPrimitive( const std::shared_ptr user_memory_p, std::vector& pipeline) { // NOLINT @@ -251,7 +263,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const T* input_data = input->data(); const T* filter_data = filter->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); PADDLE_ENFORCE(input->dims().size() == 4, "Input must be with 4 dimensions, i.e. 
NCHW"); @@ -306,6 +317,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto user_weights_memory_p = handler.AcquireWeightsMemory( user_weights_md, to_void_cast(filter_data)); + T* output_data = + output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); // create reorder primitive if the input format is not the preferred one auto src_memory_p = handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); @@ -393,13 +406,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { T* input_grad_data = nullptr; T* filter_grad_data = nullptr; - if (input_grad) { - input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - } - if (filter_grad) { - filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - } - std::vector src_tz = paddle::framework::vectorize2int(input->dims()); std::vector weights_tz = paddle::framework::vectorize2int(filter->dims()); @@ -485,6 +491,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { handler.AcquireDiffDstMemoryFromWeightsPrimitive( user_diff_dst_memory_p, pipeline); + size_t size = handler.GetDiffWeightsMemorySize(); + filter_grad_data = filter_grad->mutable_data(ctx.GetPlace(), size); + auto diff_weights_memory_p = handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( reinterpret_cast(filter_grad_data)); @@ -507,6 +516,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, pipeline); + size_t size = handler.GetDiffSourceMemorySize(); + input_grad_data = input_grad->mutable_data(ctx.GetPlace(), size); + auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( reinterpret_cast(input_grad_data)); From 4a7f0698e0b7169022409b0f962e7c7d24caab85 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 14 Aug 2018 13:29:44 +0200 Subject: [PATCH 007/140] Add consts to new MKLDNN integration Also replace memory types from int64_t to size_t --- 
paddle/fluid/framework/tensor.cc | 6 +++--- paddle/fluid/framework/tensor.h | 8 ++++---- paddle/fluid/framework/tensor_impl.h | 4 ++-- paddle/fluid/operators/conv_mkldnn_op.cc | 10 +++++----- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 222a51672f..d61dbb98a2 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -32,7 +32,7 @@ size_t Tensor::memory_size() const { } void* Tensor::mutable_data(platform::Place place, std::type_index type, - int64_t requested_size) { + size_t requested_size) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -40,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, "When calling this method, the Tensor's numel must be " "equal or larger than zero. " "Please check Tensor::Resize has been called first."); - int64_t size = requested_size ? requested_size : numel() * SizeOfType(type); + size_t size = requested_size ? requested_size : numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || holder_->size() < size + offset_) { @@ -69,7 +69,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, offset_); } -void* Tensor::mutable_data(platform::Place place, int64_t requested_size) { +void* Tensor::mutable_data(platform::Place place, size_t requested_size) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing."); return mutable_data(place, holder_->type(), requested_size); diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index a4454c90b0..4cf95fa0ae 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -89,12 +89,12 @@ class Tensor { * @note If not exist, then allocation. 
*/ template - T* mutable_data(platform::Place place, int64_t requested_size = 0); + T* mutable_data(platform::Place place, size_t requested_size = 0); void* mutable_data(platform::Place place, std::type_index type, - int64_t requested_size = 0); + size_t requested_size = 0); - void* mutable_data(platform::Place place, int64_t requested_size = 0); + void* mutable_data(platform::Place place, size_t requested_size = 0); /** * @brief Return a pointer to mutable memory block. @@ -106,7 +106,7 @@ class Tensor { * @note If not exist, then allocation. */ template - T* mutable_data(DDim dims, platform::Place place, int64_t requested_size = 0); + T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0); /*! Return the dimensions of the memory block. */ const DDim& dims() const; diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index ea10c9a265..6d3047c95d 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -47,14 +47,14 @@ inline T* Tensor::data() { template inline T* Tensor::mutable_data(DDim dims, platform::Place place, - int64_t requested_size) { + size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); return mutable_data(place, requested_size); } template -inline T* Tensor::mutable_data(platform::Place place, int64_t requested_size) { +inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); return reinterpret_cast(mutable_data(place, typeid(T), requested_size)); } diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 77d0cf07a8..d75e6412c8 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -53,15 +53,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { key_ += "-BWD"; } - size_t GetDstMemorySize() { + size_t GetDstMemorySize() const { return 
conv_pd_->dst_primitive_desc().get_size(); } - size_t GetDiffWeightsMemorySize() { + size_t GetDiffWeightsMemorySize() const { return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size(); } - size_t GetDiffSourceMemorySize() { + size_t GetDiffSourceMemorySize() const { return conv_bwd_data_pd_->diff_src_primitive_desc().get_size(); } @@ -491,7 +491,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { handler.AcquireDiffDstMemoryFromWeightsPrimitive( user_diff_dst_memory_p, pipeline); - size_t size = handler.GetDiffWeightsMemorySize(); + const size_t size = handler.GetDiffWeightsMemorySize(); filter_grad_data = filter_grad->mutable_data(ctx.GetPlace(), size); auto diff_weights_memory_p = @@ -516,7 +516,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, pipeline); - size_t size = handler.GetDiffSourceMemorySize(); + const size_t size = handler.GetDiffSourceMemorySize(); input_grad_data = input_grad->mutable_data(ctx.GetPlace(), size); auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( From 8d8d48a34f9116f5a501d69cc4dbbf9ce13a1446 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 17 Aug 2018 17:58:12 +0800 Subject: [PATCH 008/140] Complete sequence_pad_op and its CPU kernel. 
Add unittests --- .../fluid/operators/math/sequence_padding.cc | 24 +++- .../fluid/operators/math/sequence_padding.h | 3 - paddle/fluid/operators/sequence_pad_op.cc | 105 +++++++------- paddle/fluid/operators/sequence_pad_op.cu | 10 +- paddle/fluid/operators/sequence_pad_op.h | 93 +++--------- .../tests/unittests/test_sequence_pad_op.py | 134 ++++++++++++++++++ 6 files changed, 234 insertions(+), 135 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_sequence_pad_op.py diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index e8ccf006ad..d3dab64f60 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -70,9 +70,10 @@ class PaddingLoDTensorFunctor { std::vector pad_value = {0}, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth) { - auto seq_offsets = framework::ToAbsOffset(seq_tensor.lod())[lod_level]; - auto seq_tensor_dims = seq_tensor.dims(); - auto pad_tensor_dims = pad_tensor->dims(); + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); if (pad_seq_len == -1) { pad_seq_len = MaximumSequenceLength(seq_offsets); } @@ -91,12 +92,21 @@ class PaddingLoDTensorFunctor { // fill padding value T* pad_data = pad_tensor->data(); - for (int i = 0; i < pad_tensor->numel() / step_width; ++i) { - memcpy(pad_data, pad_value.data(), step_width * sizeof(T)); + for (int i = 0; i < pad_tensor->numel(); i += step_width) { + memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T)); } CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, step_width, norm_by_times, kSeqToPad, layout); + + // Set pad_tensor's lod info if possible + if (layout == kBatchLengthWidth) { + framework::LoD pad_lod(seq_lod.begin() + 
lod_level, seq_lod.end()); + for (size_t i = 0; i < pad_lod[0].size(); ++i) { + pad_lod[0][i] = i * pad_seq_len; + } + pad_tensor->set_lod(pad_lod); + } } }; @@ -109,8 +119,8 @@ class UnpaddingLoDTensorFunctor { int lod_level = 0, bool norm_by_times = false, const PadLayout& layout = kBatchLengthWidth) { auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; - auto seq_tensor_dims = seq_tensor->dims(); - auto pad_tensor_dims = pad_tensor.dims(); + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); if (pad_seq_len == -1) { pad_seq_len = MaximumSequenceLength(seq_offsets); } diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index d5790e2ba2..9b8c892c53 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -44,9 +44,6 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims, "Value of 1st dimension of the sequence tensor should be " "equal to sum of lengths of all sequences."); - PADDLE_ENFORCE(seq_tensor_dims.size() == 1 || seq_tensor_dims.size() == 2, - "seq_tensor's rank should be 1 or 2."); - PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() || seq_tensor_dims.size() == pad_tensor_dims.size(), "pad_tensor's rank should be 1 greater than seq_tensor's " diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index dc79b252c7..f23710cf4d 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -21,82 +21,85 @@ class SequencePadOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequencePadOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("PadValue"), + "Input(PadValue) 
of SequencePadOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SequencePadOp should not be null."); auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "The rank of Input(x) can't be less than 2."); + auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size()); + auto pad_value_dims = ctx->GetInputDim("PadValue"); + PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) || + pad_value_dims == time_step_dims, + "The Input(PadValue) must be a scalar or a tensor whose " + "shape equals to time steps in sequences"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, - "Only support 2-D tensor, rank of Input(X) should be 2."); - - int lod_level = ctx->Attrs().Get("lod_level"); - - int64_t max_len = -1; - int64_t seq_num = -1; - int x_lod_size = -1; + int batch_dim_size = -1; if (ctx->IsRuntime()) { + // run time framework::Variable* x_var = boost::get(ctx->GetInputVarPtrs("X")[0]); - - auto& x_lod = x_var->Get().lod(); - - x_lod_size = x_lod.size(); - - auto x_abs_offset = framework::ToAbsOffset(x_lod)[lod_level]; - - PADDLE_ENFORCE_EQ(x_dims[0], static_cast(x_abs_offset.back()), - "The first dimension of `X` should be equal to sum " - "of all sequences' length."); - - seq_num = x_abs_offset.size() - 1; - - for (int64_t i = 1; i <= seq_num; ++i) { - int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1]; - max_len = max_len < seq_len ? 
seq_len : max_len; + const auto& x_lod = x_var->Get().lod(); + PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info."); + const auto& x_lod_0 = x_lod[0]; + PADDLE_ENFORCE_GE(x_lod_0.size(), 2, + "The Input(X)'s lod info is corrupted."); + PADDLE_ENFORCE_EQ( + x_dims[0], static_cast(x_lod_0.back()), + "The Input(X)'s lod info mismatches the actual tensor shape."); + + int seq_num = x_lod_0.size() - 1; + int max_seq_len = math::MaximumSequenceLength(x_lod_0); + int padded_length = ctx->Attrs().Get("padded_length"); + if (padded_length == -1) { + padded_length = max_seq_len; } + PADDLE_ENFORCE_GE(padded_length, max_seq_len, + "The Attr(padded_length) must be -1 or an int greater " + "than the length of the longest original sequence."); + batch_dim_size = padded_length * seq_num; } else { + // compile time framework::VarDesc* x_desc = boost::get(ctx->GetInputVarPtrs("X")[0]); - x_lod_size = x_desc->GetLoDLevel(); + PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1); } - PADDLE_ENFORCE(lod_level >= 0 && lod_level < x_lod_size, - "Invalid `lod_level` which should be at least 0 and less " - "than maximum lod level of `X`"); - - ctx->SetOutputDim("Out", {seq_num, max_len, x_dims[1]}); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + auto out_dims = x_dims; + out_dims[0] = batch_dim_size; + ctx->SetOutputDim("Out", out_dims); } }; class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { public: - SequencePadOpMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { + void Make() override { AddInput("X", "(LoDTensor, default LoDTensor) Input variable which " - "should contain lod information. 
Length of each sequence would " - "be computed from the most bottom level lod."); - AddOutput("Out", - "(Tensor) Output variable which would be a common tensor " - "without lod. Each sequence would be padded to the maximum " - "length."); - AddAttr("lod_level", - "(int, default 0) Specify which level lod to referred to."); - AddAttr("pad_value", - "(float, default 0.0) Specify which value to be padded to " - "the end of each sequence."); + "should contain lod information."); + AddInput("PadValue", + "(LoDTensor), this Tensor holds values that will be fill into " + "padded steps. It can be a scalar or a tensor whose shape equals " + "to time steps in sequences. If it's a scalar, it will be " + "automatically broadcasted to the shape of time step."); + AddOutput( + "Out", + "(LoDTensor) The output vairable, which contains padded sequences."); + AddAttr( + "padded_length", + "The length of padded sequences. It can be setted to -1 or " + "any positive int. When it is -1, all sequences will be padded up to " + "the length of the longest one among them; when it a certain positive " + "value, it must be greater than the length of the longest original " + "sequence.") + .SetDefault(-1); AddComment(R"DOC( )DOC"); diff --git a/paddle/fluid/operators/sequence_pad_op.cu b/paddle/fluid/operators/sequence_pad_op.cu index a2fa62957e..ff8f81a2f0 100644 --- a/paddle/fluid/operators/sequence_pad_op.cu +++ b/paddle/fluid/operators/sequence_pad_op.cu @@ -17,7 +17,13 @@ limitations under the License. 
*/ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( sequence_pad, - ops::SequencePadOpKernel); + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel); REGISTER_OP_CUDA_KERNEL( sequence_pad_grad, - ops::SequencePadGradOpKernel); + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h index 6d136b65f1..44aff30879 100644 --- a/paddle/fluid/operators/sequence_pad_op.h +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/math/math_function.h" @@ -24,68 +26,24 @@ namespace operators { using LoDTensor = framework::LoDTensor; using LoD = framework::LoD; -template -struct CopyFunctor { - LoDTensor* lod_tensor_; - LoDTensor* pad_tensor_; - const LoD& ref_lod_; - const DeviceContext& ctx_; - bool is_lod_to_pad_; - - CopyFunctor(LoDTensor* lod_tensor, const LoD& ref_lod, LoDTensor* pad_tensor, - const DeviceContext& ctx, bool is_lod_to_pad) - : lod_tensor_(lod_tensor), - pad_tensor_(pad_tensor), - ref_lod_(ref_lod), - ctx_(ctx), - is_lod_to_pad_(is_lod_to_pad) {} - - void operator()() const { - /* - auto seq_num = ref_lod_.size() - 1; - auto max_len = pad_tensor_->dims()[0] / seq_num; - - PADDLE_ENFORCE_EQ(max_len * seq_num, pad_tensor_->dims()[0], - "First dimension of padded tensor should be equal to " - "maximum sequence length mulplied by sequence number."); - - for (size_t i = 1; i < ref_lod_.size(); ++i) { - auto seq_start = ref_lod_[i - 1]; - auto seq_end = ref_lod_[i]; - auto pad_start = (i - 1) * max_len; - auto pad_end = pad_start + (seq_end - seq_start); - auto sub_lod_tensor = 
lod_tensor_->Slice(seq_start, seq_end); - auto sub_pad_tensor = pad_tensor_->Slice(pad_start, pad_end); - if (is_lod_to_pad_) { - framework::TensorCopy(sub_lod_tensor, ctx.GetPlace(), &sub_pad_tensor); - } else { - framework::TensorCopy(sub_pad_tensor, ctx.GetPlace(), &sub_lod_tensor); - } - } - */ - } -}; - template class SequencePadOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /* - auto* x = ctx.Input("X"); - auto* out_ptr = ctx.Output("Out"); - - out_ptr->mutable_data(ctx.GetPlace()); + const auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); - // Resize(); + const auto* pad_value = ctx.Input("PadValue"); + const T* pad_value_data = pad_value->data(); + std::vector pad_value_vec(pad_value_data, + pad_value_data + pad_value->numel()); - T pad_value = static_cast(ctx.Attr("pad_value")); + int padded_length = ctx.Attr("padded_length"); math::PaddingLoDTensorFunctor()( - ctx.template device_context(), *x, *, false); - - math::SetConstant set_func; - set_func(ctx.template device_context(), out_ptr, pad_value); - */ + ctx.template device_context(), *x, out, pad_value_vec, + padded_length, 0, false, math::kBatchLengthWidth); } }; @@ -93,26 +51,17 @@ template class SequencePadGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /* - auto* x_ptr = ctx.Input("X"); - auto* g_out_ptr = ctx.Input(framework::GradVarName("Out")); - auto* g_x_ptr = ctx.Output(framework::GradVarName("X")); - - math::SetConstant set_func; - set_func(ctx.template device_context(), - g_x_ptr, - static_cast(0)); + auto* d_x = ctx.Output(framework::GradVarName("X")); + if (d_x) { + const auto* d_out = ctx.Input(framework::GradVarName("Out")); + d_x->mutable_data(ctx.GetPlace()); - auto& x_lod = x_ptr->lod(); - auto& x_last_level_lod = x_lod[x_lod.size() - 1]; + int padded_length = ctx.Attr("padded_length"); - 
CopyFunctor copy_func(g_out_ptr, - x_last_level_lod, - g_x_ptr, - ctx, - false); - copy_func(); - */ + math::UnpaddingLoDTensorFunctor()( + ctx.template device_context(), *d_out, d_x, + padded_length, 0, false, math::kBatchLengthWidth); + } } }; diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py new file mode 100644 index 0000000000..7b9eedbf52 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSequencePadOp(OpTest): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = -1 + self.dtype = 'float32' + + def set_data(self): + x_data = np.random.uniform(0.1, 0.5, self.x_shape).astype(self.dtype) + pad_value_data = np.array(self.pad_value).astype(self.dtype) + self.inputs = { + 'X': (x_data, self.x_len_lod), + 'PadValue': pad_value_data + } + self.attrs = {'padded_length': self.padded_length} + + def compute(self): + # get padded length + padded_length = self.padded_length + x_len_lod_0 = self.x_len_lod[0] + if padded_length == -1: + max_seq_len = 0 + for l in x_len_lod_0: + max_seq_len = max(max_seq_len, l) + padded_length = max_seq_len + + # do padding + x_data = self.inputs['X'][0] + pad_value_data = self.inputs['PadValue'] + if pad_value_data.shape == (1, ): + pad_value_data = np.broadcast_to( + pad_value_data, shape=x_data.shape[1:]) + padded_sequences = [] + start_idx = 0 + for l in x_len_lod_0: + end_idx = start_idx + l + seq = x_data[start_idx:end_idx] + to_pad_len = padded_length - l + for _ in range(to_pad_len): + seq = np.append(seq, pad_value_data[np.newaxis, :], axis=0) + padded_sequences.append(seq) + start_idx = end_idx + + out_len_lod = self.x_len_lod[:] + out_len_lod_0 = [padded_length] * len(x_len_lod_0) + out_len_lod[0] = out_len_lod_0 + out_data = np.concatenate(padded_sequences, axis=0) + self.outputs = {'Out': (out_data, out_len_lod)} + + def setUp(self): + self.op_type = 'sequence_pad' + self.set_attr() + self.set_data() + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSequencePadOp2(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0, 2.0, 3.0, 4.0] + self.padded_length = -1 + self.dtype = 'float32' + + +class 
TestSequencePadOp3(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = 7 + self.dtype = 'float32' + + +class TestSequencePadOp4(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0, 2.0, 3.0, 4.0] + self.padded_length = 7 + self.dtype = 'float32' + + +class TestSequencePadOp5(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp6(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [[1.0, 2.0], [3.0, 4.0]] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp7(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = 7 + self.dtype = 'float32' From 34b209cffa81593092a308e2ffe0536b475e81e6 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 20 Aug 2018 16:33:04 +0800 Subject: [PATCH 009/140] Complete sequence_padding GPU kernel --- paddle/fluid/operators/CMakeLists.txt | 1 + .../fluid/operators/math/sequence_padding.cc | 26 +-- .../fluid/operators/math/sequence_padding.cu | 151 ++++++++---------- .../fluid/operators/math/sequence_padding.h | 6 +- .../operators/math/sequence_padding_test.cc | 13 +- paddle/fluid/operators/sequence_pad_op.h | 5 +- paddle/fluid/operators/warpctc_op.h | 15 +- 7 files changed, 113 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ff0e989464..2179a5acdb 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -277,6 +277,7 @@ op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) 
op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) +op_library(sequence_pad_op DEPS sequence_padding) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index d3dab64f60..02ede3edce 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,8 +18,6 @@ namespace paddle { namespace operators { namespace math { -enum CopyType { kSeqToPad, kPadToSeq }; - template void CopyValidData(framework::Tensor* dst_tensor, const framework::Tensor* src_tensor, @@ -67,7 +65,7 @@ class PaddingLoDTensorFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, framework::LoDTensor* pad_tensor, - std::vector pad_value = {0}, int pad_seq_len = -1, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth) { auto seq_lod = seq_tensor.lod(); @@ -81,19 +79,21 @@ class PaddingLoDTensorFunctor { CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, step_width, layout); - PADDLE_ENFORCE(pad_value.size() == 1 || - static_cast(pad_value.size()) == step_width, - "The size of 'pad_value' can only be 1 or be equal to the " + PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, + "The numel of 'pad_value' can only be 1 or be equal to the " "'step_width'."); - if (pad_value.size() == 1) { - pad_value = std::vector(step_width, pad_value[0]); - } - // fill padding value T* pad_data = pad_tensor->data(); - for (int i = 0; i < pad_tensor->numel(); i += step_width) { - memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T)); + const T* pad_value_data = pad_value.data(); + if (pad_value.numel() == 1) { + for (int i = 0; i < pad_tensor->numel(); ++i) { + pad_data[i] = *pad_value_data; + } + } else { + 
for (int i = 0; i < pad_tensor->numel(); i += step_width) { + memcpy(pad_data + i, pad_value_data, step_width * sizeof(T)); + } } CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, @@ -117,7 +117,7 @@ class UnpaddingLoDTensorFunctor { const framework::LoDTensor& pad_tensor, framework::LoDTensor* seq_tensor, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, - const PadLayout& layout = kBatchLengthWidth) { + const PadLayout layout = kBatchLengthWidth) { auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; const auto& seq_tensor_dims = seq_tensor->dims(); const auto& pad_tensor_dims = pad_tensor.dims(); diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 20e3e3de2a..3b1a44a457 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -19,46 +19,32 @@ namespace paddle { namespace operators { namespace math { -template +template __global__ void SequencePaddingKernel( - T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num, - const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times, - const T& pad_value, const OutputLayout& output_layout) { + T* dst, const T* src, const T* pad_value, bool is_constant_pad, + const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len, + const size_t& step_width, bool norm_by_len, const PadLayout& layout) { size_t seq_idx = blockIdx.y; - size_t seq_start = seq_offset[seq_idx]; - size_t seq_len = seq_offset[seq_idx + 1] - seq_start; - - size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y; - - size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width; - - size_t pad_data_offset = 0; - - if (output_layout == kLengthBatchWidth) { - pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width; - } else { - pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width; - } - - if (seq_step_idx < 
seq_len) { - T scale = norm_by_times ? (1.0f / static_cast(seq_len)) : 1.0f; - if (Padding) { - /* seq -> pad */ - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i]; - } - } else { - /* pad -> seq */ - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i]; - } + size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + + size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width; + size_t pad_data_offset = layout == kBatchLengthWidth + ? (seq_idx * pad_seq_len + step_idx) * step_width + : (step_idx * seq_num + seq_idx) * step_width; + + T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset); + const T* src_data = + src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset); + + if (step_idx < seq_len) { + float scale = norm_by_len ? (1.0f / static_cast(seq_len)) : 1.0f; + for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) { + dst_data[i] = scale * src_data[i]; } - } else if (seq_step_idx < max_seq_len) { - if (Padding) { - /* seq -> pad */ - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - pad_data[pad_data_offset + i] = pad_value; - } + } else if (step_idx < pad_seq_len && Type == kSeqToPad) { + for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { + dst_data[i] = is_constant_pad ? 
pad_value[0] : pad_value[i]; } } } @@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::LoDTensor& seq_tensor, framework::Tensor* pad_tensor, - T pad_value = static_cast(0), bool norm_by_times = false, - size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(seq_tensor, lod_level); - - auto& lod = seq_tensor.lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - - auto seq_tensor_dims = seq_tensor.dims(); - auto pad_tensor_dims = pad_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; + int seq_num = seq_offset.size() - 1; - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, + "The numel of 'pad_value' can only be 1 or be equal to the " + "'step_width'."); - if (!norm_by_times && seq_num == 1UL) { + if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) { TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor); pad_tensor->Resize(pad_tensor_dims); return; @@ -98,21 +86,22 @@ class PaddingLoDTensorFunctor { * and at least 8 elements for each thread. 
*/ size_t block_dim_x = - std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); const T* seq_data = seq_tensor.data(); T* pad_data = pad_tensor->data(); + const T* pad_value_data = pad_value.data(); - SequencePaddingKernel<<>>( - pad_data, const_cast(seq_data), - seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, norm_by_times, pad_value, output_layout); + SequencePaddingKernel<<>>( + pad_data, seq_data, pad_value_data, pad_value.numel() == 1, + seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); } }; @@ -120,25 +109,23 @@ template class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, - framework::LoDTensor* seq_tensor, - const framework::Tensor& pad_tensor, - bool norm_by_times = false, size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(*seq_tensor, lod_level); - - auto& lod = seq_tensor->lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - - auto seq_tensor_dims = seq_tensor->dims(); - auto pad_tensor_dims = pad_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& 
pad_tensor_dims = pad_tensor.dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + int seq_num = seq_offset.size() - 1; - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); - if (!norm_by_times && seq_num == 1UL) { + if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) { TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); seq_tensor->Resize(seq_tensor_dims); return; @@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor { * and at least 8 elements for each thread. */ size_t block_dim_x = - std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); const T* pad_data = pad_tensor.data(); T* seq_data = seq_tensor->data(); - SequencePaddingKernel<<>>( - const_cast(pad_data), seq_data, - seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, norm_by_times, static_cast(0), output_layout); + SequencePaddingKernel<<>>( + seq_data, pad_data, nullptr, false, + seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); } }; diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 9b8c892c53..3fb5859e3b 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -25,6 +25,8 @@ namespace math { enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; +enum CopyType { 
kSeqToPad, kPadToSeq }; + inline static size_t MaximumSequenceLength( const framework::Vector& seq_offset) { size_t seq_num = seq_offset.size() - 1; @@ -82,7 +84,7 @@ class PaddingLoDTensorFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, framework::LoDTensor* pad_tensor, - std::vector pad_value = {0}, int pad_seq_len = -1, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth); }; @@ -94,7 +96,7 @@ class UnpaddingLoDTensorFunctor { const framework::LoDTensor& pad_tensor, framework::LoDTensor* seq_tensor, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, - const PadLayout& layout = kBatchLengthWidth); + const PadLayout layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index 3171c7c33e..4f61b1029c 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -24,6 +24,8 @@ void TestSequencePadding(const paddle::framework::LoD& lod, paddle::framework::LoDTensor seq; paddle::framework::LoDTensor seq_back; paddle::framework::LoDTensor padding; + paddle::framework::LoDTensor cpu_pad_value; + paddle::framework::LoDTensor pad_value; const size_t level = lod.size() - 1; auto seq_dims = @@ -55,8 +57,17 @@ void TestSequencePadding(const paddle::framework::LoD& lod, padding.mutable_data(padding_dims, *place); + T* pad_value_data = + cpu_pad_value.mutable_data({1}, paddle::platform::CPUPlace()); + *pad_value_data = static_cast(0); + if (paddle::platform::is_cpu_place(*place)) { + pad_value = cpu_pad_value; + } else { + TensorCopySync(cpu_pad_value, *place, &pad_value); + } + paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, {0}, -1, 0, false, + *context, seq, &padding, pad_value, -1, 0, 
false, paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h index 44aff30879..5fc9da69d7 100644 --- a/paddle/fluid/operators/sequence_pad_op.h +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -35,14 +35,11 @@ class SequencePadOpKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); const auto* pad_value = ctx.Input("PadValue"); - const T* pad_value_data = pad_value->data(); - std::vector pad_value_vec(pad_value_data, - pad_value_data + pad_value->numel()); int padded_length = ctx.Attr("padded_length"); math::PaddingLoDTensorFunctor()( - ctx.template device_context(), *x, out, pad_value_vec, + ctx.template device_context(), *x, out, *pad_value, padded_length, 0, false, math::kBatchLengthWidth); } }; diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 6cbf985039..444265f58d 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -161,10 +161,21 @@ class WarpCTCKernel : public framework::OpKernel { static_cast(num_sequences), static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); + + LoDTensor cpu_pad_value; + T* pad_value_data = + cpu_pad_value.mutable_data({1}, platform::CPUPlace()); + *pad_value_data = static_cast(0); + LoDTensor pad_value; + if (platform::is_cpu_place(ctx.GetPlace())) { + pad_value = cpu_pad_value; + } else { + TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value); + } + math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, - {static_cast(0)}, -1, 0, false /* norm_by_times */, - math::kLengthBatchWidth); + pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth); const T* warpctc_logits_data = warpctc_logits.data(); std::vector warpctc_label_lengths(num_sequences); From d94a3f621b3d5685505bf7e508103823fa6b0652 Mon Sep 17 00:00:00 2001 From: 
minqiyang Date: Tue, 21 Aug 2018 17:41:16 +0800 Subject: [PATCH 010/140] Disable prelu_op_test until fixing Python3 issues --- .../fluid/tests/unittests/test_prelu_op.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index 979be5af3b..1e3e40d54a 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -51,30 +51,28 @@ class PReluTest(OpTest): def test_check_output(self): self.check_output() - def test_check_grad(self): - self.check_grad(['X', 'Alpha'], 'Out') - - def test_check_grad_ignore_x(self): + def test_check_grad_1_ignore_x(self): self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) - def test_check_grad_ignore_alpha(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) - - -class TestCase1(PReluTest): - def initTestCase(self): - self.attrs = {'mode': "all"} + def test_check_grad_2(self): + self.check_grad(['X', 'Alpha'], 'Out') + def test_check_grad_3_ignore_alpha(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) -class TestCase2(PReluTest): - def initTestCase(self): - self.attrs = {'mode': "channel"} +# TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues +# class TestCase1(PReluTest): +# def initTestCase(self): +# self.attrs = {'mode': "all"} -class TestCase3(PReluTest): - def initTestCase(self): - self.attrs = {'mode': "element"} +# class TestCase2(PReluTest): +# def initTestCase(self): +# self.attrs = {'mode': "channel"} +# class TestCase3(PReluTest): +# def initTestCase(self): +# self.attrs = {'mode': "element"} if __name__ == "__main__": unittest.main() From 39c526d42fbfdd410c8bb11084a18b019460db7b Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 21 Aug 2018 20:18:01 +0800 Subject: [PATCH 011/140] Port test_dist_transpiler to it --- .../fluid/tests/unittests/test_dist_transpiler.py 
| 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 9f04d290f7..1d9ab44ed4 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -21,6 +21,7 @@ import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import delete_ops import traceback import collections +import six class TranspilerTest(unittest.TestCase): @@ -644,18 +645,18 @@ class TestLoadSliceVar(TranspilerTest): self.assertTrue(pserver._slice_vars_and_attrs) self.assertTrue(pserver2._slice_vars_and_attrs) - for idx in xrange(len(pserver._slice_vars_and_attrs)): + for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)): self.assertEqual(pserver._slice_vars_and_attrs[idx][0], pserver2._slice_vars_and_attrs[idx][0]) - total_numel = reduce(lambda x, y: x * y, - pserver._slice_vars_and_attrs[idx][0].shape) + total_numel = six.moves.reduce( + lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape) self.assertEqual( total_numel, - reduce(lambda x, y: x * y, - pserver._slice_vars_and_attrs[idx][2].shape) + reduce( - lambda x, y: x * y, - pserver2._slice_vars_and_attrs[idx][2].shape)) + six.moves.reduce(lambda x, y: x * y, + pserver._slice_vars_and_attrs[idx][2].shape) + + six.moves.reduce(lambda x, y: x * y, + pserver2._slice_vars_and_attrs[idx][2].shape)) if __name__ == "__main__": From dd7a79158b17f3613ff66b9c4db7691074fb6218 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 21 Aug 2018 19:54:10 +0800 Subject: [PATCH 012/140] add scope info in graphviz debug --- .../fluid/framework/details/multi_devices_graph_print_pass.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index 
69944a42b6..361c91dc78 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -54,7 +54,8 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ << "\\n" << var_handle_ptr->place_ << "\\n" - << var_handle_ptr->version_ << "\"]" << std::endl; + << "scope: " << var_handle_ptr->scope_idx_ << "\\n" + << "v" << var_handle_ptr->version_ << "\"]" << std::endl; } else if (dummy_ptr) { sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; } From ce182d9037b988dcbf1c7b86dafd60745afb2d4c Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 22 Aug 2018 11:23:23 +0800 Subject: [PATCH 013/140] bug fix --- .../fluid/operators/math/sequence_padding.cu | 38 ++++++++++++------- .../fluid/operators/math/sequence_padding.h | 4 +- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 3b1a44a457..93d239351a 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -22,8 +22,8 @@ namespace math { template __global__ void SequencePaddingKernel( T* dst, const T* src, const T* pad_value, bool is_constant_pad, - const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len, - const size_t& step_width, bool norm_by_len, const PadLayout& layout) { + const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len, + const size_t step_width, bool norm_by_len, const PadLayout layout) { size_t seq_idx = blockIdx.y; size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; @@ -43,7 +43,7 @@ __global__ void SequencePaddingKernel( dst_data[i] = scale * src_data[i]; } } else if (step_idx < pad_seq_len && Type == kSeqToPad) { - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { + for (size_t i = 
threadIdx.x; i < step_width; i += blockDim.x) { dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i]; } } @@ -54,7 +54,7 @@ class PaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* pad_tensor, + framework::LoDTensor* pad_tensor, const framework::LoDTensor& pad_value, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth) { @@ -62,11 +62,12 @@ class PaddingLoDTensorFunctor { const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; const auto& seq_tensor_dims = seq_tensor.dims(); const auto& pad_tensor_dims = pad_tensor->dims(); + int max_seq_len = MaximumSequenceLength(seq_offsets); if (pad_seq_len == -1) { - pad_seq_len = MaximumSequenceLength(seq_offsets); + pad_seq_len = max_seq_len; } int step_width = seq_tensor.numel() / seq_tensor_dims[0]; - int seq_num = seq_offset.size() - 1; + int seq_num = seq_offsets.size() - 1; CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, step_width, layout); @@ -74,13 +75,13 @@ class PaddingLoDTensorFunctor { "The numel of 'pad_value' can only be 1 or be equal to the " "'step_width'."); - if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) { + if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) { TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor); pad_tensor->Resize(pad_tensor_dims); return; } - const int64_t kBlockSize = 512; + const int kBlockSize = 512; /* At least use 32 threads to copy sequence_width elements, * and at least 8 elements for each thread. 
@@ -100,8 +101,16 @@ class PaddingLoDTensorFunctor { SequencePaddingKernel<<>>( pad_data, seq_data, pad_value_data, pad_value.numel() == 1, - seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len, step_width, norm_by_times, layout); + + if (layout == kBatchLengthWidth) { + framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end()); + for (size_t i = 0; i < pad_lod[0].size(); ++i) { + pad_lod[0][i] = i * pad_seq_len; + } + pad_tensor->set_lod(pad_lod); + } } }; @@ -116,22 +125,23 @@ class UnpaddingLoDTensorFunctor { auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; const auto& seq_tensor_dims = seq_tensor->dims(); const auto& pad_tensor_dims = pad_tensor.dims(); + int max_seq_len = MaximumSequenceLength(seq_offsets); if (pad_seq_len == -1) { - pad_seq_len = MaximumSequenceLength(seq_offsets); + pad_seq_len = max_seq_len; } int step_width = seq_tensor->numel() / seq_tensor_dims[0]; - int seq_num = seq_offset.size() - 1; + int seq_num = seq_offsets.size() - 1; CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, step_width, layout); - if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) { + if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) { TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); seq_tensor->Resize(seq_tensor_dims); return; } - const int64_t kBlockSize = 512; + const int kBlockSize = 512; /* At least use 32 threads to copy sequence_width elements, * and at least 8 elements for each thread. 
@@ -150,7 +160,7 @@ class UnpaddingLoDTensorFunctor { SequencePaddingKernel<<>>( seq_data, pad_data, nullptr, false, - seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len, step_width, norm_by_times, layout); } }; diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 3fb5859e3b..e752aa5897 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -81,7 +81,7 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims, template class PaddingLoDTensorFunctor { public: - void operator()(const platform::CPUDeviceContext& context, + void operator()(const DeviceContext& context, const framework::LoDTensor& seq_tensor, framework::LoDTensor* pad_tensor, const framework::LoDTensor& pad_value, int pad_seq_len = -1, @@ -92,7 +92,7 @@ class PaddingLoDTensorFunctor { template class UnpaddingLoDTensorFunctor { public: - void operator()(const platform::CPUDeviceContext& context, + void operator()(const DeviceContext& context, const framework::LoDTensor& pad_tensor, framework::LoDTensor* seq_tensor, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, From 211d81863daefee0757f9cd5e8146382d99d58ec Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 22 Aug 2018 06:17:28 +0000 Subject: [PATCH 014/140] Process elemwise grad op's lod. 
mul_op's lod --- .../operators/elementwise_add_mkldnn_op.cc | 3 +- paddle/fluid/operators/elementwise_add_op.h | 5 +++- paddle/fluid/operators/elementwise_div_op.h | 5 ++-- paddle/fluid/operators/elementwise_max_op.h | 4 ++- paddle/fluid/operators/elementwise_min_op.h | 5 ++-- paddle/fluid/operators/elementwise_mul_op.h | 5 ++-- paddle/fluid/operators/elementwise_op.h | 14 +++++++++ paddle/fluid/operators/elementwise_sub_op.h | 4 ++- paddle/fluid/operators/mul_op.h | 30 ++++++++++++------- 9 files changed, 54 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc index c86cd57316..9ad82aec81 100644 --- a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc @@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { }; template -class EltwiseAddMKLDNNGradKernel : public framework::OpKernel { +class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto* dout = ctx.Input(framework::GradVarName("Out")); diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index 5356105e2e..c60cb1f92e 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" @@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, } template -class ElementwiseAddGradKernel : public framework::OpKernel { +class ElementwiseAddGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; auto* dout = ctx.Input(framework::GradVarName("Out")); diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise_div_op.h index 95649ac46e..41a7950bf0 100644 --- a/paddle/fluid/operators/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise_div_op.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" - namespace paddle { namespace operators { @@ -53,9 +53,10 @@ struct DivGradDY { }; template -class ElementwiseDivGradKernel : public framework::OpKernel { +class ElementwiseDivGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto* x = ctx.Input("X"); diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise_max_op.h index 527a18ee3b..bfb5c93195 100644 --- a/paddle/fluid/operators/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise_max_op.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" namespace paddle { @@ -55,9 +56,10 @@ struct MaxGradDy { }; template -class ElementwiseMaxGradKernel : public framework::OpKernel { +class ElementwiseMaxGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto* x = ctx.Input("X"); diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise_min_op.h index d4e5831463..db035ffb52 100644 --- a/paddle/fluid/operators/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise_min_op.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" - namespace paddle { namespace operators { @@ -55,9 +55,10 @@ struct MinGradDy { }; template -class ElementwiseMinGradKernel : public framework::OpKernel { +class ElementwiseMinGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto* x = ctx.Input("X"); diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h index dc73cb6f23..82c5fa0472 100644 --- a/paddle/fluid/operators/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise_mul_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" - namespace paddle { namespace operators { @@ -50,9 +50,10 @@ struct MulGradDY { }; template -class ElementwiseMulGradKernel : public framework::OpKernel { +class ElementwiseMulGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto* x = ctx.Input("X"); diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index d8a12e800a..a79b900b98 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { } }; +template +class ElemwiseGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dx = + context.Output(framework::GradVarName("X")); + if (dx != nullptr) { + auto& dout = + *context.Input(framework::GradVarName("Out")); + dx->set_lod(dout.lod()); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h index 11c7e3fe62..3385df0897 100644 --- a/paddle/fluid/operators/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise_sub_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" namespace paddle { @@ -50,9 +51,10 @@ struct SubGradDY { }; template -class ElementwiseSubGradKernel : public framework::OpKernel { +class ElementwiseSubGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; auto* dout = ctx.Input(framework::GradVarName("Out")); diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h index 15dd975e3b..f72824806e 100644 --- a/paddle/fluid/operators/mul_op.h +++ b/paddle/fluid/operators/mul_op.h @@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { int x_num_col_dims = ctx.template Attr("x_num_col_dims"); int y_num_col_dims = ctx.template Attr("y_num_col_dims"); - const Tensor* x = ctx.Input("X"); - const Tensor* y = ctx.Input("Y"); - const Tensor x_matrix = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, x_num_col_dims) - : *x; - const Tensor y_matrix = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, y_num_col_dims) - : *y; - const Tensor* dout = ctx.Input(framework::GradVarName("Out")); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto x_matrix = x->dims().size() > 2 + ? framework::ReshapeToMatrix(*x, x_num_col_dims) + : static_cast(*x); + auto y_matrix = y->dims().size() > 2 + ? 
framework::ReshapeToMatrix(*y, y_num_col_dims) + : static_cast(*y); + auto* dout = ctx.Input(framework::GradVarName("Out")); Tensor dout_mat; dout_mat.ShareDataWith(*dout); dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0], framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]}); - Tensor* dx = ctx.Output(framework::GradVarName("X")); - Tensor* dy = ctx.Output(framework::GradVarName("Y")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + if (dx != nullptr) { + dx->set_lod(x->lod()); + } + if (dy != nullptr) { + dy->set_lod(y->lod()); + } + auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); if (dx) { From 2a36ad1a9655c3f618d4f77ca753f8b0fb214399 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 22 Aug 2018 06:48:00 +0000 Subject: [PATCH 015/140] Handle LoD for concat & seq_softmax ops --- paddle/fluid/operators/concat_op.h | 16 ++++++++++++++-- paddle/fluid/operators/math/concat.cc | 2 +- paddle/fluid/operators/math/concat.cu | 2 +- paddle/fluid/operators/math/concat.h | 4 ++-- paddle/fluid/operators/sequence_softmax_op.h | 3 +++ 5 files changed, 21 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index a496301526..78be2e1e1f 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -62,9 +62,21 @@ class ConcatGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto ins = ctx.MultiInput("X"); + auto ins = ctx.MultiInput("X"); auto out_var_names = ctx.Outputs(framework::GradVarName("X")); - auto outs = ctx.MultiOutput(framework::GradVarName("X")); + auto outs = + ctx.MultiOutput(framework::GradVarName("X")); + + { + auto dx = outs; + auto x = ins; + for (size_t i = 0; i < dx.size(); ++i) { + if (dx[i] != nullptr) { + 
dx[i]->set_lod(x[i]->lod()); + } + } + } + int64_t axis = static_cast(ctx.Attr("axis")); // get output tensor that the name is not kEmptyVarName diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index 55c8a472ac..fbe7c29783 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -71,7 +71,7 @@ class ConcatGradFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, - const std::vector& ref_inputs, + const std::vector& ref_inputs, const int axis, std::vector* outputs) { // TODO(zcd): Add input data validity checking size_t num = outputs->size(); diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 5863d74fca..820e73e779 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -189,7 +189,7 @@ class ConcatGradFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, - const std::vector& ref_inputs, + const std::vector& ref_inputs, const int axis, std::vector* outputs) { // TODO(zcd): Add input data validity checking int o_num = outputs->size(); diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h index 9e080f2e8b..e5d7d860b3 100644 --- a/paddle/fluid/operators/math/concat.h +++ b/paddle/fluid/operators/math/concat.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once #include #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/lod_tensor.h" namespace paddle { namespace operators { @@ -57,7 +57,7 @@ template class ConcatGradFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor& input, - const std::vector& ref_inputs, + const std::vector& ref_inputs, const int axis, std::vector* outputs); }; diff --git a/paddle/fluid/operators/sequence_softmax_op.h b/paddle/fluid/operators/sequence_softmax_op.h index cb93a02b83..bca564e16f 100644 --- a/paddle/fluid/operators/sequence_softmax_op.h +++ b/paddle/fluid/operators/sequence_softmax_op.h @@ -66,6 +66,9 @@ class SequenceSoftmaxGradKernel : public framework::OpKernel { auto* out_grad = ctx.Input(framework::GradVarName("Out")); auto* x = ctx.Input("X"); auto* x_grad = ctx.Output(framework::GradVarName("X")); + if (x_grad) { + x_grad->set_lod(x->lod()); + } auto lod = x->lod(); const size_t level = lod.size() - 1; From 94f6e54db93f06790c84e5932109f08a787c5b2a Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 15:22:07 +0800 Subject: [PATCH 016/140] Add timeout for python3 --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e7dd85ef5c..f2dce9d265 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -64,6 +64,7 @@ if(WITH_DISTRIBUTE) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) +set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 200) py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) py_test_modules(test_dist_se_resnext MODULES 
test_dist_se_resnext SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) From 57dab0bb4c4fb7f902f183d78cc197ebfec27e67 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 15:57:07 +0800 Subject: [PATCH 017/140] Change the link of flowers --- python/paddle/dataset/flowers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index aa73bbaf70..2a020ce6d0 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -45,9 +45,9 @@ from six.moves import cPickle as pickle from six.moves import zip __all__ = ['train', 'test', 'valid'] -DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' -LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' -SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' +DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz' +LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat' +SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat' DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' From 6d9b9cb4b6b44db964c592e93528cd6ad8ccfa76 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 16:26:42 +0800 Subject: [PATCH 018/140] Add debug info for anakin cpu --- paddle/scripts/paddle_build.sh | 6 ++++++ python/paddle/fluid/tests/unittests/CMakeLists.txt | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 8460f93b84..1e3e2ed0e9 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -192,6 +192,12 @@ function build() { cd ${PADDLE_ROOT}/build cat < Date: Wed, 22 Aug 2018 16:53:21 +0800 Subject: [PATCH 019/140] fix load_vars bug (#12869) --- 
python/paddle/fluid/io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index b3ed094c89..5c4ec99c53 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -406,6 +406,9 @@ def load_vars(executor, attrs={'file_path': os.path.join(dirname, filename)}) executor.run(load_prog) + if main_program is None: + main_program = default_main_program() + # load slice vars on pserver, if have it. _load_slice_up_vars(executor, dirname, main_program._slice_vars_and_attrs) From eb8fd853bceb45cdc5cdb57095d075cc7b260f2c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 22 Aug 2018 09:15:06 +0000 Subject: [PATCH 020/140] Fix sequence_softmax_cudnn op --- paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc index 0ddacb5710..7aca9f7111 100644 --- a/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc @@ -68,7 +68,9 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel { auto* out_grad = ctx.Input(framework::GradVarName("Out")); auto* x = ctx.Input("X"); auto* x_grad = ctx.Output(framework::GradVarName("X")); - + if (x_grad) { + x_grad->set_lod(x->lod()); + } auto lod = x->lod(); const size_t level = lod.size() - 1; From 774896347943f7100adc9763dad529ffd5754f6e Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 22 Aug 2018 18:57:30 +0800 Subject: [PATCH 021/140] refine op_test (#12846) --- python/paddle/fluid/tests/unittests/op_test.py | 2 +- python/paddle/fluid/tests/unittests/testsuite.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 972e44c952..44cd073379 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ 
b/python/paddle/fluid/tests/unittests/op_test.py @@ -56,8 +56,8 @@ def get_numeric_gradient(place, def get_output(): sum = [] + op.run(scope, place) for output_name in output_names: - op.run(scope, place) sum.append( np.array(scope.find_var(output_name).get_tensor()).mean()) return np.array(sum).mean() diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index 31ae25f02c..34fbb1b549 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -153,9 +153,6 @@ def append_input_output(block, op_proto, np_list, is_input, dtype): def append_loss_ops(block, output_names): mean_inputs = list(map(block.var, output_names)) - # for item in mean_inputs: - # print(item) - # print("Item", item.dtype) if len(mean_inputs) == 1: loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1]) From 6d107b0f392c8471ca6295a9a5366fd390cb7950 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 19:43:56 +0800 Subject: [PATCH 022/140] Fix the test_desc_clone's problem --- paddle/scripts/paddle_build.sh | 6 ------ python/paddle/dataset/flowers.py | 2 +- python/paddle/fluid/tests/unittests/CMakeLists.txt | 2 +- .../fluid/tests/unittests/test_desc_clone.py | 14 ++++++-------- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 1e3e2ed0e9..8460f93b84 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -192,12 +192,6 @@ function build() { cd ${PADDLE_ROOT}/build cat < Date: Wed, 22 Aug 2018 20:06:48 +0800 Subject: [PATCH 023/140] Disable in_place in batch_norm API. (#12736) * Disable in_place in batch_norm API. 
--- paddle/fluid/operators/batch_norm_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 9 +++++++-- python/paddle/fluid/nets.py | 2 +- .../paddle/fluid/tests/book/test_image_classification.py | 5 ++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 5912a1a17c..969f75544f 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Variance", "The global variance (for training) " "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization").Reuse("X"); + AddOutput("Y", "result after normalization"); AddOutput("MeanOut", "Share memory with Mean. " "Store the global mean when training") diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 71592618f5..a815ba0f2f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,6 +27,7 @@ from . import utils import random from .. import unique_name from functools import reduce +import warnings __all__ = [ 'fc', @@ -2046,7 +2047,7 @@ def batch_norm(input, param_attr(ParamAttr): The parameter attribute for Parameter `scale`. bias_attr(ParamAttr): The parameter attribute for Parameter `bias`. data_layout(string, default NCHW): NCHW|NHWC - in_place(bool, Default False): Make the input and output of batch norm reuse memory. + in_place(bool, Default False): This argument is deprecated since 0.15.0. use_mkldnn(bool, Default false): ${use_mkldnn_comment} name(string, Default None): A name for this layer(optional). If set None, the layer will be named automatically. 
@@ -2068,6 +2069,10 @@ def batch_norm(input, helper = LayerHelper('batch_norm', **locals()) dtype = helper.input_dtype() + if in_place: + raise warnings.warn("The argument in_place is deprecated since 0.15.0, " + "please do not set it True.") + input_shape = input.shape if data_layout == 'NCHW': channel_num = input_shape[1] @@ -2117,7 +2122,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) + batch_norm_out = helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm", diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 051fe84364..01563cbbb7 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -229,7 +229,7 @@ def img_conv_group(input, use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) + tmp = layers.batch_norm(input=tmp, act=conv_act) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index 9fe361425c..cd1e8cd682 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -256,7 +256,10 @@ def main(net_type, use_cuda, is_local=True): save_dirname = "image_classification_" + net_type + ".inference.model" train(net_type, use_cuda, save_dirname, is_local) - infer(use_cuda, save_dirname) + + # There is bug in fluid.InferenceTranspiler for VGG. 
+ if net_type == "resnet": + infer(use_cuda, save_dirname) class TestImageClassification(unittest.TestCase): From d49a0d755b6fa74c5cda9915f9238d600916c9d9 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 20:57:17 +0800 Subject: [PATCH 024/140] Fix common download problem --- python/paddle/dataset/common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 1d7ff582c8..ece4046f5b 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -19,6 +19,7 @@ import hashlib import os import errno import shutil +import six import sys import importlib import paddle.dataset @@ -94,6 +95,8 @@ def download(url, module_name, md5sum, save_name=None): dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): + if six.PY2: + data = six.b(data) dl += len(data) f.write(data) done = int(50 * dl / total_length) From 064b7f3de1d9584505e5b68f0a3822304f24e899 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 20:57:43 +0800 Subject: [PATCH 025/140] Change the md5sum of 102flowers dataset --- python/paddle/dataset/flowers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index fd191b6a6d..ce0cd6009a 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -48,7 +48,7 @@ __all__ = ['train', 'test', 'valid'] DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz' LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat' SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat' -DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' +DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data From f72ab8961e443c030a50f373f21eceac3800f528 Mon Sep 17 00:00:00 2001 
From: tensor-tang Date: Wed, 22 Aug 2018 11:43:20 +0800 Subject: [PATCH 026/140] refine blas gemm --- CMakeLists.txt | 6 -- paddle/fluid/operators/math/blas.h | 9 ++ paddle/fluid/operators/math/blas_impl.h | 116 +++++++++++------------ paddle/fluid/operators/math/fc_compute.h | 22 +++-- 4 files changed, 77 insertions(+), 76 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 48e52961a9..317f7f9eb4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,12 +138,6 @@ else() set(THIRD_PARTY_BUILD_TYPE Release) endif() -if(WITH_MKL) - option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF) - if (MKL_SPLIT_GEMM) - add_definitions(-DPADDLE_MKL_SPLIT_GEMM) - endif() -endif() set(WITH_MKLML ${WITH_MKL}) if (NOT DEFINED WITH_MKLDNN) if (WITH_MKL AND AVX2_FOUND) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 8dcf7c99f3..295431347a 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,11 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + template + void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, + int ldc) const; + #ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, @@ -109,6 +114,10 @@ class Blas { void GEMM_FREE(T* data) const; #endif + template + void MatMul(const int M, const int N, const int K, const T* A, const T* B, + T* C) const; + template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index dc77b6d793..d39a3e7f6e 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -217,64 +217,6 @@ struct 
CBlas { #endif }; -template -inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa, - bool transb, const T &alpha, const T &beta) { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom - constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - if (m * n * k > LIBXSMM_THRESHOLD || transa || transb || - std::abs(alpha - static_cast(1) > - std::numeric_limits::epsilon()) || - std::abs(beta) > std::numeric_limits::epsilon()) { - return false; - } else { - return true; - } -#endif - return false; -} - -template <> -inline bool UseXSMM(const int &m, const int &n, const int &k, - bool transa, bool transb, - const platform::float16 &alpha, - const platform::float16 &beta) { - return false; -} - -template -inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, - const T *A, int lda, const T *B, int ldb, T beta, T *C, - int ldc) { -#ifdef PADDLE_WITH_LIBXSMM - if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, - beta)) { - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, - &beta, C, &ldc); - return; - } -#endif - -#ifdef PADDLE_MKL_SPLIT_GEMM - constexpr int bs = 2; - if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { - for (int off = 0; off < M; off += bs) { - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha, - A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc); - } - return; - } -#endif - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - #ifdef PADDLE_WITH_MKLML template <> template @@ -319,8 +261,8 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? 
N : K; int ldc = N; - GEMM_WARP(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> @@ -329,9 +271,20 @@ void Blas::GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { - GEMM_WARP(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, + int N, int K, T alpha, const T *A, + int lda, const T *B, int ldb, + T beta, T *C, int ldc) const { + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template @@ -440,6 +393,43 @@ void Blas::BatchedGEMM( #endif } +template +template +void Blas::MatMul(const int M, const int N, const int K, + const T *A, const T *B, T *C) const { + this->template GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, + N); +} + +template <> +template +void Blas::MatMul(const int M, const int N, + const int K, const T *A, + const T *B, T *C) const { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + + // Since the matrix is very small, + // so the unit of calculation is already very fast, + // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, + // use xsmm directly. 
+ // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + const T alpha = static_cast(1); + const T beta = static_cast(0); + CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, + C, &N); + return; + +#endif + + CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, N); +} + template template void Blas::MatMul(const framework::Tensor &mat_a, diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 8600fa9e2c..1f5a49c0ab 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -25,17 +25,25 @@ namespace math { template inline void FCCompute(const BlasT& blas, const int M, const int N, const int K, const T* X, const T* W, T* Y, - const T* B = NULL) { - blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast(1), X, W, - static_cast(0), Y); - if (B) { + const T* B = NULL, bool relu = false) { + blas.MatMul(M, N, K, X, W, Y); + if (B == NULL) { + return; + } + #ifdef PADDLE_WITH_MKLML #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) #endif - for (int i = 0; i < M; i++) { - blas.AXPY(N, static_cast(1), B, Y + i * N); - } + for (int i = 0; i < M; i++) { + blas.AXPY(N, static_cast(1), B, Y + i * N); } + + if (!relu) { + return; + } + + // TODO(TJ): fuse relu + LOG(FATAL) << "Not implemented!"; } } // namespace math From a2203d0466462fcde20bdd80d79a0f7964760eb8 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 12:08:31 +0800 Subject: [PATCH 027/140] add cblas dot --- paddle/fluid/operators/math/blas.h | 3 +++ paddle/fluid/operators/math/blas_impl.h | 27 ++++++++++++++++++++++++- paddle/fluid/platform/dynload/mklml.h | 2 ++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 295431347a..96d481f739 100644 --- a/paddle/fluid/operators/math/blas.h +++ 
b/paddle/fluid/operators/math/blas.h @@ -153,6 +153,9 @@ class Blas { void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; + template + T DOT(int n, const T* x, const T* y) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index d39a3e7f6e..bbd9d4b60a 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -73,6 +73,11 @@ struct CBlas { platform::dynload::cblas_sgemv(args...); } + template + static float DOT(ARGS... args) { + return platform::dynload::cblas_sdot(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -138,6 +143,11 @@ struct CBlas { platform::dynload::cblas_dgemv(args...); } + template + static double DOT(ARGS... args) { + return platform::dynload::cblas_ddot(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -210,6 +220,7 @@ struct CBlas { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } + static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) 
{ PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -352,6 +363,21 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +T Blas::DOT(int n, const T *x, const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, y); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, @@ -423,7 +449,6 @@ void Blas::MatMul(const int M, const int N, CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); return; - #endif CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 15ad4a3b40..6efa160df0 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -66,6 +66,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_free); \ __macro(cblas_sgemm_batch); \ __macro(cblas_dgemm_batch); \ + __macro(cblas_sdot); \ + __macro(cblas_ddot); \ __macro(vsAdd); \ __macro(vdAdd); \ __macro(vsMul); \ From 0ec1f65cf110ee4e73a7bfa03456b52111426288 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 12:47:10 +0800 Subject: [PATCH 028/140] fix blas dot and add cblas scal --- paddle/fluid/operators/math/blas.h | 3 +++ paddle/fluid/operators/math/blas_impl.h | 27 ++++++++++++++++++++++++- paddle/fluid/platform/dynload/mklml.h | 2 ++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 96d481f739..fc02534a69 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -156,6 +156,9 @@ class Blas { template T DOT(int n, const T* x, const T* y) const; + template + void SCAL(int n, const T a, const T* x) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, 
int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index bbd9d4b60a..b7c56e8df1 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -78,6 +78,11 @@ struct CBlas { return platform::dynload::cblas_sdot(args...); } + template + static void SCAL(ARGS... args) { + platform::dynload::cblas_sscal(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -148,6 +153,11 @@ struct CBlas { return platform::dynload::cblas_ddot(args...); } + template + static void SCAL(ARGS... args) { + platform::dynload::cblas_dscal(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -221,6 +231,7 @@ struct CBlas { } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; + static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) 
{ PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -367,7 +378,7 @@ template <> template T Blas::DOT(int n, const T *x, const T *y) const { #ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, y); + return CBlas::DOT(n, x, 1, y, 1); #else // try to find if openblas support cblas_dot T sum = 0; @@ -378,6 +389,20 @@ T Blas::DOT(int n, const T *x, const T *y) const { #endif } +template <> +template +void Blas::SCAL(int n, const T a, + const T *x) const { +#ifdef PADDLE_WITH_MKLML + CBlas::SCAL(n, a, x, 1); +#else + // try to find if openblas support cblas_scal + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +#endif +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 6efa160df0..e50ea6740a 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -68,6 +68,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(cblas_sdot); \ __macro(cblas_ddot); \ + __macro(cblas_sscal); \ + __macro(cblas_dscal); \ __macro(vsAdd); \ __macro(vdAdd); \ __macro(vsMul); \ From 3dd66390b2702fe3083fee5e84f2ad6d5322b76b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 13:13:58 +0800 Subject: [PATCH 029/140] add blas vexp --- paddle/fluid/operators/math/blas.h | 3 +++ paddle/fluid/operators/math/blas_impl.h | 24 ++++++++++++++++++++++++ paddle/fluid/platform/dynload/mklml.h | 2 ++ 3 files changed, 29 insertions(+) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index fc02534a69..5aba170221 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -149,6 +149,9 @@ class Blas { template void VCOPY(int n, const T* x, T* y) const; + template + void VEXP(int n, const T* x, T* y) const; + template void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; diff --git 
a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index b7c56e8df1..eaad83ba18 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -97,6 +97,11 @@ struct CBlas { static void VMUL(ARGS... args) { platform::dynload::vsMul(args...); } + + template + static void VEXP(ARGS... args) { + platform::dynload::vsExp(args...); + } }; template <> @@ -172,6 +177,11 @@ struct CBlas { static void VMUL(ARGS... args) { platform::dynload::vdMul(args...); } + + template + static void VEXP(ARGS... args) { + platform::dynload::vdExp(args...); + } }; #else @@ -230,6 +240,7 @@ struct CBlas { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } + static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML @@ -374,6 +385,19 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +void Blas::VEXP(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VEXP(n, x, y); +#else + // try to find if openblas support vexp + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +#endif +} + template <> template T Blas::DOT(int n, const T *x, const T *y) const { diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index e50ea6740a..aa20553cef 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -74,6 +74,8 @@ extern void* mklml_dso_handle; __macro(vdAdd); \ __macro(vsMul); \ __macro(vdMul); \ + __macro(vsExp); \ + __macro(vdExp); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); From 9affc36c89c2df4e26d00b1a081db0eabfd8e4fe Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 20 Aug 2018 
21:15:39 +0800 Subject: [PATCH 030/140] init attention lstm --- paddle/fluid/operators/attention_lstm_op.cc | 354 ++++++++++++++++++++ paddle/fluid/operators/attention_lstm_op.h | 42 +++ 2 files changed, 396 insertions(+) create mode 100644 paddle/fluid/operators/attention_lstm_op.cc create mode 100644 paddle/fluid/operators/attention_lstm_op.h diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc new file mode 100644 index 0000000000..087df06ad5 --- /dev/null +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -0,0 +1,354 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/attention_lstm_op.h" +#include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/detail/activation_functions.h" +#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/lstm_compute.h" +#include "paddle/fluid/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { + +void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("WeightX"), + "Input(WeightX) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("WeightH"), + "Input(WeightH) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(Bias) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("XX"), + "Output(XX) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"), + "Output(BatchedGate) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), + "Output(BatchedGate) of LSTM should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(Cell) and Input(Hidden) of LSTM should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + auto wx_dims = ctx->GetInputDim("WeightX"); + PADDLE_ENFORCE_EQ(wx_dims.size(), 2, + "The rank of Input(WeightX) should be 2."); + PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], + "The first dimension of Input(WeightX) " + "should be %d.", + 
x_dims[1]); + + int frame_size = wx_dims[1] / 4; + auto wh_dims = ctx->GetInputDim("WeightH"); + PADDLE_ENFORCE_EQ(wh_dims.size(), 2, + "The rank of Input(WeightH) should be 2."); + PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, + "The first dimension of Input(WeightH) " + "should be %d.", + frame_size); + PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size, + "The second dimension of Input(WeightH) " + "should be 4 * %d.", + frame_size); + + auto b_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, + "The first dimension of Input(Bias) should be 1."); + + PADDLE_ENFORCE(!ctx->Attrs().Get("use_peepholes"), + "Do not support peephole yet."); + PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes connection", + frame_size); + + framework::DDim out_dims({x_dims[0], frame_size}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]}); + ctx->SetOutputDim("BatchCellPreAct", out_dims); + ctx->ShareLoD("X", "Hidden"); + ctx->ShareLoD("X", "Cell"); + + int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + ctx->SetOutputDim("XX", {x_dims[0], xx_width}); + ctx->ShareLoD("X", "XX"); +} + +framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); +} + +void FusionLSTMOpMaker::Make() { + AddInput("X", + "(LoDTensor) the input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T X M), where T is the " + "total time steps in this mini-batch, M is the dim size of x."); + AddInput("WeightX", + "(Tensor) the learnable weights of X." 
+ " - The shape is (M x 4D), where M is the dim size of x, D is the " + "hidden size. " + " - Weight = {W_cx, W_ix, W_fx, W_ox}"); + AddInput("WeightH", + "(Tensor) same as LSTMOp, the learnable hidden-hidden weights." + " - The shape is (D x 4D), where D is the hidden size. " + " - Weight = {W_ch, W_ih, W_fh, W_oh}"); + AddInput("Bias", + "(Tensor) the learnable weights. Almost same as LSTMOp" + "Note: we should add the fc bias into this (1x4D) in bias." + "input-hidden bias weight and peephole connections weight if " + "setting `use_peepholes` True. " + "1. `use_peepholes = False` " + " - The shape is (1 x 4D). " + " - Bias = {b_c, b_i, b_f, b_o}." + "2. `use_peepholes = True` " + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + AddInput("H0", + "(Tensor, optional) (same as LSTMOp) the initial hidden state is an " + "optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size and D is the hidden size.") + .AsDispensable(); + AddInput("C0", + "(Tensor, optional) (same as LSTMOp) (the initial cell state is an " + "optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time.") + .AsDispensable(); + AddOutput("Hidden", + "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. 
" + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("XX", + "(LoDTensor) the result after X * WeightX (size is T x 4D)" + " or batched_X (size is T x M), this will be automatically chosen," + " where T is the total time steps in this mini-batch," + " D is the hidden size, M is the dim size of x input.") + .AsIntermediate(); + AddOutput("BatchedGate", "(LoDTensor) (same as LSTMOp).").AsIntermediate(); + AddOutput("BatchCellPreAct", "(LoDTensor) (same as LSTMOp).") + .AsIntermediate(); + AddAttr("use_peepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr("is_reverse", + "(bool, defalut: False) " + "whether to compute reversed LSTM.") + .SetDefault(false); + AddAttr("gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Fusion Long-Short Term Memory (LSTM) Operator. +This operator fuse the X into LSTM, more details can refer to LSTM op. 
+)DOC"); +} + +template +inline void ReorderInitState(const DeviceContext& ctx, + const framework::Tensor& src, + framework::Vector index_lod, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + // TODO(TJ): check mem copy perf + row_shuffle(ctx, src, index_lod, dst, indexed_src); +} + +template +class FuisonLSTMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* wx = ctx.Input("WeightX"); + auto* wh = ctx.Input("WeightH"); + auto* bias = ctx.Input("Bias"); + auto* hidden_t0 = ctx.Input("H0"); + auto* cell_t0 = ctx.Input("C0"); + + auto* xx = ctx.Output("XX"); + auto* batched_gate = ctx.Output("BatchedGate"); + auto* hidden_out = ctx.Output("Hidden"); + auto* cell_out = ctx.Output("Cell"); + bool is_reverse = ctx.Attr("is_reverse"); + + T* xx_data = xx->mutable_data(ctx.GetPlace()); + T* batched_gate_data = batched_gate->mutable_data(ctx.GetPlace()); + hidden_out->mutable_data(ctx.GetPlace()); + cell_out->mutable_data(ctx.GetPlace()); + + const T* x_data = x->data(); + const T* wx_data = wx->data(); + auto x_dims = x->dims(); + auto wx_dims = wx->dims(); + + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + if (x_dims[1] > wx_dims[1]) { + math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], + x_data, wx_data, xx_data, + bias->data()); + to_batch(dev_ctx, *xx, batched_gate, true, is_reverse); + } else { + to_batch(dev_ctx, *x, xx, true, is_reverse); + batched_gate->set_lod(xx->lod()); + math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], + xx_data, wx_data, batched_gate_data, + bias->data()); + } + + int frame_size = static_cast(wx_dims[1] / 4); + framework::DDim out_dims({x_dims[0], frame_size}); + math::LstmMetaValue lstm_value; + // no peephole + lstm_value.check_ig = nullptr; + 
lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; + lstm_value.prev_state_value = nullptr; + Tensor ordered_c0; + + framework::Vector order(batched_gate->lod()[2]); + + if (cell_t0) { + // Since the batch computing for LSTM reorders the input sequence + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(dev_ctx, *cell_t0, order, &ordered_c0, + true); + lstm_value.prev_state_value = ordered_c0.data(); + } + + // Use the local variable as here. + LoDTensor batch_hidden, batch_cell; + auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + batch_hidden.mutable_data(out_dims, ctx.GetPlace()); + batch_cell.mutable_data(out_dims, ctx.GetPlace()); + batch_cell_pre_act->mutable_data(out_dims, ctx.GetPlace()); + + auto batch_starts = batched_gate->lod()[0]; + size_t max_seq_len = batch_starts.size() - 1; + auto gate_act = math::detail::GetActivationType( + ctx.Attr("gate_activation")); + auto cell_act = math::detail::GetActivationType( + ctx.Attr("cell_activation")); + auto cand_act = math::detail::GetActivationType( + ctx.Attr("candidate_activation")); + + for (size_t n = 0; n < max_seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate_t = batched_gate->Slice(bstart, bend); + Tensor out_t = batch_hidden.Slice(bstart, bend); + Tensor cell_t = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); + + int cur_batch_size = bend - bstart; + + if (n > 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); + // TODO(TJ): use gemm directly + blas.MatMul(pre_hidden_t, false, *wh, false, static_cast(1.0), + &gate_t, static_cast(1.0)); + } else if (hidden_t0) { + // TODO(TJ): move h0 outside for + // If n == 0 and there is no initialized hidden state, that is to say + // the H0 is 
zeros, the calculation W_h * H0 will be skiped. + // If n == 0 and there is initialized hidden state, calculate W_h * H0. + + // Since the batch computing for LSTM reorders the input sequence + // according to their length. The initialized hidden state also needs + // to reorder. + Tensor ordered_h0; + ReorderInitState(dev_ctx, *hidden_t0, order, + &ordered_h0, true); + // TODO(TJ): use gemm directly + blas.MatMul(ordered_h0, false, *wh, false, static_cast(1.0), &gate_t, + static_cast(1.0)); + } + + lstm_value.gate_value = gate_t.data(); + lstm_value.output_value = out_t.data(); + lstm_value.state_value = cell_t.data(); + lstm_value.state_active_value = cell_pre_act_t.data(); + math::LstmUnitFunctor::compute( + dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act, + cand_act); + lstm_value.prev_state_value = lstm_value.state_value; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden.set_lod(batched_gate->lod()); + // restore the output hidden in LoDTensor from the batch hidden + to_seq(dev_ctx, batch_hidden, hidden_out); + + batch_cell.set_lod(batched_gate->lod()); + // restore the output cell state in LoDTensor from the batch cell + to_seq(dev_ctx, batch_cell, cell_out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL( + fusion_lstm, + ops::FuisonLSTMKernel, + ops::FuisonLSTMKernel); diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h new file mode 100644 index 0000000000..39dc09b4d1 --- /dev/null +++ b/paddle/fluid/operators/attention_lstm_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +// #include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class FusionLSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionLSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle From 508548f897028bb93847f33705a30c4765fe0181 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 00:17:23 +0800 Subject: [PATCH 031/140] implement attention lstm cpu forward --- paddle/fluid/operators/attention_lstm_op.cc | 466 ++++++++++++-------- paddle/fluid/operators/attention_lstm_op.h | 5 +- paddle/fluid/operators/fusion_lstm_op.h | 1 - 3 files changed, 278 insertions(+), 194 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 087df06ad5..178a1c19a9 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -20,10 +20,12 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/operators/math/cpu_vec.h" + namespace paddle { namespace operators { -void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { +void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("WeightX"), "Input(WeightX) of LSTM should not be null."); @@ -57,6 +59,9 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "should be the same."); } + // fc_out , shape (maxseqlen,1) + int max_seq_len = 0; + auto wx_dims = ctx->GetInputDim("WeightX"); PADDLE_ENFORCE_EQ(wx_dims.size(), 2, "The rank of Input(WeightX) should be 2."); @@ -103,241 +108,321 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "XX"); } -framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( +framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.device_context()); } -void FusionLSTMOpMaker::Make() { +void AttentionLSTMOpMaker::Make() { AddInput("X", "(LoDTensor) the input is a LodTensor, which support " "variable-time length input sequence. The underlying tensor in " "this LoDTensor is a matrix with shape (T X M), where T is the " "total time steps in this mini-batch, M is the dim size of x."); - AddInput("WeightX", - "(Tensor) the learnable weights of X." - " - The shape is (M x 4D), where M is the dim size of x, D is the " - "hidden size. " - " - Weight = {W_cx, W_ix, W_fx, W_ox}"); - AddInput("WeightH", - "(Tensor) same as LSTMOp, the learnable hidden-hidden weights." - " - The shape is (D x 4D), where D is the hidden size. " - " - Weight = {W_ch, W_ih, W_fh, W_oh}"); - AddInput("Bias", - "(Tensor) the learnable weights. 
Almost same as LSTMOp" - "Note: we should add the fc bias into this (1x4D) in bias." - "input-hidden bias weight and peephole connections weight if " - "setting `use_peepholes` True. " - "1. `use_peepholes = False` " - " - The shape is (1 x 4D). " - " - Bias = {b_c, b_i, b_f, b_o}." - "2. `use_peepholes = True` " - " - The shape is (1 x 7D). " - " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + AddInput("C0", + "(Tensor) LSTM C0" + "This is a tensor with shape (N x D), where N is the batch size, D " + "is the gate size." + "C0 is necessary because of attention."); AddInput("H0", - "(Tensor, optional) (same as LSTMOp) the initial hidden state is an " - "optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size and D is the hidden size.") + "(Tensor, optional) LSTM H0" + "This is a tensor with shape (N x D), where N is the " + "batch size and D is the gate size.") .AsDispensable(); - AddInput("C0", - "(Tensor, optional) (same as LSTMOp) (the initial cell state is an " - "optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size. `H0` and `C0` can be NULL but only at the same time.") + AddInput("AttentionWeight", + "(Tensor) the weights of attention fc. Always relu the fc result." + "The shape is ((M+D) x 1), where M is the dim size of x, D is the " + "gate size of LSTM."); + AddInput("AttentionBias, optional", + "(Tensor) the bias of attention fc." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalar", + "(Tensor, optional) the scalar on the result of attentioned fc. " + "Always relu the Scalar." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalarBias", + "(Tensor, optional) the scalar bias of attention fc." 
+ "The shape is (1 x 1)") .AsDispensable(); + AddInput("LSTMWeight", + "(Tensor) the combined weight of LSTM" + " - The shape is ((D+M) x 4D), where D is the hidden gate size, M " + "is the dim size of x" + " - Weight = {W_forget, W_input, W_output, W_cell}"); + AddInput("LSTMBias", + "(Tensor) the combined bias of LSTM, shape (1x4D)." + "Note: we should add the bias of hidden and context accorindg to " + "the same gate: " + "{B_forget, B_input, B_output, B_cell}"); AddOutput("Hidden", "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " "The shape is (T x D), and lod is the same with the `Input`."); AddOutput("Cell", "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput("XX", - "(LoDTensor) the result after X * WeightX (size is T x 4D)" - " or batched_X (size is T x M), this will be automatically chosen," - " where T is the total time steps in this mini-batch," - " D is the hidden size, M is the dim size of x input.") + AddOutput( + "AttentionedX", + "(LodTensor) shape is (T x 1), the result after X * AttentionWeight," + " where T is the total time steps in this mini-batch," + " D is the hidden size.") .AsIntermediate(); - AddOutput("BatchedGate", "(LoDTensor) (same as LSTMOp).").AsIntermediate(); - AddOutput("BatchCellPreAct", "(LoDTensor) (same as LSTMOp).") + AddOutput("AttentionFCOut", + "(Tensor) (max_seq_len, 1), compute at each step.") .AsIntermediate(); - AddAttr("use_peepholes", - "(bool, defalut: True) " - "whether to enable diagonal/peephole connections.") - .SetDefault(true); - AddAttr("is_reverse", - "(bool, defalut: False) " - "whether to compute reversed LSTM.") - .SetDefault(false); + AddOutput("LSTMX", + "(Tensor) the input X of LSTM for each step." + "Shape is (1 x M), where M is the x frame size") + .AsIntermediate(); + AddOutput( + "LSTMOUT", + "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." 
+ "Shape is (1 x 4D), where M is the x frame size") + .AsIntermediate(); + // TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"}); AddAttr("gate_activation", "(string, default: sigmoid)" "The activation for input gate, forget gate and output " "gate, `sigmoid` by default.") .SetDefault("sigmoid") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); + .InEnum({"sigmoid"}); AddAttr("cell_activation", "(string, default: tanh)" "The activation for cell output, `tanh` by defalut.") .SetDefault("tanh") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); + .InEnum({"tanh"}); AddAttr("candidate_activation", "(string, default: tanh)" "The activation for candidate hidden state, " "`tanh` by default.") .SetDefault("tanh") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); + .InEnum({"tanh"}); AddComment(R"DOC( -Fusion Long-Short Term Memory (LSTM) Operator. -This operator fuse the X into LSTM, more details can refer to LSTM op. +Attention Long-Short Term Memory (LSTM) Operator. + +Attention part: +concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D)) + +tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu + +fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu + +dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M) + +LSTM part: +use lstm_x_t as input and compute as standard LSTM. + )DOC"); } +// y[i] = (x[i] + bias[0]) > 0 ? 
(x[i] + bias[0]) : 0; +template +inline void bias_relu(const int n, const T* x, const T* bias, T* y) { + if (bias) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + bias[0]; + } + vec_relu(n, y, y); + } else { + vec_relu(n, x, y); + } +} + template -inline void ReorderInitState(const DeviceContext& ctx, - const framework::Tensor& src, - framework::Vector index_lod, - framework::Tensor* dst, bool indexed_src) { - math::CopyMatrixRowsFunctor row_shuffle; - dst->mutable_data(src.dims(), ctx.GetPlace()); - // TODO(TJ): check mem copy perf - row_shuffle(ctx, src, index_lod, dst, indexed_src); +inline void vec_softmax(const BlasT& blas, const int n, + const T* x, T* y) { + T scalar = x[0]; + // max + for (int i = 1; i < n; ++i) { + scalar = scalar < x[i] ? x[i] : scalar; + } + + // sub + for (int i = 0; i < n; ++i) { + y[c] = x[c] - alpha; + } + + // exp + blas.VEXP(n, y, y); + + // sum + scalar = T(0); + for (int i = 0; i < n; ++i) { + scalar += y[i]; + } + + // scale + blas.VSCAL(n, static_cast(1) / scalar, y); +} + +__m256 exp(__m256 a) { return exp256_ps(a); } + +__m256 log(__m256 a) { return log256_ps(a); } + +__m256 sin(__m256 a) { return sin256_ps(a); } + +__m256 cos(__m256 a) { return cos256_ps(a); } + +__m256 relu(const __m256 a) { + __m256 tmp = _mm256_set1_ps(0.0f); + return _mm256_max_ps(a, tmp); +} + +__m256 sigmoid(const __m256 a) { + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); + __m256 tmp = _mm256_max_ps(a, min); + tmp = _mm256_min_ps(tmp, max); + tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); + tmp = exp(tmp); + tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); + tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); + return tmp; +} + +__m256 tanh(const __m256 a) { + __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); + __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); + tmp = _mm256_min_ps(tmp, max); + tmp = exp(tmp); + return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), + 
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), + _mm256_set1_ps(1.0f)); +} + +__m256 linear(const __m256 a) { return a; } + +inline void vec_sigmoid(const T* x, T* y) { + const real min = SIGMOID_THRESHOLD_MIN; + const real max = SIGMOID_THRESHOLD_MAX; + real tmp = (a < min) ? min : ((a > max) ? max : a); + return 1.0 / (1.0 + exp(-tmp)); } template -class FuisonLSTMKernel : public framework::OpKernel { +class AttentionLSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* wx = ctx.Input("WeightX"); - auto* wh = ctx.Input("WeightH"); - auto* bias = ctx.Input("Bias"); - auto* hidden_t0 = ctx.Input("H0"); - auto* cell_t0 = ctx.Input("C0"); - - auto* xx = ctx.Output("XX"); - auto* batched_gate = ctx.Output("BatchedGate"); - auto* hidden_out = ctx.Output("Hidden"); - auto* cell_out = ctx.Output("Cell"); - bool is_reverse = ctx.Attr("is_reverse"); - - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* batched_gate_data = batched_gate->mutable_data(ctx.GetPlace()); - hidden_out->mutable_data(ctx.GetPlace()); - cell_out->mutable_data(ctx.GetPlace()); + auto* x = ctx.Input("X"); // T x M + auto* h0 = ctx.Input("H0"); // N x D + auto* c0 = ctx.Input("C0"); // N x D + auto* atten_w = ctx.Input("AttentionWeight"); // (M+D) x 1 + auto* atten_b = ctx.Input("AttentionBias"); // 1x1 + auto* atten_scalar = ctx.Input("AttentionScalar"); // 1x1 + auto* atten_scalar_bias = ctx.Input("AttentionScalar"); // 1x1 + auto* lstm_w = ctx.Input("LSTMWeight"); // (D+M) x D*4 + auto* lstm_b = ctx.Input("LSTMBias"); // 1 x D*4 + + auto* hidden_out = ctx.Output("Hidden"); // TxD + auto* cell_out = ctx.Output("Cell"); // TxD + auto* atted_x = ctx.Output("AttentionedX"); // T x 1 + auto* fc_out = ctx.Output('AttentionFCOut'); // max_seq_len x 1 + auto* lstm_x = ctx.Output("LSTMX"); // 1 x M + auto* lstm_out = ctx.Output("LSTMOUT"); // 1 x 4D const T* x_data = x->data(); - const T* wx_data = wx->data(); 
- auto x_dims = x->dims(); - auto wx_dims = wx->dims(); - - math::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); - if (x_dims[1] > wx_dims[1]) { - math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], - x_data, wx_data, xx_data, - bias->data()); - to_batch(dev_ctx, *xx, batched_gate, true, is_reverse); - } else { - to_batch(dev_ctx, *x, xx, true, is_reverse); - batched_gate->set_lod(xx->lod()); - math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], - xx_data, wx_data, batched_gate_data, - bias->data()); - } - - int frame_size = static_cast(wx_dims[1] / 4); - framework::DDim out_dims({x_dims[0], frame_size}); - math::LstmMetaValue lstm_value; - // no peephole - lstm_value.check_ig = nullptr; - lstm_value.check_fg = nullptr; - lstm_value.check_og = nullptr; - lstm_value.prev_state_value = nullptr; - Tensor ordered_c0; - - framework::Vector order(batched_gate->lod()[2]); - - if (cell_t0) { - // Since the batch computing for LSTM reorders the input sequence - // according to their length. The initialized cell state also needs - // to reorder. - ReorderInitState(dev_ctx, *cell_t0, order, &ordered_c0, - true); - lstm_value.prev_state_value = ordered_c0.data(); - } - - // Use the local variable as here. 
- LoDTensor batch_hidden, batch_cell; - auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); - batch_hidden.mutable_data(out_dims, ctx.GetPlace()); - batch_cell.mutable_data(out_dims, ctx.GetPlace()); - batch_cell_pre_act->mutable_data(out_dims, ctx.GetPlace()); - - auto batch_starts = batched_gate->lod()[0]; - size_t max_seq_len = batch_starts.size() - 1; - auto gate_act = math::detail::GetActivationType( - ctx.Attr("gate_activation")); - auto cell_act = math::detail::GetActivationType( - ctx.Attr("cell_activation")); - auto cand_act = math::detail::GetActivationType( - ctx.Attr("candidate_activation")); - - for (size_t n = 0; n < max_seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - - Tensor gate_t = batched_gate->Slice(bstart, bend); - Tensor out_t = batch_hidden.Slice(bstart, bend); - Tensor cell_t = batch_cell.Slice(bstart, bend); - Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); - - int cur_batch_size = bend - bstart; - - if (n > 0) { - int pre_h_start = static_cast(batch_starts[n - 1]); - int pre_h_end = pre_h_start + cur_batch_size; - auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); - // TODO(TJ): use gemm directly - blas.MatMul(pre_hidden_t, false, *wh, false, static_cast(1.0), - &gate_t, static_cast(1.0)); - } else if (hidden_t0) { - // TODO(TJ): move h0 outside for - // If n == 0 and there is no initialized hidden state, that is to say - // the H0 is zeros, the calculation W_h * H0 will be skiped. - // If n == 0 and there is initialized hidden state, calculate W_h * H0. - - // Since the batch computing for LSTM reorders the input sequence - // according to their length. The initialized hidden state also needs - // to reorder. 
- Tensor ordered_h0; - ReorderInitState(dev_ctx, *hidden_t0, order, - &ordered_h0, true); - // TODO(TJ): use gemm directly - blas.MatMul(ordered_h0, false, *wh, false, static_cast(1.0), &gate_t, - static_cast(1.0)); + const T* h0_data = h0->data(); + const T* c0_data = c0->data(); + const T* lstm_w_data = lstm_w->data(); + const T* lstm_b_data = lstm_b->data(); + const T* atten_w_data = atten_w->data(); + const T* atten_b_data = atten_b ? atten_b->data() : NULL; + const T* atten_scalar_data = atten_scalar ? atten_scalar->data() : NULL; + const T* atten_scalar_bias_data = + atten_scalar_bias ? atten_scalar_bias->data() : NULL; + + T* hidden_out_data = hidden_out->mutable_data(); + T* cell_out_data = cell_out->mutable_data(); + T* atted_x_data = atted_x->mutable_data(); + T* fc_out_data = fc_out->mutable_data(); + T* lstm_x_data = lstm_x->mutable_data(); + T* lstm_out_data = lstm_out->mutable_data(); + + auto x_lod = x->lod(); + auto x_dims = x->dims(); // T x M + auto w_dims = w->dims(); // (D+M) x 4D + const int M = x_dims[1]; // x frame size + const int D = w_dims[1] / 4; // gate frame size + const int D2 = D * 2; + const int D3 = D * 3; + const int D4 = w_dims[1]; + const int batch_size = x_lod[0].size() - 1; // assert lod.size() == 1 + + // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 + auto blas = math::GetBlas(ctx); + math::FCCompute(blas, T, 1, M, x_data, atten_w_data, + atted_x_data, atten_b_data); + + const T* cur_x_data = x_data; + const T* prev_cell_data = NULL; + const T* prev_hidden_data = NULL; + T* cur_cell_out_data = cell_out_data; + T* cur_hidden_out_data = hidden_out_data; + for (int i = 0; i < batch_size; ++i) { + int seq_len = x_lod[0][i + 1]; + prev_cell_data = c0_data + i * D; + prev_hidden_data = h0 ? 
h0_data + i * D : NULL; + + for (int step = 0; step < seq_len; ++step) { + /// compute attention vector + // prev_cell(1xD) * fc(D) rest part of atten_wgt + // T = cblas_dot(); + T prev_cell_bias = blas.VDOT(D, prev_cell_data, atten_w_data + M); + // add cell bias and relu + bias_relu(seq_len, atted_x_data, &prev_cell_bias, fc_out_data); + // fc2: scalar + if (atten_scalar_data) { + // x = a*x + blas.VSCAL(seq_len, atten_scalar_data, fc_out_data); + bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, + fc_out_data); + } + vec_softmax(blas, seq_len, fc_out_data, fc_out_data); + // mul x(seq_len*M) and sum pool + math::FCCompute(blas, 1, M, seq_len, fc_out_data, + cur_x_data, lstm_x_data); + + /// compute LSTM step + // lstm weight : concat[forget , input , output , tilde] + // shape : (D + M) x (4 * D) + // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D + blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data); + if (prev_hidden_data) { + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), + prev_hidden_data, D, lstm_w_data, D4, static_cast(1), + lstm_out_data, D4); + } + // since input is 1xM, so can use add bias + blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data); + + // gate act: sigmoid + vec_sigmoid(D3, lstm_out_data, lstm_out_data); + // candicate act: tanh + vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3); + + // a = forget * prev_cell + blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); + + // b = input * tilde + blas.VMUL(D, lstm_out_data + D, lstm_out + D3, lstm_out_data + D); + + // cell_out = a + b + blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); + + // state act tanh(cell_out) * output_gate + vec_tanh(D, cur_cell_out_data, lstm_out_data); + blas.VMUL(D, lstm_out_data, lstm_out + D2, cur_hidden_out_data); + + prev_hidden_data = hidden_out + i * gate_size; + prev_cell_data = cur_cell_out_data; + cur_cell_out_data = cur_cell_out_data + D; + cur_hidden_out_data = cur_hidden_out_data + D; } 
- - lstm_value.gate_value = gate_t.data(); - lstm_value.output_value = out_t.data(); - lstm_value.state_value = cell_t.data(); - lstm_value.state_active_value = cell_pre_act_t.data(); - math::LstmUnitFunctor::compute( - dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act, - cand_act); - lstm_value.prev_state_value = lstm_value.state_value; + cur_x_data = cur_x_data + seq_len * M; } - - math::Batch2LoDTensorFunctor to_seq; - batch_hidden.set_lod(batched_gate->lod()); - // restore the output hidden in LoDTensor from the batch hidden - to_seq(dev_ctx, batch_hidden, hidden_out); - - batch_cell.set_lod(batched_gate->lod()); - // restore the output cell state in LoDTensor from the batch cell - to_seq(dev_ctx, batch_cell, cell_out); } }; @@ -345,10 +430,11 @@ class FuisonLSTMKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker, +REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp, + ops::AttentionLSTMOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OP_CPU_KERNEL( - fusion_lstm, - ops::FuisonLSTMKernel, - ops::FuisonLSTMKernel); + attention_lstm, + ops::AttentionLSTMKernel, + ops::AttentionLSTMKernel); diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h index 39dc09b4d1..6ede3a7f3c 100644 --- a/paddle/fluid/operators/attention_lstm_op.h +++ b/paddle/fluid/operators/attention_lstm_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once -// #include #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -22,7 +21,7 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; -class FusionLSTMOp : public framework::OperatorWithKernel { +class AttentionLSTMOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -33,7 +32,7 @@ class FusionLSTMOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override; }; -class FusionLSTMOpMaker : public framework::OpProtoAndCheckerMaker { +class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; }; diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fusion_lstm_op.h index 39dc09b4d1..7f79601602 100644 --- a/paddle/fluid/operators/fusion_lstm_op.h +++ b/paddle/fluid/operators/fusion_lstm_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once -// #include #include "paddle/fluid/framework/op_registry.h" namespace paddle { From 6ed20474d47a2577159a3799549c457e9f38f420 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 10:17:47 +0800 Subject: [PATCH 032/140] refine attention lstm infershape --- paddle/fluid/operators/attention_lstm_op.cc | 198 +++++++++++--------- 1 file changed, 111 insertions(+), 87 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 178a1c19a9..636deb04a1 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -26,86 +26,102 @@ namespace paddle { namespace operators { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Input(WeightX) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Input(WeightH) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasOutput("XX"), - "Output(XX) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(C0) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), + "Input(LSTMWeight) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), + "Input(LSTMBias) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), + "Input(AttentionWeight) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of LSTM should not be null."); + "Output(Hidden) of AttentionLSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of LSTM should not be null."); - 
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"), - "Output(BatchedGate) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), - "Output(BatchedGate) of LSTM should not be null."); + "Output(Cell) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), + "Output(AttentionedX) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), + "Output(AttentionFCOut) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), + "Output(LSTMX) of AttentionLSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), + "Output(LSTMOUT) of AttentionLSTM should not be null."); auto x_dims = ctx->GetInputDim("X"); + const int M = x_dims[1]; PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + auto w_dims = ctx->GetInputDim("LSTMWeight"); + const int D = w_dims[1] / 4; + PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], D + M, + "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); + + auto b_dims = ctx->GetInputDim("LSTMBias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M, + D); + PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).", + M, D); + + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); + PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); if (ctx->HasInput("H0")) { - PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(Cell) and Input(Hidden) of LSTM should not " - "be null at the same time."); auto h_dims = ctx->GetInputDim("H0"); - auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE(h_dims == c_dims, "The dimension of Input(H0) and Input(C0) " "should be the same."); } - // fc_out , shape (maxseqlen,1) - int max_seq_len = 0; - - auto wx_dims = ctx->GetInputDim("WeightX"); - 
PADDLE_ENFORCE_EQ(wx_dims.size(), 2, - "The rank of Input(WeightX) should be 2."); - PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], - "The first dimension of Input(WeightX) " - "should be %d.", - x_dims[1]); - - int frame_size = wx_dims[1] / 4; - auto wh_dims = ctx->GetInputDim("WeightH"); - PADDLE_ENFORCE_EQ(wh_dims.size(), 2, - "The rank of Input(WeightH) should be 2."); - PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, - "The first dimension of Input(WeightH) " - "should be %d.", - frame_size); - PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size, - "The second dimension of Input(WeightH) " - "should be 4 * %d.", - frame_size); - - auto b_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); - PADDLE_ENFORCE_EQ(b_dims[0], 1, - "The first dimension of Input(Bias) should be 1."); - - PADDLE_ENFORCE(!ctx->Attrs().Get("use_peepholes"), - "Do not support peephole yet."); - PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, - "The second dimension of Input(Bias) should be " - "4 * %d if disable peepholes connection", - frame_size); - - framework::DDim out_dims({x_dims[0], frame_size}); + auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); + PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, + "Input(AttentionWeight)'s rank must be 2."); + PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + if (ctx->HasInput("AttentionBias")) { + auto atten_b_dims = ctx->GetInputDim("AttentionBias"); + PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, + "Input(AttentionBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, + "AttentionBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, + "AttentionBias shapes must be 1 * 1."); + } + + if (ctx->HasInput("AttentionScalar")) { + auto dims = ctx->GetInputDim("AttentionScalar"); + PADDLE_ENFORCE_EQ(dims.size(), 2, + "Input(AttentionScalar)'s 
rank must be 2."); + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); + } + + if (ctx->HasInput("AttentionScalarBias")) { + auto dims = ctx->GetInputDim("AttentionScalarBias"); + PADDLE_ENFORCE( + ctx->HasInput("AttentionScalar"), + "AttentionScalar should not be null when have AttentionScalarBias."); + PADDLE_ENFORCE_EQ(dims.size(), 2, + "Input(AttentionScalarBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1."); + } + + framework::DDim out_dims({x_dims[0], D}); ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Cell", out_dims); - ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchCellPreAct", out_dims); + ctx->SetOutputDim("AttentionedX", {x_dims[0], 1}); + ctx->SetOutputDim("LSTMX", {1, M}); + ctx->SetOutputDim("LSTMOUT", {1, 4 * D}); + // AttentionFCOut should be reshape as (maxseqlen,1) in runtime ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Cell"); - - int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; - ctx->SetOutputDim("XX", {x_dims[0], xx_width}); - ctx->ShareLoD("X", "XX"); } framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( @@ -164,11 +180,10 @@ void AttentionLSTMOpMaker::Make() { AddOutput("Cell", "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. 
" "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput( - "AttentionedX", - "(LodTensor) shape is (T x 1), the result after X * AttentionWeight," - " where T is the total time steps in this mini-batch," - " D is the hidden size.") + AddOutput("AttentionedX", + "(Tensor) shape is (T x 1), the result after X * AttentionWeight," + " where T is the total time steps in this mini-batch," + " D is the hidden size.") .AsIntermediate(); AddOutput("AttentionFCOut", "(Tensor) (max_seq_len, 1), compute at each step.") @@ -316,12 +331,31 @@ class AttentionLSTMKernel : public framework::OpKernel { auto* lstm_w = ctx.Input("LSTMWeight"); // (D+M) x D*4 auto* lstm_b = ctx.Input("LSTMBias"); // 1 x D*4 - auto* hidden_out = ctx.Output("Hidden"); // TxD - auto* cell_out = ctx.Output("Cell"); // TxD - auto* atted_x = ctx.Output("AttentionedX"); // T x 1 - auto* fc_out = ctx.Output('AttentionFCOut'); // max_seq_len x 1 - auto* lstm_x = ctx.Output("LSTMX"); // 1 x M - auto* lstm_out = ctx.Output("LSTMOUT"); // 1 x 4D + auto* hidden_out = ctx.Output("Hidden"); // TxD + auto* cell_out = ctx.Output("Cell"); // TxD + auto* atted_x = ctx.Output("AttentionedX"); // T x 1 + auto* fc_out = ctx.Output('AttentionFCOut'); // max_seq_len x 1 + auto* lstm_x = ctx.Output("LSTMX"); // 1 x M + auto* lstm_out = ctx.Output("LSTMOUT"); // 1 x 4D + + // some shape should be reshape here since infershape can not get lod info + auto x_lod = x->lod(); + const int N = x_lod[0].size() - 1; // batch size + auto x_dims = x->dims(); // T x M + auto w_dims = w->dims(); // (D+M) x 4D + const int M = x_dims[1]; // x frame size + const int D = w_dims[1] / 4; // gate frame size + const int D2 = D * 2; + const int D3 = D * 3; + const int D4 = w_dims[1]; + int max_seq_len = x_lod[0][1]; + for (int i = 1; i < N; ++i) { + int len = x_lod[0][i + 1] - x_lod[0][i]; + max_seq_len = max_seq_len < len ? 
len : max_seq_len; + } + PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1."); + PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); + fc_out->Resize({max_seq_len, 1}); const T* x_data = x->data(); const T* h0_data = h0->data(); @@ -341,16 +375,6 @@ class AttentionLSTMKernel : public framework::OpKernel { T* lstm_x_data = lstm_x->mutable_data(); T* lstm_out_data = lstm_out->mutable_data(); - auto x_lod = x->lod(); - auto x_dims = x->dims(); // T x M - auto w_dims = w->dims(); // (D+M) x 4D - const int M = x_dims[1]; // x frame size - const int D = w_dims[1] / 4; // gate frame size - const int D2 = D * 2; - const int D3 = D * 3; - const int D4 = w_dims[1]; - const int batch_size = x_lod[0].size() - 1; // assert lod.size() == 1 - // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 auto blas = math::GetBlas(ctx); math::FCCompute(blas, T, 1, M, x_data, atten_w_data, @@ -361,7 +385,7 @@ class AttentionLSTMKernel : public framework::OpKernel { const T* prev_hidden_data = NULL; T* cur_cell_out_data = cell_out_data; T* cur_hidden_out_data = hidden_out_data; - for (int i = 0; i < batch_size; ++i) { + for (int i = 0; i < N; ++i) { int seq_len = x_lod[0][i + 1]; prev_cell_data = c0_data + i * D; prev_hidden_data = h0 ? 
h0_data + i * D : NULL; @@ -370,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel { /// compute attention vector // prev_cell(1xD) * fc(D) rest part of atten_wgt // T = cblas_dot(); - T prev_cell_bias = blas.VDOT(D, prev_cell_data, atten_w_data + M); + T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M); // add cell bias and relu bias_relu(seq_len, atted_x_data, &prev_cell_bias, fc_out_data); // fc2: scalar if (atten_scalar_data) { // x = a*x - blas.VSCAL(seq_len, atten_scalar_data, fc_out_data); + blas.SCAL(seq_len, atten_scalar_data, fc_out_data); bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, fc_out_data); } From cf5ea925c3eea2f63b099513b85eaf5032db38fa Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 16:10:55 +0800 Subject: [PATCH 033/140] fix bugs --- paddle/fluid/operators/attention_lstm_op.cc | 123 +++++++++----------- paddle/fluid/operators/math/blas.h | 17 ++- paddle/fluid/operators/math/blas_impl.h | 3 +- 3 files changed, 75 insertions(+), 68 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 636deb04a1..87fda12ea6 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -15,12 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc_compute.h" -#include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/sequence2batch.h" - -#include "paddle/fluid/operators/math/cpu_vec.h" +// #include "paddle/fluid/operators/math/detail/activation_functions.h" +// #include "paddle/fluid/operators/math/cpu_vec.h" namespace paddle { namespace operators { @@ -233,6 +230,13 @@ use lstm_x_t as input and compute as standard LSTM. 
)DOC"); } +template +inline void vec_relu(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + // y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { @@ -240,14 +244,14 @@ inline void bias_relu(const int n, const T* x, const T* bias, T* y) { for (int i = 0; i < n; ++i) { y[i] = x[i] + bias[0]; } - vec_relu(n, y, y); + vec_relu(n, y, y); } else { - vec_relu(n, x, y); + vec_relu(n, x, y); } } template -inline void vec_softmax(const BlasT& blas, const int n, +inline void vec_softmax(const math::BlasT& blas, const int n, const T* x, T* y) { T scalar = x[0]; // max @@ -257,7 +261,7 @@ inline void vec_softmax(const BlasT& blas, const int n, // sub for (int i = 0; i < n; ++i) { - y[c] = x[c] - alpha; + y[i] = x[i] - scalar; } // exp @@ -270,57 +274,45 @@ inline void vec_softmax(const BlasT& blas, const int n, } // scale - blas.VSCAL(n, static_cast(1) / scalar, y); + blas.SCAL(n, static_cast(1) / scalar, y); } -__m256 exp(__m256 a) { return exp256_ps(a); } +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 -__m256 log(__m256 a) { return log256_ps(a); } - -__m256 sin(__m256 a) { return sin256_ps(a); } - -__m256 cos(__m256 a) { return cos256_ps(a); } - -__m256 relu(const __m256 a) { - __m256 tmp = _mm256_set1_ps(0.0f); - return _mm256_max_ps(a, tmp); +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); } -__m256 sigmoid(const __m256 a) { - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); - __m256 tmp = _mm256_max_ps(a, min); - tmp = _mm256_min_ps(tmp, max); - tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); - tmp = exp(tmp); - tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); - tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); - return tmp; +template +inline T tanh(T x) { + return 2. * sigmoid(2. 
* x) - 1.; } -__m256 tanh(const __m256 a) { - __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); - __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); - tmp = _mm256_min_ps(tmp, max); - tmp = exp(tmp); - return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), - _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), - _mm256_set1_ps(1.0f)); +template +inline void vec_sigmoid(const int n, const T* x, T* y) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = 1.0 / (1.0 + std::exp(-tmp)); + } } -__m256 linear(const __m256 a) { return a; } - -inline void vec_sigmoid(const T* x, T* y) { - const real min = SIGMOID_THRESHOLD_MIN; - const real max = SIGMOID_THRESHOLD_MAX; - real tmp = (a < min) ? min : ((a > max) ? max : a); - return 1.0 / (1.0 + exp(-tmp)); +template +inline void vec_tanh(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = tanh(x[i]); + } } -template +template class AttentionLSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); // T x M auto* h0 = ctx.Input("H0"); // N x D auto* c0 = ctx.Input("C0"); // N x D @@ -334,7 +326,7 @@ class AttentionLSTMKernel : public framework::OpKernel { auto* hidden_out = ctx.Output("Hidden"); // TxD auto* cell_out = ctx.Output("Cell"); // TxD auto* atted_x = ctx.Output("AttentionedX"); // T x 1 - auto* fc_out = ctx.Output('AttentionFCOut'); // max_seq_len x 1 + auto* fc_out = ctx.Output("AttentionFCOut"); // max_seq_len x 1 auto* lstm_x = ctx.Output("LSTMX"); // 1 x M auto* lstm_out = ctx.Output("LSTMOUT"); // 1 x 4D @@ -342,9 +334,10 @@ class AttentionLSTMKernel : public framework::OpKernel { auto x_lod = x->lod(); const int N = x_lod[0].size() - 1; // batch size auto x_dims = x->dims(); // T x M - auto w_dims = w->dims(); // (D+M) x 4D - 
const int M = x_dims[1]; // x frame size - const int D = w_dims[1] / 4; // gate frame size + auto w_dims = lstm_w->dims(); // (D+M) x 4D + const int total_T = x_dims[0]; + const int M = x_dims[1]; // x frame size + const int D = w_dims[1] / 4; // gate frame size const int D2 = D * 2; const int D3 = D * 3; const int D4 = w_dims[1]; @@ -357,6 +350,8 @@ class AttentionLSTMKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); fc_out->Resize({max_seq_len, 1}); + // TODO(TJ): act functor init here + const T* x_data = x->data(); const T* h0_data = h0->data(); const T* c0_data = c0->data(); @@ -368,16 +363,16 @@ class AttentionLSTMKernel : public framework::OpKernel { const T* atten_scalar_bias_data = atten_scalar_bias ? atten_scalar_bias->data() : NULL; - T* hidden_out_data = hidden_out->mutable_data(); - T* cell_out_data = cell_out->mutable_data(); - T* atted_x_data = atted_x->mutable_data(); - T* fc_out_data = fc_out->mutable_data(); - T* lstm_x_data = lstm_x->mutable_data(); - T* lstm_out_data = lstm_out->mutable_data(); + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); + T* atted_x_data = atted_x->mutable_data(ctx.GetPlace()); + T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); + T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); + T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 auto blas = math::GetBlas(ctx); - math::FCCompute(blas, T, 1, M, x_data, atten_w_data, + math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, atted_x_data, atten_b_data); const T* cur_x_data = x_data; @@ -400,7 +395,7 @@ class AttentionLSTMKernel : public framework::OpKernel { // fc2: scalar if (atten_scalar_data) { // x = a*x - blas.SCAL(seq_len, atten_scalar_data, fc_out_data); + blas.SCAL(seq_len, *atten_scalar_data, fc_out_data); bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, 
fc_out_data); } @@ -431,16 +426,16 @@ class AttentionLSTMKernel : public framework::OpKernel { blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); // b = input * tilde - blas.VMUL(D, lstm_out_data + D, lstm_out + D3, lstm_out_data + D); + blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D); // cell_out = a + b blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); // state act tanh(cell_out) * output_gate vec_tanh(D, cur_cell_out_data, lstm_out_data); - blas.VMUL(D, lstm_out_data, lstm_out + D2, cur_hidden_out_data); + blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); - prev_hidden_data = hidden_out + i * gate_size; + prev_hidden_data = cur_hidden_out_data; prev_cell_data = cur_cell_out_data; cur_cell_out_data = cur_cell_out_data + D; cur_hidden_out_data = cur_hidden_out_data + D; @@ -458,7 +453,5 @@ REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp, ops::AttentionLSTMOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL( - attention_lstm, - ops::AttentionLSTMKernel, - ops::AttentionLSTMKernel); +REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel, + ops::AttentionLSTMKernel); diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 5aba170221..da185d93c0 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -160,7 +160,7 @@ class Blas { T DOT(int n, const T* x, const T* y) const; template - void SCAL(int n, const T a, const T* x) const; + void SCAL(int n, const T a, T* x) const; template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, @@ -233,11 +233,26 @@ class BlasT : private Blas { Base()->template VCOPY(args...); } + template + void VEXP(ARGS... args) const { + Base()->template VEXP(args...); + } + template void GEMV(ARGS... args) const { Base()->template GEMV(args...); } + template + T DOT(ARGS... 
args) const { + return Base()->template DOT(args...); + } + + template + void SCAL(ARGS... args) const { + Base()->template SCAL(args...); + } + template void BatchedGEMM(ARGS... args) const { Base()->template BatchedGEMM(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index eaad83ba18..e1df78d11e 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -415,8 +415,7 @@ T Blas::DOT(int n, const T *x, const T *y) const { template <> template -void Blas::SCAL(int n, const T a, - const T *x) const { +void Blas::SCAL(int n, const T a, T *x) const { #ifdef PADDLE_WITH_MKLML CBlas::SCAL(n, a, x, 1); #else From ec59f0d454569ef536c9ac0f7224bc7062b110ce Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 16:40:37 +0800 Subject: [PATCH 034/140] add cpu vec --- paddle/fluid/operators/attention_lstm_op.cc | 56 +++----------- paddle/fluid/operators/math/cpu_vec.h | 81 +++++++++++++++++++++ paddle/fluid/platform/cpu_info.cc | 2 + paddle/fluid/platform/cpu_info.h | 1 + 4 files changed, 95 insertions(+), 45 deletions(-) create mode 100644 paddle/fluid/operators/math/cpu_vec.h diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 87fda12ea6..14985a3f74 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -15,9 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" -// #include "paddle/fluid/operators/math/detail/activation_functions.h" -// #include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -230,13 +230,6 @@ use lstm_x_t as input and compute as standard LSTM. 
)DOC"); } -template -inline void vec_relu(const int n, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = x[i] > 0 ? x[i] : 0; - } -} - // y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { @@ -244,9 +237,9 @@ inline void bias_relu(const int n, const T* x, const T* bias, T* y) { for (int i = 0; i < n; ++i) { y[i] = x[i] + bias[0]; } - vec_relu(n, y, y); + math::vec_relu(n, y, y); } else { - vec_relu(n, x, y); + math::vec_relu(n, x, y); } } @@ -277,37 +270,6 @@ inline void vec_softmax(const math::BlasT& blas, const int n, blas.SCAL(n, static_cast(1) / scalar, y); } -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 - -template -inline T sigmoid(T x) { - return 1. / (1. + exp(-x)); -} - -template -inline T tanh(T x) { - return 2. * sigmoid(2. * x) - 1.; -} - -template -inline void vec_sigmoid(const int n, const T* x, T* y) { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < n; ++i) { - T tmp = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); - y[i] = 1.0 / (1.0 + std::exp(-tmp)); - } -} - -template -inline void vec_tanh(const int n, const T* x, T* y) { - for (int i = 0; i < n; ++i) { - y[i] = tanh(x[i]); - } -} - template class AttentionLSTMKernel : public framework::OpKernel { public: @@ -351,6 +313,10 @@ class AttentionLSTMKernel : public framework::OpKernel { fc_out->Resize({max_seq_len, 1}); // TODO(TJ): act functor init here + // if (platform::jit::MayIUse(platform::jit::avx2)) { + // } else if (platform::jit::MayIUse(platform::jit::avx)) { + // } else { + // } const T* x_data = x->data(); const T* h0_data = h0->data(); @@ -418,9 +384,9 @@ class AttentionLSTMKernel : public framework::OpKernel { blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data); // gate act: sigmoid - vec_sigmoid(D3, lstm_out_data, lstm_out_data); + math::vec_sigmoid(D3, lstm_out_data, lstm_out_data); // candicate act: tanh - vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3); + math::vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3); // a = forget * prev_cell blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); @@ -432,7 +398,7 @@ class AttentionLSTMKernel : public framework::OpKernel { blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); // state act tanh(cell_out) * output_gate - vec_tanh(D, cur_cell_out_data, lstm_out_data); + math::vec_tanh(D, cur_cell_out_data, lstm_out_data); blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); prev_hidden_data = cur_hidden_out_data; diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h new file mode 100644 index 0000000000..29476fce70 --- /dev/null +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace math { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +inline void vec_sigmoid(const int n, const T* x, T* y) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = 1.0 / (1.0 + std::exp(-tmp)); + } +} + +template +inline void vec_tanh(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = tanh(x[i]); + } +} + +template +inline void vec_relu(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + +template <> +inline void vec_relu(const int n, const float* x, + float* y) { + // TODO(TJ): complete me + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + +template <> +inline void vec_relu(const int n, const float* x, + float* y) { + // TODO(TJ): complete me + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? 
x[i] : 0; + } +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index 7d53a684d6..79a924434b 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -112,6 +112,8 @@ bool MayIUse(const cpu_isa_t cpu_isa) { switch (cpu_isa) { case sse42: return cpu.has(Cpu::tSSE42); + case avx: + return cpu.has(Cpu::tAVX); case avx2: return cpu.has(Cpu::tAVX2); case avx512_common: diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index f5f6766759..2baa21c1bd 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -43,6 +43,7 @@ namespace jit { typedef enum { isa_any, sse42, + avx, avx2, avx512_common, avx512_core, From 93cc29abc09ac3d9cc85d4490f878da46431cdda Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 16:54:57 +0800 Subject: [PATCH 035/140] init attention lstm op test --- .../tests/unittests/test_attention_lstm_op.py | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_attention_lstm_op.py diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py new file mode 100644 index 0000000000..cd555a022b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py @@ -0,0 +1,149 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from test_fusion_lstm_op import fc, ACTIVATION + + +def attention_lstm( + x, # T x M + lod, # 1 x N + h0, # N x D + c0, # N x D + fcws, # (M+D) x 1, 1x1 + fcbs, # 1 x 1, 1x1 + w, # (M+D) x 4D + b, # 1 x 4D + act_gate, + act_cell, + act_cand): + hidden + cell + return hidden, cell + + +class TestAttentionLSTMOp(OpTest): + def set_conf(self): + self.lod = [[3]] + + def setUp(self): + self.op_type = 'attention_lstm' + self.lod = [[3]] + self.M = 30 + self.D = 15 + self.has_initial_hidden = True + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.set_conf() + + T = sum(self.lod[0]) + bs = len(self.lod[0]) + + x = np.random.normal(size=(T, self.M)).astype('float32') + c0 = np.random.normal(size=(bs, self.D)).astype('float32') + if self.has_initial_hidden: + h0 = np.random.normal(size=(bs, self.D)).astype('float32') + else: + h0 = np.zeros((bs, self.D)).astype('float32') + + fcw1 = np.random.normal(size=(self.M + self.D, 1)).astype('float32') + fcb1 = np.random.normal(size=(1, 1)).astype('float32') + fcw2 = np.random.normal(size=(1, 1)).astype('float32') + fcb2 = np.random.normal(size=(1, 1)).astype('float32') + + # lstm weight and bias + w = np.random.normal(size=(self.M + self.D, + self.D * 4)).astype('float32') + b = np.random.normal(size=(1, self.D * 4)).astype('float32') + + h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2], + ACTIVATION[self.act_gate], + ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand]) + + self.inputs = { + 'X': (x, self.lod), + 'C0': c0, + 'AttentionWeight': fcw1, + 'AttentionBias': fcb1, + 'AttentionScalar': fcw2, + 'AttentionScalarBias': fcb2, + 'LSTMWeight': w, + 'LSTMBias': b + } + + if self.has_initial_hidden: + self.inputs['H0'] = h0 + + self.outputs = { + 'Hidden': (h, 
self.lod), + 'Cell': (c, self.lod), + } + self.attrs = { + 'gate_activation': self.act_gate, + 'cell_activation': self.act_cell, + 'candidate_activation': self.act_cand + } + + def test_check_output(self): + self.check_output() + + +class TestAttentionOpNonInit(TestAttentionLSTMOp): + def set_conf(self): + self.has_initial_hidden = False + + +class TestAttentionOpMD1(TestAttentionLSTMOp): + def set_conf(self): + self.M = 36 + self.D = 8 + + +class TestAttentionOpMD2(TestAttentionLSTMOp): + def set_conf(self): + self.M = 8 + self.D = 8 + + +class TestAttentionOpMD3(TestAttentionLSTMOp): + def set_conf(self): + self.M = 15 + self.D = 30 + + +class TestAttentionOpBS1(TestAttentionLSTMOp): + def set_conf(self): + self.lod = [[5]] + self.M = 16 + self.D = 32 + + +class TestAttentionOpBS2(TestAttentionLSTMOp): + def set_conf(self): + self.lod = [[3, 6]] + + +class TestAttentionOpBS5(TestAttentionLSTMOp): + def set_conf(self): + self.lod = [[3, 2, 4, 7, 5]] + + +if __name__ == '__main__': + unittest.main() From 522b3e411f33400ae2735e81c4bc65ca26438445 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 19:40:59 +0800 Subject: [PATCH 036/140] complete attention lstm op test --- .../tests/unittests/test_attention_lstm_op.py | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py index cd555a022b..dea6ec7668 100644 --- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py @@ -18,6 +18,7 @@ import unittest import numpy as np from op_test import OpTest from test_fusion_lstm_op import fc, ACTIVATION +from test_softmax_op import stable_softmax def attention_lstm( @@ -32,8 +33,56 @@ def attention_lstm( act_gate, act_cell, act_cand): - hidden - cell + + T = sum(lod[0]) + N = len(lod[0]) + M = x.shape[1] + D = b.shape[1] / 4 + assert T 
== x.shape[0] + assert len(fcws) == len(fcbs) + + hidden = [] + cell = [] + + start_offset = 0 + for bid in range(N): + seq_len = lod[0][bid] + xi = np.copy(x[start_offset:seq_len, :]).reshape(seq_len, M) + prev_cell = np.copy(c0[bid]).reshape([1, D]) + prev_hidden = np.copy(h0[bid]).reshape([1, D]) + for step in range(seq_len): + expanded_cell = np.repeat(prev_cell, seq_len, axis=0) + tmp = np.concatenate((xi, expanded_cell), axis=1) + assert tmp.shape[1] == M + D + for fcid in range(len(fcbs)): + tmp = fc(tmp, fcws[fcid], fcbs[fcid]) + tmp = ACTIVATION['relu'](tmp) + tmp = np.reshape(tmp, (1, seq_len)) + tmp = stable_softmax(tmp).reshape(seq_len, 1) + lstmx = xi * tmp # seq * M + lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M]) + lstmin = np.concatenate((prev_hidden, lstmx), axis=1) + lstmout = np.dot(lstmin, w).reshape([1, 4 * D]) + + g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1) + g_f = act_gate(g_f).reshape([1, D]) + g_i = act_gate(g_i).reshape([1, D]) + g_o = act_gate(g_o).reshape([1, D]) + cand = act_cand(cand).reshape([1, D]) + + cell_t = (prev_cell * g_f) + (g_i * cand) + hidden_t = g_o * act_cell(cell_t) + + hidden.append(hidden_t.flatten()) + cell.append(cell_t.flatten()) + + prev_cell = cell_t.reshape([1, D]) + prev_hidden = hidden_t.reshape([1, D]) + + start_offset += seq_len + + hidden = np.array(hidden).astype('float32').reshape([T, D]) + cell = np.array(cell).astype('float32').reshape([T, D]) return hidden, cell @@ -73,7 +122,7 @@ class TestAttentionLSTMOp(OpTest): b = np.random.normal(size=(1, self.D * 4)).astype('float32') h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2], - ACTIVATION[self.act_gate], + w, b, ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], ACTIVATION[self.act_cand]) From dd938d0b948cca5f968411704a023efc8b2971f4 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 22:00:45 +0800 Subject: [PATCH 037/140] fix bugs and pass op test --- 
paddle/fluid/operators/attention_lstm_op.cc | 36 +++++++++---------- .../tests/unittests/test_attention_lstm_op.py | 9 ++--- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 14985a3f74..5d57703c0b 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -59,10 +59,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { auto b_dims = ctx->GetInputDim("LSTMBias"); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M, - D); - PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).", - M, D); + PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); + PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D); auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); @@ -148,8 +146,8 @@ void AttentionLSTMOpMaker::Make() { "(Tensor) the weights of attention fc. Always relu the fc result." "The shape is ((M+D) x 1), where M is the dim size of x, D is the " "gate size of LSTM."); - AddInput("AttentionBias, optional", - "(Tensor) the bias of attention fc." + AddInput("AttentionBias", + "(Tensor, optional) the bias of attention fc." 
"The shape is (1 x 1)") .AsDispensable(); AddInput("AttentionScalar", @@ -281,7 +279,7 @@ class AttentionLSTMKernel : public framework::OpKernel { auto* atten_w = ctx.Input("AttentionWeight"); // (M+D) x 1 auto* atten_b = ctx.Input("AttentionBias"); // 1x1 auto* atten_scalar = ctx.Input("AttentionScalar"); // 1x1 - auto* atten_scalar_bias = ctx.Input("AttentionScalar"); // 1x1 + auto* atten_scalar_bias = ctx.Input("AttentionScalarBias"); // 1x1 auto* lstm_w = ctx.Input("LSTMWeight"); // (D+M) x D*4 auto* lstm_b = ctx.Input("LSTMBias"); // 1 x D*4 @@ -319,7 +317,7 @@ class AttentionLSTMKernel : public framework::OpKernel { // } const T* x_data = x->data(); - const T* h0_data = h0->data(); + const T* h0_data = h0 ? h0->data() : NULL; const T* c0_data = c0->data(); const T* lstm_w_data = lstm_w->data(); const T* lstm_b_data = lstm_b->data(); @@ -341,36 +339,35 @@ class AttentionLSTMKernel : public framework::OpKernel { math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, atted_x_data, atten_b_data); + const T* cur_atten_x_data = atted_x_data; const T* cur_x_data = x_data; const T* prev_cell_data = NULL; const T* prev_hidden_data = NULL; T* cur_cell_out_data = cell_out_data; T* cur_hidden_out_data = hidden_out_data; for (int i = 0; i < N; ++i) { - int seq_len = x_lod[0][i + 1]; + int seq_len = x_lod[0][i + 1] - x_lod[0][i]; prev_cell_data = c0_data + i * D; - prev_hidden_data = h0 ? h0_data + i * D : NULL; - + prev_hidden_data = h0_data ? h0_data + i * D : NULL; for (int step = 0; step < seq_len; ++step) { - /// compute attention vector - // prev_cell(1xD) * fc(D) rest part of atten_wgt - // T = cblas_dot(); + /// 1. compute attention vector + // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M); - // add cell bias and relu - bias_relu(seq_len, atted_x_data, &prev_cell_bias, fc_out_data); - // fc2: scalar + // 1b. 
add cell bias and relu + bias_relu(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data); + // 1c. fc scalar if (atten_scalar_data) { - // x = a*x blas.SCAL(seq_len, *atten_scalar_data, fc_out_data); bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, fc_out_data); } + // 1d. softmax vec_softmax(blas, seq_len, fc_out_data, fc_out_data); // mul x(seq_len*M) and sum pool math::FCCompute(blas, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data); - /// compute LSTM step + /// 2. compute LSTM step // lstm weight : concat[forget , input , output , tilde] // shape : (D + M) x (4 * D) // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D @@ -407,6 +404,7 @@ class AttentionLSTMKernel : public framework::OpKernel { cur_hidden_out_data = cur_hidden_out_data + D; } cur_x_data = cur_x_data + seq_len * M; + cur_atten_x_data = cur_atten_x_data + seq_len; } } }; diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py index dea6ec7668..cb02c7e586 100644 --- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py @@ -40,19 +40,20 @@ def attention_lstm( D = b.shape[1] / 4 assert T == x.shape[0] assert len(fcws) == len(fcbs) - hidden = [] cell = [] start_offset = 0 for bid in range(N): seq_len = lod[0][bid] - xi = np.copy(x[start_offset:seq_len, :]).reshape(seq_len, M) + xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len, + M) prev_cell = np.copy(c0[bid]).reshape([1, D]) prev_hidden = np.copy(h0[bid]).reshape([1, D]) for step in range(seq_len): expanded_cell = np.repeat(prev_cell, seq_len, axis=0) tmp = np.concatenate((xi, expanded_cell), axis=1) + assert tmp.shape[0] == seq_len assert tmp.shape[1] == M + D for fcid in range(len(fcbs)): tmp = fc(tmp, fcws[fcid], fcbs[fcid]) @@ -62,7 +63,7 @@ def attention_lstm( lstmx = xi * tmp # seq * M lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M]) 
lstmin = np.concatenate((prev_hidden, lstmx), axis=1) - lstmout = np.dot(lstmin, w).reshape([1, 4 * D]) + lstmout = fc(lstmin, w, b).reshape([1, 4 * D]) g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1) g_f = act_gate(g_f).reshape([1, D]) @@ -88,7 +89,7 @@ def attention_lstm( class TestAttentionLSTMOp(OpTest): def set_conf(self): - self.lod = [[3]] + pass def setUp(self): self.op_type = 'attention_lstm' From ba168bd2d23f763f1b4c6357943da01890fc6421 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 22 Aug 2018 12:14:26 +0000 Subject: [PATCH 038/140] modify API.spec --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/stack_op.h | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 9250cde1b2..c03df86e0f 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, 
defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index b139f48d87..c777d5feae 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -154,17 +154,22 @@ class StackKernel : public framework::OpKernel { if (std::is_same::value || n > kMaxThreshold) { #ifdef __NVCC__ + VLOG(10) << "Stack more than " << kMaxThreshold + << " tensors on GPU may be slow."; thrust::device_vector device_x_vec(x_datas); auto x_data_arr = device_x_vec.data().get(); #else auto x_data_arr = x_datas.data(); #endif StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); +#ifdef __NVCC__ + // Wait() must be called because device_x_vec may be destructed before + // kernel ends + dev_ctx.Wait(); +#endif } #ifdef __NVCC__ else { // NOLINT - VLOG(10) << "Stack more than " << kMaxThreshold - << " tensors on GPU may be slow."; framework::Array x_data_arr; for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i]; StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); @@ -243,6 +248,8 @@ class StackGradKernel : public framework::OpKernel { if (std::is_same::value || n > kMaxThreshold) { #ifdef __NVCC__ + VLOG(10) << "Stack more than " << kMaxThreshold + << " tensors on GPU may be slow."; thrust::device_vector device_dx_vec(dx_datas); auto dx_data_arr = device_dx_vec.data().get(); #else @@ -250,11 +257,14 @@ class StackGradKernel : public framework::OpKernel { #endif StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post); +#ifdef __NVCC__ + // Wait() must be called because device_dx_vec may be destructed before + // kernel ends + dev_ctx.Wait(); +#endif } #ifdef __NVCC__ else { // NOLINT - VLOG(10) << "Stack more than " << kMaxThreshold - << " tensors on GPU may be slow."; framework::Array dx_data_arr; for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i]; StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, From 
01eec0af91a5dffc4cdbf462f48f4effe8fc4db9 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 22 Aug 2018 23:00:08 +0800 Subject: [PATCH 039/140] Fix flowers dataset reading problem --- python/paddle/dataset/flowers.py | 5 ++++- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index ce0cd6009a..8c9c721b33 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -120,7 +120,10 @@ def reader_creator(data_file, file = file.strip() batch = None with open(file, 'rb') as f: - batch = pickle.loads(f.read()) + if six.PY2: + batch = pickle.load(f) + else: + batch = pickle.load(f, encoding='bytes') data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ae13d7ff31..e7dd85ef5c 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -64,7 +64,6 @@ if(WITH_DISTRIBUTE) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) -set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 500) py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) From 5ca0bb9aadd50b10dc0e20bbc528604b8937e2c1 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 23 Aug 2018 00:01:45 +0800 Subject: [PATCH 040/140] support more activation type and remove some comments --- paddle/fluid/operators/attention_lstm_op.cc | 57 ++++++++++--------- 
paddle/fluid/operators/math/cpu_vec.h | 26 ++++++++- .../tests/unittests/test_attention_lstm_op.py | 9 +++ 3 files changed, 63 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 5d57703c0b..1cb65346ee 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" +#include #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" @@ -192,24 +193,23 @@ void AttentionLSTMOpMaker::Make() { "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." "Shape is (1 x 4D), where M is the x frame size") .AsIntermediate(); - // TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"}); AddAttr("gate_activation", "(string, default: sigmoid)" "The activation for input gate, forget gate and output " "gate, `sigmoid` by default.") .SetDefault("sigmoid") - .InEnum({"sigmoid"}); + .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddAttr("cell_activation", "(string, default: tanh)" "The activation for cell output, `tanh` by defalut.") .SetDefault("tanh") - .InEnum({"tanh"}); + .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddAttr("candidate_activation", "(string, default: tanh)" "The activation for candidate hidden state, " "`tanh` by default.") .SetDefault("tanh") - .InEnum({"tanh"}); + .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddComment(R"DOC( Attention Long-Short Term Memory (LSTM) Operator. 
@@ -273,22 +273,23 @@ class AttentionLSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); // T x M - auto* h0 = ctx.Input("H0"); // N x D - auto* c0 = ctx.Input("C0"); // N x D - auto* atten_w = ctx.Input("AttentionWeight"); // (M+D) x 1 - auto* atten_b = ctx.Input("AttentionBias"); // 1x1 - auto* atten_scalar = ctx.Input("AttentionScalar"); // 1x1 - auto* atten_scalar_bias = ctx.Input("AttentionScalarBias"); // 1x1 - auto* lstm_w = ctx.Input("LSTMWeight"); // (D+M) x D*4 - auto* lstm_b = ctx.Input("LSTMBias"); // 1 x D*4 - - auto* hidden_out = ctx.Output("Hidden"); // TxD - auto* cell_out = ctx.Output("Cell"); // TxD - auto* atted_x = ctx.Output("AttentionedX"); // T x 1 - auto* fc_out = ctx.Output("AttentionFCOut"); // max_seq_len x 1 - auto* lstm_x = ctx.Output("LSTMX"); // 1 x M - auto* lstm_out = ctx.Output("LSTMOUT"); // 1 x 4D + + auto* x = ctx.Input("X"); + auto* h0 = ctx.Input("H0"); + auto* c0 = ctx.Input("C0"); + auto* atten_w = ctx.Input("AttentionWeight"); + auto* atten_b = ctx.Input("AttentionBias"); + auto* atten_scalar = ctx.Input("AttentionScalar"); + auto* atten_scalar_bias = ctx.Input("AttentionScalarBias"); + auto* lstm_w = ctx.Input("LSTMWeight"); + auto* lstm_b = ctx.Input("LSTMBias"); + + auto* hidden_out = ctx.Output("Hidden"); + auto* cell_out = ctx.Output("Cell"); + auto* atted_x = ctx.Output("AttentionedX"); + auto* fc_out = ctx.Output("AttentionFCOut"); + auto* lstm_x = ctx.Output("LSTMX"); + auto* lstm_out = ctx.Output("LSTMOUT"); // some shape should be reshape here since infershape can not get lod info auto x_lod = x->lod(); @@ -310,11 +311,11 @@ class AttentionLSTMKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); fc_out->Resize({max_seq_len, 1}); - // TODO(TJ): act functor init here - // if 
(platform::jit::MayIUse(platform::jit::avx2)) { - // } else if (platform::jit::MayIUse(platform::jit::avx)) { - // } else { - // } + math::VecActivations act_functor; + std::function act_gate, act_cell, act_cand; + act_gate = act_functor(ctx.Attr("gate_activation")); + act_cell = act_functor(ctx.Attr("cell_activation")); + act_cand = act_functor(ctx.Attr("candidate_activation")); const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : NULL; @@ -381,9 +382,9 @@ class AttentionLSTMKernel : public framework::OpKernel { blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data); // gate act: sigmoid - math::vec_sigmoid(D3, lstm_out_data, lstm_out_data); + act_gate(D3, lstm_out_data, lstm_out_data); // candicate act: tanh - math::vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3); + act_cand(D, lstm_out_data + D3, lstm_out_data + D3); // a = forget * prev_cell blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); @@ -395,7 +396,7 @@ class AttentionLSTMKernel : public framework::OpKernel { blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); // state act tanh(cell_out) * output_gate - math::vec_tanh(D, cur_cell_out_data, lstm_out_data); + act_cell(D, cur_cell_out_data, lstm_out_data); blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); prev_hidden_data = cur_hidden_out_data; diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 29476fce70..48c0da0e36 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once - +#include #include "paddle/fluid/platform/cpu_info.h" namespace paddle { @@ -34,6 +34,12 @@ inline T tanh(T x) { return 2. * sigmoid(2. 
* x) - 1.; } +template +inline void vec_identity(const int n, const T* x, T* y) { + // do nothing + return; +} + template inline void vec_sigmoid(const int n, const T* x, T* y) { const T min = SIGMOID_THRESHOLD_MIN; @@ -76,6 +82,24 @@ inline void vec_relu(const int n, const float* x, } } +template +class VecActivations { + public: + std::function operator()( + const std::string& type) { + if (type == "sigmoid") { + return vec_sigmoid; + } else if (type == "relu") { + return vec_relu; + } else if (type == "tanh") { + return vec_tanh; + } else if (type == "identity" || type == "") { + return vec_identity; + } + PADDLE_THROW("Not support type %s.", type); + } +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py index cb02c7e586..a7382c2244 100644 --- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py @@ -160,6 +160,15 @@ class TestAttentionOpNonInit(TestAttentionLSTMOp): self.has_initial_hidden = False +class TestAttentionOpAct(TestAttentionLSTMOp): + def set_conf(self): + self.M = 3 + self.D = 2 + self.act_gate = 'relu' + self.act_cell = 'tanh' + self.act_cand = 'sigmoid' + + class TestAttentionOpMD1(TestAttentionLSTMOp): def set_conf(self): self.M = 36 From 4e538db14d56af761d8adb8936a7f4f7435b7187 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 23 Aug 2018 00:04:03 +0800 Subject: [PATCH 041/140] refine jit space --- paddle/fluid/platform/cpu_info.cc | 13 ++++++++++--- paddle/fluid/platform/cpu_info.h | 3 --- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index 79a924434b..fcd658d67c 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -103,9 +103,8 @@ size_t CUDAPinnedMaxChunkSize() { return 
CUDAPinnedMaxAllocSize() / 256; } -#ifdef PADDLE_WITH_XBYAK namespace jit { - +#ifdef PADDLE_WITH_XBYAK static Xbyak::util::Cpu cpu; bool MayIUse(const cpu_isa_t cpu_isa) { using namespace Xbyak::util; // NOLINT @@ -136,8 +135,16 @@ bool MayIUse(const cpu_isa_t cpu_isa) { } return false; } +#else +bool MayIUse(const cpu_isa_t cpu_isa) { + if (cpu_isa == isa_any) { + return true; + } else { + return false; + } +} +#endif } // namespace jit -#endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index 2baa21c1bd..5d17978dd7 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -37,9 +37,7 @@ size_t CUDAPinnedMinChunkSize(); //! Get the maximum chunk size for buddy allocator. size_t CUDAPinnedMaxChunkSize(); -#ifdef PADDLE_WITH_XBYAK namespace jit { - typedef enum { isa_any, sse42, @@ -56,7 +54,6 @@ typedef enum { inline bool MayIUse(const cpu_isa_t cpu_isa); } // namespace jit -#endif } // namespace platform } // namespace paddle From 8b8f6487d9b7ed78bbc8c10fddbc217f4dfcd030 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 00:17:15 +0800 Subject: [PATCH 042/140] Add debug info for fetch feed --- paddle/scripts/paddle_build.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 8460f93b84..a55a9e89f7 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -313,7 +313,9 @@ function run_test() { Running unit tests ... 
======================================== EOF - ctest --output-on-failure + echo $http_proxy + echo $https_proxy + ctest -V # make install should also be test when unittest make install -j `nproc` pip install /usr/local/opt/paddle/share/wheels/*.whl From c95ff1c410165f8c97972dedaa81c079b19f8721 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 00:18:30 +0800 Subject: [PATCH 043/140] Add debug info --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e7dd85ef5c..0c9bbb766f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -64,6 +64,7 @@ if(WITH_DISTRIBUTE) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) +set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 600) py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) From 698c926ce5f1666d18b00bdd12fe63803dc738fe Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 22 Aug 2018 20:42:29 +0800 Subject: [PATCH 044/140] copy program and fix op_desc --- paddle/fluid/framework/ir/graph.h | 4 +--- paddle/fluid/framework/ir/node.h | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 25e33861c0..0d27be5fc0 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -142,8 +142,6 @@ class Graph { nodes_.erase(node); } - const ProgramDesc &program() const { return program_; } - private: // This method takes 
ownership of `node`. ir::Node *AddNode(ir::Node *node) { @@ -154,7 +152,7 @@ class Graph { } // NOTE: program_ shouldn't be exposed to user. - const ProgramDesc &program_; + const ProgramDesc program_; std::map attrs_; std::map> attr_dels_; std::map> nodes_; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 063c70fb7b..63277d2d01 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -41,8 +41,7 @@ class Node { explicit Node(OpDesc* op_desc) : name_(op_desc->Type()), var_desc_(nullptr), - op_desc_(new OpDesc(*op_desc)), // TODO(panyx0718) the pointer in the - // original OpDesc might go out. + op_desc_(new OpDesc(*op_desc, op_desc->Block())), type_(Type::kOperation) {} Type NodeType() const { return type_; } From b2df17003f22712078df75b299fb27934650319d Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 23 Aug 2018 12:25:47 +0800 Subject: [PATCH 045/140] Add Python Callstacks when Op::Run error (#12759) * Add Python Callstacks when Op::Run error * Skip op with sub-block * refactor: refine callstack info's format * Reshape only support matrix * Polish Python code * Fix UT * Fix Py3 --- paddle/fluid/framework/op_proto_maker.cc | 4 ++ paddle/fluid/framework/op_proto_maker.h | 1 + paddle/fluid/framework/operator.cc | 61 ++++++++++++++----- paddle/fluid/operators/top_k_op.cc | 2 + paddle/fluid/pybind/const_value.cc | 3 + python/paddle/fluid/framework.py | 5 ++ .../tests/unittests/test_operator_desc.py | 5 +- 7 files changed, 65 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 2288c7fe66..9c289243c5 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -129,6 +129,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, "Optimized for variable") .SetDefault({}); + AddAttr>(OpCreationCallstackAttrName(), + "Callstack for Op Creatation.") + .SetDefault({}); + 
Validate(); } diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 80970291c9..cb9c8ab170 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker { public: static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleVarAttrName() { return "op_role_var"; } + static const char *OpCreationCallstackAttrName() { return "op_callstack"; } void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d04f774496..9f8cdf1aeb 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include -#include - +#include "paddle/fluid/framework/operator.h" #include - +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/profiler.h" @@ -127,19 +129,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } void OperatorBase::Run(const Scope& scope, const platform::Place& place) { - VLOG(4) << place << " " << DebugStringEx(&scope); - if (platform::is_gpu_place(place)) { + try { + if (VLOG_IS_ON(4)) { + VLOG(4) << place << " " << DebugStringEx(&scope); + } + if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA - PADDLE_THROW("Cannot run operator on place %s", place); + PADDLE_THROW("Cannot run operator on place %s", place); #else - auto dev_id = boost::get(place).device; - platform::SetDeviceId(dev_id); + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); #endif + } + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::RecordEvent record_event(Type(), pool.Get(place)); + RunImpl(scope, place); + if (VLOG_IS_ON(3)) { + VLOG(3) << place << " " << DebugStringEx(&scope); + } + } catch (platform::EnforceNotMet exception) { + if (Attrs().count("sub_block") != 0) { + throw exception; + } + + auto& callstack = Attr>( + OpProtoAndCheckerMaker::OpCreationCallstackAttrName()); + + if (callstack.empty()) { + throw exception; + } + std::ostringstream sout; + sout << "Invoke operator " << Type() << " error.\n"; + sout << "Python Callstacks: \n"; + for (auto& line : callstack) { + sout << line; + } + sout << "C++ Callstacks: \n"; + sout << exception.err_str_; + exception.err_str_ = 
sout.str(); + throw exception; + } catch (...) { + std::rethrow_exception(std::current_exception()); } - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::RecordEvent record_event(Type(), pool.Get(place)); - RunImpl(scope, place); - VLOG(3) << place << " " << DebugStringEx(&scope); } bool OperatorBase::HasInputs(const std::string& name) const { @@ -167,7 +198,7 @@ const std::vector& OperatorBase::Inputs( } bool OperatorBase::HasOutputs(const std::string& name) const { - if (outputs_.find(name) != outputs_.end()) { + if (outputs_.end() != outputs_.find(name)) { return true; } else { return false; diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 4a8ac441cf..92a0697e27 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel { "Output(Indices) of TopkOp should not be null."); auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(input_dims.size(), 2, + "Rank of TopK op's input must be 2."); const int k = static_cast(ctx->Attrs().Get("k")); PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 76aa7d2010..9094f6051c 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -40,6 +40,9 @@ void BindConstValue(pybind11::module* m) { op_proto_and_checker_maker.def( "kOpRoleVarAttrName", framework::OpProtoAndCheckerMaker::OpRoleVarAttrName); + op_proto_and_checker_maker.def( + "kOpCreationCallstackAttrName", + framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName); } } // namespace pybind diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 62682d1032..389fce1874 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -18,6 +18,7 @@ import collections import contextlib import re import six +import 
traceback import numpy as np @@ -499,6 +500,10 @@ class Operator(object): if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: del op_attrs[role_var_name] + callstack_var_name = op_maker.kOpCreationCallstackAttrName() + op_attrs[callstack_var_name] = list( + reversed(traceback.format_stack()))[1:] + if len(self.desc.type()) != 0: return if type is None: diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 6d01955993..3ac8268073 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual( set(mul_op.attr_names), - set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"])) + set([ + "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var", + "op_callstack" + ])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1) From 79918a84429d7dab4eff9487002a7eb01d4f2aaf Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 22 Aug 2018 20:06:48 +0800 Subject: [PATCH 046/140] add sequence_mask_op for DAM model --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/batch_norm_op.cc | 2 +- paddle/fluid/operators/sequence_mask_op.cc | 26 ++++ paddle/fluid/operators/sequence_mask_op.cu | 22 ++++ paddle/fluid/operators/sequence_mask_op.h | 117 ++++++++++++++++++ python/paddle/fluid/layers/nn.py | 22 +++- python/paddle/fluid/nets.py | 2 +- .../tests/book/test_image_classification.py | 5 +- .../tests/unittests/test_sequence_mask.py | 86 +++++++++++++ 9 files changed, 278 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/sequence_mask_op.cc create mode 100644 paddle/fluid/operators/sequence_mask_op.cu create mode 100644 
paddle/fluid/operators/sequence_mask_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_sequence_mask.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 9250cde1b2..359db26ed6 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'max_len', 'mask_dtype'], varargs=None, keywords=None, defaults=('int64',)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 5912a1a17c..969f75544f 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Variance", "The global variance (for training) " "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization").Reuse("X"); + AddOutput("Y", "result after normalization"); AddOutput("MeanOut", 
"Share memory with Mean. " "Store the global mean when training") diff --git a/paddle/fluid/operators/sequence_mask_op.cc b/paddle/fluid/operators/sequence_mask_op.cc new file mode 100644 index 0000000000..e45c18d6af --- /dev/null +++ b/paddle/fluid/operators/sequence_mask_op.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_mask_op.h" + +REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp, + paddle::operators::SequenceMaskOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + sequence_mask, + paddle::operators::SequenceMaskKernel, + paddle::operators::SequenceMaskKernel); diff --git a/paddle/fluid/operators/sequence_mask_op.cu b/paddle/fluid/operators/sequence_mask_op.cu new file mode 100644 index 0000000000..ff5acf4d9e --- /dev/null +++ b/paddle/fluid/operators/sequence_mask_op.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_mask_op.h" + +REGISTER_OP_CUDA_KERNEL( + sequence_mask, + paddle::operators::SequenceMaskKernel, + paddle::operators::SequenceMaskKernel); diff --git a/paddle/fluid/operators/sequence_mask_op.h b/paddle/fluid/operators/sequence_mask_op.h new file mode 100644 index 0000000000..237857b51d --- /dev/null +++ b/paddle/fluid/operators/sequence_mask_op.h @@ -0,0 +1,117 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +class SequenceMaskOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); + auto max_len = ctx->Attrs().Get("max_len"); + PADDLE_ENFORCE_GT(max_len, 1, "Attr(max_len) must be larger than 1"); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); + auto dim = framework::vectorize2int(ctx->GetInputDim("X")); + dim.push_back(max_len); + ctx->SetOutputDim("Y", framework::make_ddim(dim)); + } +}; + +class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of sequence_mask op."); + AddOutput("Y", "The output mask of sequence_mask op."); + AddAttr("max_len", "The maximum length of the sequence.") + .GreaterThan(1); + AddAttr("out_dtype", "Output data type"); + AddComment(R"DOC( +SequenceMask Operator + +This operator outputs a Mask according to Input(X) and Attr(max_len). +Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the +Output(Y) is a mask with shape [d_1, d_2, ..., d_n, max_len], where: + +Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n)) + )DOC"); + } +}; + +template +struct SequenceMaskForRangeFunctor { + HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int max_len) + : x_(x), y_(y), max_len_(max_len) {} + + HOSTDEVICE void operator()(int y_idx) const { + int x_idx = y_idx / max_len_; + int j = y_idx % max_len_; + y_[y_idx] = static_cast(j < x_[x_idx] ? 
1 : 0); + } + + private: + const Tx *x_; + Ty *y_; + int max_len_; +}; + +template +struct SequenceMaskFunctor { + using Tensor = framework::LoDTensor; + + SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y, + int limits, int max_len) + : ctx_(ctx), x_(x), y_(y), limits_(limits), max_len_(max_len) {} + + template + void operator()() const { + auto *y_data = y_->mutable_data(ctx_.GetPlace()); + platform::ForRange for_range(ctx_, limits_); + for_range(SequenceMaskForRangeFunctor(x_, y_data, max_len_)); + } + + private: + const DeviceContext &ctx_; + const Tx *x_; + Tensor *y_; + int limits_; + int max_len_; +}; + +template +class SequenceMaskKernel : public framework::OpKernel { + using Tensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *y = ctx.Output("Y"); + auto max_len = ctx.Attr("max_len"); + auto out_dtype = static_cast( + ctx.Attr("out_dtype")); + auto &dev_ctx = ctx.template device_context(); + framework::VisitDataType(out_dtype, SequenceMaskFunctor( + dev_ctx, x->data(), y, + x->numel() * max_len, max_len)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 71592618f5..1fe457452f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,6 +27,7 @@ from . import utils import random from .. import unique_name from functools import reduce +import warnings __all__ = [ 'fc', @@ -103,6 +104,7 @@ __all__ = [ 'rank_loss', 'prelu', 'flatten', + 'sequence_mask', ] @@ -2046,7 +2048,7 @@ def batch_norm(input, param_attr(ParamAttr): The parameter attribute for Parameter `scale`. bias_attr(ParamAttr): The parameter attribute for Parameter `bias`. data_layout(string, default NCHW): NCHW|NHWC - in_place(bool, Default False): Make the input and output of batch norm reuse memory. 
+ in_place(bool, Default False): This argument is deprecated since 0.15.0. use_mkldnn(bool, Default false): ${use_mkldnn_comment} name(string, Default None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2068,6 +2070,10 @@ def batch_norm(input, helper = LayerHelper('batch_norm', **locals()) dtype = helper.input_dtype() + if in_place: + raise warnings.warn("The argument in_place is deprecated since 0.15.0, " + "please do not set it True.") + input_shape = input.shape if data_layout == 'NCHW': channel_num = input_shape[1] @@ -2117,7 +2123,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) + batch_norm_out = helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm", @@ -5517,3 +5523,15 @@ def flatten(x, axis=1, name=None): outputs={'Out': out}, attrs={"axis": axis}) return out + + +def sequence_mask(x, max_len, mask_dtype='int64'): + helper = LayerHelper('sequence_mask', **locals()) + y = helper.create_tmp_variable(dtype=mask_dtype) + helper.append_op( + type='sequence_mask', + inputs={'X': [x]}, + outputs={'Y': y}, + attrs={'max_len': max_len, + 'out_dtype': y.dtype}) + return y diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 051fe84364..01563cbbb7 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -229,7 +229,7 @@ def img_conv_group(input, use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) + tmp = layers.batch_norm(input=tmp, act=conv_act) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index 
9fe361425c..cd1e8cd682 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -256,7 +256,10 @@ def main(net_type, use_cuda, is_local=True): save_dirname = "image_classification_" + net_type + ".inference.model" train(net_type, use_cuda, save_dirname, is_local) - infer(use_cuda, save_dirname) + + # There is bug in fluid.InferenceTranspiler for VGG. + if net_type == "resnet": + infer(use_cuda, save_dirname) class TestImageClassification(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_sequence_mask.py b/python/paddle/fluid/tests/unittests/test_sequence_mask.py new file mode 100644 index 0000000000..c6d09df984 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_mask.py @@ -0,0 +1,86 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from op_test import OpTest +from paddle.fluid.framework import convert_np_dtype_to_dtype_ +import numpy as np +import copy +import unittest + + +class SequenceMaskTestBase(OpTest): + def initDefaultParameters(self): + self.op_type = 'sequence_mask' + self.max_len = 10 + self.mask_dtype = 'int64' + self.x = [[0, 3, 4], [5, 7, 9]] + + def initParameters(self): + pass + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + if not isinstance(self.x, np.ndarray): + self.x = np.array(self.x) + + self.inputs = {'X': self.x} + self.outputs = {'Y': self.calc_ground_truth_mask()} + self.attrs = { + 'max_len': self.max_len, + 'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype) + } + + def calc_ground_truth_mask(self): + shape = self.x.shape + (self.max_len, ) + index_broadcast = np.broadcast_to( + np.reshape( + range(self.max_len), newshape=[1] * self.x.ndim + [-1]), + shape=shape) + x_broadcast = np.broadcast_to( + np.reshape( + self.x, newshape=self.x.shape + (-1, )), shape=shape) + return (index_broadcast < x_broadcast).astype(self.mask_dtype) + + def test_check_output(self): + self.check_output() + + +class SequenceMaskTest1(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'bool' + + +class SequenceMaskTest2(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'uint8' + + +class SequenceMaskTest3(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'int32' + + +class SequenceMaskTest4(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'float32' + + +class SequenceMaskTest5(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'float64' + + +if __name__ == '__main__': + unittest.main() From b8da70c37098beff9b5ccf3b13ac4eb6091e0f3f Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Thu, 23 Aug 2018 13:47:16 +0800 Subject: [PATCH 047/140] Resovle multi gpu async deps (#12828) * dist transpiler add control dependency var between send and recv * fix async deps * 
follow comments and refine * fix deps connect for rpc ops --- .../details/multi_devices_graph_pass.cc | 26 ++++++++++++++++--- paddle/fluid/framework/ir/node.cc | 2 +- paddle/fluid/framework/ir/node.h | 2 +- paddle/fluid/pybind/const_value.cc | 5 +++- python/paddle/fluid/framework.py | 6 +++++ .../fluid/transpiler/distribute_transpiler.py | 18 +++++++++++-- 6 files changed, 50 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index c5a13e7e1f..bc61b0eacb 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, // Create RPC related op handles that connects its in ops and out ops. void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { + // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode + // put them into transpiler. int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. @@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by // split_byref op - // so that we can balance the variable blocks to all the pserver - // instances. 
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && node->inputs[0]->Name().find(".block") == std::string::npos) { std::vector input_var_names; for (ir::Node *n : node->inputs) { input_var_names.push_back(n->Name()); } - op_dev_id = GetAppropriateDeviceID(input_var_names); + auto send_param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U); + op_dev_id = GetAppropriateDeviceID({send_param_grad[1]}); + VLOG(10) << "send grad " << input_var_names[0] << " origin " + << send_param_grad[1] << " place: " << op_dev_id; for (auto &varname : input_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } + result->Get(kShardedVarDevice) + .emplace(send_param_grad[1], op_dev_id); } } else if (node->Op()->Type() == "recv") { std::vector output_var_names; for (ir::Node *n : node->outputs) { output_var_names.push_back(n->Name()); } - op_dev_id = GetAppropriateDeviceID(output_var_names); + auto recv_param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + // FIXME(typhoonzero): assume each recv op output one param + // Use the same place as send. + if (recv_param_grad.size() == 2U) { + op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); + VLOG(10) << "recv param " << recv_param_grad[0] + << " get grad place: " << recv_param_grad[1] + << " place: " << op_dev_id; + } else { + op_dev_id = GetAppropriateDeviceID(output_var_names); + } for (auto &varname : output_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index aca77da8d6..65c45c7d20 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -17,7 +17,7 @@ limitations under the License. 
*/ namespace paddle { namespace framework { namespace ir { -const char Node::kControlDepVarName[] = "__control_var"; +constexpr char Node::kControlDepVarName[]; } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 63277d2d01..aab3180e7e 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -27,7 +27,7 @@ namespace ir { class Node { public: enum class Type { kOperation, kVariable }; - static const char kControlDepVarName[]; + static constexpr char kControlDepVarName[] = "__control_var"; explicit Node(const std::string& name, Type type) : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 9094f6051c..a81715c3b3 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/pybind/const_value.h" -#include +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" namespace paddle { @@ -24,6 +25,8 @@ void BindConstValue(pybind11::module* m) { m->def("kTempVarName", [] { return framework::kTempVarName; }); m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; }); m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); + m->def("kControlDepVarName", + [] { return framework::ir::Node::kControlDepVarName; }); auto op_proto_and_checker_maker = m->def_submodule("op_proto_and_checker_maker"); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 389fce1874..e0ddd3b5ff 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -50,6 +50,12 @@ EMPTY_VAR_NAME = core.kEmptyVarName() TEMP_VAR_NAME = core.kTempVarName() GRAD_VAR_SUFFIX = core.kGradVarSuffix() ZERO_VAR_SUFFIX = core.kZeroVarSuffix() +CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() + + +def generate_control_dev_var_name(): + import random + return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random()) def grad_var_name(var_name): diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 540eb8c833..80d9758b3d 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -212,8 +212,10 @@ class DistributeTranspiler(object): ps_dispatcher = self.config.split_method(self.pserver_endpoints) self.has_distributed_lookup_table = self._has_distributed_lookup_table() self.param_name_to_grad_name = dict() + self.grad_name_to_param_name = dict() for param_var, grad_var in self.params_grads: self.param_name_to_grad_name[param_var.name] = grad_var.name + self.grad_name_to_param_name[grad_var.name] = param_var.name # add distributed attrs to program 
self.origin_program._is_distributed = True @@ -262,8 +264,10 @@ class DistributeTranspiler(object): AssertionError("Can not insert the send op by original " "variable name :", splited_grad_varname) - dummy_output = program.global_block().create_var() + dummy_output = program.global_block().create_var( + name=framework.generate_control_dev_var_name()) grad_name_to_send_dummy_out[grad_varname] = dummy_output + program.global_block()._insert_op( index=index + 1, type="send", @@ -272,6 +276,8 @@ class DistributeTranspiler(object): attrs={ "epmap": eplist, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: + [self.grad_name_to_param_name[grad_varname], grad_varname], "sync_mode": not self.sync_mode, }) for _, var in enumerate(splited_vars): @@ -313,6 +319,10 @@ class DistributeTranspiler(object): attrs={ "epmap": eps, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: [ + param_varname, + self.param_name_to_grad_name[param_varname] + ], "sync_mode": not self.sync_mode }) @@ -971,7 +981,11 @@ class DistributeTranspiler(object): attrs={ "sync_mode": True, "epmap": pserver_endpoints, - RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: [ + self.grad_name_to_param_name[table_grad_name], + table_grad_name + ] }) break From 80e3ce411d16052766aca33d702b31cb0ec81419 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 13:51:47 +0800 Subject: [PATCH 048/140] For test --- paddle/scripts/paddle_build.sh | 3 ++- python/paddle/dataset/flowers.py | 1 + python/paddle/fluid/tests/unittests/CMakeLists.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index a55a9e89f7..02bf8533d8 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -315,7 +315,8 @@ function run_test() { EOF echo $http_proxy echo $https_proxy - ctest -V + ctest -R 
test_parallel_executor_fetch_feed -V + ctest -R test_dist_se_resnext -V # make install should also be test when unittest make install -j `nproc` pip install /usr/local/opt/paddle/share/wheels/*.whl diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 8c9c721b33..c4a3eb55dd 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -41,6 +41,7 @@ from paddle.reader import * import os import numpy as np from multiprocessing import cpu_count +import six from six.moves import cPickle as pickle from six.moves import zip __all__ = ['train', 'test', 'valid'] diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 0c9bbb766f..228a5ab917 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -64,7 +64,7 @@ if(WITH_DISTRIBUTE) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) -set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 600) +set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 300) py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) From 8ad90558047fcd844db270a5744c78bf772242fb Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 23 Aug 2018 14:00:58 +0800 Subject: [PATCH 049/140] Add is_test for while_op (#12874) * add is_test for while_op * Change API --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/while_op.cc | 7 +++++++ python/paddle/fluid/layers/control_flow.py | 7 +++++-- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec 
index 9250cde1b2..bbf1623c39 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -191,7 +191,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 48e37796e1..65a3bc928e 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -58,11 +58,15 @@ class WhileOp : public framework::OperatorBase { PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), "Condition of while op must in CPU memory."); + bool is_test = Attr("is_test"); auto ctx = executor.Prepare(*program, block->ID()); while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); executor.RunPreparedContext(ctx.get(), ¤t_scope, false); + if (is_test) { + scope.DeleteScope(¤t_scope); + } } } }; @@ -88,6 +92,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { "variables generated in the i'th step."); AddAttr(kStepBlock, "The step block inside WhileOp"); + AddAttr("is_test", "True if in test 
phase.").SetDefault(false); AddComment(R"DOC( )DOC"); } @@ -103,6 +108,8 @@ class WhileGradOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { + PADDLE_ENFORCE(!Attr("is_test"), + "GradOp is only callable when is_test is false"); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 8bfe11916b..d2954c4c22 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -661,6 +661,7 @@ class While(object): Args: cond (Variable): condition used to compare. + is_test(bool): A flag indicating whether execution is in test phase. name (str): The name of this layer. Examples: @@ -683,7 +684,7 @@ class While(object): IN_WHILE_BLOCK = 1 AFTER_WHILE_BLOCK = 2 - def __init__(self, cond, name=None): + def __init__(self, cond, is_test=False, name=None): self.helper = LayerHelper("while", name=name) self.status = While.BEFORE_WHILE_BLOCK if not isinstance(cond, Variable): @@ -694,6 +695,7 @@ class While(object): if reduce(lambda a, b: a * b, cond.shape, 1) != 1: raise TypeError("condition should be a bool scalar") self.cond_var = cond + self.is_test = is_test def block(self): return WhileGuard(self) @@ -735,7 +737,8 @@ class While(object): }, outputs={'Out': out_vars, 'StepScopes': [step_scope]}, - attrs={'sub_block': while_block}) + attrs={'sub_block': while_block, + "is_test": self.is_test}) def lod_rank_table(x, level=0): From 9c7fde45a7fec127e3f7dc7e1c161ec647e5683b Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 23 Aug 2018 13:32:02 +0800 Subject: [PATCH 050/140] enhance test_analyzer to profile ditu inference demo --- .../ir/graph_pattern_detecter_tester.cc | 4 +- paddle/fluid/framework/selected_rows.cc | 4 +- 
.../inference/analysis/analyzer_tester.cc | 48 +++++++++++-------- paddle/fluid/operators/sampling_id_op.h | 2 +- paddle/scripts/paddle_build.sh | 2 - 5 files changed, 32 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc index 993c885a81..06f9df5546 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc @@ -163,8 +163,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) { // 3. Detect op2 -> var2 -> op4 // 4. Detect op2 -> var3 -> op5 // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 - ASSERT_GE(count, 1UL); - ASSERT_LE(count, 2UL); + ASSERT_GE(count, 1); + ASSERT_LE(count, 2); } } // namespace ir diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index c202b0a5be..a4319ffabb 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -139,7 +139,7 @@ int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) { } auto write_iter = id_to_index_.find(key); if (write_iter == id_to_index_.end()) { - size_t row_num = rows_.size(); + int row_num = rows_.size(); if (row_num == value_->dims()[0]) { rwlock_->UNLock(); PADDLE_THROW("selected rows is full, then length exceed %d", row_num); @@ -182,7 +182,7 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], "output tensor should have the same shape with table " "except the dims[0]."); - for (size_t i = 0; i < ids.numel(); ++i) { + for (int i = 0; i < ids.numel(); ++i) { int64_t index = AutoGrownIndex(ids.data()[i], auto_grown); framework::VisitDataType( framework::ToDataType(value_->type()), diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 
52f5c4f5ae..baa7600283 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -23,6 +23,8 @@ DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); +DEFINE_int32(batch_size, 10, "batch size."); +DEFINE_int32(repeat, 1, "Running the inference program repeat times."); namespace paddle { namespace inference { @@ -92,7 +94,7 @@ struct DataRecord { size_t batch_iter{0}; size_t batch_size{1}; DataRecord() = default; - DataRecord(const std::string &path, int batch_size = 1) + explicit DataRecord(const std::string &path, int batch_size = 1) : batch_size(batch_size) { Load(path); } @@ -165,7 +167,6 @@ struct DataRecord { }; void PrepareInputs(std::vector *input_slots, DataRecord *data, int batch_size) { - // DataRecord data(FLAGS_datapath, batch_size); PaddleTensor lod_attention_tensor, init_zero_tensor, lod_tensor_tensor, week_tensor, minute_tensor; lod_attention_tensor.name = "data_lod_attention"; @@ -174,28 +175,33 @@ void PrepareInputs(std::vector *input_slots, DataRecord *data, week_tensor.name = "week"; minute_tensor.name = "minute"; auto one_batch = data->NextBatch(); - // clang-format off - std::vector rnn_link_data_shape - ({static_cast(one_batch.rnn_link_data.size()), static_cast(one_batch.rnn_link_data.front().size())}); + std::vector rnn_link_data_shape( + {static_cast(one_batch.rnn_link_data.size()), + static_cast(one_batch.rnn_link_data.front().size())}); lod_attention_tensor.shape.assign({1, 2}); lod_attention_tensor.lod.assign({one_batch.lod1, one_batch.lod2}); init_zero_tensor.shape.assign({batch_size, 15}); init_zero_tensor.lod.assign({one_batch.lod3}); lod_tensor_tensor.shape = rnn_link_data_shape; lod_tensor_tensor.lod.assign({one_batch.lod1}); - week_tensor.shape.assign({(int) one_batch.rnn_week_datas.size(), (int) one_batch.rnn_week_datas.front().size()}); + // clang-format off + week_tensor.shape.assign( + 
{static_cast(one_batch.rnn_week_datas.size()), + static_cast(one_batch.rnn_week_datas.front().size())}); week_tensor.lod.assign({one_batch.lod3}); - minute_tensor.shape.assign({(int) one_batch.rnn_minute_datas.size(), - (int) one_batch.rnn_minute_datas.front().size()}); + minute_tensor.shape.assign( + {static_cast(one_batch.rnn_minute_datas.size()), + static_cast(one_batch.rnn_minute_datas.front().size())}); minute_tensor.lod.assign({one_batch.lod3}); + // clang-format on // assign data - TensorAssignData(&lod_attention_tensor, std::vector>({{0, 0}})); + TensorAssignData(&lod_attention_tensor, + std::vector>({{0, 0}})); std::vector tmp_zeros(batch_size * 15, 0.); TensorAssignData(&init_zero_tensor, {tmp_zeros}); TensorAssignData(&lod_tensor_tensor, one_batch.rnn_link_data); TensorAssignData(&week_tensor, one_batch.rnn_week_datas); TensorAssignData(&minute_tensor, one_batch.rnn_minute_datas); - // clang-format on // Set inputs. auto init_zero_tensor1 = init_zero_tensor; init_zero_tensor1.name = "hidden_init"; @@ -231,12 +237,9 @@ std::string DescribeTensor(const PaddleTensor &tensor) { os << "\n"; os << " - data: "; - // clang-format off - int dim = std::accumulate(tensor.shape.begin(), - tensor.shape.end(), - 1, - [](int a, int b) { return a * b; }); // clang-format on - for (size_t i = 0; i < dim; i++) { + int dim = std::accumulate(tensor.shape.begin(), tensor.shape.end(), 1, + [](int a, int b) { return a * b; }); + for (int i = 0; i < dim; i++) { os << static_cast(tensor.data.data())[i] << " "; } os << '\n'; @@ -300,13 +303,16 @@ void TestDituRNNPrediction(const std::string &model_path, for (int i = 0; i < num_times; i++) { predictor->Run(input_slots, &outputs); } - LOG(INFO) << "time/batch: " << timer.toc() / num_times; + LOG(INFO) << "===========profile result==========="; + LOG(INFO) << "batch_size: " << batch_size << ", repeat: " << num_times + << ", latency: " << timer.toc() / num_times << "ms"; + LOG(INFO) << "====================================="; for 
(auto &out : outputs) { size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, [](int a, int b) { return a * b; }); float *data = static_cast(out.data.data()); - for (int i = 0; + for (size_t i = 0; i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size); i++) { EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3); @@ -336,7 +342,7 @@ TEST(Analyzer, SupportIRPass) { // Directly infer with the original model. TEST(Analyzer, DituRNN_without_analysis) { TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, - 10, false, false); + FLAGS_batch_size, false, false, FLAGS_repeat); } // Inference with the original model with the analysis turned on, the analysis @@ -344,14 +350,14 @@ TEST(Analyzer, DituRNN_without_analysis) { TEST(Analyzer, DituRNN_with_analysis) { LOG(INFO) << "ditu rnn with analysis"; TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, - 10, true, false, 1); + FLAGS_batch_size, true, false, FLAGS_repeat); } // Inference with analysis and IR. The IR module will fuse some large kernels. 
TEST(Analyzer, DituRNN_with_analysis_with_IR) { LOG(INFO) << "ditu rnn with analysis and IR fuse"; TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, - 10, true, true, 1); + FLAGS_batch_size, true, true, FLAGS_repeat); } } // namespace analysis diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index f730a9746d..e1dd4539b3 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -54,7 +54,7 @@ class SamplingIdKernel : public framework::OpKernel { static_cast(context.Attr("max"))); std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { + for (int i = 0; i < batch_size; ++i) { T r = dist(engine); int idx = width - 1; for (int j = 0; j < width; ++j) { diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 8460f93b84..f2a9a6b3b9 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -116,7 +116,6 @@ function cmake_gen() { -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DWITH_CONTRIB=${WITH_CONTRIB:-ON} -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} - -DWITH_INFERENCE_DEMO=${WITH_INFERENCE_DEMO:-ON} -DPY_VERSION=${PY_VERSION:-2.7} ======================================== EOF @@ -146,7 +145,6 @@ EOF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DWITH_CONTRIB=${WITH_CONTRIB:-ON} \ -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} \ - -DWITH_INFERENCE_DEMO=${WITH_INFERENCE_DEMO:-ON} \ -DPY_VERSION=${PY_VERSION:-2.7} } From acdd95d5caf92f38a995bc6d2edf20a56520d799 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 23 Aug 2018 16:47:12 +0800 Subject: [PATCH 051/140] bug fix --- paddle/fluid/operators/sampling_id_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index f730a9746d..3f7860e1fa 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -63,7 +63,7 @@ class SamplingIdKernel : 
public framework::OpKernel { break; } } - ids[i] = ins_vector[i * width + idx]; + ids[i] = ins_vector[idx]; } std::vector out_dim; From f4a4a4cbd934a8af2d4d889cdb0db74fc6a9cfd2 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 22 Aug 2018 16:40:38 +0800 Subject: [PATCH 052/140] add op comment and python layer --- .../fluid/operators/math/sequence_padding.cu | 3 ++ paddle/fluid/operators/sequence_pad_op.cc | 46 +++++++++++++++++++ python/paddle/fluid/layers/nn.py | 45 ++++++++++++++++++ 3 files changed, 94 insertions(+) diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 93d239351a..f94e8dbc3a 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -66,6 +66,9 @@ class PaddingLoDTensorFunctor { if (pad_seq_len == -1) { pad_seq_len = max_seq_len; } + PADDLE_ENFORCE_GE(pad_seq_len, max_seq_len, + "The pad_seq_len must be equal to or greater than the " + "original max sequence length."); int step_width = seq_tensor.numel() / seq_tensor_dims[0]; int seq_num = seq_offsets.size() - 1; diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index f23710cf4d..a08804cfba 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -101,6 +101,52 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { "sequence.") .SetDefault(-1); AddComment(R"DOC( + Sequence Pad Operator + + This operator pads sequences in a same batch to a consistent length. + The length is specified by attribute 'padded_length'. New elements, + whose values are specified by input 'PadValue', will be appended to + the end of each sequence, to make their final lengths consistent. 
+ + Following are cases to better explain how this works: + + Case 1: + + Given a 1-level LoDTensor input(X): + X.lod = [[0, 2, 5]] + X.data = [a, b, c, d, e] + and Input(PadValue): + PadValue.data = [0] + and attribite 'padded_length' = 4, + then we get 1-level LoDTensor: + Out.lod = [[0, 4, 8]] + Out.data = [a, b, 0, 0, c, d, e, 0] + + Case 2: + + Given a 1-level LoDTensor input(X): + X.lod = [[0, 2, 5]] + X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]] + and Input(PadValue): + PadValue.data = [0] + and attribite 'padded_length' = -1, which mean using the length + of longest input sequence(3 in this case), + then we get 1-level LoDTensor: + Out.lod = [[0, 3, 6]] + Out.data = [[a1, a2], [b1, b2], [0, 0], [c1, c2], [d1, d2], [e1, e2]] + + Case 3: + + Given a 1-level LoDTensor input(X): + X.lod = [[0, 2, 5]] + X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]] + and Input(PadValue): + PadValue.data = [p1, p2] + and attribite 'padded_length' = -1, which mean using the length + of longest input sequence(3 in this case), + then we get 1-level LoDTensor: + Out.lod = [[0, 3, 6]] + Out.data = [[a1, a2], [b1, b2], [p1, p2], [c1, c2], [d1, d2], [e1, e2]] )DOC"); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0960b54123..d782ea7470 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2662,6 +2662,51 @@ def sequence_expand(x, y, ref_level=-1, name=None): return tmp +@templatedoc() +def sequence_pad(x, pad_value, maxlen=None): + """ + ${comment} + + Args: + x(Variable): Input variable which should contain lod information. + pad_value(Variable): The Variable that holds values that will be fill + into padded steps. It can be a scalar or a tensor whose shape + equals to time steps in sequences. If it's a scalar, it will be + automatically broadcasted to the shape of time step. + maxlen(int, default None): The length of padded sequences. It can be + None or any positive int. 
When it is None, all sequences will be + padded up to the length of the longest one among them; when it a + certain positive value, it must be greater than the length of the + longest original sequence." + + Returns: + Variable: The padded sequence batch. All sequences has the same length. + + Examples: + .. code-block:: python + + import numpy + + x = fluid.layers.data(name='y', shape=[10, 5], + dtype='float32', lod_level=1) + pad_value = fluid.layers.assign(input=numpy.array([0])) + out = fluid.layers.sequence_pad(x=x, pad_value=pad_value) + """ + + helper = LayerHelper('sequence_pad', input=x, **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + if maxlen is None: + maxlen = -1 + helper.append_op( + type='sequence_pad', + inputs={'X': x, + 'PadValue': pad_value}, + outputs={'Out': out}, + attrs={'padded_length': maxlen}) + return out + + def beam_search(pre_ids, pre_scores, ids, From 709c37023ae8cf301cc460b665655311523e8b52 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 17:18:14 +0800 Subject: [PATCH 053/140] Polish code --- paddle/scripts/paddle_build.sh | 5 +---- python/paddle/fluid/tests/unittests/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 02bf8533d8..8460f93b84 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -313,10 +313,7 @@ function run_test() { Running unit tests ... 
======================================== EOF - echo $http_proxy - echo $https_proxy - ctest -R test_parallel_executor_fetch_feed -V - ctest -R test_dist_se_resnext -V + ctest --output-on-failure # make install should also be test when unittest make install -j `nproc` pip install /usr/local/opt/paddle/share/wheels/*.whl diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 228a5ab917..8ac1cb164e 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -64,7 +64,7 @@ if(WITH_DISTRIBUTE) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) -set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 300) +set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 150) py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) From 0fb5e351c1ddc36c9361de8b279deba99444bd58 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 23 Aug 2018 09:33:00 +0000 Subject: [PATCH 054/140] update API.spec --- paddle/fluid/API.spec | 1 + python/paddle/fluid/layers/nn.py | 1 + 2 files changed, 2 insertions(+) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 46e56981ea..0df617f76d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -113,6 +113,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, 
defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)) +paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None)) paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d782ea7470..2f115afa6f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -66,6 +66,7 @@ __all__ = [ 'conv2d_transpose', 'conv3d_transpose', 'sequence_expand', + 'sequence_pad', 'lstm_unit', 'reduce_sum', 'reduce_mean', From 0dc5d9c2157bd95479bff67181d05c105e623aa3 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 17:36:54 +0800 Subject: [PATCH 055/140] Port print_siignatures --- tools/print_signatures.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/print_signatures.py b/tools/print_signatures.py index 5e7ffd44c7..e2805c4e7e 100644 --- a/tools/print_signatures.py +++ b/tools/print_signatures.py @@ -17,6 +17,8 @@ Print all signature of a python module in alphabet order. 
Usage: ./print_signature "paddle.fluid" > signature.txt """ +from __future__ import print_function + import importlib import inspect import collections @@ -64,4 +66,4 @@ def visit_all_module(mod): visit_all_module(importlib.import_module(sys.argv[1])) for name in member_dict: - print name, member_dict[name] + print(name, member_dict[name]) From e895c98f0ae43853e8150594c8ff1cc03a7663b8 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Thu, 23 Aug 2018 07:45:17 +0000 Subject: [PATCH 056/140] add support to max_len is None --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/sequence_mask_op.h | 83 ++++++++++++++----- python/paddle/fluid/layers/nn.py | 45 ++++++++-- .../tests/unittests/test_sequence_mask.py | 16 +++- 4 files changed, 112 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 359db26ed6..01b6053524 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -162,7 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) -paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'max_len', 'mask_dtype'], varargs=None, keywords=None, defaults=('int64',)) +paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 
'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/sequence_mask_op.h b/paddle/fluid/operators/sequence_mask_op.h index 237857b51d..0dd554adfe 100644 --- a/paddle/fluid/operators/sequence_mask_op.h +++ b/paddle/fluid/operators/sequence_mask_op.h @@ -14,6 +14,14 @@ #pragma once +#ifdef __NVCC__ +#include +#include +#include +#else +#include +#endif + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" @@ -26,50 +34,60 @@ class SequenceMaskOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); - auto max_len = ctx->Attrs().Get("max_len"); - PADDLE_ENFORCE_GT(max_len, 1, "Attr(max_len) must be larger than 1"); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); - auto dim = framework::vectorize2int(ctx->GetInputDim("X")); - dim.push_back(max_len); - ctx->SetOutputDim("Y", framework::make_ddim(dim)); + + auto maxlen = ctx->Attrs().Get("maxlen"); + if (maxlen > 0) { // We can only infershape when maxlen > 0 + auto dim = framework::vectorize2int(ctx->GetInputDim("X")); + dim.push_back(maxlen); + ctx->SetOutputDim("Y", framework::make_ddim(dim)); + } } }; class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "The input of sequence_mask op."); + AddInput("X", "The input tensor of sequence_mask op."); AddOutput("Y", "The output mask of sequence_mask op."); - AddAttr("max_len", "The maximum length of the sequence.") - .GreaterThan(1); + AddAttr("maxlen", + "The maximum length of the sequence. 
If maxlen < 0, maxlen " + "= max(Input(X)).") + .SetDefault(-1) + .AddCustomChecker([](int &v) { + PADDLE_ENFORCE(v < 0 || v >= 1, + "Attr(maxlen) must be less than 0 or larger than 1"); + }); AddAttr("out_dtype", "Output data type"); AddComment(R"DOC( SequenceMask Operator -This operator outputs a Mask according to Input(X) and Attr(max_len). +This operator outputs a Mask according to Input(X) and Attr(maxlen). Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the -Output(Y) is a mask with shape [d_1, d_2, ..., d_n, max_len], where: +Output(Y) is a mask with shape [d_1, d_2, ..., d_n, maxlen], where: Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n)) + +If maxlen < 0, maxlen = max(X) )DOC"); } }; template struct SequenceMaskForRangeFunctor { - HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int max_len) - : x_(x), y_(y), max_len_(max_len) {} + HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen) + : x_(x), y_(y), maxlen_(maxlen) {} HOSTDEVICE void operator()(int y_idx) const { - int x_idx = y_idx / max_len_; - int j = y_idx % max_len_; + int x_idx = y_idx / maxlen_; + int j = y_idx % maxlen_; y_[y_idx] = static_cast(j < x_[x_idx] ? 
1 : 0); } private: const Tx *x_; Ty *y_; - int max_len_; + int maxlen_; }; template @@ -77,14 +95,14 @@ struct SequenceMaskFunctor { using Tensor = framework::LoDTensor; SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y, - int limits, int max_len) - : ctx_(ctx), x_(x), y_(y), limits_(limits), max_len_(max_len) {} + int limits, int maxlen) + : ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {} template void operator()() const { auto *y_data = y_->mutable_data(ctx_.GetPlace()); platform::ForRange for_range(ctx_, limits_); - for_range(SequenceMaskForRangeFunctor(x_, y_data, max_len_)); + for_range(SequenceMaskForRangeFunctor(x_, y_data, maxlen_)); } private: @@ -92,7 +110,7 @@ struct SequenceMaskFunctor { const Tx *x_; Tensor *y_; int limits_; - int max_len_; + int maxlen_; }; template @@ -103,13 +121,32 @@ class SequenceMaskKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { auto *x = ctx.Input("X"); auto *y = ctx.Output("Y"); - auto max_len = ctx.Attr("max_len"); + auto maxlen = ctx.Attr("maxlen"); + + auto *x_data = x->data(); + auto x_numel = x->numel(); + if (maxlen < 0) { +#ifdef __NVCC__ + VLOG(10) + << "SequenceMaskOp on GPU may be slow when maxlen is not provided."; + maxlen = static_cast( + thrust::reduce(thrust::device_pointer_cast(x_data), + thrust::device_pointer_cast(x_data) + x_numel, + static_cast(0), thrust::maximum())); +#else + maxlen = static_cast(*std::max_element(x_data, x_data + x_numel)); +#endif + auto y_dim = framework::vectorize2int(x->dims()); + y_dim.push_back(maxlen); + y->Resize(framework::make_ddim(y_dim)); + } + auto out_dtype = static_cast( ctx.Attr("out_dtype")); auto &dev_ctx = ctx.template device_context(); - framework::VisitDataType(out_dtype, SequenceMaskFunctor( - dev_ctx, x->data(), y, - x->numel() * max_len, max_len)); + framework::VisitDataType(out_dtype, + SequenceMaskFunctor( + dev_ctx, x_data, y, x_numel * maxlen, maxlen)); } }; diff --git 
a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1fe457452f..211f828d6f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5525,13 +5525,46 @@ def flatten(x, axis=1, name=None): return out -def sequence_mask(x, max_len, mask_dtype='int64'): +def sequence_mask(x, maxlen=None, dtype='int64', name=None): + """ + **SequenceMask Layer** + + This layer outputs a mask according to the input :code:`x` and + :code:`maxlen` with data type of :code:`dtype`. + + Supposing :code:`x` is a Tensor with shape [d_1, d_2, ..., d_n], the + :code:`y` is a mask with shape [d_1, d_2, ..., d_n, maxlen], where: + + .. math:: + + y(i_1, i_2,..., i_n, j) = (j < x(i_1, i_2,..., i_n)) + + Args: + x (Variable): Input tensor of sequence_mask layer, + whose elements are integers less than :code:`maxlen`. + maxlen (int|None): Maximum length of the sequence. If :code:`maxlen` + is None, it would be replace with :math:`max(x)`. + dtype (np.dtype|core.VarDesc.VarType|str): Data type of the output. + name (str|None): A name for this layer(optional). If set None, the + layer will be named automatically. + + Returns: + Variable: The output sequence mask. 
+ + """ + helper = LayerHelper('sequence_mask', **locals()) - y = helper.create_tmp_variable(dtype=mask_dtype) + if name is None: + out = helper.create_tmp_variable(dtype=dtype) + else: + out = helper.create_tmp_variable(dtype=dtype, name=name) + helper.append_op( type='sequence_mask', inputs={'X': [x]}, - outputs={'Y': y}, - attrs={'max_len': max_len, - 'out_dtype': y.dtype}) - return y + outputs={'Y': out}, + attrs={ + 'max_len': maxlen if maxlen is not None else -1, + 'out_dtype': out.dtype + }) + return out diff --git a/python/paddle/fluid/tests/unittests/test_sequence_mask.py b/python/paddle/fluid/tests/unittests/test_sequence_mask.py index c6d09df984..02c5b20408 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_mask.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_mask.py @@ -13,7 +13,9 @@ # limitations under the License. from op_test import OpTest +import paddle.fluid as fluid from paddle.fluid.framework import convert_np_dtype_to_dtype_ +import paddle.fluid.core as core import numpy as np import copy import unittest @@ -22,7 +24,7 @@ import unittest class SequenceMaskTestBase(OpTest): def initDefaultParameters(self): self.op_type = 'sequence_mask' - self.max_len = 10 + self.maxlen = 10 self.mask_dtype = 'int64' self.x = [[0, 3, 4], [5, 7, 9]] @@ -38,15 +40,16 @@ class SequenceMaskTestBase(OpTest): self.inputs = {'X': self.x} self.outputs = {'Y': self.calc_ground_truth_mask()} self.attrs = { - 'max_len': self.max_len, + 'maxlen': self.maxlen, 'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype) } def calc_ground_truth_mask(self): - shape = self.x.shape + (self.max_len, ) + maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen + shape = self.x.shape + (maxlen, ) index_broadcast = np.broadcast_to( np.reshape( - range(self.max_len), newshape=[1] * self.x.ndim + [-1]), + range(maxlen), newshape=[1] * self.x.ndim + [-1]), shape=shape) x_broadcast = np.broadcast_to( np.reshape( @@ -82,5 +85,10 @@ class 
SequenceMaskTest5(SequenceMaskTestBase): self.mask_dtype = 'float64' +class SequenceMaskTest6(SequenceMaskTestBase): + def initParameters(self): + self.maxlen = -1 + + if __name__ == '__main__': unittest.main() From 41c10799b8165e67416f26728569377bc92e5775 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 17:40:26 +0800 Subject: [PATCH 057/140] Port tools --- tools/check_ctest_hung.py | 4 +++- tools/timeline.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tools/check_ctest_hung.py b/tools/check_ctest_hung.py index 7de76c381b..c44690a93a 100644 --- a/tools/check_ctest_hung.py +++ b/tools/check_ctest_hung.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import sys import re @@ -46,7 +48,7 @@ Diff: set(['test_parallel_executor_crf']) start_parts = escape(l).split(" ") m = re.search("Start\s+[0-9]+\:\s([a-z0-9_]+)", escape(l)) started.add(m.group(1)) - print "Diff: ", started - passed + print("Diff: ", started - passed) if __name__ == "__main__": diff --git a/tools/timeline.py b/tools/timeline.py index b413bb6fe0..f850476831 100644 --- a/tools/timeline.py +++ b/tools/timeline.py @@ -14,6 +14,7 @@ import argparse import json +import six import sys import unittest @@ -124,7 +125,7 @@ class Timeline(object): return cur_pid def _allocate_pids(self): - for k, profile_pb in self._profile_dict.iteritems(): + for k, profile_pb in six.iteritems(self._profile_dict): for event in profile_pb.events: if event.type == profiler_pb2.Event.CPU: if (k, event.device_id, "CPU") not in self._devices: @@ -140,7 +141,7 @@ class Timeline(object): (k, event.device_id), pid) def _allocate_events(self): - for k, profile_pb in self._profile_dict.iteritems(): + for k, profile_pb in six.iteritems(self._profile_dict): for event in profile_pb.events: if event.type == profiler_pb2.Event.CPU: type = "CPU" From 
13686c44747f5a678ee10adf3cee4c509fe07d00 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 17:41:16 +0800 Subject: [PATCH 058/140] Change to debug case --- paddle/scripts/paddle_build.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 8460f93b84..00fb0310e1 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -313,7 +313,9 @@ function run_test() { Running unit tests ... ======================================== EOF - ctest --output-on-failure + #ctest --output-on-failure + ctest -R test_parallel_executor_fetch_feed -V + ctest -R test_dist_se_resnext -V # make install should also be test when unittest make install -j `nproc` pip install /usr/local/opt/paddle/share/wheels/*.whl From 2aac36b3f9bf47e7862091ba28ea925cf6ba346f Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 19:03:15 +0800 Subject: [PATCH 059/140] For test --- paddle/scripts/paddle_build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 00fb0310e1..d5af0eefe3 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -314,8 +314,8 @@ function run_test() { ======================================== EOF #ctest --output-on-failure - ctest -R test_parallel_executor_fetch_feed -V ctest -R test_dist_se_resnext -V + ctest -R test_parallel_executor_fetch_feed -V # make install should also be test when unittest make install -j `nproc` pip install /usr/local/opt/paddle/share/wheels/*.whl From 405d6d09e1b2199711818540cfcfb87494999852 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 23 Aug 2018 19:26:28 +0800 Subject: [PATCH 060/140] Fix doc typo. 
(#12863) --- doc/fluid/dev/new_op_cn.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/fluid/dev/new_op_cn.md b/doc/fluid/dev/new_op_cn.md index c00f73be95..ff7408111f 100644 --- a/doc/fluid/dev/new_op_cn.md +++ b/doc/fluid/dev/new_op_cn.md @@ -36,19 +36,19 @@ OpProtoMake定义 -`.cc`文件,Backward Op不需要定义OpProtoMake +.cc 文件,Backward Op不需要定义OpProtoMake Op定义 - `.cc`文件 + .cc 文件 Kernel实现 - CPU、CUDA共享Kernel实现在`.h`文件中,否则,CPU 实现在`.cc`文件中,CUDA 实现在`.cu`文件中。 + CPU、CUDA共享Kernel实现在.h 文件中,否则,CPU 实现在.cc 文件中,CUDA 实现在.cu 文件中。 注册Op - Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,CUDA实现在`.cu`文件中 + Op注册实现在.cc 文件;Kernel注册CPU实现在.cc 文件中,CUDA实现在.cu 文件中 @@ -391,7 +391,7 @@ PADDLE_ENFORCE(ctx->HasInput("X"), ""); ``` 问题示例2 :提示信息过于简单 ``` -PADDLE_ENFORCE(i != nullptr, "I must be set"); // I是什么? +PADDLE_ENFORCE(i != nullptr, "i must be set"); // i是什么? ``` 2. 在报错信息中使用开发人员定义的变量缩写,不易理解! From 83f4edabe990bd496720a5dd098f3220dbdb337a Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 23 Aug 2018 19:58:14 +0800 Subject: [PATCH 061/140] remove broadcast in sequence_expand --- paddle/fluid/operators/sequence_expand_op.h | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 39301e1ac0..9228c81310 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -53,25 +53,27 @@ struct SequenceExpandFunctor { const framework::Vector& ref_lod, /*expand referenced lod*/ LoDTensor* out) { int out_offset = 0; - auto& eigen_place = *context.eigen_device(); + int x_item_length = x.numel() / x.dims()[0]; + auto out_data = out->data(); + auto x_data = x.data(); for (size_t i = 1; i < ref_lod.size(); ++i) { int repeat_num = ref_lod[i] - ref_lod[i - 1]; int x_start = x_lod[i - 1]; int x_end = x_lod[i]; int x_seq_len = x_end - x_start; if (repeat_num > 0) { - auto x_sub_tensor = x.Slice(x_start, x_end); - 
x_sub_tensor.Resize({1, x_sub_tensor.numel()}); int out_start = out_offset; if (out->lod().size() == 1) { out_start = out->lod()[0][out_offset]; } - auto out_sub_tensor = - out->Slice(out_start, out_start + x_seq_len * repeat_num); - out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]}); - EigenMatrix::From(out_sub_tensor).device(eigen_place) = - EigenMatrix::From(x_sub_tensor) - .broadcast(Eigen::array({{repeat_num, 1}})); + for (int j = 0; j < repeat_num; j++) { + for (int k = 0; k < x_seq_len; k++) { + for (int l = 0; l < x_item_length; l++) { + out_data[(out_start + j * x_seq_len + k) * x_item_length + l] = + x_data[(x_start + k) * x_item_length + l]; + } + } + } } out_offset += repeat_num; } From 23bfdf9987c1105ebc067dd42b6ffd3ec8104b4e Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 20:16:26 +0800 Subject: [PATCH 062/140] Port APISpec check --- paddle/scripts/paddle_build.sh | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d5af0eefe3..4979bd55c1 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -331,7 +331,17 @@ function assert_api_not_changed() { virtualenv .env source .env/bin/activate pip install ${PADDLE_ROOT}/build/python/dist/*whl - python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec + if [ "$1" != "" ]; then + echo "checking python abi: $1" + if [ "$1" == "cp35-cp35m" ]; then + # Always use python2 to generate api signature + LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} PATH=/opt/python/cp27-cp27mu/bin/:${PATH} python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec + else + python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec + fi + else + python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec + fi python ${PADDLE_ROOT}/tools/diff_api.py 
${PADDLE_ROOT}/paddle/fluid/API.spec new.spec deactivate @@ -625,7 +635,7 @@ function main() { gen_capi_package gen_fluid_inference_lib test_fluid_inference_lib - assert_api_not_changed + assert_api_not_changed ${PYTHON_ABI:-""} ;; *) print_usage From 0d46f518aef8d8893f5f438475e6bc53b6f2b8bd Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 23 Aug 2018 14:32:46 +0800 Subject: [PATCH 063/140] refine avx condition and warning --- cmake/configure.cmake | 22 ++++++++++++++-------- paddle/fluid/platform/CMakeLists.txt | 2 +- paddle/fluid/platform/cpu_info.h | 2 +- paddle/fluid/platform/init.cc | 17 +++++++++++++++++ 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index e03e15bfc0..7e5d8a7621 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -50,14 +50,20 @@ if(NOT WITH_PROFILER) endif(NOT WITH_PROFILER) if(NOT CMAKE_CROSSCOMPILING) - if(WITH_AVX AND AVX512F_FOUND) - set(SIMD_FLAG ${AVX512F_FLAG}) - elseif(WITH_AVX AND AVX2_FOUND) - set(SIMD_FLAG ${AVX2_FLAG}) - elseif(WITH_AVX AND AVX_FOUND) - set(SIMD_FLAG ${AVX_FLAG}) - elseif(SSE3_FOUND) - set(SIMD_FLAG ${SSE3_FLAG}) + set(SIMD_FLAG) + if(WITH_AVX) + if (AVX512F_FOUND) + set(SIMD_FLAG "${SIMD_FLAG} ${AVX512F_FLAG}") + endif() + if (AVX2_FOUND) + set(SIMD_FLAG "${SIMD_FLAG} ${AVX2_FLAG}") + endif() + if (AVX_FOUND) + set(SIMD_FLAG "${SIMD_FLAG} ${AVX_FLAG}") + endif() + if (SSE3_FOUND) + set(SIMD_FLAG "${SIMD_FLAG} ${SSE3_FLAG}") + endif() endif() endif() diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index f08c0e8e34..75d3856d0d 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -50,7 +50,7 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS malloc - place eigen3 stringpiece cpu_helper framework_proto ${GPU_CTX_DEPS} 
${MKLDNN_CTX_DEPS}) + place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) cc_test(init_test SRCS init_test.cc DEPS device_context) diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index 5d17978dd7..30c8fbcfce 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -51,7 +51,7 @@ typedef enum { } cpu_isa_t; // Instruction set architecture // May I use some instruction -inline bool MayIUse(const cpu_isa_t cpu_isa); +bool MayIUse(const cpu_isa_t cpu_isa); } // namespace jit diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 6f1f0c4796..020ce4d6f5 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/place.h" @@ -120,6 +121,22 @@ void InitDevices(bool init_p2p, const std::vector devices) { #ifndef PADDLE_WITH_MKLDNN platform::SetNumThreads(FLAGS_paddle_num_threads); #endif + + if (platform::jit::MayIUse(platform::jit::avx512_common)) { +#ifndef __AVX512F__ + LOG(WARNING) << "AVX512F is available, Please re-compile on local machine"; +#endif + } + if (platform::jit::MayIUse(platform::jit::avx2)) { +#ifndef __AVX2__ + LOG(WARNING) << "AVX2 is available, Please re-compile on local machine"; +#endif + } + if (platform::jit::MayIUse(platform::jit::avx)) { +#ifndef __AVX__ + LOG(WARNING) << "AVX is available, Please re-compile on local machine"; +#endif + } } void InitGLOG(const std::string &prog_name) { From 7570e5ef043dcbf5f78904e98df2c4283ec86d47 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 23 Aug 2018 20:32:36 
+0800 Subject: [PATCH 064/140] Print readable program codes. (#12673) --- .../tests/unittests/test_program_code.py | 81 +++++++++++ .../fluid/transpiler/details/program_utils.py | 133 ++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_program_code.py diff --git a/python/paddle/fluid/tests/unittests/test_program_code.py b/python/paddle/fluid/tests/unittests/test_program_code.py new file mode 100644 index 0000000000..e9c2b92861 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_program_code.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +import unittest +from multiprocessing import Process +import signal + +import numpy + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +from paddle.fluid.layers.io import ListenAndServ +from paddle.fluid.layers.io import Recv +from paddle.fluid.layers.io import Send + +from paddle.fluid.transpiler.details import program_to_code + + +class TestProgram2Code(unittest.TestCase): + def test_print(self): + place = fluid.CPUPlace() + self.init_serv(place) + self.init_client(place, 9123) + + def init_serv(self, place): + main = fluid.Program() + + with fluid.program_guard(main): + serv = ListenAndServ("127.0.0.1:0", ["X"], optimizer_mode=False) + with serv.do(): + out_var = main.global_block().create_var( + name="scale_0.tmp_0", + psersistable=True, + dtype="float32", + shape=[32, 32]) + x = layers.data( + shape=[32, 32], + dtype='float32', + name="X", + append_batch_size=False) + fluid.initializer.Constant(value=1.0)(x, main.global_block()) + layers.scale(x=x, scale=10.0, out=out_var) + + program_to_code(main) + + def init_client(self, place, port): + main = fluid.Program() + with fluid.program_guard(main): + x = layers.data( + shape=[32, 32], + dtype='float32', + name='X', + append_batch_size=False) + fluid.initializer.Constant(value=2.3)(x, main.global_block()) + get_var = main.global_block().create_var( + name="scale_0.tmp_0", # server side var + dtype="float32", + persistable=False, + shape=[32, 32]) + fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) + Send("127.0.0.1:%d" % port, [x]) + o = Recv("127.0.0.1:%d" % port, [get_var]) + + program_to_code(main) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 640dbf4bbe..420ae6dfd4 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -16,6 +16,9 @@ 
from __future__ import print_function import six +from paddle.fluid import core +import paddle + def delete_ops(block, ops): try: @@ -39,3 +42,133 @@ def find_op_by_output_arg(block, arg_name): if arg_name in op.output_arg_names: return index return -1 + + +def get_indent_space(indent, space_num=4): + ret = "" + for i in range(0, indent * space_num): + ret += " " + + return ret + + +def variable_to_code(var): + """ + Get readable codes of fluid variable. + + Args: + var: A fluid operator. + + Returns: + string: The formatted string. + """ + + var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".\ + format(i="{", e="}", name=var.name, type=var.type, shape=var.shape, dtype=var.dtype) + + if type(var) == paddle.fluid.framework.Parameter: + if var.trainable: + var_str = "trainable parameter " + var_str + else: + var_str = "parameter " + var_str + else: + var_str = "var " + var_str + + if var.persistable: + var_str = "persist " + var_str + + return var_str + + +def op_to_code(op): + """ + Get readable codes of fluid operator. + + Args: + op: A fluid operator. + + Returns: + string: The foramtted string. 
+ """ + + outputs_str = "{" + for i in range(0, len(op.output_names)): + outputs_str += "{name}=".format(name=op.output_names[i]) + o = op.output(op.output_names[i]) + outputs_str += "{value}".format(value=o) + if i != len(op.output_names) - 1: + outputs_str += ", " + outputs_str += "}" + + inputs_str = "{" + for i in range(0, len(op.input_names)): + inputs_str += "{name}=".format(name=op.input_names[i]) + o = op.input(op.input_names[i]) + inputs_str += "{value}".format(value=o) + + if i != len(op.input_names) - 1: + inputs_str += ", " + inputs_str += "}" + + attrs_str = "" + for i in range(0, len(op.attr_names)): + name = op.attr_names[i] + + attr_type = op.desc.attr_type(name) + if attr_type == core.AttrType.BLOCK: + a = "{name} = block[{value}]".format( + name=name, type=attr_type, value=op.block_attr_id(name)) + attrs_str += a + continue + + if attr_type == core.AttrType.BLOCKS: + a = "{name} = blocks{value}".format( + name=name, type=attr_type, value=op.blocks_attr_ids(name)) + attrs_str += a + continue + + a = "{name} = {value}".format( + name=name, type=attr_type, value=op.desc.attr(name)) + attrs_str += a + if i != len(op.attr_names) - 1: + attrs_str += ", " + + if outputs_str != "{}": + op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\ + format(outputs = outputs_str, op_type=op.type, inputs=inputs_str, attrs=attrs_str) + else: + op_str = "{op_type}(inputs={inputs}, {attrs})".\ + format(op_type=op.type, inputs=inputs_str, attrs=attrs_str) + return op_str + + +def program_to_code(prog): + """ + Print readable codes of fluid program. + + Args: + prog : A fluid program. 
+ + An example result like bellow: + https://github.com/PaddlePaddle/Paddle/pull/12673 + """ + indent = 0 + block_idx = 0 + for block in prog.blocks: + print("{0}{1} // block {2}".format( + get_indent_space(indent), '{', block_idx)) + indent += 1 + # sort all vars + all_vars = sorted(block.vars.iteritems(), key=lambda x: x[0]) + for var in all_vars: + print("{}{}".format( + get_indent_space(indent), variable_to_code(var[1]))) + + if len(all_vars) > 0: + print("") + + for op in block.ops: + print("{}{}".format(get_indent_space(indent), op_to_code(op))) + indent -= 1 + print("{0}{1}".format(get_indent_space(indent), '}')) + block_idx += 1 From 2eb46c2b06ce745eba77489029198cef15eb9980 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 23 Aug 2018 18:10:11 +0800 Subject: [PATCH 065/140] add cpu vec test --- paddle/fluid/operators/math/CMakeLists.txt | 1 + paddle/fluid/operators/math/cpu_vec.h | 12 +- paddle/fluid/operators/math/cpu_vec_test.cc | 140 ++++++++++++++++++++ 3 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/math/cpu_vec_test.cc diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index d2b772d113..1b75df5d7d 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -65,3 +65,4 @@ if(WITH_GPU) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function) endif() cc_test(concat_test SRCS concat_test.cc DEPS concat) +cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 48c0da0e36..3575d9ca67 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -15,6 +15,13 @@ limitations under the License. 
*/ #pragma once #include #include "paddle/fluid/platform/cpu_info.h" +#ifdef __AVX__ +#include +#endif + +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif namespace paddle { namespace operators { @@ -22,7 +29,6 @@ namespace math { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 template inline T sigmoid(T x) { @@ -46,7 +52,7 @@ inline void vec_sigmoid(const int n, const T* x, T* y) { const T max = SIGMOID_THRESHOLD_MAX; for (int i = 0; i < n; ++i) { T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = 1.0 / (1.0 + std::exp(-tmp)); + y[i] = sigmoid(tmp); } } @@ -96,7 +102,7 @@ class VecActivations { } else if (type == "identity" || type == "") { return vec_identity; } - PADDLE_THROW("Not support type %s.", type); + LOG(FATAL) << "Not support type: " << type; } }; diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc new file mode 100644 index 0000000000..773d4bec4f --- /dev/null +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -0,0 +1,140 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/fluid/operators/math/cpu_vec.h" + +inline double GetCurrentUS() { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; +} +constexpr int repeat = 1000; + +template +inline T _sigmoid(T x) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + T tmp = (x < min) ? min : ((x > max) ? max : x); + return 1. / (1. + std::exp(-tmp)); +} + +template +inline T _tanh(T x) { + return 2. * _sigmoid(2. * x) - 1.; +} + +template +void ref_sigmoid(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = _sigmoid(x[i]); + } +} + +template +void ref_tanh(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = _tanh(x[i]); + } +} +template +void ref_relu(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + +template +void RandomVec(const int n, T* a) { + static unsigned int seed = 100; + std::mt19937 rng(seed++); + std::uniform_real_distribution uniform_dist(0, 1); + const T lower = static_cast(-20.f); + const T upper = static_cast(-20.f); + for (int i = 0; i < n; ++i) { + a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); + } +} + +template +void TestAndBench(const int n, std::function tgt, + std::function ref) { + std::vector x(n); + std::vector ytgt(n), yref(n); + RandomVec(n, x.data()); + + const T* x_data = x.data(); + T* ytgt_data = ytgt.data(); + T* yref_data = yref.data(); + auto st = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + tgt(n, x_data, ytgt_data); + } + auto mt = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ref(n, x_data, yref_data); + } + auto et = GetCurrentUS(); + + VLOG(3) << "Vec size " << n << ": refer takes: " << (et - mt) / repeat + << " us, tgt takes: " << (mt - st) / repeat; + for (int i = 0; i < n; ++i) { + EXPECT_NEAR(ytgt_data[i], 
yref_data[i], 1e-3); + } +} + +TEST(CpuVecTest, sigmoid) { + namespace jit = paddle::platform::jit; + using namespace paddle::operators::math; // NOLINT + for (auto sz : {1, 2, 15, 16, 32, 128, 200, 512}) { + TestAndBench(sz, vec_sigmoid, ref_sigmoid); + TestAndBench(sz, vec_sigmoid, ref_sigmoid); + TestAndBench(sz, vec_sigmoid, ref_sigmoid); + TestAndBench(sz, vec_sigmoid, + ref_sigmoid); + } + TestAndBench(30, vec_sigmoid, ref_sigmoid); +} + +TEST(CpuVecTest, tanh) { + namespace jit = paddle::platform::jit; + using namespace paddle::operators::math; // NOLINT + for (auto sz : {1, 2, 15, 16, 32, 128, 200, 512}) { + TestAndBench(sz, vec_tanh, ref_tanh); + TestAndBench(sz, vec_tanh, ref_tanh); + TestAndBench(sz, vec_tanh, ref_tanh); + TestAndBench(sz, vec_tanh, + ref_tanh); + } + TestAndBench(30, vec_tanh, ref_tanh); +} + +TEST(CpuVecTest, relu) { + namespace jit = paddle::platform::jit; + using namespace paddle::operators::math; // NOLINT + for (auto sz : {1, 2, 15, 16, 32, 128, 200, 512}) { + TestAndBench(sz, vec_relu, ref_relu); + TestAndBench(sz, vec_relu, ref_relu); + TestAndBench(sz, vec_relu, ref_relu); + TestAndBench(sz, vec_relu, + ref_relu); + } + TestAndBench(30, vec_relu, ref_relu); +} From 3fd169daedb408fb922d6342f3f8b550ec1483b9 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 21:32:51 +0800 Subject: [PATCH 066/140] Resume all tests --- paddle/scripts/paddle_build.sh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 4979bd55c1..49a66799bc 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -313,9 +313,7 @@ function run_test() { Running unit tests ... 
======================================== EOF - #ctest --output-on-failure - ctest -R test_dist_se_resnext -V - ctest -R test_parallel_executor_fetch_feed -V + ctest --output-on-failure # make install should also be test when unittest make install -j `nproc` pip install /usr/local/opt/paddle/share/wheels/*.whl From 25976fe736804e415a4f3b7fadc5c8ce3c9495f7 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 23 Aug 2018 21:50:35 +0800 Subject: [PATCH 067/140] optimize the sigmoid and tanh --- paddle/fluid/operators/math/cpu_vec.h | 34 ++++++++++++++++----- paddle/fluid/operators/math/cpu_vec_test.cc | 5 +-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 3575d9ca67..6d8acbe539 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/platform/cpu_info.h" #ifdef __AVX__ @@ -31,15 +32,24 @@ namespace math { #define SIGMOID_THRESHOLD_MAX 13.0 template -inline T sigmoid(T x) { - return 1. / (1. + exp(-x)); +inline void vec_exp(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } } -template -inline T tanh(T x) { - return 2. * sigmoid(2. * x) - 1.; +#ifdef PADDLE_WITH_MKLML +template <> +inline void vec_exp(const int n, const float* x, float* y) { + platform::dynload::vsExp(n, x, y); } +template <> +inline void vec_exp(const int n, const double* x, double* y) { + platform::dynload::vdExp(n, x, y); +} +#endif + template inline void vec_identity(const int n, const T* x, T* y) { // do nothing @@ -51,15 +61,23 @@ inline void vec_sigmoid(const int n, const T* x, T* y) { const T min = SIGMOID_THRESHOLD_MIN; const T max = SIGMOID_THRESHOLD_MAX; for (int i = 0; i < n; ++i) { - T tmp = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); - y[i] = sigmoid(tmp); + y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(0) - y[i]; + } + vec_exp(n, y, y); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); } } template inline void vec_tanh(const int n, const T* x, T* y) { for (int i = 0; i < n; ++i) { - y[i] = tanh(x[i]); + y[i] = static_cast(2) * x[i]; + } + vec_exp(n, y, y); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * y[i] - static_cast(1); } } diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index 773d4bec4f..ab4858984d 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -33,12 +33,13 @@ inline T _sigmoid(T x) { const T min = SIGMOID_THRESHOLD_MIN; const T max = SIGMOID_THRESHOLD_MAX; T tmp = (x < min) ? min : ((x > max) ? max : x); - return 1. / (1. + std::exp(-tmp)); + return static_cast(1) / (static_cast(1) + std::exp(-tmp)); } template inline T _tanh(T x) { - return 2. * _sigmoid(2. * x) - 1.; + return static_cast(2) * _sigmoid(static_cast(2) * x) - + static_cast(1); } template From b1fc23869417d3bf6c1f647042c8ecfea58043b4 Mon Sep 17 00:00:00 2001 From: guochaorong Date: Thu, 23 Aug 2018 22:36:53 +0800 Subject: [PATCH 068/140] Revert "Disable in_place in batch_norm API. (#12736)" This reverts commit f5d5d7b2d989e8aa5b5e637fd04318566b23f2fe. 
--- paddle/fluid/operators/batch_norm_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 9 ++------- python/paddle/fluid/nets.py | 2 +- .../paddle/fluid/tests/book/test_image_classification.py | 5 +---- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 969f75544f..5912a1a17c 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Variance", "The global variance (for training) " "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization"); + AddOutput("Y", "result after normalization").Reuse("X"); AddOutput("MeanOut", "Share memory with Mean. " "Store the global mean when training") diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 83250f65e4..4bd260a005 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,7 +27,6 @@ from . import utils import random from .. import unique_name from functools import reduce -import warnings __all__ = [ 'fc', @@ -2048,7 +2047,7 @@ def batch_norm(input, param_attr(ParamAttr): The parameter attribute for Parameter `scale`. bias_attr(ParamAttr): The parameter attribute for Parameter `bias`. data_layout(string, default NCHW): NCHW|NHWC - in_place(bool, Default False): This argument is deprecated since 0.15.0. + in_place(bool, Default False): Make the input and output of batch norm reuse memory. use_mkldnn(bool, Default false): ${use_mkldnn_comment} name(string, Default None): A name for this layer(optional). If set None, the layer will be named automatically. 
@@ -2070,10 +2069,6 @@ def batch_norm(input, helper = LayerHelper('batch_norm', **locals()) dtype = helper.input_dtype() - if in_place: - raise warnings.warn("The argument in_place is deprecated since 0.15.0, " - "please do not set it True.") - input_shape = input.shape if data_layout == 'NCHW': channel_num = input_shape[1] @@ -2123,7 +2118,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = helper.create_tmp_variable(dtype) + batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm", diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 01563cbbb7..051fe84364 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -229,7 +229,7 @@ def img_conv_group(input, use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act) + tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index cd1e8cd682..9fe361425c 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -256,10 +256,7 @@ def main(net_type, use_cuda, is_local=True): save_dirname = "image_classification_" + net_type + ".inference.model" train(net_type, use_cuda, save_dirname, is_local) - - # There is bug in fluid.InferenceTranspiler for VGG. 
- if net_type == "resnet": - infer(use_cuda, save_dirname) + infer(use_cuda, save_dirname) class TestImageClassification(unittest.TestCase): From d82453fbdd5611d4a825a6c4ca8ce95e6aff9e07 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 23 Aug 2018 22:46:02 +0800 Subject: [PATCH 069/140] fix typo (#12896) --- .../design/dist_train/dist_train_nccl2.md | 12 +++++------ .../howto/cluster/nccl2_rdma_training.md | 20 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/doc/fluid/design/dist_train/dist_train_nccl2.md b/doc/fluid/design/dist_train/dist_train_nccl2.md index aa7455ec5d..b8b8427811 100644 --- a/doc/fluid/design/dist_train/dist_train_nccl2.md +++ b/doc/fluid/design/dist_train/dist_train_nccl2.md @@ -1,7 +1,7 @@ # Distributed Training with NCCL2 We design a pattern that can enable training with `ParallelExecutor` and -using [NCCL2](https://developer.nvidia.com/nccl) as it's collective +use [NCCL2](https://developer.nvidia.com/nccl) as it's collective communication library. In `ParallelExecutor` we can use `AllReduce` or `Reduce` and `Broadcast` @@ -9,14 +9,14 @@ to do multi GPU training. And if we initialize NCCL2 communicators as ranks in a distributed environment, we can simply run the `ParallelExecutor` as a distributed program! The only thing that may be different than in the single node version is that we need to broadcast the NCCL unique ID -to all the nodes, and initialize communicators using that ID, so NCCL2 -will know each other as ranks. +to all the nodes and initialize communicators using that ID, so NCCL2 +can know each other as ranks. To achieve this feature, we introduce a new operator: `gen_nccl_id` op, so we are ***not*** "bind to" running NCCL2 with MPI, we can run it in -what ever platform you like. +whatever platform you like. -It have two running modes: +It has two running modes: 1. Generate and broadcast mode, which should be used on trainer 0; 1. 
Listen and fetch mode, which should be used on trainers other than 0. @@ -29,7 +29,7 @@ initialize NCCL communicator objects. The above figure indicates the general process when training with NCCL2 -distributed. Each trainer have the number of communicators equal to the +distributed. Each trainer has the number of communicators equal to the number of GPUs, but the ranks should match the global ranks number: here we have total 8 GPUs, so `nranks==8`, for each trainer, the ranks should be from 0 ~ 3 on trainer 0 and 4 ~ 7 on trainer 1. diff --git a/doc/fluid/howto/cluster/nccl2_rdma_training.md b/doc/fluid/howto/cluster/nccl2_rdma_training.md index cecd5c3a7a..8adaf324fc 100644 --- a/doc/fluid/howto/cluster/nccl2_rdma_training.md +++ b/doc/fluid/howto/cluster/nccl2_rdma_training.md @@ -1,12 +1,12 @@ # Distributed Training with NCCL2 and RDMA -When doing distributed multi-GPU training, network bandwith often becomes the -bottle neck. We introduce a way to use NCCL2 to do such training job to -achieve best performace. +When doing distributed multi-GPU training, network bandwidth often becomes the +bottleneck. We introduce a way to use NCCL2 to do such training job to +achieve best performance. -## Prepare Hardwares with RDMA and Multiple GPUs +## Prepare Hardware with RDMA and Multiple GPUs -I'm using two Linux servers each of them is installed with 8 GPUs and +I'm using two Linux servers each of them installed with 8 GPUs and one 100Gb RDMA card. Base environment is: @@ -25,7 +25,7 @@ In general, the steps including: 1. Use docker to run tests and make sure GPUs and RDMA can work inside the container. -I'll ommit section "Install GPU drivers" because we can find it easily +I'll omit the section "Install GPU drivers" because we can find it easily somewhere else. ### Install RDMA drivers @@ -33,7 +33,7 @@ somewhere else. For my case, I've got two machines with device "Mellanox Technologies MT27700 Family [ConnectX-4]" installed. 
The OS was "CentOS 7.4" and I updated the kernel to version 4.4 so that docker can -work with latest overlay2 filesystem. +work with the latest overlay2 filesystem. ***NOTE: before you start, make sure you have a way to get a console of the server other than ssh because we may need to re-configure the @@ -45,14 +45,14 @@ network device.*** 1. Run `./mlnxofedinstall --add-kernel-support` in the software package. 1. Run `/etc/init.d/openibd restart` to make everything work, note that this operation may cause the network goes down if you are using this - RDMA device as default network device and use ssh to login the server. + RDMA device as default network device and use ssh to log in the server. 1. Re-configure the network interface, for example: `ifconfig eth2 192.168.16.30/20 up`, then add routes if needed: `ip route add default via 192.168.16.1 dev eth2`. 1. Do the same thing on the other node. 1. Use `ping` to test if the two nodes have typical ICMP connection. 1. Use either `udaddy` or `ib_write_bw` to test the network connection is - ready and have the desired bandwith. + ready and have the desired bandwidth. ### Prepare Docker Image to Run RDMA Programs @@ -60,7 +60,7 @@ network device.*** package in it. 1. Start a docker container and mount GPU driver libs into it (you can skip this step if you are using nvidia-docker). -1. Mount RDMA dirvers and libs into the docker image (see below section), +1. Mount RDMA drivers and libs into the docker image (see below section), also `udaddy` and `ib_write_bw` if needed. 1. Mount GPU devices and RDMA devices into the container using `--device` or just use privileged mode `--privileged`. From 1f270275a6d3b9c2a279609aa781e1cd30018523 Mon Sep 17 00:00:00 2001 From: guochaorong Date: Thu, 23 Aug 2018 22:59:59 +0800 Subject: [PATCH 070/140] Revert "Add Python Callstacks when Op::Run error (#12759)" This reverts commit b2df17003f22712078df75b299fb27934650319d. 
--- paddle/fluid/framework/op_proto_maker.cc | 4 -- paddle/fluid/framework/op_proto_maker.h | 1 - paddle/fluid/framework/operator.cc | 61 +++++-------------- paddle/fluid/operators/top_k_op.cc | 2 - paddle/fluid/pybind/const_value.cc | 3 - python/paddle/fluid/framework.py | 5 -- .../tests/unittests/test_operator_desc.py | 5 +- 7 files changed, 16 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 9c289243c5..2288c7fe66 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -129,10 +129,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, "Optimized for variable") .SetDefault({}); - AddAttr>(OpCreationCallstackAttrName(), - "Callstack for Op Creatation.") - .SetDefault({}); - Validate(); } diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index cb9c8ab170..80970291c9 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -39,7 +39,6 @@ class OpProtoAndCheckerMaker { public: static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleVarAttrName() { return "op_role_var"; } - static const char *OpCreationCallstackAttrName() { return "op_callstack"; } void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9f8cdf1aeb..d04f774496 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -11,17 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/framework/operator.h" +#include +#include + #include -#include -#include -#include -#include "gflags/gflags.h" -#include "glog/logging.h" + #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/profiler.h" @@ -129,48 +127,19 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } void OperatorBase::Run(const Scope& scope, const platform::Place& place) { - try { - if (VLOG_IS_ON(4)) { - VLOG(4) << place << " " << DebugStringEx(&scope); - } - if (platform::is_gpu_place(place)) { + VLOG(4) << place << " " << DebugStringEx(&scope); + if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA - PADDLE_THROW("Cannot run operator on place %s", place); + PADDLE_THROW("Cannot run operator on place %s", place); #else - auto dev_id = boost::get(place).device; - platform::SetDeviceId(dev_id); + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); #endif - } - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::RecordEvent record_event(Type(), pool.Get(place)); - RunImpl(scope, place); - if (VLOG_IS_ON(3)) { - VLOG(3) << place << " " << DebugStringEx(&scope); - } - } catch (platform::EnforceNotMet exception) { - if (Attrs().count("sub_block") != 0) { - throw exception; - } - - auto& callstack = Attr>( - OpProtoAndCheckerMaker::OpCreationCallstackAttrName()); - - if (callstack.empty()) { - throw exception; - } - std::ostringstream sout; - sout << "Invoke operator " << Type() << " error.\n"; - sout << "Python Callstacks: \n"; - for (auto& line : callstack) { - sout << line; - } - sout << "C++ Callstacks: \n"; - sout << exception.err_str_; - exception.err_str_ = 
sout.str(); - throw exception; - } catch (...) { - std::rethrow_exception(std::current_exception()); } + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::RecordEvent record_event(Type(), pool.Get(place)); + RunImpl(scope, place); + VLOG(3) << place << " " << DebugStringEx(&scope); } bool OperatorBase::HasInputs(const std::string& name) const { @@ -198,7 +167,7 @@ const std::vector& OperatorBase::Inputs( } bool OperatorBase::HasOutputs(const std::string& name) const { - if (outputs_.end() != outputs_.find(name)) { + if (outputs_.find(name) != outputs_.end()) { return true; } else { return false; diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 92a0697e27..4a8ac441cf 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -30,8 +30,6 @@ class TopkOp : public framework::OperatorWithKernel { "Output(Indices) of TopkOp should not be null."); auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(input_dims.size(), 2, - "Rank of TopK op's input must be 2."); const int k = static_cast(ctx->Attrs().Get("k")); PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index a81715c3b3..e4415ed15c 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -43,9 +43,6 @@ void BindConstValue(pybind11::module* m) { op_proto_and_checker_maker.def( "kOpRoleVarAttrName", framework::OpProtoAndCheckerMaker::OpRoleVarAttrName); - op_proto_and_checker_maker.def( - "kOpCreationCallstackAttrName", - framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName); } } // namespace pybind diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e0ddd3b5ff..febb750ee1 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -18,7 +18,6 @@ import collections import contextlib import re import six -import 
traceback import numpy as np @@ -506,10 +505,6 @@ class Operator(object): if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: del op_attrs[role_var_name] - callstack_var_name = op_maker.kOpCreationCallstackAttrName() - op_attrs[callstack_var_name] = list( - reversed(traceback.format_stack()))[1:] - if len(self.desc.type()) != 0: return if type is None: diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 3ac8268073..6d01955993 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -67,10 +67,7 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual( set(mul_op.attr_names), - set([ - "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var", - "op_callstack" - ])) + set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1) From 0eccd59425c24fb3367c48d1545863c624d4c77b Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 23 Aug 2018 23:00:03 +0800 Subject: [PATCH 071/140] Keep APISpec the same with Python2 --- paddle/scripts/paddle_build.sh | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 49a66799bc..5dadef7e76 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -329,16 +329,11 @@ function assert_api_not_changed() { virtualenv .env source .env/bin/activate pip install ${PADDLE_ROOT}/build/python/dist/*whl - if [ "$1" != "" ]; then - echo "checking python abi: $1" - if [ "$1" == "cp35-cp35m" ]; then - # Always use python2 to generate api signature - 
LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} PATH=/opt/python/cp27-cp27mu/bin/:${PATH} python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec - else - python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec - fi - else - python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec + python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec + if [ "$1" == "cp35-cp35m" ]; then + # Use sed to make python2 and python3 sepc keeps the same + sed -i 's/arg0: str/arg0: unicode/g' new.spec + sed -i "s/\(.*Transpiler.*\).__init__ ArgSpec(args=\['self'].*/\1.__init__ /g" new.spec fi python ${PADDLE_ROOT}/tools/diff_api.py ${PADDLE_ROOT}/paddle/fluid/API.spec new.spec deactivate From e3bb98eb38f8938ee3a0f8b07d8f486aca6ccfe3 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 00:01:09 +0800 Subject: [PATCH 072/140] optimize relu with avx and avx512 --- paddle/fluid/operators/math/cpu_vec.h | 83 ++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 6d8acbe539..e74e84055a 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -31,6 +31,13 @@ namespace math { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 +#define AVX_FLOAT_BLOCK 8 +#define AVX_DOUBLE_BLOCK 4 +#define AVX2_FLOAT_BLOCK 8 +#define AVX2_DOUBLE_BLOCK 4 +#define AVX512_FLOAT_BLOCK 16 +#define AVX512_DOUBLE_BLOCK 8 + template inline void vec_exp(const int n, const T* x, T* y) { for (int i = 0; i < n; ++i) { @@ -88,24 +95,82 @@ inline void vec_relu(const int n, const T* x, T* y) { } } +template <> +inline void vec_relu(const int n, const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block) { + vec_relu(n, x, y); + return; + } + + const int rest = n % 
block; + const int end = n - rest; + int i = 0; + __m256 zeros = _mm256_setzero_ps(); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_max_ps(tmp, zeros); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } + if (rest == 0) { + return; + } + i = n - block; + MOVE_ONE_STEP; +#undef MOVE_ONE_STEP + +#else + vec_relu(n, x, y); +#endif +} + template <> inline void vec_relu(const int n, const float* x, float* y) { - // TODO(TJ): complete me - for (int i = 0; i < n; ++i) { - y[i] = x[i] > 0 ? x[i] : 0; - } + vec_relu(n, x, y); } template <> -inline void vec_relu(const int n, const float* x, - float* y) { - // TODO(TJ): complete me - for (int i = 0; i < n; ++i) { - y[i] = x[i] > 0 ? x[i] : 0; +inline void vec_relu(const int n, + const float* x, + float* y) { +#ifdef __AVX512F__ + // test me + constexpr int block = AVX512_FLOAT_BLOCK; + if (n < block) { + vec_relu(n, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m512 zeros = _mm512_setzero_ps(); + __m512 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm512_loadu_ps(x + i); \ + tmp = _mm512_max_ps(tmp, zeros); \ + _mm512_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; } + if (rest == 0) { + return; + } + i = n - block; + MOVE_ONE_STEP; +#undef MOVE_ONE_STEP +#else + vec_relu(n, x, y); +#endif } +// TODO(TJ): optimize double of sigmoid, tanh and relu if necessary + template class VecActivations { public: From 6bd89ba5b6966f9c328cbf3fe187a5768c5e0664 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 00:47:17 +0800 Subject: [PATCH 073/140] fix typo --- paddle/fluid/operators/math/cpu_vec.h | 2 +- paddle/fluid/operators/math/cpu_vec_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index e74e84055a..a2e2b5a7fe 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ 
b/paddle/fluid/operators/math/cpu_vec.h @@ -82,7 +82,7 @@ inline void vec_tanh(const int n, const T* x, T* y) { for (int i = 0; i < n; ++i) { y[i] = static_cast(2) * x[i]; } - vec_exp(n, y, y); + vec_sigmoid(n, y, y); for (int i = 0; i < n; ++i) { y[i] = static_cast(2) * y[i] - static_cast(1); } diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index ab4858984d..0888e44fa6 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -68,7 +68,7 @@ void RandomVec(const int n, T* a) { std::mt19937 rng(seed++); std::uniform_real_distribution uniform_dist(0, 1); const T lower = static_cast(-20.f); - const T upper = static_cast(-20.f); + const T upper = static_cast(20.f); for (int i = 0; i < n; ++i) { a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } From a7849db561df9bf7a2c5961df63a861106f90b43 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 24 Aug 2018 00:47:41 +0800 Subject: [PATCH 074/140] Port new added code --- python/paddle/fluid/tests/unittests/test_attention_lstm_op.py | 2 +- python/paddle/fluid/transpiler/details/program_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py index a7382c2244..1b9c3efe0f 100644 --- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py @@ -37,7 +37,7 @@ def attention_lstm( T = sum(lod[0]) N = len(lod[0]) M = x.shape[1] - D = b.shape[1] / 4 + D = b.shape[1] // 4 assert T == x.shape[0] assert len(fcws) == len(fcbs) hidden = [] diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 420ae6dfd4..64863aceee 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ 
b/python/paddle/fluid/transpiler/details/program_utils.py @@ -159,7 +159,7 @@ def program_to_code(prog): get_indent_space(indent), '{', block_idx)) indent += 1 # sort all vars - all_vars = sorted(block.vars.iteritems(), key=lambda x: x[0]) + all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0]) for var in all_vars: print("{}{}".format( get_indent_space(indent), variable_to_code(var[1]))) From ca22586818c2ce9d9b4ac83f49a3c7a54570cc6b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 24 Aug 2018 10:51:37 +0800 Subject: [PATCH 075/140] code optimize (cherry picked from commit 587cca7) --- paddle/fluid/operators/fill_constant_op.cc | 27 +++++++++++++++------ paddle/fluid/operators/uniform_random_op.cc | 2 +- paddle/fluid/operators/uniform_random_op.cu | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 130f18dde4..2826b82117 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -15,7 +15,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { @@ -41,19 +40,33 @@ class FillConstantOp : public framework::OperatorBase { static_cast(Attr("dtype")); auto value = Attr("value"); auto force_cpu = Attr("force_cpu"); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - out.Resize(framework::make_ddim(Attr>("shape"))); + + framework::Tensor *tensor = nullptr; + + auto &out_var = *scope.FindVar(Output("Out")); + + if (out_var.IsType()) { + tensor = out_var.GetMutable(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else if (out_var.IsType()) { + tensor = out_var.GetMutable()->mutable_value(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else { + PADDLE_THROW( + "fill constant op's output only" + "supports SelectedRows and LoDTensor"); + } + if (force_cpu) { auto cpu = platform::CPUPlace(); - out.mutable_data(cpu, framework::ToTypeIndex(data_type)); + tensor->mutable_data(cpu, framework::ToTypeIndex(data_type)); } else { - out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); + tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type)); } platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); - math::set_constant(dev_ctx, &out, value); + math::set_constant(dev_ctx, tensor, value); } }; diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 5248767c2e..763bb40358 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -37,7 +37,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { } else { PADDLE_THROW( "uniform_random_op's output only" - "supports SelectedRows and Tensor"); + "supports SelectedRows and LoDTensor"); } T* data = 
tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index e1c7323a30..bbb692b0dd 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -54,7 +54,7 @@ class GPUUniformRandomKernel : public framework::OpKernel { } else { PADDLE_THROW( "uniform_random_op's output only" - "supports SelectedRows and Tensor"); + "supports SelectedRows and LoDTensor"); } T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); From fcf20eed0fba2c6576fd66139a9d3f134a0793c4 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Fri, 24 Aug 2018 10:57:40 +0800 Subject: [PATCH 076/140] fix sparse update bug --- paddle/fluid/operators/distributed/variable_response.cc | 1 + paddle/fluid/operators/listen_and_serv_op.cc | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 8e38b3713f..1617cc1b95 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -151,6 +151,7 @@ bool VariableResponse::CopySelectRowsData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, int length) { auto* slr = GetVar()->GetMutable(); + slr->mutable_rows()->clear(); slr->mutable_rows()->resize(length / framework::SizeOfType(typeid(int64_t))); // int64 int64_t* rows_data = slr->mutable_rows()->data(); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index f196e18fe1..4cc2159d9f 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -165,12 +165,13 @@ void ListenAndServOp::RunSyncLoop( recv_scope); VLOG(2) << "run all blocks spent 
" << GetTimestamp() - ts << "(ms)"; - rpc_service_->SetCond(distributed::kRequestGet); - rpc_service_->WaitBarrier(distributed::kRequestGet); - rpc_service_->ResetBarrierCounter(); // reset received sparse vars to avoid reuse it in the next mini-batch dynamic_cast(request_send_handler_.get()) ->ResetSparseVarRecorder(); + + rpc_service_->SetCond(distributed::kRequestGet); + rpc_service_->WaitBarrier(distributed::kRequestGet); + rpc_service_->ResetBarrierCounter(); } // while(true) } From fca139b5e302c46a26d99d0b57546010d3c97590 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 24 Aug 2018 11:07:47 +0800 Subject: [PATCH 077/140] Fix flowers dataset download problem --- python/paddle/dataset/common.py | 3 +++ python/paddle/dataset/flowers.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 1d7ff582c8..ece4046f5b 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -19,6 +19,7 @@ import hashlib import os import errno import shutil +import six import sys import importlib import paddle.dataset @@ -94,6 +95,8 @@ def download(url, module_name, md5sum, save_name=None): dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): + if six.PY2: + data = six.b(data) dl += len(data) f.write(data) done = int(50 * dl / total_length) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index aa73bbaf70..0a1cdaceaf 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -35,6 +35,7 @@ import itertools import functools from .common import download import tarfile +import six import scipy.io as scio from paddle.dataset.image import * from paddle.reader import * @@ -45,10 +46,10 @@ from six.moves import cPickle as pickle from six.moves import zip __all__ = ['train', 'test', 'valid'] -DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' 
-LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' -SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' -DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' +DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz' +LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat' +SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat' +DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data @@ -120,7 +121,10 @@ def reader_creator(data_file, file = file.strip() batch = None with open(file, 'rb') as f: - batch = pickle.load(f) + if six.PY2: + batch = pickle.load(f) + else: + batch = pickle.load(f, encoding='bytes') data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): From 8f9bbc2834c35d368b680d87fe50342717d28d31 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 24 Aug 2018 11:27:54 +0800 Subject: [PATCH 078/140] add unit test --- .../tests/unittests/test_fill_constant_op.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index 44fb1d047d..b73711b19d 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -18,6 +18,9 @@ import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator + class TestFillConstantOp1(OpTest): def setUp(self): @@ -47,5 +50,27 @@ class TestFillConstantOp2(OpTest): self.check_output() +class TestFillConstantOpWithSelectedRows(OpTest): + def check_with_place(self, place): + scope = core.Scope() + # create Out Variable + out = scope.var('Out').get_selected_rows() + + # create and run 
fill_constant_op operator + fill_constant_op = Operator( + "fill_constant", shape=[123, 92], value=3.8, Out='Out') + fill_constant_op.run(scope, place) + + # get result from Out + result_array = np.array(out) + self.assertEqual(result_array, np.full((123, 92), 3.8)) + + def test_fill_constant_with_selected_rows(self): + places = [core.CPUPlace()] + # currently only support CPU + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main() From 7a4924cd44a47f3562d62c01d0c40e84ca78540e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 11:46:59 +0800 Subject: [PATCH 079/140] further optimize sigmoid with avx and avx512 --- paddle/fluid/operators/math/cpu_vec.h | 116 ++++++++++++++++++++ paddle/fluid/operators/math/cpu_vec_test.cc | 6 +- 2 files changed, 119 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index a2e2b5a7fe..52f072eb0e 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -77,6 +77,122 @@ inline void vec_sigmoid(const int n, const T* x, T* y) { } } +template <> +inline void vec_sigmoid(const int n, const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block) { // can use larger threshold if necessary + vec_sigmoid(n, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); + __m256 zeros = _mm256_setzero_ps(); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_max_ps(tmp, min); \ + tmp = _mm256_min_ps(tmp, max); \ + tmp = _mm256_sub_ps(zeros, tmp); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } + if (rest != 0) { + i = n - block; + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + + vec_exp(n, y, y); + + __m256 ones = 
_mm256_set1_ps(1.0f); +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(y + i); \ + tmp = _mm256_add_ps(ones, tmp); \ + tmp = _mm256_div_ps(ones, tmp); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step + for (i = n - rest; i < n; ++i) { + y[i] = 1.f / (1.f + y[i]); + } +#else + vec_sigmoid(n, x, y); +#endif +} + +template <> +inline void vec_sigmoid(const int n, const float* x, + float* y) { + vec_sigmoid(n, x, y); +} + +template <> +inline void vec_sigmoid(const int n, + const float* x, + float* y) { +#ifdef __AVX512F__ + constexpr int block = AVX512_FLOAT_BLOCK; + if (n < block) { + vec_sigmoid(n, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m512 max = _mm512_set1_ps(SIGMOID_THRESHOLD_MAX); + __m512 min = _mm512_set1_ps(SIGMOID_THRESHOLD_MIN); + __m512 zeros = _mm512_setzero_ps(); + __m512 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm512_loadu_ps(x + i); \ + tmp = _mm512_max_ps(tmp, min); \ + tmp = _mm512_min_ps(tmp, max); \ + tmp = _mm512_sub_ps(zeros, tmp); \ + _mm512_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } + if (rest != 0) { + i = n - block; + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + + vec_exp(n, y, y); + + __m512 ones = _mm512_set1_ps(1.0f); +#define MOVE_ONE_STEP \ + tmp = _mm512_loadu_ps(y + i); \ + tmp = _mm512_add_ps(ones, tmp); \ + tmp = _mm512_div_ps(ones, tmp); \ + _mm512_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + for (i = n - rest; i < n; ++i) { + y[i] = 1.f / (1.f + y[i]); + } +#else + vec_sigmoid(n, x, y); +#endif +} + template inline void vec_tanh(const int n, const T* x, T* y) { for (int i = 0; i < n; ++i) { diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index 0888e44fa6..8b0e9c086a 
100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -104,7 +104,7 @@ void TestAndBench(const int n, std::function tgt, TEST(CpuVecTest, sigmoid) { namespace jit = paddle::platform::jit; using namespace paddle::operators::math; // NOLINT - for (auto sz : {1, 2, 15, 16, 32, 128, 200, 512}) { + for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench(sz, vec_sigmoid, ref_sigmoid); @@ -117,7 +117,7 @@ TEST(CpuVecTest, sigmoid) { TEST(CpuVecTest, tanh) { namespace jit = paddle::platform::jit; using namespace paddle::operators::math; // NOLINT - for (auto sz : {1, 2, 15, 16, 32, 128, 200, 512}) { + for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); @@ -130,7 +130,7 @@ TEST(CpuVecTest, tanh) { TEST(CpuVecTest, relu) { namespace jit = paddle::platform::jit; using namespace paddle::operators::math; // NOLINT - for (auto sz : {1, 2, 15, 16, 32, 128, 200, 512}) { + for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); From c70a3fec3e3b469e381279917deb79b786e6b821 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 11:51:50 +0800 Subject: [PATCH 080/140] fix redefinition of argument machine --- cmake/configure.cmake | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 7e5d8a7621..e03e15bfc0 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -50,20 +50,14 @@ if(NOT WITH_PROFILER) endif(NOT WITH_PROFILER) if(NOT CMAKE_CROSSCOMPILING) - set(SIMD_FLAG) - if(WITH_AVX) - if (AVX512F_FOUND) - set(SIMD_FLAG "${SIMD_FLAG} ${AVX512F_FLAG}") - endif() - if (AVX2_FOUND) - set(SIMD_FLAG 
"${SIMD_FLAG} ${AVX2_FLAG}") - endif() - if (AVX_FOUND) - set(SIMD_FLAG "${SIMD_FLAG} ${AVX_FLAG}") - endif() - if (SSE3_FOUND) - set(SIMD_FLAG "${SIMD_FLAG} ${SSE3_FLAG}") - endif() + if(WITH_AVX AND AVX512F_FOUND) + set(SIMD_FLAG ${AVX512F_FLAG}) + elseif(WITH_AVX AND AVX2_FOUND) + set(SIMD_FLAG ${AVX2_FLAG}) + elseif(WITH_AVX AND AVX_FOUND) + set(SIMD_FLAG ${AVX_FLAG}) + elseif(SSE3_FOUND) + set(SIMD_FLAG ${SSE3_FLAG}) endif() endif() From fff6f595ff502d71c29dd1b5824f1d2940cd6069 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 24 Aug 2018 12:31:29 +0800 Subject: [PATCH 081/140] add unit test --- .../paddle/fluid/tests/unittests/test_fill_constant_op.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index b73711b19d..537cabd5d0 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -62,8 +62,10 @@ class TestFillConstantOpWithSelectedRows(OpTest): fill_constant_op.run(scope, place) # get result from Out - result_array = np.array(out) - self.assertEqual(result_array, np.full((123, 92), 3.8)) + result_array = np.array(out.get_tensor()) + full_array = np.full((123, 92), 3.8, 'float32') + + self.assertTrue(np.array_equal(result_array, full_array)) def test_fill_constant_with_selected_rows(self): places = [core.CPUPlace()] From 66cc1850a8e29858776fe31e4dc00e5dab49f2be Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 24 Aug 2018 12:50:10 +0800 Subject: [PATCH 082/140] add gpu place --- python/paddle/fluid/tests/unittests/test_fill_constant_op.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index 537cabd5d0..fd59c5bb7c 100644 --- 
a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -69,7 +69,9 @@ class TestFillConstantOpWithSelectedRows(OpTest): def test_fill_constant_with_selected_rows(self): places = [core.CPUPlace()] - # currently only support CPU + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: self.check_with_place(place) From 2b4edacca0d8756665dce87402043bb5f7ca26c6 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Fri, 24 Aug 2018 13:14:35 +0800 Subject: [PATCH 083/140] enhance the forward of concat op --- paddle/fluid/operators/math/concat.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index fbe7c29783..c3c5c160db 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -48,16 +48,16 @@ class ConcatFunctor { auto cpu_place = boost::get(context.GetPlace()); // computation - for (int k = 0; k < out_rows; ++k) { - T* dst_ptr = output->data() + k * out_cols; - int col_idx = 0; - for (int j = 0; j < num; ++j) { - int col_len = input_cols[j]; - const T* src_prt = input[j].data() + k * col_len; - memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt, - sizeof(T) * col_len); - col_idx += col_len; + auto output_data = output->data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = input[j].data(); + for (int k = 0; k < out_rows; ++k) { + memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place, + input_data + k * col_len, sizeof(T) * col_len); } + col_idx += col_len; } } }; From f269614bcde3a7526dc164cb5ca9691a605709de Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 12:59:16 +0800 Subject: [PATCH 084/140] further optimize tanh with avx and mkl --- paddle/fluid/operators/math/cpu_vec.h | 175 +++++++++++++------------- 1 
file changed, 90 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 52f072eb0e..d5f247e7ef 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -45,6 +45,13 @@ inline void vec_exp(const int n, const T* x, T* y) { } } +template +inline void vec_scal(const int n, const T a, T* x) { + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +} + #ifdef PADDLE_WITH_MKLML template <> inline void vec_exp(const int n, const float* x, float* y) { @@ -55,7 +62,74 @@ template <> inline void vec_exp(const int n, const double* x, double* y) { platform::dynload::vdExp(n, x, y); } + +template <> +inline void vec_scal(const int n, const float a, float* x) { + platform::dynload::cblas_sscal(n, a, x, 1); +} + +template <> +inline void vec_scal(const int n, const double a, double* x) { + platform::dynload::cblas_dscal(n, a, x, 1); +} +#endif + +// MKL scal only support inplace, choose this if src and dst are not equal +template +inline void vec_scal(const int n, const T a, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = a * x[i]; + } +} + +template <> +inline void vec_scal(const int n, const float a, + const float* x, float* y) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block * 4) { // use larger threshold, since small ones has no boost + vec_scal(n, a, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 scalar = _mm256_set1_ps(a); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + y[i] = a * x[i]; + } +#else + vec_scal(n, a, x, y); #endif +} + +template <> +inline void vec_scal(const 
int n, const float a, + const float* x, float* y) { + vec_scal(n, a, x, y); +} + +template <> +inline void vec_scal(const int n, + const float a, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_scal(n, a, x, y); +} template inline void vec_identity(const int n, const T* x, T* y) { @@ -82,7 +156,7 @@ inline void vec_sigmoid(const int n, const float* x, float* y) { #ifdef __AVX__ constexpr int block = AVX_FLOAT_BLOCK; - if (n < block) { // can use larger threshold if necessary + if (n < block) { vec_sigmoid(n, x, y); return; } @@ -102,11 +176,15 @@ inline void vec_sigmoid(const int n, const float* x, for (i = 0; i < end; i += block) { MOVE_ONE_STEP; } +#undef MOVE_ONE_STEP if (rest != 0) { - i = n - block; - MOVE_ONE_STEP; + // can not continue move step since the src and dst address could be equal + const float xmin = SIGMOID_THRESHOLD_MIN; + const float xmax = SIGMOID_THRESHOLD_MAX; + for (i = n - rest; i < n; ++i) { + y[i] = 0.f - ((x[i] < xmin) ? xmin : ((x[i] > xmax) ? 
xmax : x[i])); + } } -#undef MOVE_ONE_STEP vec_exp(n, y, y); @@ -142,65 +220,17 @@ template <> inline void vec_sigmoid(const int n, const float* x, float* y) { -#ifdef __AVX512F__ - constexpr int block = AVX512_FLOAT_BLOCK; - if (n < block) { - vec_sigmoid(n, x, y); - return; - } - const int rest = n % block; - const int end = n - rest; - int i = 0; - __m512 max = _mm512_set1_ps(SIGMOID_THRESHOLD_MAX); - __m512 min = _mm512_set1_ps(SIGMOID_THRESHOLD_MIN); - __m512 zeros = _mm512_setzero_ps(); - __m512 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm512_loadu_ps(x + i); \ - tmp = _mm512_max_ps(tmp, min); \ - tmp = _mm512_min_ps(tmp, max); \ - tmp = _mm512_sub_ps(zeros, tmp); \ - _mm512_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } - if (rest != 0) { - i = n - block; - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - - vec_exp(n, y, y); - - __m512 ones = _mm512_set1_ps(1.0f); -#define MOVE_ONE_STEP \ - tmp = _mm512_loadu_ps(y + i); \ - tmp = _mm512_add_ps(ones, tmp); \ - tmp = _mm512_div_ps(ones, tmp); \ - _mm512_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } -#undef MOVE_ONE_STEP - if (rest == 0) { - return; - } - for (i = n - rest; i < n; ++i) { - y[i] = 1.f / (1.f + y[i]); - } -#else - vec_sigmoid(n, x, y); -#endif + // TODO(TJ): enable me + vec_sigmoid(n, x, y); } template inline void vec_tanh(const int n, const T* x, T* y) { + vec_scal(n, static_cast(2), x, y); + vec_sigmoid(n, y, y); + vec_scal(n, static_cast(2), y); for (int i = 0; i < n; ++i) { - y[i] = static_cast(2) * x[i]; - } - vec_sigmoid(n, y, y); - for (int i = 0; i < n; ++i) { - y[i] = static_cast(2) * y[i] - static_cast(1); + y[i] = y[i] - static_cast(1); } } @@ -255,35 +285,10 @@ template <> inline void vec_relu(const int n, const float* x, float* y) { -#ifdef __AVX512F__ - // test me - constexpr int block = AVX512_FLOAT_BLOCK; - if (n < block) { - vec_relu(n, x, y); - return; - } - const int rest = n % block; - const int end = n - rest; - int i 
= 0; - __m512 zeros = _mm512_setzero_ps(); - __m512 tmp; -#define MOVE_ONE_STEP \ - tmp = _mm512_loadu_ps(x + i); \ - tmp = _mm512_max_ps(tmp, zeros); \ - _mm512_storeu_ps(y + i, tmp) - for (i = 0; i < end; i += block) { - MOVE_ONE_STEP; - } - if (rest == 0) { - return; - } - i = n - block; - MOVE_ONE_STEP; -#undef MOVE_ONE_STEP -#else + // TODO(TJ): enable me vec_relu(n, x, y); -#endif } +// TODO(TJ): add vec add bias, make relu clip // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary From bb9f98e10d0d138119070af17ab74cec7e94244d Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 14:04:49 +0800 Subject: [PATCH 085/140] add inplace test --- paddle/fluid/operators/math/cpu_vec_test.cc | 61 +++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index 8b0e9c086a..bf6481c5cc 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include #include +#include #include #include "gflags/gflags.h" #include "glog/logging.h" @@ -139,3 +140,63 @@ TEST(CpuVecTest, relu) { } TestAndBench(30, vec_relu, ref_relu); } + +template +void TestInplace(const int n, std::function tgt, + std::function ref) { + std::vector x(n); + std::vector ytgt(n), yref(n); + RandomVec(n, x.data()); + + const T* x_data = x.data(); + T* yref_data = yref.data(); + T* ytgt_data = ytgt.data(); + std::memcpy(yref_data, x_data, sizeof(T) * n); + std::memcpy(ytgt_data, x_data, sizeof(T) * n); + + ref(n, yref_data, yref_data); + tgt(n, ytgt_data, ytgt_data); + + for (int i = 0; i < n; ++i) { + EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3); + } +} + +TEST(CpuVecTest, inplace_sigmoid) { + namespace jit = paddle::platform::jit; + using namespace paddle::operators::math; // NOLINT + for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { + TestInplace(sz, vec_sigmoid, ref_sigmoid); + TestInplace(sz, vec_sigmoid, ref_sigmoid); + TestInplace(sz, vec_sigmoid, ref_sigmoid); + TestInplace(sz, vec_sigmoid, + ref_sigmoid); + } + TestInplace(30, vec_sigmoid, ref_sigmoid); +} + +TEST(CpuVecTest, inplace_tanh) { + namespace jit = paddle::platform::jit; + using namespace paddle::operators::math; // NOLINT + for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { + TestInplace(sz, vec_tanh, ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); + TestInplace(sz, vec_tanh, + ref_tanh); + } + TestInplace(30, vec_tanh, ref_tanh); +} + +TEST(CpuVecTest, inplace_relu) { + namespace jit = paddle::platform::jit; + using namespace paddle::operators::math; // NOLINT + for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { + TestInplace(sz, vec_relu, ref_relu); + TestInplace(sz, vec_relu, ref_relu); + TestInplace(sz, vec_relu, ref_relu); + TestInplace(sz, vec_relu, + ref_relu); + } + TestInplace(30, vec_relu, ref_relu); +} From 786558fc680622844d45ac7ea75d899898f95b3b Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 
24 Aug 2018 06:40:13 +0000 Subject: [PATCH 086/140] fix bug to avoid warning once import paddle.fluid --- python/paddle/dataset/image.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py index 1cd50bd180..b32736ee7c 100644 --- a/python/paddle/dataset/image.py +++ b/python/paddle/dataset/image.py @@ -36,11 +36,6 @@ import numpy as np try: import cv2 except ImportError: - import sys - sys.stderr.write( - '''Warning with paddle image module: opencv-python should be imported, - or paddle image module could NOT work; please install opencv-python first.''' - ) cv2 = None import os import tarfile @@ -53,6 +48,18 @@ __all__ = [ ] +def _check_cv2(): + if cv2 is None: + import sys + sys.stderr.write( + '''Warning with paddle image module: opencv-python should be imported, + or paddle image module could NOT work; please install opencv-python first.''' + ) + return False + else: + return True + + def batch_images_from_tar(data_file, dataset_name, img2label, @@ -134,7 +141,7 @@ def load_image_bytes(bytes, is_color=True): load and return a gray image. :type is_color: bool """ - assert cv2 is not None + assert _check_cv2() is True flag = 1 if is_color else 0 file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8) @@ -159,7 +166,7 @@ def load_image(file, is_color=True): load and return a gray image. :type is_color: bool """ - assert cv2 is not None + assert _check_cv2() is True # cv2.IMAGE_COLOR for OpenCV3 # cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version @@ -188,7 +195,7 @@ def resize_short(im, size): :param size: the shorter edge size of image after resizing. 
:type size: int """ - assert cv2 is not None + assert _check_cv2() is True h, w = im.shape[:2] h_new, w_new = size, size From 3462c29940ccf4e60f56f430757655d9c9676200 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 14:53:45 +0800 Subject: [PATCH 087/140] refine add bias with avx --- paddle/fluid/operators/attention_lstm_op.cc | 30 +++------- paddle/fluid/operators/math/cpu_vec.h | 66 +++++++++++++++++++-- 2 files changed, 69 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 1cb65346ee..a73ea09f1e 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -232,40 +232,28 @@ use lstm_x_t as input and compute as standard LSTM. template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { if (bias) { - for (int i = 0; i < n; ++i) { - y[i] = x[i] + bias[0]; - } - math::vec_relu(n, y, y); + math::vec_add_bias(n, *bias, x, y); + math::vec_relu(n, y, y); } else { - math::vec_relu(n, x, y); + math::vec_relu(n, x, y); } } -template -inline void vec_softmax(const math::BlasT& blas, const int n, - const T* x, T* y) { +template +inline void vec_softmax(const int n, const T* x, T* y) { T scalar = x[0]; // max for (int i = 1; i < n; ++i) { scalar = scalar < x[i] ? x[i] : scalar; } - - // sub - for (int i = 0; i < n; ++i) { - y[i] = x[i] - scalar; - } - - // exp - blas.VEXP(n, y, y); - + math::vec_add_bias(n, -scalar, x, y); // sub + math::vec_exp(n, y, y); // exp // sum scalar = T(0); for (int i = 0; i < n; ++i) { scalar += y[i]; } - - // scale - blas.SCAL(n, static_cast(1) / scalar, y); + math::vec_scal(n, static_cast(1) / scalar, y); // scale } template @@ -363,7 +351,7 @@ class AttentionLSTMKernel : public framework::OpKernel { fc_out_data); } // 1d. 
softmax - vec_softmax(blas, seq_len, fc_out_data, fc_out_data); + vec_softmax(seq_len, fc_out_data, fc_out_data); // mul x(seq_len*M) and sum pool math::FCCompute(blas, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data); diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index d5f247e7ef..0bae926e98 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -87,7 +87,7 @@ inline void vec_scal(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ constexpr int block = AVX_FLOAT_BLOCK; - if (n < block * 4) { // use larger threshold, since small ones has no boost + if (n < block) { vec_scal(n, a, x, y); return; } @@ -131,6 +131,62 @@ inline void vec_scal(const int n, vec_scal(n, a, x, y); } +template +inline void vec_add_bias(const int n, const T a, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + a; + } +} + +template <> +inline void vec_add_bias(const int n, const float a, + const float* x, float* y) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block) { + vec_add_bias(n, a, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(a); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_add_ps(tmp, bias); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + y[i] = x[i] + a; + } +#else + vec_add_bias(n, a, x, y); +#endif +} + +template <> +inline void vec_add_bias(const int n, const float a, + const float* x, float* y) { + vec_add_bias(n, a, x, y); +} + +template <> +inline void vec_add_bias(const int n, + const float a, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_add_bias(n, a, x, y); +} + template inline void 
vec_identity(const int n, const T* x, T* y) { // do nothing @@ -229,11 +285,10 @@ inline void vec_tanh(const int n, const T* x, T* y) { vec_scal(n, static_cast(2), x, y); vec_sigmoid(n, y, y); vec_scal(n, static_cast(2), y); - for (int i = 0; i < n; ++i) { - y[i] = y[i] - static_cast(1); - } + vec_add_bias(n, static_cast(-1), y, y); } +// TODO(TJ): make relu clip template inline void vec_relu(const int n, const T* x, T* y) { for (int i = 0; i < n; ++i) { @@ -246,7 +301,7 @@ inline void vec_relu(const int n, const float* x, float* y) { #ifdef __AVX__ constexpr int block = AVX_FLOAT_BLOCK; - if (n < block) { + if (n < block * 4) { vec_relu(n, x, y); return; } @@ -288,7 +343,6 @@ inline void vec_relu(const int n, // TODO(TJ): enable me vec_relu(n, x, y); } -// TODO(TJ): add vec add bias, make relu clip // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary From ba943d38e38b96b527114b70a37321af665a5062 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 15:07:05 +0800 Subject: [PATCH 088/140] make runtime avx act --- paddle/fluid/operators/attention_lstm_op.cc | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index a73ea09f1e..8bab37c583 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -299,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); fc_out->Resize({max_seq_len, 1}); - math::VecActivations act_functor; std::function act_gate, act_cell, act_cand; - act_gate = act_functor(ctx.Attr("gate_activation")); - act_cell = act_functor(ctx.Attr("cell_activation")); - act_cand = act_functor(ctx.Attr("candidate_activation")); + auto& act_gate_str = ctx.Attr("gate_activation"); + auto& act_cell_str = ctx.Attr("cell_activation"); + auto& act_cand_str = 
ctx.Attr("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : NULL; From 3d06ccfb23f45994253cd229ddeab0e7b36e0a15 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Fri, 24 Aug 2018 17:07:25 +0800 Subject: [PATCH 089/140] update native_infer.rst --- doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst b/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst index 3571f81326..aa9377c112 100644 --- a/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst +++ b/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst @@ -9,8 +9,6 @@ Paddle 预测 API - 头文件 ``paddle_inference_api.h`` 定义了所有的接口 - 库文件\ ``libpaddle_fluid.so`` 或 ``libpaddle_fluid.a`` -- 库文件 ``libpaddle_inference_api.so`` 或 - ``libpaddle_inference_api.a`` 编译和依赖可以参考 :ref:`install_or_build_cpp_inference_lib` 。 @@ -97,8 +95,7 @@ engine CHECK(predictor->Run(slots, &outputs)); // 获取 outputs ... 
-编译时,联编 ``libpaddle_fluid.a/.so`` 和 -``libpaddle_inference_api.a/.so`` 便可。 +编译时,联编 ``libpaddle_fluid.a/.so`` 即可。 详细代码参考 ------------ From 3b38e5a4fc5be2740762d9ff7a8ff8b5b7d5e930 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Fri, 24 Aug 2018 10:04:18 +0000 Subject: [PATCH 090/140] speed up stack_op --- paddle/fluid/operators/stack_op.h | 56 ++++++++----------------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index c777d5feae..d236c5b943 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -150,30 +150,17 @@ class StackKernel : public framework::OpKernel { int total_num = pre * n * post; auto &dev_ctx = ctx.template device_context(); - constexpr auto kMaxThreshold = 16; - if (std::is_same::value || - n > kMaxThreshold) { #ifdef __NVCC__ - VLOG(10) << "Stack more than " << kMaxThreshold - << " tensors on GPU may be slow."; - thrust::device_vector device_x_vec(x_datas); - auto x_data_arr = device_x_vec.data().get(); + thrust::device_vector device_x_vec(x_datas); + auto x_data_arr = device_x_vec.data().get(); #else - auto x_data_arr = x_datas.data(); + auto x_data_arr = x_datas.data(); #endif - StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); #ifdef __NVCC__ - // Wait() must be called because device_x_vec may be destructed before - // kernel ends - dev_ctx.Wait(); -#endif - } -#ifdef __NVCC__ - else { // NOLINT - framework::Array x_data_arr; - for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i]; - StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); - } + // Wait() must be called because device_x_vec may be destructed before + // kernel ends + dev_ctx.Wait(); #endif } }; @@ -244,32 +231,17 @@ class StackGradKernel : public framework::OpKernel { int post = total_num / (n * pre); auto &dev_ctx = ctx.template device_context(); - 
constexpr auto kMaxThreshold = 16; - if (std::is_same::value || - n > kMaxThreshold) { #ifdef __NVCC__ - VLOG(10) << "Stack more than " << kMaxThreshold - << " tensors on GPU may be slow."; - thrust::device_vector device_dx_vec(dx_datas); - auto dx_data_arr = device_dx_vec.data().get(); + thrust::device_vector device_dx_vec(dx_datas); + auto dx_data_arr = device_dx_vec.data().get(); #else - auto dx_data_arr = dx_datas.data(); + auto dx_data_arr = dx_datas.data(); #endif - StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, - post); + StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post); #ifdef __NVCC__ - // Wait() must be called because device_dx_vec may be destructed before - // kernel ends - dev_ctx.Wait(); -#endif - } -#ifdef __NVCC__ - else { // NOLINT - framework::Array dx_data_arr; - for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i]; - StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, - post); - } + // Wait() must be called because device_dx_vec may be destructed before + // kernel ends + dev_ctx.Wait(); #endif } }; From 3de455665983570d592ed23420904c53ec9bcc20 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 24 Aug 2018 10:52:36 +0000 Subject: [PATCH 091/140] concat op && map cnn model support --- paddle/fluid/inference/analysis/analyzer.cc | 2 +- .../api/api_tensorrt_subgraph_engine.cc | 2 + .../inference/tensorrt/convert/CMakeLists.txt | 8 +-- .../inference/tensorrt/convert/concat_op.cc | 57 +++++++++++++++++++ .../inference/tensorrt/convert/op_converter.h | 8 +++ .../tensorrt/convert/test_concat_op.cc | 49 ++++++++++++++++ 6 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/concat_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/test_concat_op.cc diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 7d16364609..0d94ccb64e 100644 --- 
a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -72,7 +72,7 @@ class DfgPassManagerImpl final : public DfgPassManager { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", - "depthwise_conv2d", "batch_norm"}); + "depthwise_conv2d", "batch_norm", "concat"}); if (!node->IsFunction()) return false; const auto* func = static_cast(node); diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index 9ac0372971..93de7a5209 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -32,6 +32,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { : NativePaddlePredictor(config), config_(config) {} bool Init(const std::shared_ptr& parent_scope) { + FLAGS_IA_enable_tensorrt_subgraph_engine = true; VLOG(3) << "Predictor::init()"; FLAGS_tensorrt_max_batch_size = config_.max_batch_size; FLAGS_tensorrt_workspace_size = config_.workspace_size; @@ -161,3 +162,4 @@ USE_TRT_CONVERTER(fc); USE_TRT_CONVERTER(pool2d); USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(batch_norm); +USE_TRT_CONVERTER(concat); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 2a449eb95e..9d7be2d03c 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,7 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc -batch_norm_op.cc activation_op.cc softmax_op.cc +batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -18,12 +18,12 @@ nv_test(test_trt_conv_op 
SRCS test_conv2d_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) - nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL) - nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) - nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) + +nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc new file mode 100644 index 0000000000..bb9627bf95 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights. 
+ */ +class ConcatOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + std::vector itensors; + for (auto& input_name : op_desc.Input("X")) { + itensors.push_back(engine_->GetITensor(input_name)); + } + int axis = boost::get(op_desc.GetAttr("axis")); + PADDLE_ENFORCE(axis > 0, + "The axis attr of Concat op should be large than 0 for trt"); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(), + itensors.size()); + axis = axis - 1; // Remove batch dim + layer->setAxis(axis); + auto output_name = op_desc.Output("Out")[0]; + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(concat, ConcatOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 41faaf7212..d309d94c56 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -79,6 +79,14 @@ class OpConverter { it = Registry::Lookup("elementwise_" + op_type + "_tensor"); } + PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", + op_desc.Type()); + } + + if (op_desc.Type() == "depthwise_conv2d") { + it = Registry::Lookup("conv2d"); + PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", + op_desc.Type()); } if (!it) { diff --git a/paddle/fluid/inference/tensorrt/convert/test_concat_op.cc b/paddle/fluid/inference/tensorrt/convert/test_concat_op.cc new file mode 100644 index 0000000000..4f284a4db5 --- 
/dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_concat_op.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(concat_op, test) { + std::unordered_set parameters({""}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("concat_x1", nvinfer1::DimsCHW(10, 3, 1)); + validator.DeclInputVar("concat_x2", nvinfer1::DimsCHW(3, 3, 1)); + validator.DeclInputVar("concat_x3", nvinfer1::DimsCHW(7, 3, 1)); + validator.DeclOutputVar("concat_out", nvinfer1::DimsCHW(20, 3, 1)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("concat"); + desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"}); + desc.SetOutput("Out", {"concat_out"}); + + int axis = 1; + desc.SetAttr("axis", axis); + + validator.SetOp(*desc.Proto()); + + validator.Execute(5); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle +USE_OP(concat); From 6be273cbdbf3fe46fe5a4b4af787977b9bd59929 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 24 Aug 2018 22:40:38 +0800 Subject: [PATCH 092/140] add seq mode lstm --- paddle/fluid/operators/fusion_lstm_op.cc | 52 ++++++++++++++++++++---- 1 file changed, 45 
insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 3888333ec5..870292827d 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,10 +15,14 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/cpu_info.h" + +DEFINE_bool(seq_mode, true, "Use sequence mode"); namespace paddle { namespace operators { @@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Cell"); - int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + int xx_width; + if (FLAGS_seq_mode) { + xx_width = wx_dims[1]; + } else { + xx_width = x_dims[1] > wx_dims[1] ? 
wx_dims[1] : x_dims[1]; + } ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->ShareLoD("X", "XX"); } @@ -205,10 +214,34 @@ inline void ReorderInitState(const DeviceContext& ctx, row_shuffle(ctx, src, index_lod, dst, indexed_src); } -template +template class FuisonLSTMKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { + void SeqCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* x = ctx.Input("X"); + auto* wx = ctx.Input("WeightX"); + auto* wh = ctx.Input("WeightH"); + auto* bias = ctx.Input("Bias"); + + auto* xx = ctx.Output("XX"); + + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int M = x_dims[1]; // x frame size + const int D4 = wh_dims[1]; + + const T* x_data = x->data(); + const T* wx_data = wx->data(); + T* xx_data = xx->mutable_data(ctx.GetPlace()); + + auto blas = math::GetBlas(ctx); + math::FCCompute(blas, x_dims[0], D4, M, x_data, wx_data, + xx_data, bias->data()); + } + + void BatchCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = platform::CPUDeviceContext; auto* x = ctx.Input("X"); auto* wx = ctx.Input("WeightX"); auto* wh = ctx.Input("WeightH"); @@ -339,6 +372,13 @@ class FuisonLSTMKernel : public framework::OpKernel { // restore the output cell state in LoDTensor from the batch cell to_seq(dev_ctx, batch_cell, cell_out); } + void Compute(const framework::ExecutionContext& ctx) const override { + if (FLAGS_seq_mode) { + SeqCompute(ctx); + } else { + BatchCompute(ctx); + } + } }; } // namespace operators @@ -348,7 +388,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL( - fusion_lstm, - ops::FuisonLSTMKernel, - ops::FuisonLSTMKernel); +REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel, + ops::FuisonLSTMKernel); 
From 593ac0f23ec6c9beac85791a37d29228414903db Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 11:05:39 +0800 Subject: [PATCH 093/140] openblas (#12937) --- cmake/external/glog.cmake | 7 +++++++ cmake/external/openblas.cmake | 19 +++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/cmake/external/glog.cmake b/cmake/external/glog.cmake index ac0181e69c..25ef2970ac 100644 --- a/cmake/external/glog.cmake +++ b/cmake/external/glog.cmake @@ -60,6 +60,13 @@ ExternalProject_Add( -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} ) +IF(WIN32) + IF(NOT EXISTS "${GLOG_INSTALL_DIR}/lib/libglog.lib") + add_custom_command(TARGET extern_glog POST_BUILD + COMMAND cmake -E rename ${GLOG_INSTALL_DIR}/lib/glog.lib ${GLOG_INSTALL_DIR}/lib/libglog.lib + ) + ENDIF() +ENDIF(WIN32) ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 56024edf5b..c3fbe4dbdb 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -17,20 +17,29 @@ IF(USE_EIGEN_FOR_BLAS) ENDIF(USE_EIGEN_FOR_BLAS) INCLUDE(cblas) +# IF(WIN32 AND NOT ${CBLAS_FOUND}) + + IF(NOT ${CBLAS_FOUND}) + INCLUDE(ExternalProject) SET(CBLAS_SOURCES_DIR ${THIRD_PARTY_PATH}/openblas) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) - SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) + SET(CBLAS_INCLUDE_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "openblas library." 
FORCE) ADD_DEFINITIONS(-DPADDLE_USE_OPENBLAS) + IF (WIN32) + SET(CBLAS_FOUND true) + MESSAGE(WARNING, "In windows, openblas only support msvc build, please build it manually and put it at " ${CBLAS_INSTALL_DIR}) + ENDIF(WIN32) + IF (NOT WIN32) SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable") SET(OPENBLAS_COMMIT "v0.2.20") @@ -69,7 +78,6 @@ IF(NOT ${CBLAS_FOUND}) ENDIF() SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs) - ExternalProject_Add( extern_openblas ${EXTERNAL_PROJECT_LOG_ARGS} @@ -84,9 +92,11 @@ IF(NOT ${CBLAS_FOUND}) UPDATE_COMMAND "" CONFIGURE_COMMAND "" ) + ELSE() + ENDIF(NOT WIN32) SET(CBLAS_PROVIDER openblas) IF(WITH_C_API) - INSTALL(DIRECTORY ${CBLAS_INC_DIR} DESTINATION third_party/openblas) + INSTALL(DIRECTORY ${CBLAS_INCLUDE_DIR} DESTINATION third_party/openblas) # Because libopenblas.a is a symbolic link of another library, thus need to # install the whole directory. IF(ANDROID) @@ -107,7 +117,8 @@ IF(NOT ${CBLAS_FOUND}) ENDIF(NOT ${CBLAS_FOUND}) MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}") -INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) +MESSAGE(STATUS "BLAS Include: ${CBLAS_INCLUDE_DIR}") +INCLUDE_DIRECTORIES(${CBLAS_INCLUDE_DIR}) # FIXME(gangliao): generate cblas target to track all high performance # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) From 5df65811010162743959090d7a80e557d9594178 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 11:05:49 +0800 Subject: [PATCH 094/140] merge_static_libs (#12936) --- cmake/generic.cmake | 38 ++++++++++++++++++++++++++++++++++++-- cmake/inference_lib.cmake | 9 +++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 82c958073c..6d23094232 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -148,7 +148,8 @@ function(merge_static_libs TARGET_NAME) COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" COMMAND /usr/bin/libtool -static -o 
"${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles} ) - else() # general UNIX: use "ar" to extract objects and re-add to a common lib + endif(APPLE) + if(LINUX) # general UNIX: use "ar" to extract objects and re-add to a common lib set(target_DIR ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}.dir) foreach(lib ${libs}) @@ -187,7 +188,36 @@ function(merge_static_libs TARGET_NAME) COMMAND ${CMAKE_AR} crs ${target_LIBNAME} `find ${target_DIR} -name '*.o'` COMMAND ${CMAKE_RANLIB} ${target_LIBNAME} WORKING_DIRECTORY ${target_DIR}) - endif() + endif(LINUX) + if(WIN32) # windows do not support gcc/nvcc combined compiling. Use msvc lib.exe to merge libs. + # Make the generated dummy source file depended on all static input + # libs. If input lib changes,the source file is touched + # which causes the desired effect (relink). + add_custom_command(OUTPUT ${target_SRCS} + COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS} + DEPENDS ${libs}) + + # Generate dummy staic lib + file(WRITE ${target_SRCS} "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";") + add_library(${TARGET_NAME} STATIC ${target_SRCS}) + target_link_libraries(${TARGET_NAME} ${libs_deps}) + + foreach(lib ${libs}) + # Get the file names of the libraries to be merged + #if(NOT $ MATCHES "lib.*\\.lib") + # message("library" ${lib}) + # set(libfiles ${libfiles} lib$) + #else() + set(libfiles ${libfiles} $) + #endif() + endforeach() + + # windows cmd return error in clean env. 
+ # COMMAND del "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/${TARGET_NAME}.lib" + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + COMMAND lib /OUT:${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.lib ${libfiles} + ) + endif(WIN32) endfunction(merge_static_libs) function(cc_library TARGET_NAME) @@ -195,6 +225,10 @@ function(cc_library TARGET_NAME) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + if(WIN32) + # add libxxx.lib prefix in windows + set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") + endif(WIN32) if(cc_library_SRCS) if(cc_library_SHARED OR cc_library_shared) # build *.so add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 834ab5a9e5..bc36683a9f 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -101,6 +101,7 @@ if(WITH_MKLDNN) ) endif() +if (NOT WIN32) if(NOT MOBILE_INFERENCE AND NOT RPI) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy") copy(snappy_lib @@ -120,15 +121,23 @@ if(NOT MOBILE_INFERENCE AND NOT RPI) DSTS ${dst_dir} ${dst_dir}/lib DEPS zlib) endif() +endif(NOT WIN32) # paddle fluid module set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") set(dst_dir "${FLUID_INSTALL_DIR}/paddle/fluid") set(module "framework") +if (NOT WIN32) copy(framework_lib DEPS framework_py_proto SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ) +else() +copy(framework_lib + SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h + DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} +) +endif(NOT WIN32) 
set(module "memory") copy(memory_lib From 669304f4e5005c9dd4763a86a2f91773d68941be Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 11:06:00 +0800 Subject: [PATCH 095/140] protobuf (#12935) --- cmake/external/protobuf.cmake | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 2665996432..550b0dada8 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -14,11 +14,14 @@ INCLUDE(ExternalProject) # Always invoke `FIND_PACKAGE(Protobuf)` for importing function protobuf_generate_cpp +IF(NOT WIN32) FIND_PACKAGE(Protobuf QUIET) +ENDIF(NOT WIN32) macro(UNSET_VAR VAR_NAME) UNSET(${VAR_NAME} CACHE) UNSET(${VAR_NAME}) endmacro() + UNSET_VAR(PROTOBUF_INCLUDE_DIR) UNSET_VAR(PROTOBUF_FOUND) UNSET_VAR(PROTOBUF_PROTOC_EXECUTABLE) @@ -94,12 +97,14 @@ macro(PROMPT_PROTOBUF_LIB) SET(protobuf_DEPS ${ARGN}) MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}") + MESSAGE(STATUS "Protobuf-lite library: ${PROTOBUF_LITE_LIBRARY}") MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}") + MESSAGE(STATUS "Protoc library: ${PROTOBUF_PROTOC_LIBRARY}") MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}") INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) # Assuming that all the protobuf libraries are of the same type. 
- IF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$") + IF(${PROTOBUF_LIBRARY} MATCHES ${CMAKE_STATIC_LIBRARY_SUFFIX}) SET(protobuf_LIBTYPE STATIC) ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}$") SET(protobuf_LIBTYPE SHARED) @@ -137,18 +142,25 @@ macro(SET_PROTOBUF_VERSION) endmacro() set(PROTOBUF_ROOT "" CACHE PATH "Folder contains protobuf") +IF (WIN32) + SET(PROTOBUF_ROOT ${THIRD_PARTY_PATH}/install/protobuf) + MESSAGE(WARNING, "In windows, protobuf only support msvc build, please build it manually and put it at " ${PROTOBUF_ROOT}) +ENDIF(WIN32) + if (NOT "${PROTOBUF_ROOT}" STREQUAL "") + find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH) - find_library(PROTOBUF_LIBRARY protobuf PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) - find_library(PROTOBUF_LITE_LIBRARY protobuf-lite PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) - find_library(PROTOBUF_PROTOC_LIBRARY protoc PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) + find_library(PROTOBUF_LIBRARY protobuf libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) + find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) + find_library(PROTOBUF_PROTOC_LIBRARY protoc libprotoc.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) find_program(PROTOBUF_PROTOC_EXECUTABLE protoc PATHS ${PROTOBUF_ROOT}/bin NO_DEFAULT_PATH) if (PROTOBUF_INCLUDE_DIR AND PROTOBUF_LIBRARY AND PROTOBUF_LITE_LIBRARY AND PROTOBUF_PROTOC_LIBRARY AND PROTOBUF_PROTOC_EXECUTABLE) message(STATUS "Using custom protobuf library in ${PROTOBUF_ROOT}.") + SET(PROTOBUF_FOUND true) SET_PROTOBUF_VERSION() PROMPT_PROTOBUF_LIB() else() - message(WARNING "Cannot find protobuf library in ${PROTOBUF_ROOT}.") + message(WARNING "Cannot find protobuf library in ${PROTOBUF_ROOT}") endif() endif() @@ -239,6 +251,7 @@ IF(CMAKE_CROSSCOMPILING) CACHE FILEPATH "protobuf executable." 
FORCE) ENDIF() + IF(NOT PROTOBUF_FOUND) build_protobuf(extern_protobuf FALSE) From a4ffdf3088daaef939eab72db0c96473db8e2621 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 11:25:45 +0800 Subject: [PATCH 096/140] gflags (#12928) --- cmake/external/gflags.cmake | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index a1d2d0f446..cf58cc3976 100644 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -18,7 +18,7 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags) SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE) IF(WIN32) - set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) + set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) ELSE(WIN32) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) ENDIF(WIN32) @@ -45,7 +45,13 @@ ExternalProject_Add( -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} ) - +IF(WIN32) + IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib") + add_custom_command(TARGET extern_gflags POST_BUILD + COMMAND cmake -E rename ${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib ${GFLAGS_INSTALL_DIR}/lib/libgflags.lib + ) + ENDIF() +ENDIF(WIN32) ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL) SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES}) ADD_DEPENDENCIES(gflags extern_gflags) @@ -60,3 +66,4 @@ IF(WITH_C_API) INSTALL(FILES ${GFLAGS_LIBRARIES} DESTINATION third_party/gflags/lib) ENDIF() ENDIF() + From eca4563e5dfd949d4ee8c945494a5f25412dae17 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 11:37:46 +0800 Subject: [PATCH 097/140] operators module (#12938) --- paddle/fluid/operators/CMakeLists.txt | 5 +++-- 
paddle/fluid/operators/math/math_function.h | 4 ++++ paddle/fluid/platform/float16.h | 4 ++++ paddle/fluid/pybind/CMakeLists.txt | 12 +++++++----- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 68fbde2c09..8da0aaaafe 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -85,7 +85,7 @@ function(op_library TARGET) #remove windows unsupported op if (WIN32) - foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op") + foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op") if ("${TARGET}" STREQUAL "${windows_unsupport_op}") return() endif() @@ -319,8 +319,9 @@ foreach(src ${GENERAL_OPS}) endforeach() file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") - +if (NOT WIN32) add_subdirectory(reader) +endif(NOT WIN32) foreach(src ${READER_LIBRARY}) set(OP_LIBRARY ${src} ${OP_LIBRARY}) endforeach() diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index 7ec78d9ef8..c63ad89e46 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -19,6 +19,10 @@ limitations under the License. */ #ifdef PADDLE_USE_OPENBLAS #include +// remove typedef in openblas +#undef FLOAT +#undef INT +#undef SIZE #endif #include diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index efb021c838..ee16fc66e4 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -56,7 +56,11 @@ limitations under the License. 
*/ #include #endif // PADDLE_ARM +#if !defined(_WIN32) #define PADDLE_ALIGN(x) __attribute__((aligned(x))) +#else +#define PADDLE_ALIGN(x) /*do nothing*/ +#endif namespace paddle { namespace platform { diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d6a14b3305..b5bd07d401 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,17 +1,19 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor prune profiler feed_fetch_method - ) + +set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method) +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) if(NOT WIN32) -list(APPEND PYBIND_DEPS parallel_executor) +list(APPEND PYBIND_DEPS parallel_executor profiler) +list(APPEND PYBIND_SRCS recordio.cc) endif() if(WITH_PYTHON) if(WITH_AMD_GPU) hip_library(paddle_pybind SHARED - SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc + SRCS ${PYBIND_SRCS} DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB}) else() cc_library(paddle_pybind SHARED - SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc + SRCS ${PYBIND_SRCS} DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB}) if(NOT APPLE AND NOT ANDROID AND NOT WIN32) From d0b713493eebd79c8bc6c40a8d55f6f31bad4021 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sat, 25 Aug 2018 13:30:22 +0800 Subject: [PATCH 098/140] enhance DebugStringEx (#12949) --- paddle/fluid/framework/operator.cc | 52 +++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d04f774496..d58d6e4f3e 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -74,6 +74,12 @@ static DDim GetDims(const Scope& scope, const std::string& name, } } +static bool VarInited(const Scope& scope, const std::string& name) { + Variable* var = scope.FindVar(name); + if (var == nullptr) return false; + return 
var->IsInitialized(); +} + static std::string GetDtype(const Scope& scope, const std::string& name) { Variable* var = scope.FindVar(name); if (var == nullptr) { @@ -87,8 +93,12 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { } return DataTypeToString(ToDataType(tensor.type())); } else if (var->IsType()) { - return DataTypeToString( - ToDataType(var->Get().value().type())); + auto tensor = var->Get().value(); + if (UNLIKELY(!tensor.IsInitialized())) { + return "uninited"; + } else { + return DataTypeToString(ToDataType(tensor.type())); + } } else { return ""; } @@ -197,16 +207,21 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { auto& input = *it; ss << input.first << "["; for (size_t i = 0; i < input.second.size(); ++i) { - ss << input.second[i]; + auto var_name = input.second[i]; + ss << var_name; if (scope) { - int row_size = GetRowSize(*scope, input.second[i]); - if (row_size >= 0) { - ss << "[row_size=" << row_size << "]"; + if (!VarInited(*scope, var_name)) { + ss << "[uninited]"; + } else { + int row_size = GetRowSize(*scope, var_name); + if (row_size >= 0) { + ss << "[row_size=" << row_size << "]"; + } + std::string dtype = GetDtype(*scope, var_name); + ss << ":" << dtype; + ss << "[" << GetDims(*scope, var_name, true) << "]"; + ss << "(" << GetLoD(*scope, var_name) << ")"; } - std::string dtype = GetDtype(*scope, input.second[i]); - ss << ":" << dtype; - ss << "[" << GetDims(*scope, input.second[i], true) << "]"; - ss << "(" << GetLoD(*scope, input.second[i]) << ")"; } if (i != input.second.size() - 1) { ss << ", "; @@ -223,14 +238,19 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { auto& output = *it; ss << output.first << "["; for (size_t i = 0; i < output.second.size(); ++i) { - ss << output.second[i]; + auto var_name = output.second[i]; + ss << var_name; if (scope) { - int row_size = GetRowSize(*scope, output.second[i]); - if (row_size >= 0) { - ss << "[row_size=" << row_size << 
"]"; + if (!VarInited(*scope, var_name)) { + ss << "[uninited]"; + } else { + int row_size = GetRowSize(*scope, output.second[i]); + if (row_size >= 0) { + ss << "[row_size=" << row_size << "]"; + } + ss << "[" << GetDims(*scope, var_name, true) << "]"; + ss << "(" << GetLoD(*scope, var_name) << ")"; } - ss << "[" << GetDims(*scope, output.second[i], true) << "]"; - ss << "(" << GetLoD(*scope, output.second[i]) << ")"; } if (i != output.second.size() - 1) { ss << ", "; From c790d57cd4ac80610f7e0f3c4ab164e57f74e463 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 14:53:27 +0800 Subject: [PATCH 099/140] data_type (#12933) * data_type * "remove tabs" --- paddle/fluid/CMakeLists.txt | 6 ++++- paddle/fluid/framework/CMakeLists.txt | 21 ++++++++++++++-- paddle/fluid/framework/data_type.h | 35 +++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 2577e59d9c..ee1f655e25 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -2,9 +2,13 @@ add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) add_subdirectory(operators) -add_subdirectory(pybind) add_subdirectory(string) + +if (NOT WIN32) +add_subdirectory(pybind) add_subdirectory(recordio) +endif(NOT WIN32) + if(WITH_INFERENCE) # NOTE: please add subdirectory inference at last. 
add_subdirectory(inference) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2ec422cc17..2c62d4ed6b 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1,5 +1,7 @@ -add_subdirectory(details) add_subdirectory(ir) +if (NOT WIN32) +add_subdirectory(details) +endif (NOT WIN32) # ddim lib proto_library(framework_proto SRCS framework.proto) @@ -28,8 +30,12 @@ if(WITH_GPU) else() cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) endif() - +if (NOT WIN32) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) +else() +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) +endif (NOT WIN32) + cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) @@ -69,14 +75,22 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) + +if (NOT WIN32) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference data_transform lod_tensor profiler) +else() +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog + shape_inference data_transform lod_tensor) +endif(NOT WIN32) + cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) +if (NOT WIN32) py_proto_compile(framework_py_proto 
SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) @@ -86,6 +100,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/ COMMENT "Copy generated python proto into directory paddle/fluid/proto." WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) +endif(NOT WIN32) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) @@ -120,7 +135,9 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) # cc_test(channel_test SRCS channel_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) +if (NOT WIN32) cc_test(rw_lock_test SRCS rw_lock_test.cc) +endif (NOT WIN32) # disable test temporarily. # TODO https://github.com/PaddlePaddle/Paddle/issues/11971 diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 491413db8c..f8c72ffc89 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -26,6 +26,7 @@ namespace framework { extern proto::VarType::Type ToDataType(std::type_index type); extern std::type_index ToTypeIndex(proto::VarType::Type type); +#if !defined(_WIN32) template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { switch (type) { @@ -57,6 +58,40 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { PADDLE_THROW("Not supported %d", type); } } +#else +// the msvc compiler do not implement two-stage name lookup correctly. 
+template +inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { + switch (type) { + case proto::VarType::FP16: + visitor.operator()(); + break; + case proto::VarType::FP32: + visitor.operator()(); + break; + case proto::VarType::FP64: + visitor.operator()(); + break; + case proto::VarType::INT32: + visitor.operator()(); + break; + case proto::VarType::INT64: + visitor.operator()(); + break; + case proto::VarType::BOOL: + visitor.operator()(); + break; + case proto::VarType::UINT8: + visitor.operator()(); + break; + case proto::VarType::INT16: + visitor.operator()(); + break; + default: + PADDLE_THROW("Not supported %d", type); + } +} +#endif // _WIN32 extern std::string DataTypeToString(const proto::VarType::Type type); extern size_t SizeOfType(std::type_index type); From 77c0aeb91e8906f6f1cecc1cdb28f1731e4a46c0 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 14:55:56 +0800 Subject: [PATCH 100/140] boost (#12929) * "fix ci" * "windows tab" * "fix ci" --- cmake/external/boost.cmake | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake index 73713d93d5..ada61de8eb 100644 --- a/cmake/external/boost.cmake +++ b/cmake/external/boost.cmake @@ -28,7 +28,12 @@ if((NOT DEFINED BOOST_TAR) OR (NOT DEFINED BOOST_URL)) set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE) set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE) endif() -MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") +IF (WIN32) + MESSAGE(WARNING, "In windows, boost can not be downloaded automaticlly, please build it manually and put it at " ${THIRD_PARTY_PATH}install/boost) +else() + MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") +ENDIF(WIN32) + set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost) set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}") set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}/${BOOST_TAR}" 
CACHE PATH "boost include directory." FORCE) @@ -36,12 +41,13 @@ set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1) include_directories(${BOOST_INCLUDE_DIR}) +if (NOT WIN32) ExternalProject_Add( ${BOOST_PROJECT} ${EXTERNAL_PROJECT_LOG_ARGS} DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR} DOWNLOAD_COMMAND wget --no-check-certificate ${BOOST_URL} -c -q -O ${BOOST_TAR}.tar.gz - && tar zxf ${BOOST_TAR}.tar.gz + && tar zxf ${BOOST_TAR}.tar.gz DOWNLOAD_NO_PROGRESS 1 PREFIX ${BOOST_SOURCES_DIR} CONFIGURE_COMMAND "" @@ -49,8 +55,9 @@ ExternalProject_Add( INSTALL_COMMAND "" UPDATE_COMMAND "" ) +endif(NOT WIN32) -if (${CMAKE_VERSION} VERSION_LESS "3.3.0") +if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c) file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";") add_library(boost STATIC ${dummyfile}) From 04b1e4dcea1cb2a590643c464c70167b87ed94d4 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 15:38:24 +0800 Subject: [PATCH 101/140] tensor module windows support (#12934) * tensor windows support * "fix ci" * "remove utils" --- paddle/fluid/framework/lod_tensor.cc | 17 ++++++++++++++++- paddle/fluid/framework/lod_tensor_test.cc | 2 ++ paddle/fluid/framework/rw_lock.h | 12 ++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 919029c38f..adeb26e4e7 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -25,8 +25,10 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" +#if !defined(_WIN32) #include "paddle/fluid/recordio/scanner.h" #include "paddle/fluid/recordio/writer.h" +#endif // _WIN32 namespace paddle { namespace framework { @@ -300,6 +302,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, TensorFromStream(is, static_cast(tensor), dev_ctx); } +#if !defined(_WIN32) void WriteToRecordIO(recordio::Writer *writer, const std::vector &tensor, const platform::DeviceContext &dev_ctx) { @@ -329,7 +332,19 @@ bool ReadFromRecordIO(recordio::Scanner *scanner, return true; } - +#else +class Writer {}; +class Scanner {}; +void WriteToRecordIO(recordio::Writer *writer, + const std::vector &tensor, + const platform::DeviceContext &dev_ctx) {} +bool ReadFromRecordIO(recordio::Scanner *scanner, + const platform::DeviceContext &dev_ctx, + std::vector *result_ptr) { + PADDLE_ENFORCE("windows didn't supported recordio!."); + return true; +} +#endif // _WIN32 std::vector LoDTensor::SplitLoDTensor( const std::vector places) const { check_memory_size(); diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index cd50aaa260..cbf5fd04d7 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -274,6 +274,7 @@ TEST(LoD, ConvertToOffsetBasedLoD) { EXPECT_EQ(offset_lod, expected); } +#if !defined(_WIN32) template static void TestRecordIO() { LoDTensor tensor; @@ -320,6 +321,7 @@ TEST(LoDTensor, RecordIO) { TestRecordIO(); TestRecordIO(); } +#endif // !defined(_WIN32) } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h index 1418fb5134..a068d3543d 100644 --- a/paddle/fluid/framework/rw_lock.h +++ b/paddle/fluid/framework/rw_lock.h @@ -14,13 +14,16 @@ limitations under the License. 
*/ #pragma once +#if !defined(_WIN32) #include +#endif // !_WIN32 #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace framework { +#if !defined(_WIN32) struct RWLock { RWLock() { pthread_rwlock_init(&lock_, nullptr); } @@ -43,6 +46,15 @@ struct RWLock { private: pthread_rwlock_t lock_; }; +#else +// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive +// In windows, rw_lock seems like a hack. Use empty object and do nothing. +struct RWLock { + void RDLock() {} + void WRLock() {} + void UNLock() {} +}; +#endif } // namespace framework } // namespace paddle From dbd7896678ade5a57705477d0f963525e909733c Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 25 Aug 2018 15:38:42 +0800 Subject: [PATCH 102/140] cmakelist windows (#12927) * picked pr * "fix ci" --- CMakeLists.txt | 17 ++++++++++++----- cmake/configure.cmake | 5 +++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 317f7f9eb4..b1d0abdf2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,6 +24,9 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: " "${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: " "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") +if(WIN32) + set(CMAKE_STATIC_LIBRARY_PREFIX lib) +endif(WIN32) if(NOT CMAKE_CROSSCOMPILING) find_package(CUDA QUIET) @@ -165,7 +168,6 @@ include(external/python) # download, build, install python include(external/openblas) # download, build, install openblas include(external/mkldnn) # download, build, install mkldnn include(external/swig) # download, build, install swig -include(external/warpctc) # download, build, install warpctc include(external/boost) # download boost include(external/any) # download libn::any include(external/eigen) # download eigen3 @@ -173,6 +175,14 @@ include(external/pybind11) # download pybind11 include(external/cares) include(external/cub) +if (NOT 
WIN32) +# there is no official support of snappystream, warpctc, nccl, cupti in windows +include(external/snappy) # download snappy +include(external/snappystream) # download snappystream +include(external/warpctc) # download, build, install warpctc +include(cupti) +endif (NOT WIN32) + if(WITH_DISTRIBUTE) if(WITH_GRPC) include(external/grpc) @@ -194,13 +204,10 @@ if(WITH_BRPC_RDMA) endif() endif() -include(external/snappy) # download snappy -include(external/snappystream) -include(external/threadpool) +include(external/threadpool) include(flags) # set paddle compile flags include(cudnn) # set cudnn libraries, must before configure -include(cupti) include(configure) # add paddle env configuration if(WITH_GPU) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index e03e15bfc0..ce1857582b 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -61,6 +61,11 @@ if(NOT CMAKE_CROSSCOMPILING) endif() endif() +if(WIN32) + # windows stupid compile option for all targets. + add_definitions(-D_XKEYCHECK_H) +endif(WIN32) + if(NOT WITH_GOLANG) add_definitions(-DPADDLE_WITHOUT_GOLANG) endif(NOT WITH_GOLANG) From 3c58b87b45440cf13be778a53c6b2744c1d00e7e Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 26 Aug 2018 10:00:41 +0800 Subject: [PATCH 103/140] fix auc layer and add check for auc op (#12954) * fix auc layer and add check for auc op * use input to check if states are inited * optimize code --- paddle/fluid/operators/auc_op.h | 14 ++++++++++++++ paddle/fluid/operators/math/cpu_vec_test.cc | 1 + python/paddle/fluid/layers/metric_op.py | 12 ++++++++---- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/auc_op.h index 0a18585edb..0651203286 100644 --- a/paddle/fluid/operators/auc_op.h +++ b/paddle/fluid/operators/auc_op.h @@ -60,6 +60,20 @@ class AucKernel : public framework::OpKernel { const T* inference_data = predict->data(); const auto* label_data = label->data(); + // check if 
states are inited. + auto* tp_in = ctx.Input("TP"); + auto* fp_in = ctx.Input("FP"); + auto* tn_in = ctx.Input("TN"); + auto* fn_in = ctx.Input("FN"); + PADDLE_ENFORCE(tp_in->IsInitialized(), "true_positive is not inited!"); + PADDLE_ENFORCE(fp_in->IsInitialized(), "false_negative is not inited!"); + PADDLE_ENFORCE(tn_in->IsInitialized(), "true_negative is not inited!"); + PADDLE_ENFORCE(fn_in->IsInitialized(), "false_positive is not inited!"); + PADDLE_ENFORCE_EQ(tp_in->numel(), num_thresholds, ""); + PADDLE_ENFORCE_EQ(fp_in->numel(), num_thresholds, ""); + PADDLE_ENFORCE_EQ(tn_in->numel(), num_thresholds, ""); + PADDLE_ENFORCE_EQ(fn_in->numel(), num_thresholds, ""); + auto* tp_data = true_positive->mutable_data(ctx.GetPlace()); auto* fn_data = false_negative->mutable_data(ctx.GetPlace()); auto* tn_data = true_negative->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index bf6481c5cc..3ce66f49ed 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include #include +#include #include #include "gflags/gflags.h" #include "glog/logging.h" diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py index 2c3bdd77e1..0182bbeb63 100644 --- a/python/paddle/fluid/layers/metric_op.py +++ b/python/paddle/fluid/layers/metric_op.py @@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): helper = LayerHelper("auc", **locals()) auc_out = helper.create_tmp_variable(dtype="float64") # make tp, tn, fp, fn persistable, so that can accumulate all batches. 
- tp = helper.create_global_variable(persistable=True, dtype='int64') - tn = helper.create_global_variable(persistable=True, dtype='int64') - fp = helper.create_global_variable(persistable=True, dtype='int64') - fn = helper.create_global_variable(persistable=True, dtype='int64') + tp = helper.create_global_variable( + persistable=True, dtype='int64', shape=[num_thresholds]) + tn = helper.create_global_variable( + persistable=True, dtype='int64', shape=[num_thresholds]) + fp = helper.create_global_variable( + persistable=True, dtype='int64', shape=[num_thresholds]) + fn = helper.create_global_variable( + persistable=True, dtype='int64', shape=[num_thresholds]) for var in [tp, tn, fp, fn]: helper.set_variable_initializer( var, Constant( From 4fcc2936174315969275abb2a0172c72e3b01bbe Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sun, 26 Aug 2018 16:25:15 +0800 Subject: [PATCH 104/140] memory module (#12931) * memory module * "fix ci" --- .../inference/api/demo_ci/CMakeLists.txt | 2 + .../fluid/memory/detail/system_allocator.cc | 49 +++++++++++++------ 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index ba73a6eaa6..a697218377 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -23,9 +23,11 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +if (NOT WIN32) include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/zlib/include") +endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") 
include_directories("${PADDLE_LIB}/third_party/eigen3") diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 9b1ab1e228..1b96798d23 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -11,12 +11,18 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define GLOG_NO_ABBREVIATED_SEVERITIES #include "paddle/fluid/memory/detail/system_allocator.h" -#include // for malloc and free +#ifdef _WIN32 +#include +#include // VirtualLock/VirtualUnlock +#else #include // for mlock and munlock -#include // for std::max +#endif +#include // for malloc and free +#include // for std::max #include "gflags/gflags.h" #include "paddle/fluid/platform/assert.h" @@ -35,31 +41,42 @@ namespace paddle { namespace memory { namespace detail { -void* CPUAllocator::Alloc(size_t* index, size_t size) { - // According to http://www.cplusplus.com/reference/cstdlib/malloc/, - // malloc might not return nullptr if size is zero, but the returned - // pointer shall not be dereferenced -- so we make it nullptr. 
- if (size <= 0) return nullptr; - - *index = 0; // unlock memory - +void* AlignedMalloc(size_t size) { void* p = nullptr; - + size_t alignment = 32ul; #ifdef PADDLE_WITH_MKLDNN // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp // memory alignment - PADDLE_ENFORCE_EQ(posix_memalign(&p, 4096ul, size), 0, "Alloc %ld error!", - size); + alignment = 4096ul; +#endif +#ifdef _WIN32 + p = _aligned_malloc(size, alignment); #else - PADDLE_ENFORCE_EQ(posix_memalign(&p, 32ul, size), 0, "Alloc %ld error!", + PADDLE_ENFORCE_EQ(posix_memalign(&p, alignment, size), 0, "Alloc %ld error!", size); #endif PADDLE_ENFORCE(p, "Fail to allocate CPU memory: size = %d .", size); + return p; +} + +void* CPUAllocator::Alloc(size_t* index, size_t size) { + // According to http://www.cplusplus.com/reference/cstdlib/malloc/, + // malloc might not return nullptr if size is zero, but the returned + // pointer shall not be dereferenced -- so we make it nullptr. + if (size <= 0) return nullptr; + + *index = 0; // unlock memory + + void* p = AlignedMalloc(size); if (p != nullptr) { if (FLAGS_use_pinned_memory) { *index = 1; +#ifdef _WIN32 + VirtualLock(p, size); +#else mlock(p, size); // lock memory +#endif } } @@ -68,7 +85,11 @@ void* CPUAllocator::Alloc(size_t* index, size_t size) { void CPUAllocator::Free(void* p, size_t size, size_t index) { if (p != nullptr && index == 1) { +#ifdef _WIN32 + VirtualUnlock(p, size); +#else munlock(p, size); +#endif } free(p); } From 607c41952e78d8c5d489a75590204f802d392ee5 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 26 Aug 2018 16:10:45 +0800 Subject: [PATCH 105/140] compute gates --- paddle/fluid/operators/fusion_lstm_op.cc | 87 +++++++++++++++++++++++- 1 file changed, 84 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 870292827d..604c6f1839 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ 
-220,24 +220,105 @@ class FuisonLSTMKernel : public framework::OpKernel { void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); + auto* h0 = ctx.Input("H0"); + auto* c0 = ctx.Input("C0"); auto* wx = ctx.Input("WeightX"); auto* wh = ctx.Input("WeightH"); auto* bias = ctx.Input("Bias"); auto* xx = ctx.Output("XX"); + auto* hidden_out = ctx.Output("Hidden"); + auto* cell_out = ctx.Output("Cell"); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 4D - const int M = x_dims[1]; // x frame size + auto x_lod = x->lod(); + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int N = x_lod[0].size() - 1; // batch size + const int M = x_dims[1]; // x frame size + const int D = wh_dims[0]; + const int D2 = D * 2; + const int D3 = D * 3; const int D4 = wh_dims[1]; const T* x_data = x->data(); + const T* h0_data = h0 ? h0->data() : NULL; + const T* c0_data = c0 ? 
c0->data() : NULL; const T* wx_data = wx->data(); + const T* wh_data = wh->data(); T* xx_data = xx->mutable_data(ctx.GetPlace()); + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); auto blas = math::GetBlas(ctx); math::FCCompute(blas, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data()); + + for (int i = 0; i < N; ++i) { + int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + const T* prev_cell_data = NULL; + const T* prev_hidden_data = NULL; + int tstart = 0; + if (h0_data) { + prev_hidden_data = h0_data + i * D; + prev_cell_data = c0_data + i * D; + } else { + // W_ch, W_ih, W_fh, W_oh + // actgate + math::vec_sigmoid(D3, xx_data + D, xx_data + D); + // ch gate + math::vec_tanh(D, xx_data, xx_data); + // cell out= input*tilde + blas.VMUL(D, xx_data, xx_data + D, cell_out_data); + // hidden out= act_state(cellout) * outgate + // act state + math::vec_tanh(D, cell_out_data, xx_data + D2); + blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); + + // prev + prev_hidden_data = hidden_out_data; + prev_cell_data = cell_out_data; + tstart = 1; + + // move offset + xx_data = xx_data + D4; + hidden_out_data = hidden_out_data + D; + cell_out_data = cell_out_data + D; + } + for (int step = tstart; step < seq_len; ++step) { + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), + prev_hidden_data, D, wh_data, D4, static_cast(1), xx_data, + D4); + + // W_ch, W_ih, W_fh, W_oh + // actgate + math::vec_sigmoid(D3, xx_data + D, xx_data + D); + // ch gate + math::vec_tanh(D, xx_data, xx_data); + + // a = forget * prev_cell + blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); + + // b = input * tilde + blas.VMUL(D, xx_data, xx_data + D, xx_data + D); + + // cell out= a+b + blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + + // hidden out= act_state(cellout) * outgate + // act state + math::vec_tanh(D, cell_out_data, xx_data + D2); + blas.VMUL(D, xx_data + D2, xx_data + D3, 
hidden_out_data); + + // prev + prev_hidden_data = hidden_out_data; + prev_cell_data = cell_out_data; + + // move offset + xx_data = xx_data + D4; + hidden_out_data = hidden_out_data + D; + cell_out_data = cell_out_data + D; + } + } } void BatchCompute(const framework::ExecutionContext& ctx) const { From 4b28fab8c94863d5ff24ce4c59ff31bb5d06b4ee Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 26 Aug 2018 18:24:00 +0800 Subject: [PATCH 106/140] enable more acts --- paddle/fluid/operators/fusion_lstm_op.cc | 34 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 2 +- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 604c6f1839..97852e2928 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel { auto* hidden_out = ctx.Output("Hidden"); auto* cell_out = ctx.Output("Cell"); + std::function act_gate, act_cell, act_cand; + auto& act_gate_str = ctx.Attr("gate_activation"); + auto& act_cell_str = ctx.Attr("cell_activation"); + auto& act_cand_str = ctx.Attr("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } + auto x_lod = x->lod(); auto x_dims = x->dims(); // T x M auto wh_dims = wh->dims(); // D x 4D @@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_cell_data = c0_data + i * D; } else { // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data 
+ D); + act_cand(D, xx_data, xx_data); // cell out= input*tilde blas.VMUL(D, xx_data, xx_data + D, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel { D4); // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // a = forget * prev_cell blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); @@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 9d8bef677f..d807f0a8b6 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -45,7 +45,7 @@ def fusion_lstm( class TestLstmOp(OpTest): def set_argument(self): - self.lod = [[2, 3, 2]] + pass def setUp(self): self.op_type = 'fusion_lstm' From 1777cd09f652e18c85a5017058cd29c4794446fa Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 26 Aug 2018 18:42:20 +0800 Subject: [PATCH 107/140] refine fusion lstm op test --- .../tests/unittests/test_fusion_lstm_op.py | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index d807f0a8b6..19f22fc7bd 100644 
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -43,13 +43,13 @@ def fusion_lstm( act_cell, act_cand) -class TestLstmOp(OpTest): - def set_argument(self): +class TestFusionLSTMOp(OpTest): + def set_conf(self): pass def setUp(self): self.op_type = 'fusion_lstm' - self.lod = [[2, 3, 2]] + self.lod = [[2, 3, 5, 4]] self.M = 8 self.D = 16 self.has_initial_state = False @@ -58,33 +58,33 @@ class TestLstmOp(OpTest): self.act_cell = 'tanh' self.act_cand = 'tanh' self.use_peepholes = False - self.set_argument() + self.set_conf() T = sum(self.lod[0]) bs = len(self.lod[0]) - x = np.random.normal(size=(T, self.M)).astype('float64') + x = np.random.normal(size=(T, self.M)).astype('float32') if self.has_initial_state: - h0 = np.random.normal(size=(bs, self.D)).astype('float64') - c0 = np.random.normal(size=(bs, self.D)).astype('float64') + h0 = np.random.normal(size=(bs, self.D)).astype('float32') + c0 = np.random.normal(size=(bs, self.D)).astype('float32') else: - h0 = np.zeros((bs, self.D)).astype('float64') - c0 = np.zeros((bs, self.D)).astype('float64') + h0 = np.zeros((bs, self.D)).astype('float32') + c0 = np.zeros((bs, self.D)).astype('float32') - wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64') + wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32') if self.use_peepholes: - b = np.random.normal(size=(1, 7 * self.D)).astype('float64') + b = np.random.normal(size=(1, 7 * self.D)).astype('float32') else: - b = np.random.normal(size=(1, 4 * self.D)).astype('float64') + b = np.random.normal(size=(1, 4 * self.D)).astype('float32') w_b = np.copy(b[:, 0:4 * self.D]) w_c = b[:, 4 * self.D:] if self.use_peepholes else None # this is the weight of fc - wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64') + wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32') # this is the bias of fc # and it should be manually added into the bias of this 
fusion LSTM - bx = np.random.normal(size=(1, 4 * self.D)).astype('float64') + bx = np.random.normal(size=(1, 4 * self.D)).astype('float32') b[0, 0:4 * self.D] += bx[0, :] h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c, self.is_reverse, ACTIVATION[self.act_gate], @@ -114,35 +114,44 @@ class TestLstmOp(OpTest): } def test_check_output(self): - self.check_output(atol=1e-8) + self.check_output() -class TestLstmOpInitReverse(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpInit(TestFusionLSTMOp): + def set_conf(self): self.has_initial_state = True - self.is_reverse = True -class TestLstmOpMD1(TestLstmOp): - def set_argument(self): +# class TestFusionLSTMOpReverse(TestFusionLSTMOp): +# def set_conf(self): +# self.is_reverse = True + +# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): +# def set_conf(self): +# self.has_initial_state = True +# self.is_reverse = True + + +class TestFusionLSTMOpMD1(TestFusionLSTMOp): + def set_conf(self): self.M = 36 self.D = 8 -class TestLstmOpMD2(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpMD2(TestFusionLSTMOp): + def set_conf(self): self.M = 8 self.D = 8 -class TestLstmOpMD3(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpMD3(TestFusionLSTMOp): + def set_conf(self): self.M = 15 self.D = 3 -class TestLstmOpBS1(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpBS1(TestFusionLSTMOp): + def set_conf(self): self.lod = [[3]] self.D = 16 From e61cf3214da019ca1de1fb68ae143928877b4e62 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 26 Aug 2018 21:00:56 +0800 Subject: [PATCH 108/140] complete reverse seq --- paddle/fluid/operators/fusion_lstm_op.cc | 41 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 17 ++++---- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 97852e2928..e4e4ac8e33 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ 
b/paddle/fluid/operators/fusion_lstm_op.cc @@ -229,6 +229,7 @@ class FuisonLSTMKernel : public framework::OpKernel { auto* xx = ctx.Output("XX"); auto* hidden_out = ctx.Output("Hidden"); auto* cell_out = ctx.Output("Cell"); + bool is_reverse = ctx.Attr("is_reverse"); std::function act_gate, act_cell, act_cand; auto& act_gate_str = ctx.Attr("gate_activation"); @@ -247,8 +248,9 @@ class FuisonLSTMKernel : public framework::OpKernel { } auto x_lod = x->lod(); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 4D + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int total_T = x_dims[0]; const int N = x_lod[0].size() - 1; // batch size const int M = x_dims[1]; // x frame size const int D = wh_dims[0]; @@ -266,17 +268,34 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); auto blas = math::GetBlas(ctx); - math::FCCompute(blas, x_dims[0], D4, M, x_data, wx_data, + math::FCCompute(blas, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); + int xx_offset = D4; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 4; + hidden_out_data = hidden_out_data + offset; + cell_out_data = cell_out_data + offset; + xx_offset = -D4; + gate_offset = -D; + } + + auto move_step = [&]() { + xx_data = xx_data + xx_offset; + hidden_out_data = hidden_out_data + gate_offset; + cell_out_data = cell_out_data + gate_offset; + }; for (int i = 0; i < N; ++i) { - int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + int bid = is_reverse ? 
N - 1 - i : i; + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; const T* prev_cell_data = NULL; const T* prev_hidden_data = NULL; int tstart = 0; if (h0_data) { - prev_hidden_data = h0_data + i * D; - prev_cell_data = c0_data + i * D; + prev_hidden_data = h0_data + bid * D; + prev_cell_data = c0_data + bid * D; } else { // W_ch, W_ih, W_fh, W_oh act_gate(D3, xx_data + D, xx_data + D); @@ -292,10 +311,7 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_cell_data = cell_out_data; tstart = 1; - // move offset - xx_data = xx_data + D4; - hidden_out_data = hidden_out_data + D; - cell_out_data = cell_out_data + D; + move_step(); } for (int step = tstart; step < seq_len; ++step) { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), @@ -323,10 +339,7 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_hidden_data = hidden_out_data; prev_cell_data = cell_out_data; - // move offset - xx_data = xx_data + D4; - hidden_out_data = hidden_out_data + D; - cell_out_data = cell_out_data + D; + move_step(); } } } diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 19f22fc7bd..5805bdf461 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -122,14 +122,15 @@ class TestFusionLSTMOpInit(TestFusionLSTMOp): self.has_initial_state = True -# class TestFusionLSTMOpReverse(TestFusionLSTMOp): -# def set_conf(self): -# self.is_reverse = True - -# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): -# def set_conf(self): -# self.has_initial_state = True -# self.is_reverse = True +class TestFusionLSTMOpReverse(TestFusionLSTMOp): + def set_conf(self): + self.is_reverse = True + + +class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + self.is_reverse = True class TestFusionLSTMOpMD1(TestFusionLSTMOp): From 
d361624c1db2bac156d856e6d80fb7a10f288536 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 27 Aug 2018 09:42:24 +0800 Subject: [PATCH 109/140] platform module (#12932) * platform module * Update profiler.h --- paddle/fluid/platform/CMakeLists.txt | 5 ++++ paddle/fluid/platform/cpu_info.cc | 21 ++++++++++--- paddle/fluid/platform/device_tracer.h | 10 ++++++- paddle/fluid/platform/dynload/CMakeLists.txt | 2 ++ .../fluid/platform/dynload/dynamic_loader.cc | 3 +- paddle/fluid/platform/enforce.h | 30 +++++++++++++++++-- paddle/fluid/platform/profiler.h | 10 +++++++ 7 files changed, 71 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 75d3856d0d..e25efebe6c 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -1,3 +1,4 @@ +if (NOT WIN32) proto_library(profiler_proto SRCS profiler.proto DEPS framework_proto) py_proto_compile(profiler_py_proto SRCS profiler.proto) @@ -10,6 +11,7 @@ add_custom_command(TARGET profiler_py_proto POST_BUILD COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/profiler COMMENT "Copy generated python proto into directory paddle/fluid/proto/profiler." 
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) +endif(NOT WIN32) if(WITH_GPU) nv_library(enforce SRCS enforce.cc) @@ -58,9 +60,12 @@ cc_test(init_test SRCS init_test.cc DEPS device_context) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) + +if (NOT WIN32) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer) cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) +endif(NOT WIN32) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index fcd658d67c..2880c09263 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -22,9 +22,13 @@ limitations under the License. */ #ifdef __APPLE__ #include #include + +#elif defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#include #else #include -#endif +#endif // _WIN32 #include #include "gflags/gflags.h" @@ -32,16 +36,20 @@ limitations under the License. 
*/ DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," "reserve the rest for page tables, etc"); - +#if !defined(_WIN32) DEFINE_uint64(initial_cpu_memory_in_mb, #ifdef PADDLE_WITH_MKLDNN /* Aligned with mozga-intel, MKLDNN need at least 5000 MB * to obtain the best performance*/ - 5000, + 5000ul, #else - 500, + 500ul, #endif "Initial CPU memory for PaddlePaddle, in MD unit."); +#else +DEFINE_uint64(initial_cpu_memory_in_mb, 500ul, + "Initial CPU memory for PaddlePaddle, in MD unit."); +#endif // !defined(_WIN32) DEFINE_double( fraction_of_cuda_pinned_memory_to_use, 0.5, @@ -60,6 +68,11 @@ inline size_t CpuTotalPhysicalMemory() { size_t len = sizeof(size); if (sysctl(mib, 2, &size, &len, NULL, 0) == 0) return (size_t)size; return 0L; +#elif defined(_WIN32) + MEMORYSTATUSEX sMeminfo; + sMeminfo.dwLength = sizeof(sMeminfo); + GlobalMemoryStatusEx(&sMeminfo); + return sMeminfo.ullTotalPhys; #else int64_t pages = sysconf(_SC_PHYS_PAGES); int64_t page_size = sysconf(_SC_PAGE_SIZE); diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index 322996fb4f..f59fc40b71 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#if !defined(_WIN32) #include +#else +#include +#endif // !_WIN32 + #include #include // NOLINT #include @@ -27,12 +32,15 @@ namespace platform { /////////////////////// // WARN: Under Development. Don't depend on it yet. ////////////////////// - +#if !defined(_WIN32) inline uint64_t PosixInNsec() { struct timeval tv; gettimeofday(&tv, nullptr); return 1000 * (static_cast(tv.tv_sec) * 1000000 + tv.tv_usec); } +#else +inline uint64_t PosixInNsec() { return static_cast(0); } +#endif // !_WIN32 // DeviceTracer performs the following tasks: // 1. 
Register cuda callbacks for various events: kernel, memcpy, etc. diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 07159d4a12..5939c500c9 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -16,7 +16,9 @@ if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) endif(CUPTI_FOUND) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) +if (NOT WIN32) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) +endif(NOT WIN32) if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 93bf7c1351..4fbfa6354a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include - #include #include // NOLINT #include @@ -23,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/port.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a76ba75f9e..61a653d931 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -18,6 +18,11 @@ limitations under the License. 
*/ #include // for __cxa_demangle #endif // __GNUC__ +#if defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h +#endif + #ifdef PADDLE_WITH_CUDA #include #include @@ -117,7 +122,12 @@ struct EOFException : public std::exception { // always forces branch prediction of true. // This generates faster binary code. __builtin_expect is since C++11. // For more details, please check https://stackoverflow.com/a/43870188/724872. +#if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#else +// there is no equivalent intrinsics in msvc. +#define UNLIKELY(condition) (condition == 0) +#endif template inline typename std::enable_if::type throw_on_error( @@ -230,6 +240,7 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } +#if !defined(_WIN32) #define PADDLE_THROW(...) \ do { \ throw ::paddle::platform::EnforceNotMet( \ @@ -248,15 +259,28 @@ inline void throw_on_error(T e) { __FILE__, __LINE__); \ } \ } while (false) -#else -#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__); -#endif #define PADDLE_THROW_EOF() \ do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ __LINE__); \ } while (false) + +#else +#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__) +#endif // REPLACE_ENFORCE_GLOG + +#else // !_WIN32 +// disable enforce, caused by the varardic macro exception error +#define PADDLE_THROW(x) \ + do { \ + throw std::make_exception_ptr( \ + std::runtime_error("Windows disable the enforce.")); \ + } while (false) + +#define PADDLE_ENFORCE(x, ...) 
x +#endif // !_WIN32 + /* * Some enforce helpers here, usage: * int a = 1; diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index c99d9c807d..38630686f7 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -69,6 +69,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx); void PopEvent(const std::string& name, const DeviceContext* dev_ctx); +#if !defined(_WIN32) struct RecordEvent { RecordEvent(const std::string& name, const DeviceContext* dev_ctx); @@ -94,6 +95,15 @@ struct RecordBlock { std::string name_; uint64_t start_ns_; }; +#else +// windows do not support profiler temporarily. +struct RecordEvent { + RecordEvent(const std::string& name, const DeviceContext* dev_ctx) {} +}; +struct RecordBlock { + explicit RecordBlock(int block_id) {} +}; +#endif // Return the event list of all threads. Assumed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. From 954b0e113f255a094800c892b495099d65c7378c Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 10:21:48 +0800 Subject: [PATCH 110/140] init fusion seq expand concat fc op --- .../operators/fusion_seq_concat_fc_op.cc | 417 ++++++++++++++++++ .../fluid/operators/fusion_seq_concat_fc_op.h | 41 ++ 2 files changed, 458 insertions(+) create mode 100644 paddle/fluid/operators/fusion_seq_concat_fc_op.cc create mode 100644 paddle/fluid/operators/fusion_seq_concat_fc_op.h diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc new file mode 100644 index 0000000000..810df3c3ed --- /dev/null +++ b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc @@ -0,0 +1,417 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h" +#include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { + +void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(C0) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), + "Input(LSTMWeight) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), + "Input(LSTMBias) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("AttentionWeight"), + "Input(AttentionWeight) of FusionSeqConcatFC should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AttentionedX"), + "Output(AttentionedX) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AttentionFCOut"), + "Output(AttentionFCOut) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), + "Output(LSTMX) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), + "Output(LSTMOUT) of FusionSeqConcatFC should not be null."); + + auto 
x_dims = ctx->GetInputDim("X"); + const int M = x_dims[1]; + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + + auto w_dims = ctx->GetInputDim("LSTMWeight"); + const int D = w_dims[1] / 4; + PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], D + M, + "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); + + auto b_dims = ctx->GetInputDim("LSTMBias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); + PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D); + + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); + PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); + if (ctx->HasInput("H0")) { + auto h_dims = ctx->GetInputDim("H0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); + PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, + "Input(AttentionWeight)'s rank must be 2."); + PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + if (ctx->HasInput("AttentionBias")) { + auto atten_b_dims = ctx->GetInputDim("AttentionBias"); + PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, + "Input(AttentionBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, + "AttentionBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, + "AttentionBias shapes must be 1 * 1."); + } + + if (ctx->HasInput("AttentionScalar")) { + auto dims = ctx->GetInputDim("AttentionScalar"); + PADDLE_ENFORCE_EQ(dims.size(), 2, + "Input(AttentionScalar)'s rank must be 2."); + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); + 
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); + } + + if (ctx->HasInput("AttentionScalarBias")) { + auto dims = ctx->GetInputDim("AttentionScalarBias"); + PADDLE_ENFORCE( + ctx->HasInput("AttentionScalar"), + "AttentionScalar should not be null when have AttentionScalarBias."); + PADDLE_ENFORCE_EQ(dims.size(), 2, + "Input(AttentionScalarBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1."); + } + + framework::DDim out_dims({x_dims[0], D}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("AttentionedX", {x_dims[0], 1}); + ctx->SetOutputDim("LSTMX", {1, M}); + ctx->SetOutputDim("LSTMOUT", {1, 4 * D}); + // AttentionFCOut should be reshape as (maxseqlen,1) in runtime + ctx->ShareLoD("X", "Hidden"); + ctx->ShareLoD("X", "Cell"); +} + +framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); +} + +void FusionSeqConcatFCOpMaker::Make() { + AddInput("X", + "(LoDTensor) the input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T X M), where T is the " + "total time steps in this mini-batch, M is the dim size of x."); + AddInput("C0", + "(Tensor) LSTM C0" + "This is a tensor with shape (N x D), where N is the batch size, D " + "is the gate size." + "C0 is necessary because of attention."); + AddInput("H0", + "(Tensor, optional) LSTM H0" + "This is a tensor with shape (N x D), where N is the " + "batch size and D is the gate size.") + .AsDispensable(); + AddInput("AttentionWeight", + "(Tensor) the weights of attention fc. Always relu the fc result." 
+ "The shape is ((M+D) x 1), where M is the dim size of x, D is the " + "gate size of LSTM."); + AddInput("AttentionBias", + "(Tensor, optional) the bias of attention fc." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalar", + "(Tensor, optional) the scalar on the result of attentioned fc. " + "Always relu the Scalar." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalarBias", + "(Tensor, optional) the scalar bias of attention fc." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("LSTMWeight", + "(Tensor) the combined weight of LSTM" + " - The shape is ((D+M) x 4D), where D is the hidden gate size, M " + "is the dim size of x" + " - Weight = {W_forget, W_input, W_output, W_cell}"); + AddInput("LSTMBias", + "(Tensor) the combined bias of LSTM, shape (1x4D)." + "Note: we should add the bias of hidden and context accorindg to " + "the same gate: " + "{B_forget, B_input, B_output, B_cell}"); + AddOutput("Hidden", + "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("AttentionedX", + "(Tensor) shape is (T x 1), the result after X * AttentionWeight," + " where T is the total time steps in this mini-batch," + " D is the hidden size.") + .AsIntermediate(); + AddOutput("AttentionFCOut", + "(Tensor) (max_seq_len, 1), compute at each step.") + .AsIntermediate(); + AddOutput("LSTMX", + "(Tensor) the input X of LSTM for each step." + "Shape is (1 x M), where M is the x frame size") + .AsIntermediate(); + AddOutput( + "LSTMOUT", + "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." 
+ "Shape is (1 x 4D), where M is the x frame size") + .AsIntermediate(); + AddAttr("gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Fusion Sequence expand + concat + fc Operator. + +Only support seq_expand ref_level=0, + +and the ref lod of seq_expand level is the first input of concat, + +and the other inputs should have same lod and same batch size of ref lod. + +)DOC"); +} + +// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; +template +inline void bias_relu(const int n, const T* x, const T* bias, T* y) { + if (bias) { + math::vec_add_bias(n, *bias, x, y); + math::vec_relu(n, y, y); + } else { + math::vec_relu(n, x, y); + } +} + +template +inline void vec_softmax(const int n, const T* x, T* y) { + T scalar = x[0]; + // max + for (int i = 1; i < n; ++i) { + scalar = scalar < x[i] ? 
x[i] : scalar; + } + math::vec_add_bias(n, -scalar, x, y); // sub + math::vec_exp(n, y, y); // exp + // sum + scalar = T(0); + for (int i = 0; i < n; ++i) { + scalar += y[i]; + } + math::vec_scal(n, static_cast(1) / scalar, y); // scale +} + +template +class FusionSeqConcatFCKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using DeviceContext = paddle::platform::CPUDeviceContext; + + auto* x = ctx.Input("X"); + auto* h0 = ctx.Input("H0"); + auto* c0 = ctx.Input("C0"); + auto* atten_w = ctx.Input("AttentionWeight"); + auto* atten_b = ctx.Input("AttentionBias"); + auto* atten_scalar = ctx.Input("AttentionScalar"); + auto* atten_scalar_bias = ctx.Input("AttentionScalarBias"); + auto* lstm_w = ctx.Input("LSTMWeight"); + auto* lstm_b = ctx.Input("LSTMBias"); + + auto* hidden_out = ctx.Output("Hidden"); + auto* cell_out = ctx.Output("Cell"); + auto* atted_x = ctx.Output("AttentionedX"); + auto* fc_out = ctx.Output("AttentionFCOut"); + auto* lstm_x = ctx.Output("LSTMX"); + auto* lstm_out = ctx.Output("LSTMOUT"); + + // some shape should be reshape here since infershape can not get lod info + auto x_lod = x->lod(); + const int N = x_lod[0].size() - 1; // batch size + auto x_dims = x->dims(); // T x M + auto w_dims = lstm_w->dims(); // (D+M) x 4D + const int total_T = x_dims[0]; + const int M = x_dims[1]; // x frame size + const int D = w_dims[1] / 4; // gate frame size + const int D2 = D * 2; + const int D3 = D * 3; + const int D4 = w_dims[1]; + int max_seq_len = x_lod[0][1]; + for (int i = 1; i < N; ++i) { + int len = x_lod[0][i + 1] - x_lod[0][i]; + max_seq_len = max_seq_len < len ? 
len : max_seq_len; + } + PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1."); + PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); + fc_out->Resize({max_seq_len, 1}); + + std::function act_gate, act_cell, act_cand; + auto& act_gate_str = ctx.Attr("gate_activation"); + auto& act_cell_str = ctx.Attr("cell_activation"); + auto& act_cand_str = ctx.Attr("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } + + const T* x_data = x->data(); + const T* h0_data = h0 ? h0->data() : NULL; + const T* c0_data = c0->data(); + const T* lstm_w_data = lstm_w->data(); + const T* lstm_b_data = lstm_b->data(); + const T* atten_w_data = atten_w->data(); + const T* atten_b_data = atten_b ? atten_b->data() : NULL; + const T* atten_scalar_data = atten_scalar ? atten_scalar->data() : NULL; + const T* atten_scalar_bias_data = + atten_scalar_bias ? 
atten_scalar_bias->data() : NULL; + + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); + T* atted_x_data = atted_x->mutable_data(ctx.GetPlace()); + T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); + T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); + T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); + + // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 + auto blas = math::GetBlas(ctx); + math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, + atted_x_data, atten_b_data); + + const T* cur_atten_x_data = atted_x_data; + const T* cur_x_data = x_data; + const T* prev_cell_data = NULL; + const T* prev_hidden_data = NULL; + T* cur_cell_out_data = cell_out_data; + T* cur_hidden_out_data = hidden_out_data; + for (int i = 0; i < N; ++i) { + int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + prev_cell_data = c0_data + i * D; + prev_hidden_data = h0_data ? h0_data + i * D : NULL; + for (int step = 0; step < seq_len; ++step) { + /// 1. compute attention vector + // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt + T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M); + // 1b. add cell bias and relu + bias_relu(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data); + // 1c. fc scalar + if (atten_scalar_data) { + blas.SCAL(seq_len, *atten_scalar_data, fc_out_data); + bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, + fc_out_data); + } + // 1d. softmax + vec_softmax(seq_len, fc_out_data, fc_out_data); + // mul x(seq_len*M) and sum pool + math::FCCompute(blas, 1, M, seq_len, fc_out_data, + cur_x_data, lstm_x_data); + + /// 2. 
compute LSTM step + // lstm weight : concat[forget , input , output , tilde] + // shape : (D + M) x (4 * D) + // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D + blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data); + if (prev_hidden_data) { + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), + prev_hidden_data, D, lstm_w_data, D4, static_cast(1), + lstm_out_data, D4); + } + // since input is 1xM, so can use add bias + blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data); + + // gate act: sigmoid + act_gate(D3, lstm_out_data, lstm_out_data); + // candicate act: tanh + act_cand(D, lstm_out_data + D3, lstm_out_data + D3); + + // a = forget * prev_cell + blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); + + // b = input * tilde + blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D); + + // cell_out = a + b + blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); + + // state act tanh(cell_out) * output_gate + act_cell(D, cur_cell_out_data, lstm_out_data); + blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); + + prev_hidden_data = cur_hidden_out_data; + prev_cell_data = cur_cell_out_data; + cur_cell_out_data = cur_cell_out_data + D; + cur_hidden_out_data = cur_hidden_out_data + D; + } + cur_x_data = cur_x_data + seq_len * M; + cur_atten_x_data = cur_atten_x_data + seq_len; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp, + ops::FusionSeqConcatFCOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc, + ops::FusionSeqConcatFCKernel, + ops::FusionSeqConcatFCKernel); diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.h b/paddle/fluid/operators/fusion_seq_concat_fc_op.h new file mode 100644 index 0000000000..66ac48f4c1 --- /dev/null +++ b/paddle/fluid/operators/fusion_seq_concat_fc_op.h @@ -0,0 +1,41 @@ +/* 
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class FusionSeqConcatFCOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle From 6cc7870517bd1820e82bdb635391bf705820578c Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 27 Aug 2018 09:42:24 +0800 Subject: [PATCH 111/140] fix concat synchronization bug --- paddle/fluid/operators/math/concat.cu | 6 ++++ paddle/fluid/platform/CMakeLists.txt | 5 ++++ paddle/fluid/platform/cpu_info.cc | 21 ++++++++++--- paddle/fluid/platform/device_tracer.h | 10 ++++++- paddle/fluid/platform/dynload/CMakeLists.txt | 2 ++ .../fluid/platform/dynload/dynamic_loader.cc | 3 +- paddle/fluid/platform/enforce.h | 30 +++++++++++++++++-- paddle/fluid/platform/profiler.h | 10 +++++++ 8 files changed, 77 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cu 
b/paddle/fluid/operators/math/concat.cu index 820e73e779..342379268b 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -177,6 +177,9 @@ class ConcatFunctor { dev_ins_data, dev_ins_col_data, static_cast(inputs_col.size()), out_row, out_col, output->data()); } + // Wait() must be called because `inputs_data` may be destructed before + // kernel ends + context.Wait(); } }; @@ -252,6 +255,9 @@ class ConcatGradFunctor { input.data(), in_row, in_col, dev_outs_col_data, static_cast(outputs_cols.size()), dev_out_gpu_data); } + // Wait() must be called because `outputs_data` may be destructed before + // kernel ends + context.Wait(); } }; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 75d3856d0d..e25efebe6c 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -1,3 +1,4 @@ +if (NOT WIN32) proto_library(profiler_proto SRCS profiler.proto DEPS framework_proto) py_proto_compile(profiler_py_proto SRCS profiler.proto) @@ -10,6 +11,7 @@ add_custom_command(TARGET profiler_py_proto POST_BUILD COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/profiler COMMENT "Copy generated python proto into directory paddle/fluid/proto/profiler." 
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) +endif(NOT WIN32) if(WITH_GPU) nv_library(enforce SRCS enforce.cc) @@ -58,9 +60,12 @@ cc_test(init_test SRCS init_test.cc DEPS device_context) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) + +if (NOT WIN32) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer) cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) +endif(NOT WIN32) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index fcd658d67c..2880c09263 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -22,9 +22,13 @@ limitations under the License. */ #ifdef __APPLE__ #include #include + +#elif defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#include #else #include -#endif +#endif // _WIN32 #include #include "gflags/gflags.h" @@ -32,16 +36,20 @@ limitations under the License. 
*/ DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," "reserve the rest for page tables, etc"); - +#if !defined(_WIN32) DEFINE_uint64(initial_cpu_memory_in_mb, #ifdef PADDLE_WITH_MKLDNN /* Aligned with mozga-intel, MKLDNN need at least 5000 MB * to obtain the best performance*/ - 5000, + 5000ul, #else - 500, + 500ul, #endif "Initial CPU memory for PaddlePaddle, in MD unit."); +#else +DEFINE_uint64(initial_cpu_memory_in_mb, 500ul, + "Initial CPU memory for PaddlePaddle, in MD unit."); +#endif // !defined(_WIN32) DEFINE_double( fraction_of_cuda_pinned_memory_to_use, 0.5, @@ -60,6 +68,11 @@ inline size_t CpuTotalPhysicalMemory() { size_t len = sizeof(size); if (sysctl(mib, 2, &size, &len, NULL, 0) == 0) return (size_t)size; return 0L; +#elif defined(_WIN32) + MEMORYSTATUSEX sMeminfo; + sMeminfo.dwLength = sizeof(sMeminfo); + GlobalMemoryStatusEx(&sMeminfo); + return sMeminfo.ullTotalPhys; #else int64_t pages = sysconf(_SC_PHYS_PAGES); int64_t page_size = sysconf(_SC_PAGE_SIZE); diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index 322996fb4f..f59fc40b71 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#if !defined(_WIN32) #include +#else +#include +#endif // !_WIN32 + #include #include // NOLINT #include @@ -27,12 +32,15 @@ namespace platform { /////////////////////// // WARN: Under Development. Don't depend on it yet. ////////////////////// - +#if !defined(_WIN32) inline uint64_t PosixInNsec() { struct timeval tv; gettimeofday(&tv, nullptr); return 1000 * (static_cast(tv.tv_sec) * 1000000 + tv.tv_usec); } +#else +inline uint64_t PosixInNsec() { return static_cast(0); } +#endif // !_WIN32 // DeviceTracer performs the following tasks: // 1. 
Register cuda callbacks for various events: kernel, memcpy, etc. diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 07159d4a12..5939c500c9 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -16,7 +16,9 @@ if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) endif(CUPTI_FOUND) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) +if (NOT WIN32) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) +endif(NOT WIN32) if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 93bf7c1351..4fbfa6354a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include - #include #include // NOLINT #include @@ -23,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/port.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a76ba75f9e..61a653d931 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -18,6 +18,11 @@ limitations under the License. 
*/ #include // for __cxa_demangle #endif // __GNUC__ +#if defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h +#endif + #ifdef PADDLE_WITH_CUDA #include #include @@ -117,7 +122,12 @@ struct EOFException : public std::exception { // always forces branch prediction of true. // This generates faster binary code. __builtin_expect is since C++11. // For more details, please check https://stackoverflow.com/a/43870188/724872. +#if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#else +// there is no equivalent intrinsics in msvc. +#define UNLIKELY(condition) (condition == 0) +#endif template inline typename std::enable_if::type throw_on_error( @@ -230,6 +240,7 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } +#if !defined(_WIN32) #define PADDLE_THROW(...) \ do { \ throw ::paddle::platform::EnforceNotMet( \ @@ -248,15 +259,28 @@ inline void throw_on_error(T e) { __FILE__, __LINE__); \ } \ } while (false) -#else -#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__); -#endif #define PADDLE_THROW_EOF() \ do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ __LINE__); \ } while (false) + +#else +#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__) +#endif // REPLACE_ENFORCE_GLOG + +#else // !_WIN32 +// disable enforce, caused by the varardic macro exception error +#define PADDLE_THROW(x) \ + do { \ + throw std::make_exception_ptr( \ + std::runtime_error("Windows disable the enforce.")); \ + } while (false) + +#define PADDLE_ENFORCE(x, ...) 
x +#endif // !_WIN32 + /* * Some enforce helpers here, usage: * int a = 1; diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index c99d9c807d..38630686f7 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -69,6 +69,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx); void PopEvent(const std::string& name, const DeviceContext* dev_ctx); +#if !defined(_WIN32) struct RecordEvent { RecordEvent(const std::string& name, const DeviceContext* dev_ctx); @@ -94,6 +95,15 @@ struct RecordBlock { std::string name_; uint64_t start_ns_; }; +#else +// windows do not support profiler temporarily. +struct RecordEvent { + RecordEvent(const std::string& name, const DeviceContext* dev_ctx) {} +}; +struct RecordBlock { + explicit RecordBlock(int block_id) {} +}; +#endif // Return the event list of all threads. Assumed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. From 3e1050a2e86dae6c6aa747aa9ce18722ae8c1938 Mon Sep 17 00:00:00 2001 From: chengduo Date: Mon, 27 Aug 2018 11:03:17 +0800 Subject: [PATCH 112/140] Add pad_constant_like_op (#12943) * Add pad_constant_batch_size_like * refine pad_op * optimize memory --- paddle/fluid/operators/math/padding.h | 124 +++++++++++ .../fluid/operators/pad_constant_like_op.cc | 196 ++++++++++++++++++ .../fluid/operators/pad_constant_like_op.cu | 27 +++ paddle/fluid/operators/pad_constant_like_op.h | 93 +++++++++ paddle/fluid/operators/pad_op.h | 113 ++-------- .../tests/unittests/test_pad_constant_like.py | 69 ++++++ 6 files changed, 529 insertions(+), 93 deletions(-) create mode 100644 paddle/fluid/operators/math/padding.h create mode 100644 paddle/fluid/operators/pad_constant_like_op.cc create mode 100644 paddle/fluid/operators/pad_constant_like_op.cu create mode 100644 paddle/fluid/operators/pad_constant_like_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_pad_constant_like.py diff --git 
a/paddle/fluid/operators/math/padding.h b/paddle/fluid/operators/math/padding.h new file mode 100644 index 0000000000..3ae25eae98 --- /dev/null +++ b/paddle/fluid/operators/math/padding.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenTensor = framework::EigenTensor; + +template +void PadFunction(const framework::ExecutionContext& context, + const std::vector& pads, const framework::Tensor& src, + T pad_value, framework::Tensor* out) { + Eigen::array, D> paddings; + + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = pads[i * 2]; + paddings[i].second = pads[i * 2 + 1]; + } + + auto src_tensor = EigenTensor::From(src); + auto out_tensor = EigenTensor::From(*out); + + auto& place = + *context.template device_context().eigen_device(); + out_tensor.device(place) = src_tensor.pad(paddings, pad_value); +} + +template +void PadGradFunction(const framework::ExecutionContext& context, + const std::vector& pads, const framework::Tensor& src, + framework::Tensor* d_out) { + Eigen::array, D> paddings; + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = -pads[i * 2]; + paddings[i].second = -pads[i * 2 + 1]; + } + + auto d_out_tensor = EigenTensor::From(*d_out); + auto src_tensor = EigenTensor::From(src); + auto& 
place = + *context.template device_context().eigen_device(); + d_out_tensor.device(place) = src_tensor.pad(paddings, 0); +} + +template +void PaddingFunctor(int rank, const framework::ExecutionContext& context, + const std::vector& pads, T pad_value, + const framework::Tensor& src, framework::Tensor* out) { + switch (rank) { + case 1: + PadFunction(context, pads, src, pad_value, out); + break; + case 2: + PadFunction(context, pads, src, pad_value, out); + break; + case 3: + PadFunction(context, pads, src, pad_value, out); + break; + case 4: + PadFunction(context, pads, src, pad_value, out); + break; + case 5: + PadFunction(context, pads, src, pad_value, out); + break; + case 6: + PadFunction(context, pads, src, pad_value, out); + break; + default: + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); + } +} + +template +void PaddingGradFunctor(int rank, const framework::ExecutionContext& context, + const std::vector& pads, + const framework::Tensor& src, framework::Tensor* out) { + switch (rank) { + case 1: + PadGradFunction(context, pads, src, out); + break; + case 2: + PadGradFunction(context, pads, src, out); + break; + case 3: + PadGradFunction(context, pads, src, out); + break; + case 4: + PadGradFunction(context, pads, src, out); + break; + case 5: + PadGradFunction(context, pads, src, out); + break; + case 6: + PadGradFunction(context, pads, src, out); + break; + default: + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); + } +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc new file mode 100644 index 0000000000..5958811d38 --- /dev/null +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -0,0 +1,196 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/pad_constant_like_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class PadConstantLikeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of PadConstantLikeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of PadConstantLikeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PadConstantLikeOp should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dim.size(), y_dim.size(), + "The dimention of X and Y should be the same."); + + for (int i = 0; i < x_dim.size(); ++i) { + PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]); + } + ctx->SetOutputDim("Out", x_dim); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class PadConstantLikeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input of pad_constant_like op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddInput("Y", + "The input of pad_constant_like op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddOutput("Out", + "The output of pad_constant_like op. 
" + "A tensor with the same shape as X."); + AddAttr("pad_value", + "(float, default 0.0) " + "The value to fill the padded areas.") + .SetDefault(0.0f); + AddComment(R"DOC( +PadConstantLikeOp Operator. + +Pad input(Y) with a pad_value, the number of values padded to the edges of each +axis is specified by the difference of the shape of X and Y. +((0, shape_x_0 - shape_y_0), … (0, shape_x_n - shape_y_n)) unique pad widths for +each axis. +The input should be a k-D tensor(k > 0 and k < 7). As an example: + +case1: + Given: + X = [[1, 2], + [3, 4], + [1, 2], + [3, 4]]], + X.shape = (4, 2) + + Y = [[5, 6], + [7, 8]], + Y.shape = (2, 2) + + And + pad_value = 0, + + Return: + Out = [[5, 6], + [7, 8], + [0, 0], + [0, 0]] + Out.shape = (4, 2) + +case2: + Given: + X = [[[[ 0, 1, 2], + [ 3, 4, 5]], + [[ 6, 7, 8], + [ 9, 10, 11]], + [[12, 13, 14], + [15, 16, 17]]], + [[[18, 19, 20], + [21, 22, 23]], + [[24, 25, 26], + [27, 28, 29]], + [[30, 31, 32], + [33, 34, 35]]]] + X.shape = (2, 3, 2, 3) + + Y = [[[[35, 36, 37]], + [[38, 39, 40]], + [[41, 42, 43]]]] + Y.shape = (1, 3, 1, 3) + + And + pad_value = -1, + + Return: + + Out = [[[[35, 36, 37], + [-1, -1, -1]], + [[38, 39, 40], + [-1, -1, -1]], + [[41, 42, 43], + [-1, -1, -1]]], + [[[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]]]] + Out.shape = (2, 3, 2, 3) +)DOC"); + } +}; + +class PadConstantLikeOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto y_dim = ctx->GetInputDim("Y"); + auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(dout_dim.size(), y_dim.size(), + "The dimention of X and Y should be the same."); + + auto y_grad_name 
= framework::GradVarName("Y"); + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dim); + ctx->ShareLoD("Y", /*->*/ y_grad_name); + + for (int i = 0; i < y_dim.size(); ++i) { + PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]); + } + } + } +}; + +class PadConstantLikeOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *bind = new framework::OpDesc(); + bind->SetType("pad_constant_like_grad"); + bind->SetInput("Y", Input("Y")); + bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind->SetOutput(framework::GradVarName("Y"), InputGrad("Y")); + bind->SetAttrMap(Attrs()); + return std::unique_ptr(bind); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(pad_constant_like, ops::PadConstantLikeOp, + ops::PadConstantLikeOpMaker, ops::PadConstantLikeOpGradMaker); +REGISTER_OPERATOR(pad_constant_like_grad, ops::PadConstantLikeOpGrad); + +REGISTER_OP_CPU_KERNEL( + pad_constant_like, + ops::PadConstantLikeKernel, + ops::PadConstantLikeKernel); +REGISTER_OP_CPU_KERNEL( + pad_constant_like_grad, + ops::PadConstantLikeGradKernel, + ops::PadConstantLikeGradKernel); diff --git a/paddle/fluid/operators/pad_constant_like_op.cu b/paddle/fluid/operators/pad_constant_like_op.cu new file mode 100644 index 0000000000..ea69577904 --- /dev/null +++ b/paddle/fluid/operators/pad_constant_like_op.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/fluid/operators/pad_constant_like_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + pad_constant_like, + ops::PadConstantLikeKernel, + ops::PadConstantLikeKernel); +REGISTER_OP_CUDA_KERNEL( + pad_constant_like_grad, + ops::PadConstantLikeGradKernel, + ops::PadConstantLikeGradKernel); diff --git a/paddle/fluid/operators/pad_constant_like_op.h b/paddle/fluid/operators/pad_constant_like_op.h new file mode 100644 index 0000000000..01d66901af --- /dev/null +++ b/paddle/fluid/operators/pad_constant_like_op.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/padding.h" + +namespace paddle { +namespace operators { + +template +class PadConstantLikeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto in_x = context.Input("X"); + auto in_y = context.Input("Y"); + auto* out = context.Output("Out"); + + if (in_x->dims() == in_y->dims()) { + // TensorCopy(in_y, context.GetPlace(), context, out); + out->ShareDataWith(*in_y); + return; + } + + T pad_value = context.Attr("pad_value"); + out->mutable_data(context.GetPlace()); + + int rank = context.Input("X")->dims().size(); + + std::vector pads(rank * 2, 0); + + for (int j = 0; j < rank; ++j) { + pads[j * 2] = 0; + pads[j * 2 + 1] = static_cast(in_x->dims()[j] - in_y->dims()[j]); + } + + math::PaddingFunctor(rank, context, pads, pad_value, + *in_y, out); + } +}; + +template +class PadConstantLikeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto in_y = context.Input("Y"); + auto in_dout = + context.Input(framework::GradVarName("Out")); + auto* d_y = context.Output(framework::GradVarName("Y")); + + if (d_y == nullptr) { + return; + } + + if (in_dout->dims() == in_y->dims()) { + // TensorCopy(in_dout, context.GetPlace(), context, d_y); + d_y->ShareDataWith(*in_dout); + return; + } + + d_y->mutable_data(context.GetPlace()); + int rank = in_dout->dims().size(); + + std::vector pads(static_cast(rank) * 2, 0); + for (int j = 0; j < rank; ++j) { + pads[j * 2] = 0; + pads[j * 2 + 1] = static_cast(in_dout->dims()[j] - in_y->dims()[j]); + } + + math::PaddingGradFunctor(rank, context, pads, *in_dout, + d_y); + } +}; + +} // namespace operators +} // namespace paddle 
diff --git a/paddle/fluid/operators/pad_op.h b/paddle/fluid/operators/pad_op.h index c93c096575..32698dac49 100644 --- a/paddle/fluid/operators/pad_op.h +++ b/paddle/fluid/operators/pad_op.h @@ -18,117 +18,44 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/padding.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenTensor = framework::EigenTensor; - -template -void PadFunction(const framework::ExecutionContext& context) { - auto pads = context.Attr>("paddings"); - Eigen::array, D> paddings; - for (size_t i = 0; i < paddings.size(); ++i) { - paddings[i].first = pads[i * 2]; - paddings[i].second = pads[i * 2 + 1]; - } - T pad_value = context.Attr("pad_value"); - - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - - auto x_tensor = EigenTensor::From(*x); - auto out_tensor = EigenTensor::From(*out); - auto& place = - *context.template device_context().eigen_device(); - out_tensor.device(place) = x_tensor.pad(paddings, pad_value); -} - template class PadKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - int rank = context.Input("X")->dims().size(); - switch (rank) { - case 1: - PadFunction(context); - break; - case 2: - PadFunction(context); - break; - case 3: - PadFunction(context); - break; - case 4: - PadFunction(context); - break; - case 5: - PadFunction(context); - break; - case 6: - PadFunction(context); - break; - default: - PADDLE_THROW( - "PadOp only support tensors with no more than 6 dimensions."); - } + auto pads = context.Attr>("paddings"); + T pad_value = context.Attr("pad_value"); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + int rank = x->dims().size(); + math::PaddingFunctor(rank, 
context, pads, pad_value, *x, + out); } }; -template -void PadGradFunction(const framework::ExecutionContext& context) { - auto pads = context.Attr>("paddings"); - Eigen::array, D> paddings; - for (size_t i = 0; i < paddings.size(); ++i) { - paddings[i].first = -pads[i * 2]; - paddings[i].second = -pads[i * 2 + 1]; - } - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_x = context.Output(framework::GradVarName("X")); - if (d_x != nullptr) { - d_x->mutable_data(context.GetPlace()); - auto d_x_tensor = EigenTensor::From(*d_x); - auto d_out_tensor = EigenTensor::From(*d_out); - auto& place = - *context.template device_context().eigen_device(); - d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0); - } -} - template class PadGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - size_t rank = - context.Input(framework::GradVarName("Out"))->dims().size(); - switch (rank) { - case 1: - PadGradFunction(context); - break; - case 2: - PadGradFunction(context); - break; - case 3: - PadGradFunction(context); - break; - case 4: - PadGradFunction(context); - break; - case 5: - PadGradFunction(context); - break; - case 6: - PadGradFunction(context); - break; - default: - PADDLE_THROW( - "PadOp only support tensors with no more than 6 dimensions."); + auto pads = context.Attr>("paddings"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); + if (d_x == nullptr) { + return; } + + d_x->mutable_data(context.GetPlace()); + int rank = d_out->dims().size(); + math::PaddingGradFunctor(rank, context, pads, *d_out, + d_x); } }; diff --git a/python/paddle/fluid/tests/unittests/test_pad_constant_like.py b/python/paddle/fluid/tests/unittests/test_pad_constant_like.py new file mode 100644 index 0000000000..6b733fd8fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pad_constant_like.py @@ -0,0 +1,69 @@ +# Copyright 
(c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestPadOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = "pad_constant_like" + self.inputs = { + 'X': np.random.random(self.x_shape).astype("float32"), + 'Y': np.random.random(self.y_shape).astype("float32") + } + self.attrs = {} + self.attrs['pad_value'] = self.pad_value + self.outputs = { + 'Out': np.pad(self.inputs['Y'], + self.paddings, + mode='constant', + constant_values=self.pad_value) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Y'], 'Out', max_relative_error=0.006) + + def initTestCase(self): + self.x_shape = (16, 16) + self.y_shape = (3, 16) + self.pad_value = 0.1 + self.paddings = [(0, 13), (0, 0)] + + +class TestCase1(TestPadOp): + def initTestCase(self): + self.x_shape = (4, 3, 4, 4) + self.y_shape = (2, 3, 4, 4) + self.paddings = [(0, 2), (0, 0), (0, 0), (0, 0)] + self.pad_value = 0.5 + + +class TestCase2(TestPadOp): + def initTestCase(self): + self.x_shape = (4, 3, 4, 4) + self.y_shape = (2, 3, 2, 4) + self.paddings = [(0, 2), (0, 0), (0, 2), (0, 0)] + self.pad_value = 0.5 + + +if __name__ == '__main__': + unittest.main() From 478eeabdd488fb155242629f39516ac3897d0bce Mon Sep 17 00:00:00 2001 From: nhzlx Date: Mon, 27 Aug 2018 05:11:41 +0000 Subject: 
[PATCH 113/140] refine uttest of api_tensorrt_subgraph_engine --- .../fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc index 8f1a72316d..9e7425eddd 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc @@ -37,6 +37,7 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) { config1.use_gpu = true; config1.fraction_of_gpu_memory = 0.3; config1.device = 0; + config1.max_batch_size = 10; auto predictor0 = CreatePaddlePredictor(config0); From c7c25067338dc147c5b6b282ce34205f4bfee373 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 13:12:33 +0800 Subject: [PATCH 114/140] add forward implementation --- .../operators/fusion_seq_concat_fc_op.cc | 318 +++++------------- 1 file changed, 83 insertions(+), 235 deletions(-) diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc index 810df3c3ed..203ebaf3e2 100644 --- a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc +++ b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc @@ -25,30 +25,15 @@ namespace operators { void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(C0) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), - "Input(LSTMWeight) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), - "Input(LSTMBias) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE( - ctx->HasInput("AttentionWeight"), - "Input(AttentionWeight) of FusionSeqConcatFC should not be null."); - - PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - 
"Output(Hidden) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("AttentionedX"), - "Output(AttentionedX) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("AttentionFCOut"), - "Output(AttentionFCOut) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), - "Output(LSTMX) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), - "Output(LSTMOUT) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasInput("FCWeight"), + "Input(FCWeight) of FusionSeqConcatFC should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("FCOut"), + "Output(FCOut) of FusionSeqConcatFC should not be null."); + + // need check fc height = all inputs width sum auto x_dims = ctx->GetInputDim("X"); const int M = x_dims[1]; @@ -120,6 +105,9 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { // AttentionFCOut should be reshape as (maxseqlen,1) in runtime ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Cell"); + + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( @@ -131,95 +119,37 @@ framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( void FusionSeqConcatFCOpMaker::Make() { AddInput("X", - "(LoDTensor) the input is a LodTensor, which support " - "variable-time length input sequence. The underlying tensor in " - "this LoDTensor is a matrix with shape (T X M), where T is the " - "total time steps in this mini-batch, M is the dim size of x."); - AddInput("C0", - "(Tensor) LSTM C0" - "This is a tensor with shape (N x D), where N is the batch size, D " - "is the gate size." 
- "C0 is necessary because of attention."); - AddInput("H0", - "(Tensor, optional) LSTM H0" - "This is a tensor with shape (N x D), where N is the " - "batch size and D is the gate size.") - .AsDispensable(); - AddInput("AttentionWeight", - "(Tensor) the weights of attention fc. Always relu the fc result." - "The shape is ((M+D) x 1), where M is the dim size of x, D is the " - "gate size of LSTM."); - AddInput("AttentionBias", - "(Tensor, optional) the bias of attention fc." - "The shape is (1 x 1)") - .AsDispensable(); - AddInput("AttentionScalar", - "(Tensor, optional) the scalar on the result of attentioned fc. " - "Always relu the Scalar." - "The shape is (1 x 1)") - .AsDispensable(); - AddInput("AttentionScalarBias", - "(Tensor, optional) the scalar bias of attention fc." - "The shape is (1 x 1)") - .AsDispensable(); - AddInput("LSTMWeight", - "(Tensor) the combined weight of LSTM" - " - The shape is ((D+M) x 4D), where D is the hidden gate size, M " - "is the dim size of x" - " - Weight = {W_forget, W_input, W_output, W_cell}"); - AddInput("LSTMBias", - "(Tensor) the combined bias of LSTM, shape (1x4D)." - "Note: we should add the bias of hidden and context accorindg to " - "the same gate: " - "{B_forget, B_input, B_output, B_cell}"); - AddOutput("Hidden", - "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " - "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput("Cell", - "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " - "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput("AttentionedX", - "(Tensor) shape is (T x 1), the result after X * AttentionWeight," - " where T is the total time steps in this mini-batch," - " D is the hidden size.") - .AsIntermediate(); - AddOutput("AttentionFCOut", - "(Tensor) (max_seq_len, 1), compute at each step.") - .AsIntermediate(); - AddOutput("LSTMX", - "(Tensor) the input X of LSTM for each step." 
- "Shape is (1 x M), where M is the x frame size") - .AsIntermediate(); + "(LoDTensor) input LodDTensors, the first one must be have ref lod " + "for sequence expand, and the rest input should have same lod.") + .AsDuplicable(); + AddInput("FCWeight", "(Tensor) the weights of fc."); + AddInput("FCBias", "(Tensor, optional) the bias of fc.").AsDispensable(); + AddOutput("Out", "(LoDTensor) Output LodTensor."); AddOutput( - "LSTMOUT", - "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." - "Shape is (1 x 4D), where M is the x frame size") + "FCOut", + "(Tensor) the intermediate tensor to keep the result of fc." + "Shape is (N x D), where N is the batch size, D is the output dim of fc") .AsIntermediate(); - AddAttr("gate_activation", - "(string, default: sigmoid)" - "The activation for input gate, forget gate and output " - "gate, `sigmoid` by default.") - .SetDefault("sigmoid") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("cell_activation", - "(string, default: tanh)" - "The activation for cell output, `tanh` by defalut.") - .SetDefault("tanh") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("candidate_activation", - "(string, default: tanh)" - "The activation for candidate hidden state, " - "`tanh` by default.") - .SetDefault("tanh") + AddAttr("fc_activation", + "(string, default: identity)" + "The activation for the result of fc." + "`identity` by default.") + .SetDefault("identity") .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddComment(R"DOC( Fusion Sequence expand + concat + fc Operator. -Only support seq_expand ref_level=0, +All below conditions should be meet: -and the ref lod of seq_expand level is the first input of concat, +The ref_level of seq_expand should be 0. -and the other inputs should have same lod and same batch size of ref lod. +The ref lod of seq_expand level is the first input of concat. + +The other inputs should have same lod and same batch size of ref lod. 
+ +The seq len of other inputs should be 1. + +The concat axis should be 1. )DOC"); } @@ -257,150 +187,68 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using DeviceContext = paddle::platform::CPUDeviceContext; + auto* ins = ctx.Input("X"); + auto* w = ctx.Input("FCWeight"); + auto* b = ctx.Input("FCBias"); - auto* x = ctx.Input("X"); - auto* h0 = ctx.Input("H0"); - auto* c0 = ctx.Input("C0"); - auto* atten_w = ctx.Input("AttentionWeight"); - auto* atten_b = ctx.Input("AttentionBias"); - auto* atten_scalar = ctx.Input("AttentionScalar"); - auto* atten_scalar_bias = ctx.Input("AttentionScalarBias"); - auto* lstm_w = ctx.Input("LSTMWeight"); - auto* lstm_b = ctx.Input("LSTMBias"); - - auto* hidden_out = ctx.Output("Hidden"); - auto* cell_out = ctx.Output("Cell"); - auto* atted_x = ctx.Output("AttentionedX"); - auto* fc_out = ctx.Output("AttentionFCOut"); - auto* lstm_x = ctx.Output("LSTMX"); - auto* lstm_out = ctx.Output("LSTMOUT"); - - // some shape should be reshape here since infershape can not get lod info - auto x_lod = x->lod(); - const int N = x_lod[0].size() - 1; // batch size - auto x_dims = x->dims(); // T x M - auto w_dims = lstm_w->dims(); // (D+M) x 4D - const int total_T = x_dims[0]; - const int M = x_dims[1]; // x frame size - const int D = w_dims[1] / 4; // gate frame size - const int D2 = D * 2; - const int D3 = D * 3; - const int D4 = w_dims[1]; - int max_seq_len = x_lod[0][1]; - for (int i = 1; i < N; ++i) { - int len = x_lod[0][i + 1] - x_lod[0][i]; - max_seq_len = max_seq_len < len ? 
len : max_seq_len; - } - PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1."); - PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); - fc_out->Resize({max_seq_len, 1}); - - std::function act_gate, act_cell, act_cand; - auto& act_gate_str = ctx.Attr("gate_activation"); - auto& act_cell_str = ctx.Attr("cell_activation"); - auto& act_cand_str = ctx.Attr("candidate_activation"); + auto* out = ctx.Output("Out"); + auto* fc_out = ctx.Output("FCOUT"); + + std::function fc_act; + auto& fc_act_str = ctx.Attr("fc_activation"); if (platform::jit::MayIUse(platform::jit::avx)) { math::VecActivations act_functor; - act_gate = act_functor(act_gate_str); - act_cell = act_functor(act_cell_str); - act_cand = act_functor(act_cand_str); + fc_act = act_functor(fc_act_str); } else { math::VecActivations act_functor; - act_gate = act_functor(act_gate_str); - act_cell = act_functor(act_cell_str); - act_cand = act_functor(act_cand_str); + fc_act = act_functor(fc_act_str); } - const T* x_data = x->data(); - const T* h0_data = h0 ? h0->data() : NULL; - const T* c0_data = c0->data(); - const T* lstm_w_data = lstm_w->data(); - const T* lstm_b_data = lstm_b->data(); - const T* atten_w_data = atten_w->data(); - const T* atten_b_data = atten_b ? atten_b->data() : NULL; - const T* atten_scalar_data = atten_scalar ? atten_scalar->data() : NULL; - const T* atten_scalar_bias_data = - atten_scalar_bias ? atten_scalar_bias->data() : NULL; - - T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); - T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); - T* atted_x_data = atted_x->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must larger than 1."); + auto* ref_in = ins[0]; + auto ref_in_lod = ref_in->lod(); + const int N = ref_in_lod[0].size() - 1; + auto ref_in_dims = ref_in->dims(); // T x M0 + auto w_dims = w->dims(); // (M0+M1+M2+..) 
x D + const int total_T = ref_in_dims[0]; + const int M0 = ref_in_dims[1]; + const int M1 = ins[1]->dims()[1]; + const int D = w_dims[1]; + + const T* ref_in_data = + ref_in->data(); // size should be check at infershape + const T* in1_data = ins[1]->data(); + const T* w_data = w->data(); + T* out_data = out->mutable_data(ctx.GetPlace()); T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); - T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); - T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); - // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 auto blas = math::GetBlas(ctx); - math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, - atted_x_data, atten_b_data); - - const T* cur_atten_x_data = atted_x_data; - const T* cur_x_data = x_data; - const T* prev_cell_data = NULL; - const T* prev_hidden_data = NULL; - T* cur_cell_out_data = cell_out_data; - T* cur_hidden_out_data = hidden_out_data; + math::FCCompute(blas, total_T, D, M0, ref_in_data, w_data, + out_data, b ? b->data() : NULL); + w_data = w_data + M0 * D; + + // first one use write on + blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); + w_data = w_data + M1 * D; + for (int i = 2; i < ins.size(); ++i) { + // add on + const T* in_data = ins[i]->data(); + const int K = ins[i]->dims()[1]; + blas.GEMM(CblasNoTrans, CblasNoTrans, N, D, K, static_cast(1), in_data, + K, w_data, D, static_cast(1), fc_out_data, D); + w_data = w_data + K * D; + } + for (int i = 0; i < N; ++i) { - int seq_len = x_lod[0][i + 1] - x_lod[0][i]; - prev_cell_data = c0_data + i * D; - prev_hidden_data = h0_data ? h0_data + i * D : NULL; + int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i]; + T* src = fc_out_data + i * D; for (int step = 0; step < seq_len; ++step) { - /// 1. compute attention vector - // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt - T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M); - // 1b. 
add cell bias and relu - bias_relu(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data); - // 1c. fc scalar - if (atten_scalar_data) { - blas.SCAL(seq_len, *atten_scalar_data, fc_out_data); - bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, - fc_out_data); - } - // 1d. softmax - vec_softmax(seq_len, fc_out_data, fc_out_data); - // mul x(seq_len*M) and sum pool - math::FCCompute(blas, 1, M, seq_len, fc_out_data, - cur_x_data, lstm_x_data); - - /// 2. compute LSTM step - // lstm weight : concat[forget , input , output , tilde] - // shape : (D + M) x (4 * D) - // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D - blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data); - if (prev_hidden_data) { - blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), - prev_hidden_data, D, lstm_w_data, D4, static_cast(1), - lstm_out_data, D4); - } - // since input is 1xM, so can use add bias - blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data); - - // gate act: sigmoid - act_gate(D3, lstm_out_data, lstm_out_data); - // candicate act: tanh - act_cand(D, lstm_out_data + D3, lstm_out_data + D3); - - // a = forget * prev_cell - blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); - - // b = input * tilde - blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D); - - // cell_out = a + b - blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); - - // state act tanh(cell_out) * output_gate - act_cell(D, cur_cell_out_data, lstm_out_data); - blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); - - prev_hidden_data = cur_hidden_out_data; - prev_cell_data = cur_cell_out_data; - cur_cell_out_data = cur_cell_out_data + D; - cur_hidden_out_data = cur_hidden_out_data + D; + blas.VADD(D, out_data, src, out_data); + out_data = out_data + D; } - cur_x_data = cur_x_data + seq_len * M; - cur_atten_x_data = cur_atten_x_data + seq_len; } + + fc_act(out_dims[0] * out_dims[1], out_data, out_data); } }; From 
0153c21d83c0ac77eb2125db20a59e233e83c098 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 27 Aug 2018 09:42:24 +0800 Subject: [PATCH 115/140] add unstack_op --- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/unstack_op.cc | 26 ++++ paddle/fluid/operators/unstack_op.h | 135 ++++++++++++++++++ paddle/fluid/platform/CMakeLists.txt | 5 + paddle/fluid/platform/cpu_info.cc | 21 ++- paddle/fluid/platform/device_tracer.h | 10 +- paddle/fluid/platform/dynload/CMakeLists.txt | 2 + .../fluid/platform/dynload/dynamic_loader.cc | 3 +- paddle/fluid/platform/enforce.h | 30 +++- paddle/fluid/platform/profiler.h | 10 ++ python/paddle/fluid/layers/nn.py | 42 ++++++ .../fluid/tests/unittests/test_unstack_op.py | 81 +++++++++++ 12 files changed, 356 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/operators/unstack_op.cc create mode 100644 paddle/fluid/operators/unstack_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_unstack_op.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 8da0aaaafe..e73d31562a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) +op_library(unstack_op DEPS stack_op) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc new file mode 100644 index 0000000000..4ff3249cc3 --- /dev/null +++ b/paddle/fluid/operators/unstack_op.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/unstack_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +USE_OP(stack); + +REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker, + ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker); + +REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp, + ops::UnStackOpGradInferShape); diff --git a/paddle/fluid/operators/unstack_op.h b/paddle/fluid/operators/unstack_op.h new file mode 100644 index 0000000000..348a103880 --- /dev/null +++ b/paddle/fluid/operators/unstack_op.h @@ -0,0 +1,135 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class UnStackOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist."); + + int axis = ctx->Attrs().Get("axis"); + int num = ctx->Attrs().Get("num"); + auto x_dim = ctx->GetInputDim("X"); + int rank = x_dim.size(); + PADDLE_ENFORCE(axis >= -rank && axis < rank, + "Attr(axis) must be inside [-rank, rank), where rank = %d", + rank); + if (axis < 0) axis += rank; + + PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast(num), + "Number of Outputs(Y) is wrong"); + if (x_dim[axis] > 0) { + PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong"); + } + auto vec = framework::vectorize2int(x_dim); + vec.erase(vec.begin() + axis); + ctx->SetOutputsDim("Y", std::vector( // NOLINT + x_dim[axis], framework::make_ddim(vec))); + } +}; + +class UnStackOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of unstack op."); + AddOutput("Y", "The output of unstack op.").AsDuplicable(); + AddAttr("axis", "The axis along which Input(X) should be unstacked.") + .SetDefault(0); + AddAttr("num", "The number of outputs(Y).").GreaterThan(0); + AddComment(R"DOC( + UnStack Operator. + + UnStack Input(X) into several tensors along Attr(axis). 
+ )DOC"); + } +}; + +class UnStackOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto stack_grad_op = framework::OpRegistry::CreateOp( + "stack_grad", {{framework::GradVarName("Y"), {Input("X")}}}, + {{framework::GradVarName("X"), Outputs("Y")}}, Attrs()); + stack_grad_op->Run(scope, place); + } +}; + +class UnStackOpGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0, + "Number of Inputs(Y@Grad) must be larger than 0"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad) must exist."); + + auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y")); + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0], + "Dims of all Inputs(Y@Grad) must be the same"); + } + + int axis = ctx->Attrs().Get("axis"); + int rank = input_dims[0].size(); + PADDLE_ENFORCE( + axis >= -(rank + 1) && axis < rank + 1, + "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank); + if (axis < 0) axis += (rank + 1); + + auto vec = framework::vectorize2int(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec)); + } +}; + +class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("unstack_grad"); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +class UnStackGradOp : public framework::OperatorBase 
{ + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto stack_op = framework::OpRegistry::CreateOp( + "stack", {{"X", Inputs(framework::GradVarName("Y"))}}, + {{"Y", {Output(framework::GradVarName("X"))}}}, Attrs()); + stack_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 75d3856d0d..e25efebe6c 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -1,3 +1,4 @@ +if (NOT WIN32) proto_library(profiler_proto SRCS profiler.proto DEPS framework_proto) py_proto_compile(profiler_py_proto SRCS profiler.proto) @@ -10,6 +11,7 @@ add_custom_command(TARGET profiler_py_proto POST_BUILD COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/profiler COMMENT "Copy generated python proto into directory paddle/fluid/proto/profiler." WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) +endif(NOT WIN32) if(WITH_GPU) nv_library(enforce SRCS enforce.cc) @@ -58,9 +60,12 @@ cc_test(init_test SRCS init_test.cc DEPS device_context) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) + +if (NOT WIN32) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer) cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) +endif(NOT WIN32) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index fcd658d67c..2880c09263 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -22,9 +22,13 @@ limitations under the License. 
*/ #ifdef __APPLE__ #include #include + +#elif defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#include #else #include -#endif +#endif // _WIN32 #include #include "gflags/gflags.h" @@ -32,16 +36,20 @@ limitations under the License. */ DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," "reserve the rest for page tables, etc"); - +#if !defined(_WIN32) DEFINE_uint64(initial_cpu_memory_in_mb, #ifdef PADDLE_WITH_MKLDNN /* Aligned with mozga-intel, MKLDNN need at least 5000 MB * to obtain the best performance*/ - 5000, + 5000ul, #else - 500, + 500ul, #endif "Initial CPU memory for PaddlePaddle, in MD unit."); +#else +DEFINE_uint64(initial_cpu_memory_in_mb, 500ul, + "Initial CPU memory for PaddlePaddle, in MD unit."); +#endif // !defined(_WIN32) DEFINE_double( fraction_of_cuda_pinned_memory_to_use, 0.5, @@ -60,6 +68,11 @@ inline size_t CpuTotalPhysicalMemory() { size_t len = sizeof(size); if (sysctl(mib, 2, &size, &len, NULL, 0) == 0) return (size_t)size; return 0L; +#elif defined(_WIN32) + MEMORYSTATUSEX sMeminfo; + sMeminfo.dwLength = sizeof(sMeminfo); + GlobalMemoryStatusEx(&sMeminfo); + return sMeminfo.ullTotalPhys; #else int64_t pages = sysconf(_SC_PHYS_PAGES); int64_t page_size = sysconf(_SC_PAGE_SIZE); diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index 322996fb4f..f59fc40b71 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#if !defined(_WIN32) #include +#else +#include +#endif // !_WIN32 + #include #include // NOLINT #include @@ -27,12 +32,15 @@ namespace platform { /////////////////////// // WARN: Under Development. Don't depend on it yet. 
////////////////////// - +#if !defined(_WIN32) inline uint64_t PosixInNsec() { struct timeval tv; gettimeofday(&tv, nullptr); return 1000 * (static_cast(tv.tv_sec) * 1000000 + tv.tv_usec); } +#else +inline uint64_t PosixInNsec() { return static_cast(0); } +#endif // !_WIN32 // DeviceTracer performs the following tasks: // 1. Register cuda callbacks for various events: kernel, memcpy, etc. diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 07159d4a12..5939c500c9 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -16,7 +16,9 @@ if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) endif(CUPTI_FOUND) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) +if (NOT WIN32) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) +endif(NOT WIN32) if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 93bf7c1351..4fbfa6354a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include - #include #include // NOLINT #include @@ -23,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/port.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. 
For instance, " diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a76ba75f9e..61a653d931 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -18,6 +18,11 @@ limitations under the License. */ #include // for __cxa_demangle #endif // __GNUC__ +#if defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h +#endif + #ifdef PADDLE_WITH_CUDA #include #include @@ -117,7 +122,12 @@ struct EOFException : public std::exception { // always forces branch prediction of true. // This generates faster binary code. __builtin_expect is since C++11. // For more details, please check https://stackoverflow.com/a/43870188/724872. +#if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#else +// there is no equivalent intrinsics in msvc. +#define UNLIKELY(condition) (condition == 0) +#endif template inline typename std::enable_if::type throw_on_error( @@ -230,6 +240,7 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } +#if !defined(_WIN32) #define PADDLE_THROW(...) \ do { \ throw ::paddle::platform::EnforceNotMet( \ @@ -248,15 +259,28 @@ inline void throw_on_error(T e) { __FILE__, __LINE__); \ } \ } while (false) -#else -#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__); -#endif #define PADDLE_THROW_EOF() \ do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ __LINE__); \ } while (false) + +#else +#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__) +#endif // REPLACE_ENFORCE_GLOG + +#else // !_WIN32 +// disable enforce, caused by the varardic macro exception error +#define PADDLE_THROW(x) \ + do { \ + throw std::make_exception_ptr( \ + std::runtime_error("Windows disable the enforce.")); \ + } while (false) + +#define PADDLE_ENFORCE(x, ...) 
x +#endif // !_WIN32 + /* * Some enforce helpers here, usage: * int a = 1; diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index c99d9c807d..38630686f7 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -69,6 +69,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx); void PopEvent(const std::string& name, const DeviceContext* dev_ctx); +#if !defined(_WIN32) struct RecordEvent { RecordEvent(const std::string& name, const DeviceContext* dev_ctx); @@ -94,6 +95,15 @@ struct RecordBlock { std::string name_; uint64_t start_ns_; }; +#else +// windows do not support profiler temporarily. +struct RecordEvent { + RecordEvent(const std::string& name, const DeviceContext* dev_ctx) {} +}; +struct RecordBlock { + explicit RecordBlock(int block_id) {} +}; +#endif // Return the event list of all threads. Assumed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 66b776c08e..44416381c7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -105,6 +105,7 @@ __all__ = [ 'flatten', 'sequence_mask', 'stack', + 'unstack', ] @@ -5601,3 +5602,44 @@ def stack(x, axis=0): type='stack', inputs={'X': x}, outputs={'Y': out}, attrs={'axis': axis}) return out + + +def unstack(x, axis=0, num=None): + """ + **UnStack Layer** + + This layer unstacks input :code:`x` into several tensors along axis. + + If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x)`. + If :code:`num` is None, it would be inferred from :code:`x.shape[axis]`, + and if :code:`x.shape[axis]` <= 0 or is unknown, :code:`ValueError` is + raised. + + Args: + x (Variable): Input variable. + axis (int): The axis along which the input is unstacked. + num (int|None): The number of output variables. + + Returns: + list(Variable): The unstacked variables. 
+ + """ + + helper = LayerHelper('unstack', **locals()) + if num is None: + if axis is None or x.shape[axis] <= 0: + raise ValueError('unknown unstack number') + else: + num = x.shape[axis] + + outs = [] + for _ in num: + outs.append(helper.create_tmp_variable(x.dtype)) + + helper.append_op( + type='unstack', + inputs={'X': [x]}, + outputs={'Y': outs}, + attrs={'axis': axis, + 'num': num}) + return outs diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py new file mode 100644 index 0000000000..7cbac8928e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from op_test import OpTest +import numpy as np +import unittest + + +class TestUnStackOpBase(OpTest): + def initDefaultParameters(self): + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'float32' + + def initParameters(self): + pass + + def get_y_names(self): + y_names = [] + for i in range(self.input_dim[self.axis]): + y_names.append('y{}'.format(i)) + return y_names + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + self.op_type = 'unstack' + self.x = np.random.random(size=self.input_dim).astype(self.dtype) + + outs = np.split(self.x, self.input_dim[self.axis], self.axis) + new_shape = list(self.input_dim) + del new_shape[self.axis] + y_names = self.get_y_names() + tmp = [] + for i in range(self.input_dim[self.axis]): + tmp.append((y_names[i], np.reshape(outs[i], new_shape))) + + self.inputs = {'X': self.x} + self.outputs = {'Y': tmp} + self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad('X', self.get_y_names()) + + +class TestStackOp3(TestUnStackOpBase): + def initParameters(self): + self.axis = -1 + + +class TestStackOp4(TestUnStackOpBase): + def initParameters(self): + self.axis = -3 + + +class TestStackOp5(TestUnStackOpBase): + def initParameters(self): + self.axis = 1 + + +class TestStackOp6(TestUnStackOpBase): + def initParameters(self): + self.axis = 2 + + +if __name__ == '__main__': + unittest.main() From 33b4def10a901ebb6c3f90c77ee4f1eecae43f3e Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 27 Aug 2018 05:41:47 +0000 Subject: [PATCH 116/140] add api.spec --- paddle/fluid/API.spec | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 37c2523c9f..71493409f4 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -164,6 +164,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara paddle.fluid.layers.flatten 
ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)) paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) +paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) From c45cee0349a58ecd87e106fde958f8a78a066513 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 14:46:04 +0800 Subject: [PATCH 117/140] refine infershape and forward --- .../operators/fusion_seq_concat_fc_op.cc | 176 ++++++------------ 1 file changed, 54 insertions(+), 122 deletions(-) diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc index 203ebaf3e2..f61c822abf 100644 --- a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc +++ b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc @@ -23,91 +23,36 @@ namespace paddle { namespace operators { void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of FusionSeqConcatFC should not be null."); + PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL, + "Inputs(X) of FusionSeqConcatFCOp should larger than 1."); PADDLE_ENFORCE(ctx->HasInput("FCWeight"), 
"Input(FCWeight) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FusionSeqConcatFC should not be null."); PADDLE_ENFORCE(ctx->HasOutput("FCOut"), "Output(FCOut) of FusionSeqConcatFC should not be null."); - // need check fc height = all inputs width sum - - auto x_dims = ctx->GetInputDim("X"); - const int M = x_dims[1]; - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); - - auto w_dims = ctx->GetInputDim("LSTMWeight"); - const int D = w_dims[1] / 4; - PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); - PADDLE_ENFORCE_EQ(w_dims[0], D + M, - "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); - - auto b_dims = ctx->GetInputDim("LSTMBias"); - PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); - PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D); - - auto c_dims = ctx->GetInputDim("C0"); - PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); - PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); - if (ctx->HasInput("H0")) { - auto h_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE(h_dims == c_dims, - "The dimension of Input(H0) and Input(C0) " - "should be the same."); - } - - auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); - PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, - "Input(AttentionWeight)'s rank must be 2."); - PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, - "AttentionWeight shapes must be (%d + %d) * 1.", M, D); - PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, - "AttentionWeight shapes must be (%d + %d) * 1.", M, D); - if (ctx->HasInput("AttentionBias")) { - auto atten_b_dims = ctx->GetInputDim("AttentionBias"); - PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, - "Input(AttentionBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, - "AttentionBias shapes must be 1 * 1."); - PADDLE_ENFORCE_EQ(atten_b_dims[1], 
1, - "AttentionBias shapes must be 1 * 1."); + auto ins_dims = ctx->GetInputsDim("X"); + auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D + PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2."); + const int D = w_dims[1]; + int sum = ins_dims[0][1]; + for (size_t i = 1; i < ins_dims.size(); ++i) { + sum += ins_dims[i][1]; } - - if (ctx->HasInput("AttentionScalar")) { - auto dims = ctx->GetInputDim("AttentionScalar"); - PADDLE_ENFORCE_EQ(dims.size(), 2, - "Input(AttentionScalar)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); - PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(sum, w_dims[0], + "FC height should be sum of all inputs width."); + if (ctx->HasInput("FCBias")) { + auto b_dims = ctx->GetInputDim("FCBias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D); + PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D); } - if (ctx->HasInput("AttentionScalarBias")) { - auto dims = ctx->GetInputDim("AttentionScalarBias"); - PADDLE_ENFORCE( - ctx->HasInput("AttentionScalar"), - "AttentionScalar should not be null when have AttentionScalarBias."); - PADDLE_ENFORCE_EQ(dims.size(), 2, - "Input(AttentionScalarBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1."); - PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1."); - } - - framework::DDim out_dims({x_dims[0], D}); - ctx->SetOutputDim("Hidden", out_dims); - ctx->SetOutputDim("Cell", out_dims); - ctx->SetOutputDim("AttentionedX", {x_dims[0], 1}); - ctx->SetOutputDim("LSTMX", {1, M}); - ctx->SetOutputDim("LSTMOUT", {1, 4 * D}); - // AttentionFCOut should be reshape as (maxseqlen,1) in runtime - ctx->ShareLoD("X", "Hidden"); - ctx->ShareLoD("X", "Cell"); - - ctx->SetOutputDim("Out", out_dims); - ctx->ShareLoD("X", 
/*->*/ "Out"); + ctx->SetOutputDim("Out", {ins_dims[0][0], D}); + // fcout should be reshape when run since can not get lod in infershape + // explicit share the ref lod + ctx->ShareLoD("X", "Out", 0); } framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( @@ -154,46 +99,46 @@ The concat axis should be 1. )DOC"); } -// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; -template -inline void bias_relu(const int n, const T* x, const T* bias, T* y) { - if (bias) { - math::vec_add_bias(n, *bias, x, y); - math::vec_relu(n, y, y); - } else { - math::vec_relu(n, x, y); - } -} - -template -inline void vec_softmax(const int n, const T* x, T* y) { - T scalar = x[0]; - // max - for (int i = 1; i < n; ++i) { - scalar = scalar < x[i] ? x[i] : scalar; - } - math::vec_add_bias(n, -scalar, x, y); // sub - math::vec_exp(n, y, y); // exp - // sum - scalar = T(0); - for (int i = 0; i < n; ++i) { - scalar += y[i]; - } - math::vec_scal(n, static_cast(1) / scalar, y); // scale -} - template class FusionSeqConcatFCKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* ins = ctx.Input("X"); + auto ins = ctx.MultiInput("X"); auto* w = ctx.Input("FCWeight"); auto* b = ctx.Input("FCBias"); - auto* out = ctx.Output("Out"); auto* fc_out = ctx.Output("FCOUT"); + auto* ref_in = ins[0]; + auto ref_lod = ref_in->lod(); + auto in1_lod = ins[1]->lod(); + auto ref_dims = ref_in->dims(); // T x M0 + auto in1_dims = ins[1]->dims(); // N x M1 + auto w_dims = w->dims(); + const int N = ref_lod[0].size() - 1; + const int total_T = ref_dims[0]; + const int M0 = ref_dims[1]; + const int M1 = in1_dims[1]; + const int D = w_dims[1]; + + // some check and fcout should be reshape here + // since infershape can not get lod info + PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1."); + PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod 
size is 1."); + PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N, + "Batch size of all inputs should be equal."); + PADDLE_ENFORCE_EQ(in1_lod[0][N], N, + "Seq_length of other inputs should be 1."); + PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size."); + for (size_t i = 2; i < ins.size(); ++i) { + PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N, + "All other inputs height should be equal"); + PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod, + "All other inputs should have same lod"); + } + fc_out->Resize({N, D}); + std::function fc_act; auto& fc_act_str = ctx.Attr("fc_activation"); if (platform::jit::MayIUse(platform::jit::avx)) { @@ -204,19 +149,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { fc_act = act_functor(fc_act_str); } - PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must larger than 1."); - auto* ref_in = ins[0]; - auto ref_in_lod = ref_in->lod(); - const int N = ref_in_lod[0].size() - 1; - auto ref_in_dims = ref_in->dims(); // T x M0 - auto w_dims = w->dims(); // (M0+M1+M2+..) x D - const int total_T = ref_in_dims[0]; - const int M0 = ref_in_dims[1]; - const int M1 = ins[1]->dims()[1]; - const int D = w_dims[1]; - - const T* ref_in_data = - ref_in->data(); // size should be check at infershape + const T* ref_in_data = ref_in->data(); const T* in1_data = ins[1]->data(); const T* w_data = w->data(); T* out_data = out->mutable_data(ctx.GetPlace()); @@ -226,11 +159,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { math::FCCompute(blas, total_T, D, M0, ref_in_data, w_data, out_data, b ? 
b->data() : NULL); w_data = w_data + M0 * D; - // first one use write on blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); w_data = w_data + M1 * D; - for (int i = 2; i < ins.size(); ++i) { + for (size_t i = 2; i < ins.size(); ++i) { // add on const T* in_data = ins[i]->data(); const int K = ins[i]->dims()[1]; @@ -240,7 +172,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { } for (int i = 0; i < N; ++i) { - int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i]; + int seq_len = ref_lod[0][i + 1] - ref_lod[0][i]; T* src = fc_out_data + i * D; for (int step = 0; step < seq_len; ++step) { blas.VADD(D, out_data, src, out_data); @@ -248,7 +180,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { } } - fc_act(out_dims[0] * out_dims[1], out_data, out_data); + fc_act(total_T * D, out_data, out_data); } }; From 1f09bc320c14e1364812580de2e72454f8192cdc Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 27 Aug 2018 15:04:24 +0800 Subject: [PATCH 118/140] Support data type int8_t . (#12841) * Support int8 type. --- paddle/fluid/API.spec | 2 +- paddle/fluid/framework/data_type.cc | 1 + paddle/fluid/framework/data_type.h | 3 +++ paddle/fluid/framework/framework.proto | 1 + paddle/fluid/operators/math/math_function.cc | 3 ++- paddle/fluid/operators/math/math_function.cu | 9 ++++---- paddle/fluid/pybind/protobuf.cc | 1 + paddle/fluid/pybind/pybind.cc | 3 +++ paddle/fluid/pybind/tensor_py.h | 2 +- python/paddle/fluid/framework.py | 2 ++ .../fluid/tests/unittests/test_tensor.py | 21 +++++++++++++++++++ .../fluid/tests/unittests/test_variable.py | 3 ++- 12 files changed, 43 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 37c2523c9f..d21cb2697b 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -378,7 +378,7 @@ paddle.fluid.LoDTensor.__init__ 1. 
__init__(self: paddle.fluid.core.LoDTensor, a paddle.fluid.LoDTensor.has_valid_recursive_sequence_lengths has_valid_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> bool paddle.fluid.LoDTensor.lod lod(self: paddle.fluid.core.LoDTensor) -> List[List[int]] paddle.fluid.LoDTensor.recursive_sequence_lengths recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> List[List[int]] -paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 11. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 15. 
set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None +paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CPUPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 11. 
set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 15. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 22. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 23. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None 24. 
set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPinnedPlace) -> None paddle.fluid.LoDTensor.set_lod set_lod(self: paddle.fluid.core.LoDTensor, arg0: List[List[int]]) -> None paddle.fluid.LoDTensor.set_recursive_sequence_lengths set_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor, arg0: List[List[int]]) -> None paddle.fluid.LoDTensor.shape shape(self: paddle.fluid.core.Tensor) -> List[int] diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index 1a9ce746ea..28f3da88fa 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -64,6 +64,7 @@ static DataTypeMap* InitDataTypeMap() { RegType(size_t, proto::VarType::SIZE_T); RegType(int16_t, proto::VarType::INT16); RegType(uint8_t, proto::VarType::UINT8); + RegType(int8_t, proto::VarType::INT8); #undef RegType return retv; diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index f8c72ffc89..84691a2059 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -54,6 +54,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { case proto::VarType::INT16: visitor.template operator()(); break; + case proto::VarType::INT8: + visitor.template operator()(); + break; default: PADDLE_THROW("Not supported %d", type); } diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 2cf14bd371..c658843581 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -107,6 +107,7 @@ message VarType { // Tensor is used in C++. 
SIZE_T = 19; UINT8 = 20; + INT8 = 21; // Other types that may need additional descriptions LOD_TENSOR = 7; diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index c3387be6da..9a6e646b28 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -41,7 +41,8 @@ template struct SetConstant; template struct Transpose; \ template struct Transpose; \ template struct Transpose; \ - template struct Transpose; + template struct Transpose; \ + template struct Transpose; DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(2); diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index d5af718723..12d1baa8fb 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -33,10 +33,11 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; -#define DEFINE_GPU_TRANS(RANK) \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; DEFINE_GPU_TRANS(1); DEFINE_GPU_TRANS(2); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index c2137ec6d7..f21f8d23f9 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -234,6 +234,7 @@ void BindVarDsec(pybind11::module *m) { pybind11::enum_(var_desc, "VarType", "") .value("BOOL", pd::proto::VarType::BOOL) .value("UINT8", pd::proto::VarType::UINT8) + .value("INT8", pd::proto::VarType::INT8) .value("INT16", pd::proto::VarType::INT16) .value("INT32", pd::proto::VarType::INT32) .value("INT64", pd::proto::VarType::INT64) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 6773465923..5b20b87174 100644 --- a/paddle/fluid/pybind/pybind.cc +++ 
b/paddle/fluid/pybind/pybind.cc @@ -130,6 +130,7 @@ PYBIND11_PLUGIN(core) { .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) + .def("set", PyCPUTensorSetFromArray) #ifdef PADDLE_WITH_CUDA .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) @@ -138,6 +139,7 @@ PYBIND11_PLUGIN(core) { .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) + .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDAPinnedTensorSetFromArray) .def("set", PyCUDAPinnedTensorSetFromArray) .def("set", PyCUDAPinnedTensorSetFromArray) @@ -145,6 +147,7 @@ PYBIND11_PLUGIN(core) { .def("set", PyCUDAPinnedTensorSetFromArray) .def("set", PyCUDAPinnedTensorSetFromArray) .def("set", PyCUDAPinnedTensorSetFromArray) + .def("set", PyCUDAPinnedTensorSetFromArray) #endif .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("_set_float_element", TensorSetElement) diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 3e2ea1ef88..51614a6a3d 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -97,7 +97,7 @@ struct CastToPyBufferImpl { inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) { auto buffer_info = details::CastToPyBufferImpl()(tensor); + uint8_t, int8_t, platform::float16>()(tensor); return buffer_info; } diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index febb750ee1..fbe766336b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -95,6 +95,8 @@ def convert_np_dtype_to_dtype_(np_dtype): return core.VarDesc.VarType.INT16 elif dtype == np.uint8: return core.VarDesc.VarType.UINT8 + elif dtype == np.int8: + return core.VarDesc.VarType.INT8 else: raise ValueError("Not supported numpy dtype %s" % dtype) diff --git a/python/paddle/fluid/tests/unittests/test_tensor.py 
b/python/paddle/fluid/tests/unittests/test_tensor.py index e9d0f8a019..1822957c23 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor.py @@ -59,6 +59,27 @@ class TestTensor(unittest.TestCase): self.assertAlmostEqual(1.0, tensor_array_2[3, 9]) self.assertAlmostEqual(2.0, tensor_array_2[19, 11]) + def test_int8_tensor(self): + scope = core.Scope() + var = scope.var("int8_tensor") + cpu_tensor = var.get_tensor() + tensor_array = numpy.random.randint( + -127, high=128, size=[100, 200], dtype=numpy.int8) + place = core.CPUPlace() + cpu_tensor.set(tensor_array, place) + cpu_tensor_array_2 = numpy.array(cpu_tensor) + self.assertAlmostEqual(cpu_tensor_array_2.all(), tensor_array.all()) + + if core.is_compiled_with_cuda(): + cuda_tensor = var.get_tensor() + tensor_array = numpy.random.randint( + -127, high=128, size=[100, 200], dtype=numpy.int8) + place = core.CUDAPlace(0) + cuda_tensor.set(tensor_array, place) + cuda_tensor_array_2 = numpy.array(cuda_tensor) + self.assertAlmostEqual(cuda_tensor_array_2.all(), + tensor_array.all()) + def test_int_lod_tensor(self): place = core.CPUPlace() scope = core.Scope() diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index b0830e130d..4f3c26ca7b 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -31,7 +31,8 @@ class TestVariable(unittest.TestCase): self.assertEqual(DT.INT16, convert("int16")) self.assertEqual(DT.INT64, convert("int64")) self.assertEqual(DT.BOOL, convert("bool")) - self.assertRaises(ValueError, lambda: convert("int8")) + self.assertEqual(DT.INT8, convert("int8")) + self.assertEqual(DT.UINT8, convert("uint8")) def test_var(self): b = default_main_program().current_block() From 1a67061fee99e3205cf40cfc4d1153198bd371fa Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 24 Aug 2018 13:49:41 +0800 Subject: [PATCH 
119/140] graph to program pass fix a few other things --- paddle/fluid/framework/CMakeLists.txt | 4 +- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/graph_to_program_pass.cc | 65 +++++++++++++++++++ .../framework/ir/graph_to_program_pass.h | 30 +++++++++ .../ir/graph_to_program_pass_test.cc | 21 ++++++ paddle/fluid/framework/op_desc.cc | 4 +- paddle/fluid/framework/program_desc.cc | 6 ++ paddle/fluid/framework/program_desc.h | 2 + paddle/fluid/inference/CMakeLists.txt | 2 +- paddle/fluid/inference/io.cc | 1 - paddle/fluid/inference/tests/test_helper.h | 14 ++++ paddle/fluid/operators/parallel_do_op.cc | 1 + .../test_memopt_image_classification_train.py | 4 +- .../test_memopt_machine_translation.py | 4 +- 14 files changed, 151 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/framework/ir/graph_to_program_pass.cc create mode 100644 paddle/fluid/framework/ir/graph_to_program_pass.h create mode 100644 paddle/fluid/framework/ir/graph_to_program_pass_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2c62d4ed6b..0668ff43c8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) if(WITH_DISTRIBUTE) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS 
${DISTRIBUTE_COMPILE_FLAGS}) else() - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) endif() if (NOT WIN32) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index da0955a9a0..9300573d7f 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -3,6 +3,7 @@ cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) +cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper) cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter) @@ -12,5 +13,6 @@ cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) +cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto) diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.cc b/paddle/fluid/framework/ir/graph_to_program_pass.cc new file mode 100644 index 0000000000..414d8f79b1 --- 
/dev/null +++ b/paddle/fluid/framework/ir/graph_to_program_pass.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph_to_program_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" + +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr GraphToProgramPass::ApplyImpl( + std::unique_ptr graph) const { + ProgramDesc& program = Get("program"); + + std::unique_ptr program_pb( + new proto::ProgramDesc(*program.Proto())); + + auto block = program_pb->mutable_blocks(kRootBlockIndex); + block->clear_vars(); + std::unordered_set visited_vars; + for (ir::Node* n : graph->Nodes()) { + if (n->NodeType() == ir::Node::Type::kVariable) { + if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) { + visited_vars.insert(n->Var()->Name()); + block->add_vars()->MergeFrom(*n->Var()->Proto()); + } + } + } + + block->clear_ops(); + std::vector nodes = TopologySortOperations(*graph); + for (ir::Node* n : nodes) { + if (!n->Op()) { + continue; + } + block->add_ops()->MergeFrom(*n->Op()->Proto()); + } + + program.CopyFrom(*program_pb); + return graph; +} +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(graph_to_program_pass, paddle::framework::ir::GraphToProgramPass); diff --git 
a/paddle/fluid/framework/ir/graph_to_program_pass.h b/paddle/fluid/framework/ir/graph_to_program_pass.h new file mode 100644 index 0000000000..124ec5a8e7 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_to_program_pass.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class GraphToProgramPass : public Pass { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc new file mode 100644 index 0000000000..3adbf888a8 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc @@ -0,0 +1,21 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph_to_program_pass.h" + +namespace paddle { +namespace framework { +namespace ir {} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 122dc161b4..59b6007284 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -132,7 +132,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) std::string attr_name = attr.name(); // The sub_block referred to by the BLOCK attr hasn't been added // to ProgramDesc class yet, we skip setting BLOCK attr here. - if (attr.type() != proto::AttrType::BLOCK) { + // TODO(paddle-dev): Need copy fix this to copy Block as well. + if (attr.type() != proto::AttrType::BLOCK && + attr.type() != proto::AttrType::BLOCKS) { attrs_[attr_name] = GetAttrValue(attr); } } diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 344c001a69..c2b91069d9 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { InitFromProto(); } +void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) { + blocks_.clear(); + desc_ = desc; + InitFromProto(); +} + ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE(desc_.ParseFromString(binary_str), "Fail to parse program_desc from binary string."); diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index f3afc85eb9..a0e81cade1 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -53,6 +53,8 @@ class ProgramDesc { void Flush(); + void CopyFrom(const proto::ProgramDesc &desc); + proto::ProgramDesc *Proto(); // The output variable of feed_op is referenced as feed_target. 
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index ba7645aa02..a4f6364ae5 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor) # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal? cc_library(paddle_fluid_api SRCS io.cc - DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 181868977d..f29b190a73 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -138,7 +138,6 @@ std::unique_ptr Load( std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); - LoadPersistables(executor, scope, *main_program, "", param_filename); return main_program; } diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 695790a37d..94f0550df5 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include #include +#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/profiler.h" @@ -135,6 +136,15 @@ std::vector> GetFeedTargetShapes( return feed_target_shapes; } +void Compile(paddle::framework::ProgramDesc* program) { + std::unique_ptr g( + new paddle::framework::ir::Graph(*program)); + auto pass = paddle::framework::ir::PassRegistry::Instance().Get( + "graph_to_program_pass"); + pass->SetNotOwned("program", program); + pass->Apply(std::move(g)); +} + template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, @@ -172,6 +182,8 @@ void TestInference(const std::string& dirname, paddle::platform::DeviceContextPool::Instance().Get(place)); inference_program = InitProgram(&executor, scope, dirname, is_combined); } + Compile(inference_program.get()); + // Disable the profiler and print the timing information paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, "load_program_profiler"); @@ -249,3 +261,5 @@ void TestInference(const std::string& dirname, delete scope; } + +USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index eb09470f37..97c36a83fc 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { grad->SetInput(framework::GradVarName(output_param), og_names); } } + grad->SetInput("Communicator", {"nccl_com__do_not_change_"}); grad->SetAttrMap(this->Attrs()); grad->SetBlockAttr(kParallelBlock, grad_block_[0]); diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py index 3951e7b8ca..a231bbfbc8 100644 --- 
a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost) batch_size = fluid.layers.create_tensor(dtype='int64') batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size) -# fluid.memory_optimize(fluid.default_main_program(), level=0) -fluid.release_memory(fluid.default_main_program()) +fluid.memory_optimize(fluid.default_main_program(), level=0) +# fluid.release_memory(fluid.default_main_program()) BATCH_SIZE = 16 PASS_NUM = 1 diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py index 1ad51936b5..e520c89650 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py @@ -92,8 +92,8 @@ def main(): optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4) optimizer.minimize(avg_cost) - # fluid.memory_optimize(fluid.default_main_program()) - fluid.release_memory(fluid.default_main_program()) + fluid.memory_optimize(fluid.default_main_program()) + # fluid.release_memory(fluid.default_main_program()) # fix the order of training data train_data = paddle.batch( From 880cb8c4c33744379f5e519252302c23a6dde030 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 24 Aug 2018 13:51:43 +0800 Subject: [PATCH 120/140] clean --- paddle/fluid/inference/io.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index f29b190a73..181868977d 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -138,6 +138,7 @@ std::unique_ptr Load( std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); + LoadPersistables(executor, 
scope, *main_program, "", param_filename); return main_program; } From 6fdb7f534822711e6873de13a292bdf08fcd18ed Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 24 Aug 2018 16:26:41 +0800 Subject: [PATCH 121/140] add test --- .../ir/graph_to_program_pass_test.cc | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc index 3adbf888a8..9aa2ea7664 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc +++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc @@ -14,8 +14,99 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_to_program_pass.h" +#include +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/program_desc.h" + namespace paddle { namespace framework { -namespace ir {} // namespace ir +namespace ir { + +void BuildNoCircleGraph(Graph* g) { + OpDesc op1; + op1.SetType("op1"); + OpDesc op2; + op2.SetType("op2"); + OpDesc op3; + op3.SetType("op3"); + OpDesc op4; + op4.SetType("op4"); + OpDesc op5; + op5.SetType("op5"); + VarDesc var1("var1"); + VarDesc var2("var2"); + VarDesc var3("var3"); + VarDesc var4("var4"); + + ir::Node* o1 = g->CreateOpNode(&op1); + ir::Node* o2 = g->CreateOpNode(&op2); + ir::Node* o3 = g->CreateOpNode(&op3); + ir::Node* o4 = g->CreateOpNode(&op4); + ir::Node* o5 = g->CreateOpNode(&op5); + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + ir::Node* v3 = g->CreateVarNode(&var3); + ir::Node* v4 = g->CreateVarNode(&var4); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o2->v2->o3 + // o2->v2->o4 + o2->outputs.push_back(v2); + o3->inputs.push_back(v2); + o4->inputs.push_back(v2); + v2->outputs.push_back(o3); + v2->outputs.push_back(o4); + v2->inputs.push_back(o2); + // o2->v3->o5 + o2->outputs.push_back(v3); + 
o5->inputs.push_back(v3); + v3->inputs.push_back(o2); + v3->outputs.push_back(o5); + // o3-v4->o5 + o3->outputs.push_back(v4); + o5->inputs.push_back(v4); + v4->inputs.push_back(o3); + v4->outputs.push_back(o5); +} + +TEST(GraphToProgramPass, Basic) { + ProgramDesc prog; + std::unique_ptr g(new Graph(prog)); + BuildNoCircleGraph(g.get()); + + auto pass = paddle::framework::ir::PassRegistry::Instance().Get( + "graph_to_program_pass"); + + ProgramDesc compiled_prog; + pass->SetNotOwned("program", &compiled_prog); + pass->Apply(std::move(g)); + std::vector ops = compiled_prog.Block(0).AllOps(); + compiled_prog.Flush(); + LOG(ERROR) << compiled_prog.Proto()->DebugString(); + EXPECT_EQ(ops[0]->Type(), "op1"); + EXPECT_EQ(ops[1]->Type(), "op2"); + if (ops[2]->Type() == "op3") { + EXPECT_EQ(ops[3]->Type(), "op4"); + } else if (ops[2]->Type() == "op4") { + EXPECT_EQ(ops[3]->Type(), "op3"); + } + EXPECT_EQ(ops[4]->Type(), "op5"); + + std::unordered_set vars; + for (VarDesc* v : compiled_prog.Block(0).AllVars()) { + vars.insert(v->Name()); + } + EXPECT_TRUE(vars.find("var1") != vars.end()); + EXPECT_TRUE(vars.find("var2") != vars.end()); + EXPECT_TRUE(vars.find("var3") != vars.end()); +} +} // namespace ir } // namespace framework } // namespace paddle + +USE_PASS(graph_to_program_pass); From 08352fe56a1e915b415775d772ad350bea85c4a5 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 24 Aug 2018 17:00:23 +0800 Subject: [PATCH 122/140] fix --- paddle/fluid/framework/ir/graph_to_program_pass_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc index 9aa2ea7664..88ad17a0c6 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc +++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc @@ -86,8 +86,6 @@ TEST(GraphToProgramPass, Basic) { pass->SetNotOwned("program", &compiled_prog); pass->Apply(std::move(g)); std::vector ops = 
compiled_prog.Block(0).AllOps(); - compiled_prog.Flush(); - LOG(ERROR) << compiled_prog.Proto()->DebugString(); EXPECT_EQ(ops[0]->Type(), "op1"); EXPECT_EQ(ops[1]->Type(), "op2"); if (ops[2]->Type() == "op3") { From cf547e2714fec03a77d2f2aeb1e676d83a66a9f9 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 27 Aug 2018 13:57:37 +0800 Subject: [PATCH 123/140] fix program_desc feed/fetch names' order. --- paddle/fluid/framework/program_desc.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index c2b91069d9..a63944eaee 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -117,10 +117,16 @@ void ProgramDesc::InitFromProto() { const std::vector ProgramDesc::GetFeedTargetNames() { auto &global_block = Block(0); + // The order of feed_target_names must follow the index specified in `col`. + // since feed operator's order doesn't necessary follow 'col'. std::vector feed_target_names; for (auto *op : global_block.AllOps()) { if (op->Type() == kFeedOpType) { - feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]); + int col = boost::get(op->GetAttr("col")); + if (col >= feed_target_names.size()) { + feed_target_names.resize(col + 1); + } + feed_target_names[col] = op->Output("Out")[0]; } } return feed_target_names; @@ -128,10 +134,16 @@ const std::vector ProgramDesc::GetFeedTargetNames() { const std::vector ProgramDesc::GetFetchTargetNames() { auto &global_block = Block(0); + // The order of fetch_target_names must follow the index specified in `col`. + // since fetch operator's order doesn't necessary follow 'col'. 
std::vector fetch_target_names; for (auto *op : global_block.AllOps()) { if (op->Type() == kFetchOpType) { - fetch_target_names.push_back(op->Input("X")[0]); + int col = boost::get(op->GetAttr("col")); + if (col >= fetch_target_names.size()) { + fetch_target_names.resize(col + 1); + } + fetch_target_names[col] = op->Input("X")[0]; } } return fetch_target_names; From 0f0d48230c07b7c5ce2f8ecc1138d360c67fa8ce Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 14:54:20 +0800 Subject: [PATCH 124/140] add fusion seq_concat_fc op test --- .../unittests/test_fusion_seq_concat_fc_op.py | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py diff --git a/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py b/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py new file mode 100644 index 0000000000..a389b605f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py @@ -0,0 +1,139 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from test_fusion_lstm_op import fc, ACTIVATION + + +def fusion_seqexpand_concat_fc(xs, lod, w, b, fc_act): + + T = sum(lod[0]) + N = len(lod[0]) + num_inputs = len(xs) + D = w.shape[1] + + expanded_inputs = [xs[0]] + for i in range(num_inputs - 1): + x = xs[i + 1] + assert x.shape[0] == N + expanded = np.repeat(x, lod[0], axis=0) + assert expanded.shape[0] == T + assert expanded.shape[1] == x.shape[1] + expanded_inputs.append(expanded) + + fc_input = np.concatenate(expanded_inputs, axis=1) + assert fc_input.shape[0] == T + assert fc_input.shape[1] == w.shape[0] + fc_out = fc(fc_input, w, b) + fc_out = fc_act(fc_out) + assert fc_out.shape[0] == T + assert fc_out.shape[1] == D + return fc_out + + +class TestFusionSeqExpandConcatFCOp(OpTest): + def set_conf(self): + pass + + def setUp(self): + self.op_type = 'fusion_seq_concat_fc' + self.lod = [[3, 5, 8, 2]] + self.inputs_M = [15, 10, 10] + self.D = 20 + self.with_bias = True + self.fc_act = 'relu' + self.set_conf() + + T = sum(self.lod[0]) + bs = len(self.lod[0]) + num_inputs = len(self.inputs_M) + + x0 = np.random.normal(size=(T, self.inputs_M[0])).astype('float32') + xs = [x0] + for i in range(num_inputs - 1): + xi = np.random.normal(size=(bs, + self.inputs_M[i + 1])).astype('float32') + xs.append(xi) + + # fc weight and bias + w = np.random.normal(size=(sum(self.inputs_M), + self.D)).astype('float32') + b = np.random.normal(size=( + 1, self.D)).astype('float32') if self.with_bias else np.zeros( + (1, self.D)).astype('float32') + + out = fusion_seqexpand_concat_fc(xs, self.lod, w, b, + ACTIVATION[self.fc_act]) + + self.inputs = {'X': [(x0, self.lod)], 'FCWeight': w} + normal_lod = [i for i in range(bs + 1)] + for i in range(num_inputs - 1): + self.inputs['X'].append((xs[i + 1], normal_lod)) + + if self.with_bias: + self.inputs['FCBias'] = b + + self.outputs = {'Out': (out, self.lod)} + self.attrs = 
{'fc_activation': self.fc_act, } + + def test_check_output(self): + self.check_output() + + +class TestFusionSECFCOpNonBias(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.with_bias = False + + +class TestFusionSECFCOpNonAct(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.fc_act = 'identity' + + +class TestFusionSECFCOpMD1(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.inputs_M = [3, 4, 2, 1, 5] + self.D = 8 + + +class TestFusionSECFCOpMD2(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.lod = [[5, 6]] + self.inputs_M = [1, 1] + + +class TestFusionSECFCOpBS1_1(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.lod = [[1]] + self.inputs_M = [3, 4, 2] + + +class TestFusionSECFCOpBS1_2(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.lod = [[1]] + self.inputs_M = [3, 4] + + +class TestFusionSECFCOpBS1_3(TestFusionSeqExpandConcatFCOp): + def set_conf(self): + self.lod = [[5]] + self.inputs_M = [6, 3] + + +if __name__ == '__main__': + unittest.main() From 02909335e9208dc9c1a8835b6e25b708ea366005 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 16:01:29 +0800 Subject: [PATCH 125/140] rename fusion seq_concat_fc to fusion seqexpand_concat_fc --- ...op.cc => fusion_seqexpand_concat_fc_op.cc} | 41 +++++++++++-------- ...c_op.h => fusion_seqexpand_concat_fc_op.h} | 7 ++-- ... 
=> test_fusion_seqexpand_concat_fc_op.py} | 2 +- 3 files changed, 28 insertions(+), 22 deletions(-) rename paddle/fluid/operators/{fusion_seq_concat_fc_op.cc => fusion_seqexpand_concat_fc_op.cc} (85%) rename paddle/fluid/operators/{fusion_seq_concat_fc_op.h => fusion_seqexpand_concat_fc_op.h} (82%) rename python/paddle/fluid/tests/unittests/{test_fusion_seq_concat_fc_op.py => test_fusion_seqexpand_concat_fc_op.py} (98%) diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc similarity index 85% rename from paddle/fluid/operators/fusion_seq_concat_fc_op.cc rename to paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc index f61c822abf..641851585d 100644 --- a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc +++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h" +#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h" #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" @@ -22,15 +22,20 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL, - "Inputs(X) of FusionSeqConcatFCOp should larger than 1."); - PADDLE_ENFORCE(ctx->HasInput("FCWeight"), - "Input(FCWeight) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("FCOut"), - "Output(FCOut) of FusionSeqConcatFC should not be null."); +void FusionSeqExpandConcatFCOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE_GT( + ctx->Inputs("X").size(), 1UL, + "Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1."); + PADDLE_ENFORCE( + ctx->HasInput("FCWeight"), + "Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FusionSeqExpandConcatFCOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("FCOut"), + "Output(FCOut) of FusionSeqExpandConcatFCOp should not be null."); auto ins_dims = ctx->GetInputsDim("X"); auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D @@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Out", 0); } -framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( +framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.device_context()); } -void FusionSeqConcatFCOpMaker::Make() { +void FusionSeqExpandConcatFCOpMaker::Make() { AddInput("X", "(LoDTensor) input LodDTensors, the first one must be have ref lod " "for sequence expand, and the rest input should have same lod.") @@ -100,7 +105,7 @@ The concat axis should be 1. 
} template -class FusionSeqConcatFCKernel : public framework::OpKernel { +class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using DeviceContext = paddle::platform::CPUDeviceContext; @@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp, - ops::FusionSeqConcatFCOpMaker, +REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp, + ops::FusionSeqExpandConcatFCOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc, - ops::FusionSeqConcatFCKernel, - ops::FusionSeqConcatFCKernel); +REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc, + ops::FusionSeqExpandConcatFCOpKernel, + ops::FusionSeqExpandConcatFCOpKernel); diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.h b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h similarity index 82% rename from paddle/fluid/operators/fusion_seq_concat_fc_op.h rename to paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h index 66ac48f4c1..f78e820f60 100644 --- a/paddle/fluid/operators/fusion_seq_concat_fc_op.h +++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; -class FusionSeqConcatFCOp : public framework::OperatorWithKernel { +class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override; }; -class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker { +class FusionSeqExpandConcatFCOpMaker + : public framework::OpProtoAndCheckerMaker { public: void Make() override; }; diff --git a/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py b/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py similarity index 98% rename from python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py rename to python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py index a389b605f0..7baf39eb3f 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py @@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest): pass def setUp(self): - self.op_type = 'fusion_seq_concat_fc' + self.op_type = 'fusion_seqexpand_concat_fc' self.lod = [[3, 5, 8, 2]] self.inputs_M = [15, 10, 10] self.D = 20 From 50d3e6e96bb03bc2352ce9e69df206dcec0e3863 Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Thu, 2 Aug 2018 13:05:23 +0200 Subject: [PATCH 126/140] Reusing primitives for forward Batch Norm operator --- .../fluid/operators/batch_norm_mkldnn_op.cc | 421 +++++++++++++----- 1 file changed, 303 insertions(+), 118 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index 9ab2179b5f..cd1fb754a1 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ 
b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -37,6 +37,122 @@ struct bn_type_traits { using op_prim = typename op_type::primitive_desc; }; +class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { + public: + BatchNormMKLDNNHandler( + std::shared_ptr batch_norm_pd, + const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine, + const std::string &base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) { + batch_norm_pd_ = batch_norm_pd; + } + + std::shared_ptr AcquireScaleshiftMemoryFromPrimitive(void *ptr) { + return this->AcquireMemoryFromPrimitive( + batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p"); + } + + std::shared_ptr AcquireMeanMemoryFromPrimitive(void *ptr) { + return this->AcquireMemoryFromPrimitive( + batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p"); + } + + std::shared_ptr AcquireVarianceMemoryFromPrimitive(void *ptr) { + return this->AcquireMemoryFromPrimitive( + batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p"); + } + + std::shared_ptr AcquireTestBatchNormFwd( + std::shared_ptr src_memory, + const mkldnn::primitive::at &mean_memory, + const mkldnn::primitive::at &variance_memory, + std::shared_ptr scaleshift_memory, + std::shared_ptr dst_memory) { + auto prim_key = key_ + "@batch_norm_p"; + auto batch_norm_p = + std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); + PADDLE_ENFORCE( + (batch_norm_p != nullptr) || (is_reusing_ == false), + "Fail to find batch norm primitive for test in device context"); + if (batch_norm_p == nullptr) { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, mean_memory, variance_memory, + *scaleshift_memory, *dst_memory); + + dev_ctx_.SetBlob(prim_key, batch_norm_p); + } else { + is_reusing_ = true; + } + return batch_norm_p; + } + + std::shared_ptr AcquireTrainingBatchNormFwd( + std::shared_ptr src_memory, + std::shared_ptr scaleshift_memory, + std::shared_ptr dst_memory, std::shared_ptr mean_memory, + std::shared_ptr 
variance_memory) { + auto prim_key = key_ + "@batch_norm_p"; + auto batch_norm_p = + std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); + PADDLE_ENFORCE( + (batch_norm_p != nullptr) || (is_reusing_ == false), + "Fail to find batch norm primitive for training in device context"); + if (batch_norm_p == nullptr) { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory, + *mean_memory, *variance_memory); + + dev_ctx_.SetBlob(prim_key, batch_norm_p); + } else { + is_reusing_ = true; + } + return batch_norm_p; + } + // + static std::string GetHash(const memory::dims &input_dims, float epsilon, + unsigned flag, bool is_test, memory::format format, + const std::string &suffix) { + auto dims2str = [](const memory::dims &operand_dims) { + std::string dstr = ""; + for (size_t i = 0; i < operand_dims.size(); ++i) { + dstr += std::to_string(operand_dims[i]) + "-"; + } + return dstr; + }; + return dims2str(input_dims) + std::to_string(epsilon) + + std::to_string(flag) + std::to_string(is_test) + + std::to_string(format) + suffix; + } + + private: + std::shared_ptr batch_norm_pd_; +}; + +std::string gethash(const memory::dims &input_dims, float epsilon, + unsigned flag, bool is_test, memory::format format) { + auto dims2str = [](const memory::dims &operand_dims) { + std::string dstr = ""; + for (size_t i = 0; i < operand_dims.size(); ++i) { + dstr += std::to_string(operand_dims[i]) + "-"; + } + return dstr; + }; + return dims2str(input_dims) + std::to_string(epsilon) + std::to_string(flag) + + std::to_string(is_test) + std::to_string(format); +} + +std::shared_ptr UpdateMemoryData( + const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key, + void *new_ptr) { + auto mem = std::static_pointer_cast(dev_ctx.GetBlob(key)); + PADDLE_ENFORCE( + mem != nullptr, + (std::string("Fail to find memory in device context [key: ") + key + "]") + .c_str()); + mem->set_data_handle(new_ptr); + return mem; +} + template void 
copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, Container *c) { @@ -48,15 +164,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end)))); } -template -void run_batch_norm_op(Args &&... args) { - Op batch_norm_op{args...}; - - std::vector pipeline; - pipeline.push_back(batch_norm_op); - mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); -} - } // namespace template @@ -110,6 +217,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); const unsigned int ic = scale_tz[0]; + // MKLDNN requires a single piece of memory for scale and shift/bias data + const size_t scaleshift_size = 2 * ic; + std::vector scaleshift_data; + scaleshift_data.reserve(scaleshift_size); + + copy_to_weights(scale->data(), scale->data() + ic, shift->data(), + shift->data() + ic, &scaleshift_data); + unsigned flags = mkldnn::use_scale_shift; if (is_test) flags |= mkldnn::use_global_stats; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; @@ -118,64 +233,70 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::format input_format = platform::MKLDNNFormatForSize(src_tz.size(), x->format()); - auto src_memory = memory( - {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, - to_void_cast(x_data)); + // keys for backward pass + const std::string key = BatchNormMKLDNNHandler::GetHash( + src_tz, epsilon, flags, is_test, input_format, + ctx.op().Output("SavedMean")); + const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; + + auto user_src_md = platform::MKLDNNMemDesc( + {src_tz}, platform::MKLDNNGetDataType(), input_format); // create primitive descriptor for batch norm forward using bn_fwd_types = bn_type_traits; - auto batch_norm_fwd_desc = bn_fwd_types::op_desc{ - propagation, src_memory.get_primitive_desc().desc(), 
epsilon, flags}; - std::shared_ptr batch_norm_fwd_pd = - std::shared_ptr( - new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc, - mkldnn_engine)); - - // Save the pd to be used in backward pass - const std::string key = ctx.op().Output("SavedMean"); - const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; + auto batch_norm_fwd_desc = + bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags}; + auto batch_norm_fwd_pd = std::make_shared( + batch_norm_fwd_desc, mkldnn_engine); + // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd); - // MKLDNN requires a single piece of memory for scale and shift/bias data - const size_t scaleshift_size = 2 * ic; - std::vector scaleshift_data; - scaleshift_data.reserve(scaleshift_size); + BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine, + key); - copy_to_weights(scale->data(), scale->data() + ic, shift->data(), - shift->data() + ic, &scaleshift_data); + auto src_memory = + handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data)); // crate mkldnn memory for weights(scale/shift) - auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(), - scaleshift_data.data()); + auto scaleshift_memory = + handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data()); // create mkldnn memory for output y tensor - auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data); + auto dst_memory = handler.AcquireDstMemory( + batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data); + std::shared_ptr batch_norm_p; if (is_test) { // create mkldnn memory for stats (as input) - auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(), - to_void_cast(mean_data)); - auto variance_memory = - memory(batch_norm_fwd_pd->variance_primitive_desc(), - to_void_cast(variance_data)); - - run_batch_norm_op( - *batch_norm_fwd_pd, src_memory, - (const mkldnn::primitive::at &)mean_memory, - (const 
mkldnn::primitive::at &)variance_memory, scaleshift_memory, + std::shared_ptr mean_memory = + handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data)); + std::shared_ptr variance_memory = + handler.AcquireVarianceMemoryFromPrimitive( + to_void_cast(variance_data)); + + batch_norm_p = handler.AcquireTestBatchNormFwd( + src_memory, (const mkldnn::primitive::at &)*mean_memory, + (const mkldnn::primitive::at &)*variance_memory, scaleshift_memory, dst_memory); } else { // create mkldnn memory for stats (as output) - auto mean_memory = - memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data); - auto variance_memory = memory( - batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data); - - run_batch_norm_op(*batch_norm_fwd_pd, src_memory, - scaleshift_memory, dst_memory, - mean_memory, variance_memory); + std::shared_ptr mean_memory = + handler.AcquireMeanMemoryFromPrimitive(batch_mean_data); + std::shared_ptr variance_memory = + handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data); + + batch_norm_p = handler.AcquireTrainingBatchNormFwd( + src_memory, scaleshift_memory, dst_memory, mean_memory, + variance_memory); } + y->set_layout(DataLayout::kMKLDNN); + y->set_format(platform::GetMKLDNNFormat(*dst_memory)); + + std::vector pipeline; + pipeline.push_back(*batch_norm_p); + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + if (!is_test) { // mkldnn only compute stats for current batch // so we need compute momentum stats via Eigen lib @@ -192,10 +313,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { running_variance_e = variance_e * momentum + batch_variance_e * one_minus_momentum; } - - y->set_layout(DataLayout::kMKLDNN); - y->set_format( - (memory::format)dst_memory.get_primitive_desc().desc().data.format); } }; @@ -242,61 +359,47 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { const unsigned int ic = scale_tz[0]; - // Retrieve bn_fwd_pd from device context - 
const std::string key = ctx.op().Input("SavedMean"); - const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; - auto batch_norm_fwd_pd = - std::static_pointer_cast( - dev_ctx.GetBlob(key_batch_norm_fwd_pd)); - PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr, - "Fail to find batch_norm_fwd_pd in device context"); - using bn_bwd_types = bn_type_traits; - // create mkldnn memory from input diff_y tensor - mkldnn::memory::format dst_format = platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); - auto user_diff_dst_memory = memory( - {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, - to_void_cast(diff_y_data)); - - // create mkldnn memory from input x tensor mkldnn::memory::format input_format = platform::MKLDNNFormatForSize(src_tz.size(), x->format()); - auto src_memory = memory( - {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, - to_void_cast(x_data)); + unsigned flags = mkldnn::use_scale_shift; + + // keys from forward pass + const std::string key = BatchNormMKLDNNHandler::GetHash( + src_tz, epsilon, flags, false, input_format, + ctx.op().Input("SavedMean")); + const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; - // for diff_dst, try to use same format as dst in forward pass - auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); - auto diff_dst_md = diff_dst_pd.desc(); + // keys for primitives reuse + const std::string key_with_hash = + key + gethash(src_tz, epsilon, flags, false, input_format); + const std::string key_batch_norm_bwd_p = + key_with_hash + "@batch_norm_bwd_p"; + const std::string key_batch_norm_src_mem_p = + key_with_hash + "@batch_norm_bwd_src_mem_p"; + const std::string key_batch_norm_mean_mem_p = + key_with_hash + "@batch_norm_bwd_mean_mem_p"; + const std::string key_batch_norm_variance_mem_p = + key_with_hash + "@batch_norm_bwd_variance_mem_p"; + const std::string key_batch_norm_scaleshift_mem_p = + key_with_hash + "@batch_norm_bwd_scaleshift_mem_p"; + const std::string 
key_batch_norm_diff_scaleshift_mem_p = + key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p"; + const std::string key_batch_norm_diff_src_mem_p = + key_with_hash + "@batch_norm_bwd_diff_src_mem_p"; + const std::string key_batch_norm_diff_dst_mem_p = + key_with_hash + "@batch_norm_bwd_diff_dst_mem_p"; - // create primitive descriptor for batch norm backward - unsigned flags = mkldnn::use_scale_shift; - auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ - mkldnn::prop_kind::backward, diff_dst_md, - src_memory.get_primitive_desc().desc(), epsilon, flags}; - auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ - batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd}; - - // reorder user_diff_dst if it's not in preferred format - auto diff_dst_memory = user_diff_dst_memory; primitive reorder_diff_dst; bool is_diff_dst_reordered = false; - if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) { - diff_dst_memory = memory(diff_dst_pd); - reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory); - is_diff_dst_reordered = true; - } - - // create mkldnn memory for input tensors (src/mean/variance) - auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(), - to_void_cast(batch_mean_data)); - auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(), - to_void_cast(batch_variance_data)); + auto user_diff_dst_memory = memory( + {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, + to_void_cast(diff_y_data)); // MKLDNN requires a single piece of memory for scale and shift/bias data const size_t scaleshift_size = 2 * ic; @@ -306,30 +409,118 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic, &scaleshift_data); - // create mkldnn memory for input tensors (scale/shift) - auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(), - scaleshift_data.data()); - - // create mkldnn memory for output diff weights 
(combined scale/shift) std::vector diff_scaleshift_data; diff_scaleshift_data.reserve(scaleshift_size); - auto diff_scaleshift_memory = - memory(batch_norm_bwd_pd.diff_weights_primitive_desc(), - diff_scaleshift_data.data()); - // here assume diff_src is in the same format of src - auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data); + auto batch_norm_fwd_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_batch_norm_fwd_pd)); + PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr, + "Fail to find batch_norm_fwd_pd in device context"); - // finally create batch_norm backward primitive - auto batch_norm_bwd_prim = - batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory, - variance_memory, diff_dst_memory, scaleshift_memory, - diff_src_memory, diff_scaleshift_memory); + auto batch_norm_bwd_p = std::static_pointer_cast( + dev_ctx.GetBlob(key_batch_norm_bwd_p)); + + if (batch_norm_bwd_p == nullptr) { + auto src_memory = std::shared_ptr(new memory( + {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, + to_void_cast(x_data))); + + // for diff_dst, try to use same format as dst in forward pass + auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); + auto diff_dst_md = diff_dst_pd.desc(); + + // create primitive descriptor for batch norm backward + auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ + mkldnn::prop_kind::backward, diff_dst_md, + src_memory->get_primitive_desc().desc(), epsilon, flags}; + auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ + batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd}; + + // reorder user_diff_dst if it's not in preferred format + auto diff_dst_memory = std::make_shared(user_diff_dst_memory); + if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) { + diff_dst_memory = std::make_shared(diff_dst_pd); + reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); + is_diff_dst_reordered = true; + } + + // create mkldnn memory for input tensors (src/mean/variance) + auto 
mean_memory = + std::make_shared(batch_norm_bwd_pd.mean_primitive_desc(), + to_void_cast(batch_mean_data)); + auto variance_memory = + std::make_shared(batch_norm_bwd_pd.variance_primitive_desc(), + to_void_cast(batch_variance_data)); + + // create mkldnn memory for input tensors (scale/shift) + auto scaleshift_memory = std::make_shared( + batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()); + + // create mkldnn memory for output diff weights (combined scale/shift) + auto diff_scaleshift_memory = std::make_shared( + batch_norm_bwd_pd.diff_weights_primitive_desc(), + diff_scaleshift_data.data()); + + // here assume diff_src is in the same format of src + auto diff_src_memory = std::make_shared( + src_memory->get_primitive_desc(), diff_x_data); + + // finally create batch_norm backward primitive + batch_norm_bwd_p = std::make_shared( + batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory, + *diff_dst_memory, *scaleshift_memory, *diff_src_memory, + *diff_scaleshift_memory); + + dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p); + dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory); + dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory); + dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory); + dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory); + dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p, + diff_scaleshift_memory); + dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory); + dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory); + + // set layout/format of output tensors + diff_x->set_layout(DataLayout::kMKLDNN); + diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc() + .desc() + .data.format); + } else { + // primitives already exist + UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data)); + UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p, + to_void_cast(batch_mean_data)); + UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p, + 
to_void_cast(batch_variance_data)); + UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p, + scaleshift_data.data()); + UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p, + diff_scaleshift_data.data()); + auto diff_src_memory = UpdateMemoryData( + dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data)); + auto diff_dst_memory = UpdateMemoryData( + dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data)); + + // reorder user_diff_dst if it's not in preferred format + if (diff_dst_memory->get_primitive_desc() != + user_diff_dst_memory.get_primitive_desc()) { + reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); + is_diff_dst_reordered = true; + } + + // set layout/format of output tensors + diff_x->set_layout(DataLayout::kMKLDNN); + diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc() + .desc() + .data.format); + } // execute optional reorder and batch_norm backward primitive std::vector pipeline; if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst); - pipeline.push_back(batch_norm_bwd_prim); + pipeline.push_back(*batch_norm_bwd_p); stream(stream::kind::eager).submit(pipeline).wait(); // copy back diff sacle/shift to output tensors (diff scale/shift) @@ -338,12 +529,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { std::copy(it, std::next(it, ic), diff_scale_data); std::copy(std::next(it, ic), std::end(diff_scaleshift_data), diff_shift_data); - - // set layout/format of output tensors - diff_x->set_layout(DataLayout::kMKLDNN); - diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc() - .desc() - .data.format); } }; } // namespace operators From fb4b4f8d57b8219a708b738ca4a7cf5790e46897 Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Mon, 27 Aug 2018 10:54:21 +0200 Subject: [PATCH 127/140] Refactor code --- .../fluid/operators/batch_norm_mkldnn_op.cc | 79 ++++++------------- 1 file changed, 26 insertions(+), 53 deletions(-) diff --git 
a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index cd1fb754a1..de641cb08e 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -62,56 +62,42 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p"); } - std::shared_ptr AcquireTestBatchNormFwd( + std::shared_ptr AcquireTestTrainingBatchNormFwd( std::shared_ptr src_memory, - const mkldnn::primitive::at &mean_memory, - const mkldnn::primitive::at &variance_memory, std::shared_ptr scaleshift_memory, - std::shared_ptr dst_memory) { + std::shared_ptr dst_memory, std::shared_ptr mean_memory, + std::shared_ptr variance_memory, bool is_test) { auto prim_key = key_ + "@batch_norm_p"; auto batch_norm_p = std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); - PADDLE_ENFORCE( - (batch_norm_p != nullptr) || (is_reusing_ == false), - "Fail to find batch norm primitive for test in device context"); - if (batch_norm_p == nullptr) { - batch_norm_p = std::make_shared( - *batch_norm_pd_, *src_memory, mean_memory, variance_memory, - *scaleshift_memory, *dst_memory); - dev_ctx_.SetBlob(prim_key, batch_norm_p); - } else { - is_reusing_ = true; - } - return batch_norm_p; - } + PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_, + "Fail to find batch norm primitive in device context"); - std::shared_ptr AcquireTrainingBatchNormFwd( - std::shared_ptr src_memory, - std::shared_ptr scaleshift_memory, - std::shared_ptr dst_memory, std::shared_ptr mean_memory, - std::shared_ptr variance_memory) { - auto prim_key = key_ + "@batch_norm_p"; - auto batch_norm_p = - std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); - PADDLE_ENFORCE( - (batch_norm_p != nullptr) || (is_reusing_ == false), - "Fail to find batch norm primitive for training in device context"); if (batch_norm_p == nullptr) { - batch_norm_p = std::make_shared( - *batch_norm_pd_, 
*src_memory, *scaleshift_memory, *dst_memory, - *mean_memory, *variance_memory); + if (is_test) { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, + (const mkldnn::primitive::at &)*mean_memory, + (const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory, + *dst_memory); + } else { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory, + *mean_memory, *variance_memory); + } dev_ctx_.SetBlob(prim_key, batch_norm_p); } else { is_reusing_ = true; } + return batch_norm_p; } - // + static std::string GetHash(const memory::dims &input_dims, float epsilon, unsigned flag, bool is_test, memory::format format, - const std::string &suffix) { + const std::string &suffix = "") { auto dims2str = [](const memory::dims &operand_dims) { std::string dstr = ""; for (size_t i = 0; i < operand_dims.size(); ++i) { @@ -128,19 +114,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr batch_norm_pd_; }; -std::string gethash(const memory::dims &input_dims, float epsilon, - unsigned flag, bool is_test, memory::format format) { - auto dims2str = [](const memory::dims &operand_dims) { - std::string dstr = ""; - for (size_t i = 0; i < operand_dims.size(); ++i) { - dstr += std::to_string(operand_dims[i]) + "-"; - } - return dstr; - }; - return dims2str(input_dims) + std::to_string(epsilon) + std::to_string(flag) + - std::to_string(is_test) + std::to_string(format); -} - std::shared_ptr UpdateMemoryData( const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key, void *new_ptr) { @@ -274,10 +247,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { handler.AcquireVarianceMemoryFromPrimitive( to_void_cast(variance_data)); - batch_norm_p = handler.AcquireTestBatchNormFwd( - src_memory, (const mkldnn::primitive::at &)*mean_memory, - (const mkldnn::primitive::at &)*variance_memory, scaleshift_memory, - dst_memory); + batch_norm_p = 
handler.AcquireTestTrainingBatchNormFwd( + src_memory, scaleshift_memory, dst_memory, mean_memory, + variance_memory, true); } else { // create mkldnn memory for stats (as output) std::shared_ptr mean_memory = @@ -285,9 +257,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr variance_memory = handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data); - batch_norm_p = handler.AcquireTrainingBatchNormFwd( + batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( src_memory, scaleshift_memory, dst_memory, mean_memory, - variance_memory); + variance_memory, false); } y->set_layout(DataLayout::kMKLDNN); @@ -377,7 +349,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { // keys for primitives reuse const std::string key_with_hash = - key + gethash(src_tz, epsilon, flags, false, input_format); + key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false, + input_format); const std::string key_batch_norm_bwd_p = key_with_hash + "@batch_norm_bwd_p"; const std::string key_batch_norm_src_mem_p = From 312f3b86546ddbf7fce2a3bd0fb991c9cca2962a Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 27 Aug 2018 16:58:50 +0800 Subject: [PATCH 128/140] Fix random diff between python2 and python3 --- python/paddle/dataset/movielens.py | 5 ++- python/paddle/fluid/layers/nn.py | 42 +++++++++---------- .../fluid/transpiler/distribute_transpiler.py | 5 +-- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/python/paddle/dataset/movielens.py b/python/paddle/dataset/movielens.py index c98e0019f7..64bf741481 100644 --- a/python/paddle/dataset/movielens.py +++ b/python/paddle/dataset/movielens.py @@ -24,6 +24,7 @@ set and test set into paddle reader creators. 
from __future__ import print_function +import numpy as np import zipfile import paddle.dataset.common import re @@ -150,12 +151,12 @@ def __initialize_meta_info__(): def __reader__(rand_seed=0, test_ratio=0.1, is_test=False): fn = __initialize_meta_info__() - rand = random.Random(x=rand_seed) + np.random.seed(rand_seed) with zipfile.ZipFile(file=fn) as package: with package.open('ml-1m/ratings.dat') as rating: for line in rating: line = cpt.to_text(line, encoding='latin') - if (rand.random() < test_ratio) == is_test: + if (np.random.random() < test_ratio) == is_test: uid, mov_id, rating, _ = line.strip().split("::") uid = int(uid) mov_id = int(mov_id) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 66b776c08e..ca10d73b08 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -17,6 +17,7 @@ All layers just related to the neural network. from __future__ import print_function +import numpy as np from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable @@ -24,7 +25,6 @@ from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc from .tensor import concat from . import utils -import random from .. import unique_name from functools import reduce @@ -5102,7 +5102,7 @@ def random_crop(x, shape, seed=None): dtype = x.dtype out = helper.create_tmp_variable(dtype) if seed is None: - seed = random.randint(-65536, 65535) + seed = np.random.randint(-65536, 65536) op_attrs = {"shape": shape} if isinstance(seed, int): op_attrs["startup_seed"] = seed @@ -5416,7 +5416,7 @@ def prelu(x, mode, param_attr=None, name=None): channel:elements in a channel share same weight element:each element has a weight name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + will be named automatically. Returns: Variable: The output tensor with the same shape as input. 
@@ -5530,23 +5530,23 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): Supposing :code:`x` is a Tensor with shape [d_1, d_2, ..., d_n], the :code:`y` is a mask with shape [d_1, d_2, ..., d_n, maxlen], where: - + .. math:: - + y(i_1, i_2,..., i_n, j) = (j < x(i_1, i_2,..., i_n)) Args: - x (Variable): Input tensor of sequence_mask layer, + x (Variable): Input tensor of sequence_mask layer, whose elements are integers less than :code:`maxlen`. maxlen (int|None): Maximum length of the sequence. If :code:`maxlen` is None, it would be replace with :math:`max(x)`. dtype (np.dtype|core.VarDesc.VarType|str): Data type of the output. - name (str|None): A name for this layer(optional). If set None, the - layer will be named automatically. - + name (str|None): A name for this layer(optional). If set None, the + layer will be named automatically. + Returns: Variable: The output sequence mask. - + """ helper = LayerHelper('sequence_mask', **locals()) @@ -5571,23 +5571,23 @@ def stack(x, axis=0): **Stack Layer** This layer stacks all of the input :code:`x` along axis. - - Input :code:`x` can be a single variable, a :code:`list` of variables, - or a :code:`tuple` of variables. If :code:`x` is a :code:`list` or - :code:`tuple`, the shapes of all these variables must be the same. - Supposing the shape of each input is :math:`[d_0, d_1, ..., d_{n-1}]`, - the shape of the output variable would be - :math:`[d_0, d_1, ..., d_{axis}=len(x), ..., d_{n-1}]`. + + Input :code:`x` can be a single variable, a :code:`list` of variables, + or a :code:`tuple` of variables. If :code:`x` is a :code:`list` or + :code:`tuple`, the shapes of all these variables must be the same. + Supposing the shape of each input is :math:`[d_0, d_1, ..., d_{n-1}]`, + the shape of the output variable would be + :math:`[d_0, d_1, ..., d_{axis}=len(x), ..., d_{n-1}]`. If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x[0])+1`. - If :code:`axis` is None, it would be replaced with 0. 
+ If :code:`axis` is None, it would be replaced with 0. Args: - x (Variable|list(Variable)|tuple(Variable)): Input variables. + x (Variable|list(Variable)|tuple(Variable)): Input variables. axis (int|None): The axis along which all inputs are stacked. - + Returns: Variable: The stacked variable. - + """ helper = LayerHelper('stack', **locals()) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 80d9758b3d..28ae89acd3 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -31,7 +31,6 @@ Steps to transpile pserver: """ import math -import random import numpy as np import collections import six @@ -239,8 +238,8 @@ class DistributeTranspiler(object): grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping)) if not self.config.slice_var_up: - random.seed(self.origin_program.random_seed) - random.shuffle(grad_var_mapping_items) + np.random.seed(self.origin_program.random_seed) + np.random.shuffle(grad_var_mapping_items) grad_name_to_send_dummy_out = dict() for grad_varname, splited_vars in grad_var_mapping_items: From 9cb455fa7d89f78199662f796149f3c108070bbe Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 27 Aug 2018 17:11:33 +0800 Subject: [PATCH 129/140] update function --- .../fluid/operators/math/sequence_padding.cc | 9 ------ .../fluid/operators/math/sequence_padding.cu | 8 ----- paddle/fluid/operators/sequence_pad_op.cc | 32 +++++++++++-------- .../tests/unittests/test_sequence_pad_op.py | 7 ++-- 4 files changed, 20 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index 02ede3edce..25f06a25a0 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -98,15 +98,6 @@ class PaddingLoDTensorFunctor { CopyValidData(pad_tensor, &seq_tensor, 
seq_offsets, pad_seq_len, step_width, norm_by_times, kSeqToPad, layout); - - // Set pad_tensor's lod info if possible - if (layout == kBatchLengthWidth) { - framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end()); - for (size_t i = 0; i < pad_lod[0].size(); ++i) { - pad_lod[0][i] = i * pad_seq_len; - } - pad_tensor->set_lod(pad_lod); - } } }; diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index f94e8dbc3a..035e10dcbe 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -106,14 +106,6 @@ class PaddingLoDTensorFunctor { pad_data, seq_data, pad_value_data, pad_value.numel() == 1, seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len, step_width, norm_by_times, layout); - - if (layout == kBatchLengthWidth) { - framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end()); - for (size_t i = 0; i < pad_lod[0].size(); ++i) { - pad_lod[0][i] = i * pad_seq_len; - } - pad_tensor->set_lod(pad_lod); - } } }; diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index a08804cfba..44d73aa407 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -40,7 +40,8 @@ class SequencePadOp : public framework::OperatorWithKernel { "The Input(PadValue) must be a scalar or a tensor whose " "shape equals to time steps in sequences"); - int batch_dim_size = -1; + int out_dim_0 = -1; + int out_dim_1 = -1; if (ctx->IsRuntime()) { // run time @@ -64,7 +65,8 @@ class SequencePadOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GE(padded_length, max_seq_len, "The Attr(padded_length) must be -1 or an int greater " "than the length of the longest original sequence."); - batch_dim_size = padded_length * seq_num; + out_dim_0 = seq_num; + out_dim_1 = padded_length; } else { // compile time framework::VarDesc* x_desc = @@ -72,9 +74,11 @@ class 
SequencePadOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1); } - auto out_dims = x_dims; - out_dims[0] = batch_dim_size; - ctx->SetOutputDim("Out", out_dims); + std::vector out_dims_vec{out_dim_0, out_dim_1}; + auto time_step_dims_vec = framework::vectorize2int(time_step_dims); + out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(), + time_step_dims_vec.end()); + ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec)); } }; @@ -118,9 +122,9 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { and Input(PadValue): PadValue.data = [0] and attribite 'padded_length' = 4, - then we get 1-level LoDTensor: - Out.lod = [[0, 4, 8]] - Out.data = [a, b, 0, 0, c, d, e, 0] + then we get LoDTensor: + Out.data = [[a, b, 0, 0], + [c, d, e, 0]] Case 2: @@ -131,9 +135,9 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { PadValue.data = [0] and attribite 'padded_length' = -1, which mean using the length of longest input sequence(3 in this case), - then we get 1-level LoDTensor: - Out.lod = [[0, 3, 6]] - Out.data = [[a1, a2], [b1, b2], [0, 0], [c1, c2], [d1, d2], [e1, e2]] + then we get LoDTensor: + Out.data = [[[a1, a2], [b1, b2], [0, 0]], + [[c1, c2], [d1, d2], [e1, e2]]] Case 3: @@ -144,9 +148,9 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { PadValue.data = [p1, p2] and attribite 'padded_length' = -1, which mean using the length of longest input sequence(3 in this case), - then we get 1-level LoDTensor: - Out.lod = [[0, 3, 6]] - Out.data = [[a1, a2], [b1, b2], [p1, p2], [c1, c2], [d1, d2], [e1, e2]] + then we get LoDTensor: + Out.data = [[[a1, a2], [b1, b2], [p1, p2]], + [[c1, c2], [d1, d2], [e1, e2]]] )DOC"); } diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py index 7b9eedbf52..471515c817 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -61,11 +61,8 @@ class TestSequencePadOp(OpTest): padded_sequences.append(seq) start_idx = end_idx - out_len_lod = self.x_len_lod[:] - out_len_lod_0 = [padded_length] * len(x_len_lod_0) - out_len_lod[0] = out_len_lod_0 - out_data = np.concatenate(padded_sequences, axis=0) - self.outputs = {'Out': (out_data, out_len_lod)} + out_data = np.array(padded_sequences) + self.outputs = {'Out': out_data} def setUp(self): self.op_type = 'sequence_pad' From 49c31febb5fb5073705d2d3b17b8954ecea10ec1 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 17:06:18 +0800 Subject: [PATCH 130/140] fix typo and op test --- .../operators/fusion_seqexpand_concat_fc_op.cc | 13 ++++++------- .../unittests/test_fusion_seqexpand_concat_fc_op.py | 8 ++++---- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc index 641851585d..90aba5fe89 100644 --- a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc @@ -63,7 +63,7 @@ void FusionSeqExpandConcatFCOp::InferShape( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), + framework::ToDataType(ctx.MultiInput("X")[0]->type()), ctx.device_context()); } @@ -113,7 +113,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { auto* w = ctx.Input("FCWeight"); auto* b = ctx.Input("FCBias"); auto* out = ctx.Output("Out"); - auto* fc_out = ctx.Output("FCOUT"); + auto* fc_out = ctx.Output("FCOut"); auto* ref_in = ins[0]; auto ref_lod = ref_in->lod(); @@ -164,7 +164,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { math::FCCompute(blas, total_T, D, M0, ref_in_data, w_data, out_data, b ? 
b->data() : NULL); w_data = w_data + M0 * D; - // first one use write on + // first write on blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); w_data = w_data + M1 * D; for (size_t i = 2; i < ins.size(); ++i) { @@ -175,16 +175,15 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { K, w_data, D, static_cast(1), fc_out_data, D); w_data = w_data + K * D; } - + T* cur_out_data = out_data; for (int i = 0; i < N; ++i) { int seq_len = ref_lod[0][i + 1] - ref_lod[0][i]; T* src = fc_out_data + i * D; for (int step = 0; step < seq_len; ++step) { - blas.VADD(D, out_data, src, out_data); - out_data = out_data + D; + blas.VADD(D, cur_out_data, src, cur_out_data); + cur_out_data = cur_out_data + D; } } - fc_act(total_T * D, out_data, out_data); } }; diff --git a/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py b/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py index 7baf39eb3f..aeee3a9999 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py @@ -80,16 +80,16 @@ class TestFusionSeqExpandConcatFCOp(OpTest): out = fusion_seqexpand_concat_fc(xs, self.lod, w, b, ACTIVATION[self.fc_act]) - self.inputs = {'X': [(x0, self.lod)], 'FCWeight': w} - normal_lod = [i for i in range(bs + 1)] + self.inputs = {'X': [('x0', (x0, self.lod))], 'FCWeight': w} + normal_lod = [[1] * bs] for i in range(num_inputs - 1): - self.inputs['X'].append((xs[i + 1], normal_lod)) + self.inputs['X'].append(('x%d' % (i + 1), (xs[i + 1], normal_lod))) if self.with_bias: self.inputs['FCBias'] = b self.outputs = {'Out': (out, self.lod)} - self.attrs = {'fc_activation': self.fc_act, } + self.attrs = {'fc_activation': self.fc_act} def test_check_output(self): self.check_output() From b6d261dff51b216551b6a886faa67406a93849f9 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Mon, 27 Aug 2018 09:02:01 +0200 Subject: [PATCH 131/140] 
Enforce requested size of tensor to be sufficiently large --- paddle/fluid/framework/tensor.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index d61dbb98a2..b6ba0df033 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -40,7 +40,11 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, "When calling this method, the Tensor's numel must be " "equal or larger than zero. " "Please check Tensor::Resize has been called first."); - size_t size = requested_size ? requested_size : numel() * SizeOfType(type); + size_t size = numel() * SizeOfType(type); + if (requested_size) { + PADDLE_ENFORCE_GE(requested_size, size); + size = requested_size; + } /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || holder_->size() < size + offset_) { From 9b2b49ff26e4c6e457f9f0b1da560b8ce061fd83 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Tue, 28 Aug 2018 10:08:54 +0800 Subject: [PATCH 132/140] test fix release branch api check (#12977) * test fix release branch api check * fix reviews > 30 * check approval after test, check api diff before test --- paddle/scripts/paddle_build.sh | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index a643e0ded0..7199424b47 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -335,12 +335,18 @@ function assert_api_not_changed() { fi python ${PADDLE_ROOT}/tools/diff_api.py ${PADDLE_ROOT}/paddle/fluid/API.spec new.spec deactivate +} + +function assert_api_spec_approvals() { + if [ -z ${BRANCH} ]; then + BRANCH="develop" + fi - API_CHANGE=`git diff --name-only upstream/develop | grep "paddle/fluid/API.spec" || true` + API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "paddle/fluid/API.spec" || true` echo "checking 
API.spec change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}" if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then - # TODO: curl -H 'Authorization: token ${TOKEN}' - APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews | \ + # NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable. + APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433` echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" if [ "${APPROVALS}" == "FALSE" ]; then @@ -622,11 +628,12 @@ function main() { cicheck) cmake_gen ${PYTHON_ABI:-""} build + assert_api_not_changed ${PYTHON_ABI:-""} run_test gen_capi_package gen_fluid_inference_lib test_fluid_inference_lib - assert_api_not_changed ${PYTHON_ABI:-""} + assert_api_spec_approvals ;; *) print_usage From 0ee6fed05b67a4f1e54676d0817a65311d851e45 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Tue, 28 Aug 2018 11:15:26 +0800 Subject: [PATCH 133/140] Refine dist rpc deps (#12899) * refine dist train RPC deps * clean up * clean up * fix ut * remove input for fetch_barrier * follow comments --- .../details/multi_devices_graph_pass.cc | 65 ++++++++++++------- paddle/fluid/framework/ir/graph.cc | 57 ---------------- paddle/fluid/operators/fetch_barrier_op.cc | 2 + paddle/fluid/operators/send_barrier_op.cc | 4 ++ python/paddle/fluid/layers/io.py | 11 +++- .../fluid/tests/unittests/dist_se_resnext.py | 22 ++++--- .../fluid/tests/unittests/dist_word2vec.py | 16 +++-- .../fluid/tests/unittests/test_dist_train.py | 2 +- .../tests/unittests/test_dist_word2vec.py | 2 +- .../fluid/transpiler/distribute_transpiler.py | 26 ++++++-- 10 files changed, 101 insertions(+), 106 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc 
b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index bc61b0eacb..7722c9401e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -754,17 +754,26 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, node->Op()->Type()); CreateComputationalOp(result, node, op_dev_id); - if (node->Op()->Type() == "concat") { - ConnectOp(result, result->Get(kGraphOps).back().get(), - "fetch_barrier"); +} + +void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { + auto *op_handle = result->Get(kGraphOps).back().get(); + for (ir::Node *input : node->inputs) { + VarHandle *var = nullptr; + for (int place_offset = 0; place_offset < num_places; ++place_offset) { + auto &var_holders = result->Get(kGraphVars)[place_offset]; + auto &var_holder = var_holders[input->Name()]; + if (!var_holder.empty()) { + var = var_holder.rbegin()->get(); + op_handle->AddInput(var); + } + } } } // Create RPC related op handles that connects its in ops and out ops. void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { - // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode - // put them into transpiler. int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. @@ -799,8 +808,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } auto recv_param_grad = boost::get>( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); - // FIXME(typhoonzero): assume each recv op output one param - // Use the same place as send. 
if (recv_param_grad.size() == 2U) { op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); VLOG(10) << "recv param " << recv_param_grad[0] @@ -814,34 +821,44 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, .emplace(varname, op_dev_id); } } else { - // send_barrier and fetch_barrier op can be scheduled on device 0 + // send_barrier, fetch_barrier will run on place 0; op_dev_id = 0; } PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", node->Op()->Type()); - result->Get(kGraphOps).emplace_back(new RPCOpHandle( result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], node->Op()->Type(), places_[op_dev_id])); - // TODO(panyx0718): This might not be needed anymore. - if (node->Op()->Type() == "send_barrier") { - ConnectOp(result, result->Get(kGraphOps).back().get(), "send"); - } else if (node->Op()->Type() == "recv") { - ConnectOp(result, result->Get(kGraphOps).back().get(), - "send_barrier"); - } else if (node->Op()->Type() == "fetch_barrier") { - ConnectOp(result, result->Get(kGraphOps).back().get(), "recv"); - } else if (node->Op()->Type() == "send") { - // do nothing + if (node->Op()->Type() == "send") { + CreateOpHandleIOs(result, node, op_dev_id); } else { - PADDLE_THROW( - "rpc op should be in [" - "send, send_barrier. 
recv, fetch_barrier]"); - } + // send_barrier, recv, fetch_barrier's inputs are deps var, get them from + // all places + auto p = places_[op_dev_id]; + auto *op_handle = result->Get(kGraphOps).back().get(); + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); - CreateOpHandleIOs(result, node, op_dev_id); + SetOpInputsAllPlaces(result, node, places_.size()); + for (ir::Node *output : node->outputs) { + int outvar_dev_id = op_dev_id; + if (node->Op()->Type() == "fetch_barrier") { + outvar_dev_id = GetVarDeviceID(*result, output->Name()); + PADDLE_ENFORCE_NE(outvar_dev_id, -1); + } + p = places_[outvar_dev_id]; + ir::Node *new_node = nullptr; + if (output->Var()) { + new_node = result->CreateVarNode(output->Var()); + } else { + new_node = + result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); + } + CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id); + } + } } bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 2a6bf4ac23..39b0f2f038 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -132,63 +132,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { } } - std::vector send_ops; - ir::Node *send_bar = nullptr; - std::vector recv_ops; - ir::Node *fetch_bar = nullptr; - for (ir::Node *node : Nodes()) { - if (node->Name() == "send") { - send_ops.push_back(node); - } else if (node->Name() == "send_barrier") { - PADDLE_ENFORCE(!send_bar, "only has one send barrier"); - send_bar = node; - } else if (node->Name() == "recv") { - recv_ops.push_back(node); - } else if (node->Name() == "fetch_barrier") { - PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier"); - fetch_bar = node; - } - } - if (send_bar) { - for (ir::Node *send : send_ops) { - ir::Node *dep_var = CreateControlDepVar(); - send->outputs.push_back(dep_var); - dep_var->inputs.push_back(send); - 
send_bar->inputs.push_back(dep_var); - dep_var->outputs.push_back(send_bar); - } - for (ir::Node *recv : recv_ops) { - ir::Node *dep_var = CreateControlDepVar(); - recv->inputs.push_back(dep_var); - dep_var->outputs.push_back(recv); - send_bar->outputs.push_back(dep_var); - dep_var->inputs.push_back(send_bar); - } - } - if (fetch_bar) { - for (ir::Node *recv : recv_ops) { - ir::Node *dep_var = CreateControlDepVar(); - recv->outputs.push_back(dep_var); - dep_var->inputs.push_back(recv); - fetch_bar->inputs.push_back(dep_var); - dep_var->outputs.push_back(fetch_bar); - } - } - - std::vector send_vars = FindDistTrainSendVars(send_ops); - std::vector recv_vars = FindDistTrainRecvVars(recv_ops); - for (ir::Node *node : Nodes()) { - if (IsDistTrainOp(node, send_vars, recv_vars)) { - if (fetch_bar && node->Name() == "concat") { - ir::Node *dep_var = CreateControlDepVar(); - fetch_bar->outputs.push_back(dep_var); - dep_var->inputs.push_back(fetch_bar); - node->inputs.push_back(dep_var); - dep_var->outputs.push_back(node); - } - } - } - /** * We should handle write after read(WAR) and write after write(WAW) here. * Because some of the operators of the program can be executed parallelly. 
diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index d9cd956dfd..9d7ac7ab61 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -52,6 +52,8 @@ class FetchBarrierOp : public framework::OperatorBase { class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { + AddOutput("Out", "(Any) Dummy outputs, used for control dependency") + .AsDuplicable(); AddComment(R"DOC( SendBarrier operator diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 14b07649c4..4040429526 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -56,6 +56,10 @@ class SendBarrierOp : public framework::OperatorBase { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { + AddInput("X", "(Any) Dummy inputs, used for control dependency") + .AsDuplicable(); + AddOutput("Out", "(Any) Dummy outputs, used for control dependency") + .AsDuplicable(); AddComment(R"DOC( SendBarrier operator diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index b03ee514f5..0cf7aaef4a 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -246,7 +246,11 @@ def Send(endpoints, send_vars, dummy_output=None, sync=True): rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC }) if sync: - helper.append_op(type="send_barrier", attrs={"endpoints": endpoints}) + helper.append_op( + type="send_barrier", + inputs={"X": dummy_output}, + outputs={"Out": []}, + attrs={"endpoints": endpoints}) def Recv(endpoints, get_vars, dummy_input=None, sync=True): @@ -282,7 +286,10 @@ def Recv(endpoints, get_vars, dummy_input=None, sync=True): attrs={"endpoints": endpoints, "epmap": epmap}) if sync: - helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints}) + helper.append_op( + 
type="fetch_barrier", + outputs={"Out": get_vars}, + attrs={"endpoints": endpoints}) return get_vars diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index 0387e91188..a4ffe7d40c 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -134,7 +134,7 @@ class SE_ResNeXt(): size=class_dim, act='softmax', param_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant(value=0.2))) + initializer=fluid.initializer.Constant(value=0.05))) return out def shortcut(self, input, ch_out, stride): @@ -184,7 +184,7 @@ class SE_ResNeXt(): act=None, # avoid pserver CPU init differs from GPU param_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant(value=0.2)), + initializer=fluid.initializer.Constant(value=0.05)), bias_attr=False) return fluid.layers.batch_norm(input=conv, act=act) @@ -192,13 +192,19 @@ class SE_ResNeXt(): pool = fluid.layers.pool2d( input=input, pool_size=0, pool_type='avg', global_pooling=True) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) - squeeze = fluid.layers.fc(input=pool, - size=num_channels // reduction_ratio, - act='relu') + squeeze = fluid.layers.fc( + input=pool, + size=num_channels // reduction_ratio, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.05)), + act='relu') stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) - excitation = fluid.layers.fc(input=squeeze, - size=num_channels, - act='sigmoid') + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.05)), + act='sigmoid') scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) return scale diff --git a/python/paddle/fluid/tests/unittests/dist_word2vec.py b/python/paddle/fluid/tests/unittests/dist_word2vec.py index 0ad994a258..f3e740fc70 100644 --- a/python/paddle/fluid/tests/unittests/dist_word2vec.py 
+++ b/python/paddle/fluid/tests/unittests/dist_word2vec.py @@ -49,28 +49,32 @@ class TestDistWord2vec2x2(TestDistRunnerBase): dtype='float32', is_sparse=IS_SPARSE, param_attr=fluid.ParamAttr( - name='shared_w', initializer=fluid.initializer.Constant())) + name='shared_w', + initializer=fluid.initializer.Constant(value=0.1))) embed_second = fluid.layers.embedding( input=words[1], size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr=fluid.ParamAttr( - name='shared_w', initializer=fluid.initializer.Constant())) + name='shared_w', + initializer=fluid.initializer.Constant(value=0.1))) embed_third = fluid.layers.embedding( input=words[2], size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr=fluid.ParamAttr( - name='shared_w', initializer=fluid.initializer.Constant())) + name='shared_w', + initializer=fluid.initializer.Constant(value=0.1))) embed_forth = fluid.layers.embedding( input=words[3], size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr=fluid.ParamAttr( - name='shared_w', initializer=fluid.initializer.Constant())) + name='shared_w', + initializer=fluid.initializer.Constant(value=0.1))) concat_embed = fluid.layers.concat( input=[embed_first, embed_second, embed_third, embed_forth], @@ -80,13 +84,13 @@ class TestDistWord2vec2x2(TestDistRunnerBase): size=HIDDEN_SIZE, act='sigmoid', param_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant())) + initializer=fluid.initializer.Constant(value=0.1))) predict_word = fluid.layers.fc( input=hidden1, size=dict_size, act='softmax', param_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant())) + initializer=fluid.initializer.Constant(value=0.1))) cost = fluid.layers.cross_entropy( input=predict_word, label=words[4]) avg_cost = fluid.layers.mean(cost) diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index 9581abdf39..083525ccf5 100644 --- 
a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -100,7 +100,7 @@ class TestSendOp(unittest.TestCase): main.global_block().append_op( type="fetch_barrier", inputs={}, - outputs={}, + outputs={"Out": []}, attrs={ "endpoints": ["127.0.0.1:{0}".format(port)], RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py index 38af149ad3..9a3e92e8d7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py @@ -22,7 +22,7 @@ class TestDistSeResneXt2x2(TestDistBase): self._sync_mode = True def test_se_resnext(self): - self.check_with_place("dist_word2vec.py", delta=1e-7) + self.check_with_place("dist_word2vec.py", delta=1e-4) class TestDistSeResneXt2x2Async(TestDistBase): diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 28ae89acd3..21c51bd139 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -283,10 +283,13 @@ class DistributeTranspiler(object): send_vars.append(var) if self.sync_mode: + send_barrier_out = program.global_block().create_var( + name=framework.generate_control_dev_var_name()) + input_deps = grad_name_to_send_dummy_out.values() program.global_block().append_op( type="send_barrier", - inputs={}, - outputs={}, + inputs={"X": input_deps}, + outputs={"Out": send_barrier_out}, attrs={ "endpoints": pserver_endpoints, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE @@ -304,16 +307,22 @@ class DistributeTranspiler(object): self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) # step4: Concat the parameters splits together after recv. 
+ all_recv_outputs = [] for param_varname, splited_var in six.iteritems(self.param_var_mapping): eps = [] for var in splited_var: index = [v.name for v in recv_vars].index(var.name) eps.append(eplist[index]) - grad_send_dummy_out = grad_name_to_send_dummy_out[ - self.param_name_to_grad_name[param_varname]] + if self.sync_mode: + recv_dep_in = send_barrier_out + else: + # connect deps to send op in async mode + recv_dep_in = grad_name_to_send_dummy_out[ + self.param_name_to_grad_name[param_varname]] + all_recv_outputs.extend(splited_var) program.global_block().append_op( type="recv", - inputs={"X": [grad_send_dummy_out]}, + inputs={"X": [recv_dep_in]}, outputs={"Out": splited_var}, attrs={ "epmap": eps, @@ -326,10 +335,11 @@ class DistributeTranspiler(object): }) if self.sync_mode: + # form a WAW dependency program.global_block().append_op( type="fetch_barrier", inputs={}, - outputs={}, + outputs={"Out": all_recv_outputs}, attrs={ "endpoints": pserver_endpoints, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE @@ -413,10 +423,12 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) + fetch_barrier_out = startup_program.global_block().create_var( + name=framework.generate_control_dev_var_name()) startup_program.global_block().append_op( type="fetch_barrier", inputs={}, - outputs={}, + outputs={"Out": fetch_barrier_out}, attrs={ "endpoints": self.pserver_endpoints, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE From 11e01d9b2d9a7796a57a22a77f68ea6427d75f5f Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 28 Aug 2018 12:47:26 +0800 Subject: [PATCH 134/140] Scale support selectedrows (#12960) * add ScaleOpVarTypeInference for scale op * scale op support scale selected rows * optimize code * use FindVar * use FindVarRecursive in ScaleOpVarTypeInference --- paddle/fluid/operators/scale_op.cc | 21 +++++++- paddle/fluid/operators/scale_op.h | 28 +++++++--- .../fluid/tests/unittests/test_scale_op.py | 54 +++++++++++++++++++ 3 files 
changed, 94 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 7f8822e400..c614de2eac 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/scale_op.h" + #include +#include "paddle/fluid/operators/detail/safe_ref.h" + namespace paddle { namespace operators { @@ -52,6 +55,21 @@ $$Out = scale*X$$ } }; +class ScaleOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto &in_var_name = op_desc.Input("X").front(); + auto &in_var = detail::Ref(block->FindVarRecursive(in_var_name)); + + auto out_var_name = op_desc.Output("Out").front(); + auto *out_var = block->FindVarRecursive(out_var_name); + + out_var->SetType(in_var.GetType()); + out_var->SetDataType(in_var.GetDataType()); + } +}; + class ScaleGradMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -71,7 +89,8 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; -REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker); +REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker, + ops::ScaleOpVarTypeInference); REGISTER_OP_CPU_KERNEL( scale, ops::ScaleKernel, ops::ScaleKernel, diff --git a/paddle/fluid/operators/scale_op.h b/paddle/fluid/operators/scale_op.h index c6a59b76ad..fe035aba81 100644 --- a/paddle/fluid/operators/scale_op.h +++ b/paddle/fluid/operators/scale_op.h @@ -22,17 +22,29 @@ namespace operators { template class ScaleKernel : public framework::OpKernel { public: - virtual void Compute(const framework::ExecutionContext& context) const { - auto* tensor = 
context.Output("Out"); - auto* in = context.Input("X"); - tensor->mutable_data(in->place()); + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto* in_var = ctx.InputVar("X"); + auto* in = ctx.Input("X"); - auto scale = static_cast(context.Attr("scale")); + auto* out_var = ctx.OutputVar("Out"); + auto* out = ctx.Output("Out"); + out->mutable_data(in->place()); - auto eigen_out = framework::EigenVector::Flatten(*tensor); + PADDLE_ENFORCE_EQ(in->dims(), out->dims(), + "in and out should have the same dim"); + + auto scale = static_cast(ctx.Attr("scale")); + + if (in_var->IsType() && in_var != out_var) { + auto& in_slr = in_var->Get(); + auto* out_slr = out_var->GetMutable(); + out_slr->set_rows(in_slr.rows()); + out_slr->set_height(in_slr.height()); + } + + auto eigen_out = framework::EigenVector::Flatten(*out); auto eigen_in = framework::EigenVector::Flatten(*in); - auto& dev = - *context.template device_context().eigen_device(); + auto& dev = *ctx.template device_context().eigen_device(); eigen_out.device(dev) = scale * eigen_in; } }; diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py index 0a8a43253d..032af6ed5c 100644 --- a/python/paddle/fluid/tests/unittests/test_scale_op.py +++ b/python/paddle/fluid/tests/unittests/test_scale_op.py @@ -17,6 +17,8 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator class TestScaleOp(OpTest): @@ -33,5 +35,57 @@ class TestScaleOp(OpTest): self.check_grad(['X'], 'Out') +class TestScaleOpSelectedRows(unittest.TestCase): + def check_with_place(self, place, in_name, out_name): + scope = core.Scope() + + # create and initialize Grad Variable + in_height = 10 + in_rows = [0, 4, 7] + in_row_numel = 12 + scale = 2.0 + + in_selected_rows = scope.var(in_name).get_selected_rows() + in_selected_rows.set_height(in_height) + 
in_selected_rows.set_rows(in_rows) + in_array = np.random.random( + (len(in_rows), in_row_numel)).astype("float32") + + in_tensor = in_selected_rows.get_tensor() + in_tensor.set(in_array, place) + + # create and initialize Param Variable + out_selected_rows = scope.var(out_name).get_selected_rows() + out_tensor = out_selected_rows.get_tensor() + out_tensor._set_dims(in_tensor._get_dims()) + + # create and run sgd operator + scale_op = Operator("scale", X=in_name, Out=out_name, scale=scale) + scale_op.run(scope, place) + + # get and compare result + out_height = out_selected_rows.height() + out_rows = out_selected_rows.rows() + result_array = np.array(out_tensor) + + assert (in_array * scale == result_array).all() + assert in_height == out_height + assert in_rows == out_rows + + def test_scale_selected_rows(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place, 'in', 'out') + + def test_scale_selected_rows_inplace(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place, 'in', 'in') + + if __name__ == "__main__": unittest.main() From 0353eddb51414e48693f0434121b1df761f4644e Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 28 Aug 2018 12:59:31 +0800 Subject: [PATCH 135/140] Improve fake_dequantize_op. (#12877) * Improve fake_dequantize_op. * Follow comments. 
--- paddle/fluid/operators/fake_dequantize_op.cc | 37 +++++++++++++------ paddle/fluid/operators/fake_dequantize_op.cu | 36 ++++++++++++++++++ paddle/fluid/operators/fake_dequantize_op.h | 23 ++++++++---- .../unittests/test_fake_dequantize_op.py | 33 +++++++++++------ 4 files changed, 97 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 43f9491111..2008e70275 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -18,15 +18,32 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct DequantizeFunctor { + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + T max_range, framework::Tensor* out) { + auto in_e = framework::EigenVector::Flatten(*in); + const T* scale_factor = scale->data(); + auto out_e = framework::EigenVector::Flatten(*out); + + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = (scale_factor[0] / max_range) * in_e; + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; + class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel { public: - FakeDequantizeMaxAbsOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + FakeDequantizeMaxAbsOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FakeDequantizeMaxAbsOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -42,21 +59,17 @@ class 
FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor) The input with float-32/64 type is the " "low precision tensor."); + AddInput("Scale", "(float) The scale in quantization stage."); AddOutput("Out", "(Tensor) The output is the dequantized high " "precision tensor."); - AddAttr("num_bits", - "(int) `num_bits` is the quantization level bits, " - "such as 2, 5, 8."); - AddAttr("scale", - "(float) The maximum absolute value of low precision tensor." - "It is usually calculated by the fake_quantize_max_abs_op."); + AddAttr("max_range", "(float) The max range in quantization stage."); AddComment(R"DOC( FakeDequantizeMaxAbsOp operator. This calculation is an opposite operation of FakeQuantizeMaxAbsOp: -$$Out = \frac{scale*X}{2^{num_bits} - 1}$$ +$$Out = \frac{scale*X}{ max_range }$$ )DOC"); } diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 1bd38d1bd2..225bcc45bc 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -14,6 +14,42 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fake_dequantize_op.h" +namespace paddle { +namespace operators { + +template +__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num, + T* out) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < num) { + out[idx] = in[idx] * scale[0] / max_range; + } +} + +template +struct DequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + T max_range, framework::Tensor* out) { + const T* in_data = in->data(); + const T* scale_factor = scale->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + + int num = in->numel(); + int block = 512; + int grid = (num + block - 1) / block; + + KeDequantize<<>>( + in_data, scale_factor, max_range, num, out_data); + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs, diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index 0901e68b37..d9923a10da 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -19,22 +19,29 @@ limitations under the License. 
*/ namespace paddle { namespace operators { + +template +struct DequantizeFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor* scale, T max_range, + framework::Tensor* out); +}; + template class FakeDequantizeMaxAbsKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& ctx) const { auto* in = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); auto* out = ctx.Output("Out"); - out->mutable_data(in->place()); - int num_bits = ctx.Attr("num_bits"); - T scale = static_cast(ctx.Attr("scale")); - int range = std::pow(2, num_bits) - 1; + float max_range = ctx.Attr("max_range"); + + auto& dev_ctx = ctx.template device_context(); + out->mutable_data(dev_ctx.GetPlace()); - auto eigen_out = framework::EigenVector::Flatten(*out); - auto eigen_in = framework::EigenVector::Flatten(*in); - auto& dev = *ctx.template device_context().eigen_device(); - eigen_out.device(dev) = (scale / range) * eigen_in; + DequantizeFunctor()(dev_ctx, in, scale, + static_cast(max_range), out); } }; diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index d84ebed3fa..1bb4662e8d 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -20,41 +20,50 @@ import math from op_test import OpTest -def quantize_max_abs(x, num_bits): - range = math.pow(2, num_bits) - 1 +def quantize_max_abs(x, max_range): scale = np.max(np.abs(x).flatten()) - y = np.round(x / scale * range) + y = np.round(x / scale * max_range) return y, scale -def dequantize_max_abs(x, num_bits, scale): - range = math.pow(2, num_bits) - 1 - y = (scale / range) * x +def dequantize_max_abs(x, scale, max_range): + y = (scale / max_range) * x return y class TestFakeDequantizeMaxAbsOp(OpTest): def set_args(self): self.num_bits = 8 + self.max_range = math.pow(2, 
self.num_bits - 1) - 1 + self.data_type = "float32" def setUp(self): self.set_args() self.op_type = "fake_dequantize_max_abs" - x = np.random.randn(31, 65).astype("float32") - yq, scale = quantize_max_abs(x, self.num_bits) - ydq = dequantize_max_abs(yq, self.num_bits, scale) + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + ydq = dequantize_max_abs(yq, scale, self.max_range) - self.inputs = {'X': yq} - self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)} + self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)} + self.attrs = {'max_range': self.max_range} self.outputs = {'Out': ydq} def test_check_output(self): self.check_output() -class TestFakeDequantizeMaxAbsOp5Bits(OpTest): +class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp): + def set_args(self): + self.num_bits = 8 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "float64" + + +class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp): def set_args(self): self.num_bits = 5 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "float32" if __name__ == "__main__": From 7ad39c4077b9bc50ab61079be4e7117140a9b18b Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 28 Aug 2018 13:16:32 +0800 Subject: [PATCH 136/140] Enhance pad_constant_like_op (#12999) * enhance pad_constant_like_op * add API * add API --- paddle/fluid/API.spec | 1 + .../fluid/operators/pad_constant_like_op.cc | 16 ++++ python/paddle/fluid/layers/nn.py | 83 ++++++++++++++++++- 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index a9ca260621..7ae0f445a8 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -147,6 +147,7 @@ paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', ' paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)) 
paddle.fluid.layers.lrn ArgSpec(args=['input', 'n', 'k', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(5, 1.0, 0.0001, 0.75, None)) paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None)) +paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None)) paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None)) paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 5958811d38..37646c7b4c 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -43,6 +43,14 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Y")->type()), + ctx.device_context()); + } }; class PadConstantLikeOpMaker : public framework::OpProtoAndCheckerMaker { @@ -159,6 +167,14 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { } } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Y")->type()), + ctx.device_context()); + } }; class PadConstantLikeOpGradMaker : public framework::SingleGradOpDescMaker { diff 
--git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f98b18afa7..3e3f884137 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -88,6 +88,7 @@ __all__ = [ 'lod_reset', 'lrn', 'pad', + 'pad_constant_like', 'label_smooth', 'roi_pool', 'dice_loss', @@ -4755,6 +4756,86 @@ def pad(x, paddings, pad_value=0., name=None): return out +def pad_constant_like(x, y, pad_value=0., name=None): + """ + Pad input(Y) with :attr:`pad_value`, the number of values padded to + the edges of each axis is specified by the difference of the shape + of X and Y. ((0, shape_x_0 - shape_y_0), ... (0, shape_x_n - shape_y_n)) + unique pad widths for each axis. The input should be a k-D + tensor(k > 0 and k < 7). + + See below for an example. + + .. code-block:: text + + Given: + X = [[[[ 0, 1, 2], + [ 3, 4, 5]], + [[ 6, 7, 8], + [ 9, 10, 11]], + [[12, 13, 14], + [15, 16, 17]]], + [[[18, 19, 20], + [21, 22, 23]], + [[24, 25, 26], + [27, 28, 29]], + [[30, 31, 32], + [33, 34, 35]]]] + X.shape = (2, 3, 2, 3) + + Y = [[[[35, 36, 37]], + [[38, 39, 40]], + [[41, 42, 43]]]] + Y.shape = (1, 3, 1, 3) + + And + pad_value = -1, + + Return: + Out = [[[[35, 36, 37], + [-1, -1, -1]], + [[38, 39, 40], + [-1, -1, -1]], + [[41, 42, 43], + [-1, -1, -1]]], + [[[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]]]] + Out.shape = (2, 3, 2, 3) + + Args: + x (Variable): The input tensor variable. + y (Variable): The input tensor variable. + pad_value (float): The constant value used to pad. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The padded tensor variable. + + Examples: + .. code-block:: python + + # x is a rank 4 tensor variable, x.shape = (2, 3, 2, 3) + # y is a rank 4 tensor variable, y.shape = (1, 3, 1, 3) + out = fluid.layers.pad_constant_like(x=x, y=y, pad_value=0.) 
+ # out is a rank 4 tensor variable, and out.shape = [2, 3 ,2 , 3] + """ + helper = LayerHelper('pad_constant_like', input=x, **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + helper.append_op( + type='pad_constant_like', + inputs={'X': x, + 'Y': y}, + outputs={'Out': out}, + attrs={'pad_value': float(pad_value)}) + return out + + def label_smooth(label, prior_dist=None, epsilon=0.1, @@ -5351,7 +5432,7 @@ def crop(x, shape=None, offsets=None, name=None): helper = LayerHelper('crop', **locals()) if not (isinstance(shape, list) or isinstance(shape, tuple) or \ - isinstance(shape, Variable)): + isinstance(shape, Variable)): raise ValueError("The shape should be a list, tuple or Variable.") if offsets is None: From 9be39bb4b702c70649cd59bdd7aee95b3db0c34b Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 28 Aug 2018 14:55:36 +0800 Subject: [PATCH 137/140] Enhence optimizer. (#13004) --- python/paddle/fluid/optimizer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 031ddd09a0..6b9749a579 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -46,10 +46,12 @@ class Optimizer(object): def __init__(self, learning_rate, regularization=None, - LARS_weight_decay=0.0): + LARS_weight_decay=0.0, + name=None): if not isinstance(learning_rate, float) and \ not isinstance(learning_rate, framework.Variable): raise TypeError("learning rate should be float or Variable") + self._name = name self.regularization = regularization self._learning_rate = learning_rate # the learning rate type should be inferenced from loss @@ -153,6 +155,8 @@ class Optimizer(object): dtype: data type of the accumulator variable fill_value: value to initialize the accumulator variable """ + if self._name is not None: + name = self._name + "_" + name if (name in self._accumulators and param.name in self._accumulators[name]): raise 
Exception("Accumulator {} already exists for parameter {}". @@ -181,6 +185,8 @@ class Optimizer(object): Returns: accumulator variable for the parameter """ + if self._name is not None: + name = self._name + "_" + name if (name not in self._accumulators or param.name not in self._accumulators[name]): raise Exception("Accumulator {} does not exist for parameter {}". From 8965cee89f83f2d2d4d403e0908232a2810e3149 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 28 Aug 2018 16:26:37 +0800 Subject: [PATCH 138/140] Polish PrintOp (#12895) * Polish PrintOp * Polish PrintOp * Polish PrintOp * Refine test_print_op --- paddle/fluid/framework/var_type.h | 2 +- paddle/fluid/operators/print_op.cc | 104 ++++++------------ python/paddle/fluid/layers/control_flow.py | 5 +- .../fluid/tests/unittests/test_print_op.py | 5 +- 4 files changed, 37 insertions(+), 79 deletions(-) diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h index 429997c8b8..e9550dbfb9 100644 --- a/paddle/fluid/framework/var_type.h +++ b/paddle/fluid/framework/var_type.h @@ -26,7 +26,7 @@ namespace paddle { namespace framework { template -bool IsType(const std::type_index& type_index) { +inline bool IsType(const std::type_index& type_index) { return type_index == std::type_index(typeid(T)); } diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index cceac40295..e7f1caf4d3 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -13,14 +13,12 @@ limitations under the License. 
*/ #include -#include - #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/framework/variable.h" namespace paddle { namespace operators { +using framework::GradVarName; #define CLOG std::cout @@ -35,7 +33,7 @@ struct Formater { std::type_index dtype{typeid(const char)}; framework::LoD lod; int summarize; - void* data{nullptr}; + void *data{nullptr}; void operator()(size_t size) { PrintMessage(); @@ -101,7 +99,7 @@ struct Formater { template void Display(size_t size) { - auto* d = reinterpret_cast(data); + auto *d = reinterpret_cast(data); CLOG << "\tdata: "; if (summarize != -1) { summarize = std::min(size, (size_t)summarize); @@ -120,51 +118,36 @@ struct Formater { // TODO(ChunweiYan) there should be some other printers for TensorArray class TensorPrintOp : public framework::OperatorBase { public: - TensorPrintOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) + TensorPrintOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - TensorPrintOp(const TensorPrintOp& o) + TensorPrintOp(const TensorPrintOp &o) : framework::OperatorBase( - static_cast(o)) { + static_cast(o)) { PADDLE_THROW("Not implemented."); } private: - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - const framework::Variable* in_var_ptr = nullptr; - std::string phase(kForward); + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + const framework::Variable *in_var_ptr = nullptr; std::string printed_var_name = ""; - auto& inputs = Inputs(); - if (inputs.find("In") != inputs.end() && !Inputs("In").empty()) { - in_var_ptr = scope.FindVar(Input("In")); - printed_var_name = Inputs("In").front(); - } 
else if (inputs.find("In@GRAD") != inputs.end() && - !Inputs("In@GRAD").empty()) { - in_var_ptr = scope.FindVar(Input("In@GRAD")); - printed_var_name = Inputs("In@GRAD").front(); - phase = std::string(kBackward); - } else { - PADDLE_THROW("Unknown phase, should be forward or backward."); - } + in_var_ptr = scope.FindVar(Input("In")); + printed_var_name = Inputs("In").front(); PADDLE_ENFORCE_NOT_NULL(in_var_ptr); - auto& in_tensor = in_var_ptr->Get(); - auto* out_var_ptr = scope.FindVar(Output("Out")); - auto& out_tensor = *out_var_ptr->GetMutable(); - - // Just copy data from input tensor to output tensor - // output tensor share same memory with input tensor - out_tensor.ShareDataWith(in_tensor); - out_tensor.set_lod(in_tensor.lod()); + auto &in_tensor = in_var_ptr->Get(); std::string print_phase = Attr("print_phase"); - if (print_phase != phase && print_phase != std::string(kBoth)) { + bool is_forward = Attr("is_forward"); + + if ((is_forward && print_phase == kBackward) || + (!is_forward && print_phase == kForward)) { return; } @@ -192,7 +175,7 @@ class TensorPrintOp : public framework::OperatorBase { formater.dtype = printed_tensor.type(); } if (Attr("print_tensor_shape")) { - auto& dims = printed_tensor.dims(); + auto &dims = printed_tensor.dims(); formater.dims.resize(dims.size()); for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i]; } @@ -200,7 +183,7 @@ class TensorPrintOp : public framework::OperatorBase { formater.lod = printed_tensor.lod(); } formater.summarize = Attr("summarize"); - formater.data = reinterpret_cast(printed_tensor.data()); + formater.data = reinterpret_cast(printed_tensor.data()); formater(printed_tensor.numel()); } @@ -219,14 +202,14 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker { AddAttr("print_tensor_type", "Whether to print the tensor's dtype."); AddAttr("print_tensor_shape", "Whether to print the tensor's shape."); AddAttr("print_tensor_lod", "Whether to print the tensor's lod."); - 
AddAttr( - "print_phase", - "(string, default 'BOTH') Which phase to display including 'FORWARD' " - "'BACKWARD' and 'BOTH'.") + AddAttr("print_phase", + "(string, default 'FORWARD') Which phase to display " + "including 'FORWARD' " + "'BACKWARD' and 'BOTH'.") .SetDefault(std::string(kBoth)) .InEnum({std::string(kForward), std::string(kBackward), std::string(kBoth)}); - AddOutput("Out", "Output tensor with same data as input tensor."); + AddAttr("is_forward", "Whether is forward or not").SetDefault(true); AddComment(R"DOC( Creates a print op that will print when a tensor is accessed. @@ -238,40 +221,21 @@ tensor `t`.)DOC"); class InferShapeForward : public framework::InferShapeBase { public: - void operator()(framework::InferShapeContext* context) const override { + void operator()(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInput("In"), "Input(In) should not be null."); - context->ShareLoD("In", /*->*/ "Out"); - context->SetOutputDim("Out", context->GetInputDim("In")); - } -}; - -class InferShapeBackward : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* context) const override { - PADDLE_ENFORCE(context->HasInput("In@GRAD"), - "Input(In@GRAD) should not be null."); - context->ShareLoD("In@GRAD", /*->*/ "Out"); - context->SetOutputDim("Out", context->GetInputDim("In@GRAD")); } }; -class InferVarType : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override {} -}; - -class PrintOpProtoAndCheckGradOpMaker - : public framework::SingleGradOpDescMaker { +class PrintOpGradientMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; std::unique_ptr Apply() const override { - auto* op_desc_ptr = new framework::OpDesc(); - op_desc_ptr->SetType("print_grad"); - op_desc_ptr->SetInput("In@GRAD", OutputGrad("Out")); - 
op_desc_ptr->SetOutput("Out", InputGrad("In")); + auto *op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType("print"); + op_desc_ptr->SetInput("In", InputGrad("In")); op_desc_ptr->SetAttrMap(Attrs()); + op_desc_ptr->SetAttr("is_forward", false); return std::unique_ptr(op_desc_ptr); } }; @@ -282,6 +246,4 @@ class PrintOpProtoAndCheckGradOpMaker namespace ops = paddle::operators; REGISTER_OPERATOR(print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker, - ops::PrintOpProtoAndCheckGradOpMaker, ops::InferShapeForward, - ops::InferVarType); -REGISTER_OPERATOR(print_grad, ops::TensorPrintOp, ops::InferShapeBackward); + ops::PrintOpGradientMaker, ops::InferShapeForward); diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index d2954c4c22..c9a2f8a0ab 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -189,7 +189,6 @@ def Print(input, message="The content of some_layer: ") ''' helper = LayerHelper('print', **locals()) - out = helper.create_tmp_variable(dtype=helper.input_dtype()) helper.append_op( type='print', inputs={'In': input}, @@ -202,9 +201,7 @@ def Print(input, 'print_tensor_shape': print_tensor_shape, 'print_tensor_lod': print_tensor_lod, 'print_phase': print_phase.upper() - }, - outputs={'Out': out}) - return out + }) class BlockGuard(object): diff --git a/python/paddle/fluid/tests/unittests/test_print_op.py b/python/paddle/fluid/tests/unittests/test_print_op.py index ac682d6181..8097b5f734 100644 --- a/python/paddle/fluid/tests/unittests/test_print_op.py +++ b/python/paddle/fluid/tests/unittests/test_print_op.py @@ -35,9 +35,8 @@ class TestPrintOpCPU(unittest.TestCase): def build_network(self, only_forward, **kargs): x = layers.data('x', shape=[3], dtype='float32', lod_level=1) x.stop_gradient = False - printed = layers.Print(input=x, **kargs) - if only_forward: return printed - loss = layers.mean(printed) + layers.Print(input=x, **kargs) + loss = 
layers.mean(x) append_backward(loss=loss) return loss From a22309afe8e0a520ec16fb4ada3dbc2f12e7ce57 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 28 Aug 2018 17:39:37 +0800 Subject: [PATCH 139/140] clean useless check code in auc_op (#13023) --- paddle/fluid/operators/auc_op.h | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/auc_op.h index 0651203286..0a18585edb 100644 --- a/paddle/fluid/operators/auc_op.h +++ b/paddle/fluid/operators/auc_op.h @@ -60,20 +60,6 @@ class AucKernel : public framework::OpKernel { const T* inference_data = predict->data(); const auto* label_data = label->data(); - // check if states are inited. - auto* tp_in = ctx.Input("TP"); - auto* fp_in = ctx.Input("FP"); - auto* tn_in = ctx.Input("TN"); - auto* fn_in = ctx.Input("FN"); - PADDLE_ENFORCE(tp_in->IsInitialized(), "true_positive is not inited!"); - PADDLE_ENFORCE(fp_in->IsInitialized(), "false_negative is not inited!"); - PADDLE_ENFORCE(tn_in->IsInitialized(), "true_negative is not inited!"); - PADDLE_ENFORCE(fn_in->IsInitialized(), "false_positive is not inited!"); - PADDLE_ENFORCE_EQ(tp_in->numel(), num_thresholds, ""); - PADDLE_ENFORCE_EQ(fp_in->numel(), num_thresholds, ""); - PADDLE_ENFORCE_EQ(tn_in->numel(), num_thresholds, ""); - PADDLE_ENFORCE_EQ(fn_in->numel(), num_thresholds, ""); - auto* tp_data = true_positive->mutable_data(ctx.GetPlace()); auto* fn_data = false_negative->mutable_data(ctx.GetPlace()); auto* tn_data = true_negative->mutable_data(ctx.GetPlace()); From 82671e9486e9f6ed13eded1792750512c19f2a8f Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 28 Aug 2018 19:52:21 +0800 Subject: [PATCH 140/140] Fix bug in flowers dataset. 
(#13024) --- python/paddle/dataset/image.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py index b32736ee7c..920dbf3b4e 100644 --- a/python/paddle/dataset/image.py +++ b/python/paddle/dataset/image.py @@ -203,7 +203,7 @@ def resize_short(im, size): h_new = size * h // w else: w_new = size * w // h - im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) + im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_CUBIC) return im @@ -345,7 +345,6 @@ def simple_transform(im, if np.random.randint(2) == 0: im = left_right_flip(im, is_color) else: - im = center_crop(im, crop_size, is_color) im = center_crop(im, crop_size, is_color=is_color) if len(im.shape) == 3: im = to_chw(im)