From 334c84526b7ec858d1ea952e24b9c96ece39862c Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 20 Sep 2017 11:06:21 -0700 Subject: [PATCH 01/28] lstm unit --- paddle/operators/lstm_unit_op.cc | 101 +++++++++++++++++++++ paddle/operators/lstm_unit_op.h | 147 +++++++++++++++++++++++++++++++ 2 files changed, 248 insertions(+) create mode 100644 paddle/operators/lstm_unit_op.cc create mode 100644 paddle/operators/lstm_unit_op.h diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc new file mode 100644 index 0000000000..e3cac2605a --- /dev/null +++ b/paddle/operators/lstm_unit_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lstm_unit_op.h" + +namespace paddle { +namespace operators { + +class LstmUnitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), + "Input(C_prev) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), + "Output(C) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), + "Output(H) of LSTM should not be null."); + + auto *x = ctx.Input("X"); + auto *c_prev = ctx.Input("C_prev"); + + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, + "Dimension of FC should equal to prev state * 4"); + + int b_size = c_prev->dims()[0]; // batch size + int s_dim = c_prev->dims()[1]; // state dim + ctx.Output("C")->Resize({b_size, s_dim}); + ctx.Output("H")->Resize({b_size, s_dim}); + } +}; + +template +class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LstmUnitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "FC input before the non-linear activation."); + AddInput( + "C_prev", + "The cell state tensor of last time-step in the Lstm Unit operator."); + AddOutput("C", "The cell tensor of Lstm Unit operator."); + AddOutput("H", "The hidden state tensor of Lstm Unit operator."); + + AddComment(R"DOC(Lstm-Unit Operator + +Equation: + i, j, f, o = split(X) + C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) + H = C * sigm(o) + +)DOC"); + AddAttr("forget_bias", "The forget bias of Lstm Unit.") + .SetDefault(0.0); + } +}; + +class LstmUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), + "Input(C@GRAD) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), + "Input(H@GRAD) should not be null"); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("X")->dims()); + ctx.Output(framework::GradVarName("C_prev")) + ->Resize(ctx.Input("C_prev")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, + lstm_unit_grad, ops::LstmUnitGradOp); +REGISTER_OP_CPU_KERNEL(lstm_unit, + ops::LstmUnitKernel); diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h new file mode 100644 index 0000000000..6e870f65e2 --- /dev/null +++ b/paddle/operators/lstm_unit_op.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; + +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +class LstmUnitKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + for (int n = 0; n < b_size; ++n) { + for (int d = 0; d < D; ++d) { + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = f * c_prev + i * g; + C[d] = c; + const T tanh_c = tanh(c); + H[d] = o * tanh_c; + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + } + } +}; + +template +class LstmUnitGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + for (int n = 0; n < N; ++n) { + for (int d = 0; d < D; ++d) { + T* c_prev_diff = C_prev_diff + d; + T* i_diff = X_diff + d; + T* f_diff = X_diff + 1 * D + d; + T* o_diff = X_diff + 2 * D + d; + T* g_diff = X_diff + 3 * D + d; + + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = C[d]; + const T tanh_c = tanh(c); + const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[d] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + C_diff += D; + H_diff += D; + X_diff += 4 * D; + C_prev_diff += D; + } + } +}; + +} // namespace operators +} // namespace paddle From 7883227716d1652bcff3d652873e5c0f1d7a004c Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 20 Sep 2017 14:56:55 -0700 Subject: [PATCH 02/28] lstm --- paddle/operators/lstm_unit_op.cc | 4 +- paddle/operators/lstm_unit_op.h | 1 + .../v2/framework/tests/test_lstm_unit_op.py | 39 +++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/framework/tests/test_lstm_unit_op.py diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index e3cac2605a..3600f19977 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -64,7 +64,7 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(Lstm-Unit Operator Equation: - i, j, f, o = split(X) + i, f, o, j = split(X) C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) H = C * sigm(o) @@ -99,3 +99,5 @@ REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, lstm_unit_grad, ops::LstmUnitGradOp); REGISTER_OP_CPU_KERNEL(lstm_unit, ops::LstmUnitKernel); +REGISTER_OP_CPU_KERNEL( + lstm_unit_grad, ops::LstmUnitGradKernel); diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h index 6e870f65e2..683034fe15 100644 --- a/paddle/operators/lstm_unit_op.h +++ b/paddle/operators/lstm_unit_op.h @@ -13,6 +13,7 @@ limitations under the License. */ #pragma once +#include "glog/logging.h" #include "paddle/framework/op_registry.h" namespace paddle { diff --git a/python/paddle/v2/framework/tests/test_lstm_unit_op.py b/python/paddle/v2/framework/tests/test_lstm_unit_op.py new file mode 100644 index 0000000000..f5888e908d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lstm_unit_op.py @@ -0,0 +1,39 @@ +import unittest +import numpy as np +from op_test import OpTest +import glog as log + + +def sigmoid_np(x): + return 1. / (1. + np.exp(-x)) + + +def tanh_np(x): + return 2 * sigmoid_np(2. * x) - 1. + + +class LstmUnitTest(OpTest): + def setUp(self): + self.op_type = "lstm_unit" + x_np = np.random.normal(size=(5, 16)).astype("float32") + c_np = np.random.normal(size=(5, 4)).astype("float32") + i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1) + forget_bias_np = 0. + self.attrs = {'forget_bias': 0.} + + new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np( + i_np) * tanh_np(j_np) + new_h = tanh_np(new_c) * sigmoid_np(o_np) + + self.inputs = {'X': x_np, 'C_prev': c_np} + self.outputs = {'C': new_c, 'H': new_h} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01) + + +if __name__ == "__main__": + unittest.main() From bda67d9d4bd3875e8ca86072dd5b92aa9c2ec761 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 19 Sep 2017 11:27:47 -0700 Subject: [PATCH 03/28] Add TensorCopy method A method to copy a tensor with stride and dimension. It is useful for Crop, Concat, etc. --- paddle/operators/CMakeLists.txt | 1 + paddle/operators/detail/tensor_copy.h | 93 +++++++++++++++++++++++++++ paddle/operators/tensor_copy.h | 43 +++++++++++++ paddle/operators/tensor_copy_test.cc | 77 ++++++++++++++++++++++ 4 files changed, 214 insertions(+) create mode 100644 paddle/operators/detail/tensor_copy.h create mode 100644 paddle/operators/tensor_copy.h create mode 100644 paddle/operators/tensor_copy_test.cc diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e3e934bccc..95f0acace9 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -96,3 +96,4 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) +cc_test(tensor_copy_test SRCS tensor_copy_test.cc DEPS tensor paddle_memory) diff --git a/paddle/operators/detail/tensor_copy.h b/paddle/operators/detail/tensor_copy.h new file mode 100644 index 0000000000..44fe495648 --- /dev/null +++ b/paddle/operators/detail/tensor_copy.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/ddim.h" +#include "paddle/memory/memcpy.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace detail { + +template +struct TensorCopyFunctor; + +template +struct TensorCopyFunctor { + void operator()(const platform::DeviceContext& dev_ctx, const T* src, + framework::Dim<1> src_stride, framework::Dim<1> dst_dim, + framework::Dim<1> dst_stride, T* dst) const { + auto place = dev_ctx.GetPlace(); + if (platform::is_cpu_place(place)) { + auto& cpu_place = boost::get(place); + memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head); + } else { +#ifndef PADDLE_ONLY_CPU + auto& gpu_place = boost::get(place); + auto& cuda_ctx = + reinterpret_cast(dev_ctx); + memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head, + cuda_ctx.stream()); +#else + PADDLE_THROW("Paddle is not compiled with GPU"); +#endif + } + } +}; + +template +struct TensorCopyFunctor { + void operator()(const platform::DeviceContext& dev_ctx, const T* src, + framework::Dim src_stride, framework::Dim dst_dim, + framework::Dim dst_stride, T* dst) const { + for (int64_t i = 0; i < dst_dim.head; ++i) { + TensorCopyFunctor func; + func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst); + src += src_stride.head; + dst += dst_stride.head; + } + } +}; + +template +struct TensorCopyDimVisitor : public boost::static_visitor { + TensorCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src, + const framework::DDim& src_stride, + const framework::DDim& dst_stride, T* dst) + : dev_ctx_(dev_ctx), + src_(src), + src_stride_(src_stride), + dst_stride_(dst_stride), + dst_(dst) {} + + template + void operator()(Dim dst_dim) const { + Dim src_stride = boost::get(src_stride_); + Dim dst_stride = boost::get(dst_stride_); + constexpr int dim = Dim::dimensions; + TensorCopyFunctor functor; + functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_); + } + + const platform::DeviceContext& dev_ctx_; + const T* src_; + const framework::DDim& src_stride_; + const framework::DDim& dst_stride_; + T* dst_; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/tensor_copy.h b/paddle/operators/tensor_copy.h new file mode 100644 index 0000000000..9210b4638b --- /dev/null +++ b/paddle/operators/tensor_copy.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/operators/detail/tensor_copy.h" + +namespace paddle { +namespace operators { + +// Copy a tensor from src to dst. +// The src and dst should be both on dev_ctx.GetPlace() +// +// the stride of an array (also referred to as increment, pitch or step size) is +// the number of locations in memory between beginnings of successive array +// elements +// +// For example, for tensor like [1, 3, 300, 300]. If there is no padding, the +// stride is [270000, 90000, 300, 1]. +// +// NOTE: When use GPU, the memcpy is async. To sync memcpy, please invoke +// `dev_ctx.Wait()`. +template +inline void TensorCopy(const platform::DeviceContext& dev_ctx, const T* src, + const framework::DDim& src_stride, + const framework::DDim& dst_dim, + const framework::DDim& dst_stride, T* dst) { + using namespace detail; + TensorCopyDimVisitor func(dev_ctx, src, src_stride, dst_stride, dst); + boost::apply_visitor(func, dst_dim); +} +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/tensor_copy_test.cc b/paddle/operators/tensor_copy_test.cc new file mode 100644 index 0000000000..df177096d3 --- /dev/null +++ b/paddle/operators/tensor_copy_test.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/tensor_copy.h" +#include "gtest/gtest.h" +#include "paddle/memory/memory.h" + +namespace paddle { +namespace operators { +TEST(TensorCopy, CPU_COPY) { + int src[] = { + 0, 1, 2, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, + }; + + framework::DDim src_stride({5, 1}); + + int dst[4]; + framework::DDim dst_dim({2, 2}); + framework::DDim dst_stride({2, 1}); + + platform::CPUDeviceContext ctx; + TensorCopy(ctx, src + 1, src_stride, dst_dim, dst_stride, dst); + + ASSERT_EQ(1, dst[0]); + ASSERT_EQ(2, dst[1]); + ASSERT_EQ(3, dst[2]); + ASSERT_EQ(4, dst[3]); +} + +#ifndef PADDLE_ONLY_CPU +TEST(TensorCopy, GPU_COPY) { + int src[] = { + 0, 1, 2, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, + }; + + platform::GPUPlace gpu0(0); + platform::CPUPlace cpu; + + int* gpu_src = reinterpret_cast(memory::Alloc(gpu0, sizeof(src))); + memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src)); + + framework::DDim src_stride({5, 1}); + + int dst[4]; + int* gpu_dst = reinterpret_cast(memory::Alloc(gpu0, sizeof(dst))); + + framework::DDim dst_dim({2, 2}); + framework::DDim dst_stride({2, 1}); + + platform::CUDADeviceContext ctx(gpu0); + TensorCopy(ctx, gpu_src + 1, src_stride, dst_dim, dst_stride, gpu_dst); + + memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst)); + + ASSERT_EQ(1, dst[0]); + ASSERT_EQ(2, dst[1]); + ASSERT_EQ(3, dst[2]); + ASSERT_EQ(4, dst[3]); + + memory::Free(gpu0, gpu_dst); + memory::Free(gpu0, gpu_src); +} + +#endif +} // namespace operators +} // namespace paddle \ No newline at end of file From 3fb0b6e67bb15530e2d41afd3f018278900a744d Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 19 Sep 2017 14:30:34 -0700 Subject: [PATCH 04/28] Renamed to strided_memcpy and prettify unittests Add unittests for Crop and Concat --- paddle/operators/CMakeLists.txt | 2 +- .../{tensor_copy.h => strided_memcpy.h} | 18 +- .../{tensor_copy.h => strided_memcpy.h} | 20 ++- paddle/operators/strided_memcpy_test.cc | 160 ++++++++++++++++++ paddle/operators/tensor_copy_test.cc | 77 --------- 5 files changed, 181 insertions(+), 96 deletions(-) rename paddle/operators/detail/{tensor_copy.h => strided_memcpy.h} (86%) rename paddle/operators/{tensor_copy.h => strided_memcpy.h} (65%) create mode 100644 paddle/operators/strided_memcpy_test.cc delete mode 100644 paddle/operators/tensor_copy_test.cc diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 95f0acace9..90c7171419 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -96,4 +96,4 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) -cc_test(tensor_copy_test SRCS tensor_copy_test.cc DEPS tensor paddle_memory) +cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) diff --git a/paddle/operators/detail/tensor_copy.h b/paddle/operators/detail/strided_memcpy.h similarity index 86% rename from paddle/operators/detail/tensor_copy.h rename to paddle/operators/detail/strided_memcpy.h index 44fe495648..b165224b37 100644 --- a/paddle/operators/detail/tensor_copy.h +++ b/paddle/operators/detail/strided_memcpy.h @@ -22,10 +22,10 @@ namespace operators { namespace detail { template -struct TensorCopyFunctor; +struct StridedMemcpyFunctor; template -struct TensorCopyFunctor { +struct StridedMemcpyFunctor { void operator()(const platform::DeviceContext& dev_ctx, const T* src, framework::Dim<1> src_stride, framework::Dim<1> dst_dim, framework::Dim<1> dst_stride, T* dst) const { @@ -48,12 +48,12 @@ struct TensorCopyFunctor { }; template -struct TensorCopyFunctor { +struct StridedMemcpyFunctor { void operator()(const platform::DeviceContext& dev_ctx, const T* src, framework::Dim src_stride, framework::Dim dst_dim, framework::Dim dst_stride, T* dst) const { for (int64_t i = 0; i < dst_dim.head; ++i) { - TensorCopyFunctor func; + StridedMemcpyFunctor func; func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst); src += src_stride.head; dst += dst_stride.head; @@ -62,10 +62,10 @@ struct TensorCopyFunctor { }; template -struct TensorCopyDimVisitor : public boost::static_visitor { - TensorCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src, - const framework::DDim& src_stride, - const framework::DDim& dst_stride, T* dst) +struct StridedCopyDimVisitor : public boost::static_visitor { + StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src, + const framework::DDim& src_stride, + const framework::DDim& dst_stride, T* dst) : dev_ctx_(dev_ctx), src_(src), src_stride_(src_stride), @@ -77,7 +77,7 @@ struct TensorCopyDimVisitor : public boost::static_visitor { Dim src_stride = boost::get(src_stride_); Dim dst_stride = boost::get(dst_stride_); constexpr int dim = Dim::dimensions; - TensorCopyFunctor functor; + StridedMemcpyFunctor functor; functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_); } diff --git a/paddle/operators/tensor_copy.h b/paddle/operators/strided_memcpy.h similarity index 65% rename from paddle/operators/tensor_copy.h rename to paddle/operators/strided_memcpy.h index 9210b4638b..c9dd805184 100644 --- a/paddle/operators/tensor_copy.h +++ b/paddle/operators/strided_memcpy.h @@ -13,15 +13,17 @@ limitations under the License. */ #pragma once -#include "paddle/operators/detail/tensor_copy.h" +#include "paddle/operators/detail/strided_memcpy.h" namespace paddle { namespace operators { -// Copy a tensor from src to dst. -// The src and dst should be both on dev_ctx.GetPlace() +// Strided memory copy from src to dst. // -// the stride of an array (also referred to as increment, pitch or step size) is +// The src and dst should be both on dev_ctx.GetPlace(), otherwise, there will +// be a segment fault. +// +// The stride of an array (also referred to as increment, pitch or step size) is // the number of locations in memory between beginnings of successive array // elements // @@ -31,12 +33,12 @@ namespace operators { // NOTE: When use GPU, the memcpy is async. To sync memcpy, please invoke // `dev_ctx.Wait()`. template -inline void TensorCopy(const platform::DeviceContext& dev_ctx, const T* src, - const framework::DDim& src_stride, - const framework::DDim& dst_dim, - const framework::DDim& dst_stride, T* dst) { +inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src, + const framework::DDim& src_stride, + const framework::DDim& dst_dim, + const framework::DDim& dst_stride, T* dst) { using namespace detail; - TensorCopyDimVisitor func(dev_ctx, src, src_stride, dst_stride, dst); + StridedCopyDimVisitor func(dev_ctx, src, src_stride, dst_stride, dst); boost::apply_visitor(func, dst_dim); } } // namespace operators diff --git a/paddle/operators/strided_memcpy_test.cc b/paddle/operators/strided_memcpy_test.cc new file mode 100644 index 0000000000..05882a8873 --- /dev/null +++ b/paddle/operators/strided_memcpy_test.cc @@ -0,0 +1,160 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/strided_memcpy.h" +#include "gtest/gtest.h" +#include "paddle/memory/memory.h" + +namespace paddle { +namespace operators { + +TEST(StridedMemcpy, CPUCrop) { + // clang-format off + int src[] = { + 0, 1, 2, 0, 0, + 0, 3, 4, 0, 0, + 0, 0, 0, 0, 0, + }; + // clang-format on + + framework::DDim src_stride({5, 1}); + + int dst[4]; + framework::DDim dst_dim({2, 2}); + framework::DDim dst_stride({2, 1}); + + platform::CPUDeviceContext ctx; + StridedMemcpy(ctx, src + 1, src_stride, dst_dim, dst_stride, dst); + + ASSERT_EQ(1, dst[0]); + ASSERT_EQ(2, dst[1]); + ASSERT_EQ(3, dst[2]); + ASSERT_EQ(4, dst[3]); +} + +TEST(StridedMemcpy, CPUConcat) { + // clang-format off + int src[] = { + 1, 2, + 3, 4 + }; + // clang-format on + + int dst[8]; + + framework::DDim src_stride({2, 1}); + framework::DDim dst_dim({2, 2}); + framework::DDim dst_stride({4, 1}); + platform::CPUDeviceContext ctx; + + StridedMemcpy(ctx, src, src_stride, dst_dim, dst_stride, dst); + StridedMemcpy(ctx, src, src_stride, dst_dim, dst_stride, dst + 2); + + // clang-format off + int expect_dst[] = { + 1, 2, 1, 2, + 3, 4, 3, 4 + }; + // clang-format on + for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) { + ASSERT_EQ(expect_dst[i], dst[i]); + } +} + +#ifndef PADDLE_ONLY_CPU +TEST(StridedMemcpy, GPUCrop) { + // clang-format off + int src[] = { + 0, 1, 2, 0, 0, + 0, 3, 4, 0, 0, + 0, 0, 0, 0, 0, + }; + // clang-format on + + platform::GPUPlace gpu0(0); + platform::CPUPlace cpu; + + int* gpu_src = reinterpret_cast(memory::Alloc(gpu0, sizeof(src))); + memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src)); + + framework::DDim src_stride({5, 1}); + + int dst[4]; + int* gpu_dst = reinterpret_cast(memory::Alloc(gpu0, sizeof(dst))); + + framework::DDim dst_dim({2, 2}); + framework::DDim dst_stride({2, 1}); + + platform::CUDADeviceContext ctx(gpu0); + StridedMemcpy(ctx, gpu_src + 1, src_stride, dst_dim, dst_stride, + gpu_dst); + + memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream()); + ctx.Wait(); + + ASSERT_EQ(1, dst[0]); + ASSERT_EQ(2, dst[1]); + ASSERT_EQ(3, dst[2]); + ASSERT_EQ(4, dst[3]); + + memory::Free(gpu0, gpu_dst); + memory::Free(gpu0, gpu_src); +} + +TEST(StridedMemcpy, GPUConcat) { + // clang-format off + int src[] = { + 1, 2, + 3, 4 + }; + // clang-format on + + platform::GPUPlace gpu0(0); + platform::CPUPlace cpu; + + int* gpu_src = reinterpret_cast(memory::Alloc(gpu0, sizeof(src))); + memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src)); + + int dst[8]; + int* gpu_dst = reinterpret_cast(memory::Alloc(gpu0, sizeof(dst))); + + framework::DDim src_stride({2, 1}); + framework::DDim dst_dim({2, 2}); + framework::DDim dst_stride({4, 1}); + platform::CUDADeviceContext ctx(gpu0); + + StridedMemcpy(ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst); + StridedMemcpy(ctx, gpu_src, src_stride, dst_dim, dst_stride, + gpu_dst + 2); + + memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream()); + ctx.Wait(); + + // clang-format off + int expect_dst[] = { + 1, 2, 1, 2, + 3, 4, 3, 4 + }; + // clang-format on + for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) { + ASSERT_EQ(expect_dst[i], dst[i]); + } + + memory::Free(gpu0, gpu_dst); + memory::Free(gpu0, gpu_src); +} + +#endif +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/operators/tensor_copy_test.cc b/paddle/operators/tensor_copy_test.cc deleted file mode 100644 index df177096d3..0000000000 --- a/paddle/operators/tensor_copy_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/operators/tensor_copy.h" -#include "gtest/gtest.h" -#include "paddle/memory/memory.h" - -namespace paddle { -namespace operators { -TEST(TensorCopy, CPU_COPY) { - int src[] = { - 0, 1, 2, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, - }; - - framework::DDim src_stride({5, 1}); - - int dst[4]; - framework::DDim dst_dim({2, 2}); - framework::DDim dst_stride({2, 1}); - - platform::CPUDeviceContext ctx; - TensorCopy(ctx, src + 1, src_stride, dst_dim, dst_stride, dst); - - ASSERT_EQ(1, dst[0]); - ASSERT_EQ(2, dst[1]); - ASSERT_EQ(3, dst[2]); - ASSERT_EQ(4, dst[3]); -} - -#ifndef PADDLE_ONLY_CPU -TEST(TensorCopy, GPU_COPY) { - int src[] = { - 0, 1, 2, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, - }; - - platform::GPUPlace gpu0(0); - platform::CPUPlace cpu; - - int* gpu_src = reinterpret_cast(memory::Alloc(gpu0, sizeof(src))); - memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src)); - - framework::DDim src_stride({5, 1}); - - int dst[4]; - int* gpu_dst = reinterpret_cast(memory::Alloc(gpu0, sizeof(dst))); - - framework::DDim dst_dim({2, 2}); - framework::DDim dst_stride({2, 1}); - - platform::CUDADeviceContext ctx(gpu0); - TensorCopy(ctx, gpu_src + 1, src_stride, dst_dim, dst_stride, gpu_dst); - - memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst)); - - ASSERT_EQ(1, dst[0]); - ASSERT_EQ(2, dst[1]); - ASSERT_EQ(3, dst[2]); - ASSERT_EQ(4, dst[3]); - - memory::Free(gpu0, gpu_dst); - memory::Free(gpu0, gpu_src); -} - -#endif -} // namespace operators -} // namespace paddle \ No newline at end of file From a7a66b80f70aef70e2adb7b75ca22d87f33885f6 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 19 Sep 2017 22:28:16 -0700 Subject: [PATCH 05/28] move OpProtoAndCheckerMaker from operator to op_proto_maker --- paddle/framework/CMakeLists.txt | 4 +- paddle/framework/op_proto_maker.cc | 58 ++++++++++++++++ paddle/framework/op_proto_maker.h | 88 +++++++++++++++++++++++++ paddle/framework/op_proto_maker_test.cc | 51 ++++++++++++++ paddle/framework/op_registry.h | 1 + paddle/framework/operator.cc | 38 ----------- paddle/framework/operator.h | 65 ------------------ paddle/framework/operator_test.cc | 34 ---------- paddle/operators/prelu_op.h | 2 +- 9 files changed, 202 insertions(+), 139 deletions(-) create mode 100644 paddle/framework/op_proto_maker.cc create mode 100644 paddle/framework/op_proto_maker.h create mode 100644 paddle/framework/op_proto_maker_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 3371962c63..e535f84dba 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,12 +19,14 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) +cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) +cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) diff --git a/paddle/framework/op_proto_maker.cc b/paddle/framework/op_proto_maker.cc new file mode 100644 index 0000000000..151d61d5b1 --- /dev/null +++ b/paddle/framework/op_proto_maker.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { + +void OpProtoAndCheckerMaker::Validate() { + validated_ = true; + CheckNoDuplicatedInOutAttrs(); +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( + const std::string& name, const std::string& comment) { + auto* input = proto_->add_inputs(); + input->set_name(name); + input->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{input}; +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( + const std::string& name, const std::string& comment) { + auto* output = proto_->add_outputs(); + output->set_name(name); + output->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{output}; +} + +void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { + std::unordered_set names; + auto checker = [&](const std::string& name) { + PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); + names.insert(name); + }; + for (auto& attr : proto_->attrs()) { + checker(attr.name()); + } + for (auto& input : proto_->inputs()) { + checker(input.name()); + } + for (auto& output : proto_->outputs()) { + checker(output.name()); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_proto_maker.h b/paddle/framework/op_proto_maker.h new file mode 100644 index 0000000000..fea15a374b --- /dev/null +++ b/paddle/framework/op_proto_maker.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/attribute.h" +#include "paddle/framework/framework.pb.h" + +namespace paddle { +namespace framework { + +// this class not only make proto but also init attribute checkers. +class OpProtoAndCheckerMaker { + public: + OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : proto_(proto), op_checker_(op_checker) {} + + ~OpProtoAndCheckerMaker() { + PADDLE_ENFORCE(validated_, "should call Validate after build"); + } + + void Validate(); + + protected: + struct VariableBuilder { + OpProto::Var* var_; + + VariableBuilder& AsDuplicable() { + var_->set_duplicable(true); + return *this; + } + + VariableBuilder& AsIntermediate() { + var_->set_intermediate(true); + return *this; + } + + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, const std::string& comment); + + VariableBuilder AddOutput(const std::string& name, + const std::string& comment); + + template + TypedAttrChecker& AddAttr(const std::string& name, + const std::string& comment, + bool generated = false) { + auto* attr = proto_->add_attrs(); + attr->set_name(name); + attr->set_comment(comment); + attr->set_generated(generated); + attr->set_type(AttrTypeID()); + return op_checker_->AddAttrChecker(name); + } + + void AddComment(const std::string& comment) { proto_->set_comment(comment); } + + private: + void CheckNoDuplicatedInOutAttrs(); + + OpProto* proto_; + OpAttrChecker* op_checker_; + bool validated_{false}; +}; + +class NOPMaker : public OpProtoAndCheckerMaker { + public: + NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_proto_maker_test.cc b/paddle/framework/op_proto_maker_test.cc new file mode 100644 index 0000000000..b01e30f753 --- /dev/null +++ b/paddle/framework/op_proto_maker_test.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_proto_maker.h" + +#include "gtest/gtest.h" + +class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { + public: + TestAttrProtoMaker(paddle::framework::OpProto* proto, + paddle::framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("scale", "scale of test op"); + AddAttr("scale", "scale of test op"); + } +}; + +TEST(ProtoMaker, DuplicatedAttr) { + paddle::framework::OpProto op_proto; + paddle::framework::OpAttrChecker op_checker; + auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); + ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); +} + +class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { + public: + TestInOutProtoMaker(paddle::framework::OpProto* proto, + paddle::framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddInput("input", "input of test op"); + } +}; + +TEST(ProtoMaker, DuplicatedInOut) { + paddle::framework::OpProto op_proto; + paddle::framework::OpAttrChecker op_checker; + auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); + ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); +} \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 572dff860a..90077d0192 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_info.h" +#include "paddle/framework/op_proto_maker.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index f8a64a7866..49509af663 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -228,43 +228,5 @@ std::vector ExecutionContext::MultiOutput( return res; } -void OpProtoAndCheckerMaker::Validate() { - validated_ = true; - CheckNoDuplicatedInOutAttrs(); -} - -OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( - const std::string& name, const std::string& comment) { - auto* input = proto_->add_inputs(); - input->set_name(name); - input->set_comment(comment); - return OpProtoAndCheckerMaker::VariableBuilder{input}; -} - -OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( - const std::string& name, const std::string& comment) { - auto* output = proto_->add_outputs(); - output->set_name(name); - output->set_comment(comment); - return OpProtoAndCheckerMaker::VariableBuilder{output}; -} - -void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { - std::unordered_set names; - auto checker = [&](const std::string& name) { - PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); - names.insert(name); - }; - for (auto& attr : proto_->attrs()) { - checker(attr.name()); - } - for (auto& input : proto_->inputs()) { - checker(input.name()); - } - for (auto& output : proto_->outputs()) { - checker(output.name()); - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index b7c9c39402..1a78b6d1e1 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -167,71 +167,6 @@ class NOP : public OperatorBase { } }; -// this class not only make proto but also init attribute checkers. -class OpProtoAndCheckerMaker { - public: - OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : proto_(proto), op_checker_(op_checker) {} - - ~OpProtoAndCheckerMaker() { - PADDLE_ENFORCE(validated_, "should call Validate after build"); - } - - void Validate(); - - protected: - struct VariableBuilder { - OpProto::Var* var_; - - VariableBuilder& AsDuplicable() { - var_->set_duplicable(true); - return *this; - } - - VariableBuilder& AsIntermediate() { - var_->set_intermediate(true); - return *this; - } - - VariableBuilder& NotInGradient() { - var_->set_not_in_gradient(true); - return *this; - } - }; - - VariableBuilder AddInput(const std::string& name, const std::string& comment); - - VariableBuilder AddOutput(const std::string& name, - const std::string& comment); - - template - TypedAttrChecker& AddAttr(const std::string& name, - const std::string& comment, - bool generated = false) { - auto* attr = proto_->add_attrs(); - attr->set_name(name); - attr->set_comment(comment); - attr->set_generated(generated); - attr->set_type(AttrTypeID()); - return op_checker_->AddAttrChecker(name); - } - - void AddComment(const std::string& comment) { proto_->set_comment(comment); } - - private: - void CheckNoDuplicatedInOutAttrs(); - - OpProto* proto_; - OpAttrChecker* op_checker_; - bool validated_{false}; -}; - -class NOPMaker : public OpProtoAndCheckerMaker { - public: - NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) {} -}; - class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 20bbb11896..0beab0fac5 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -264,37 +264,3 @@ TEST(Operator, Clone) { auto b = a.Clone(); ASSERT_EQ(a.Type(), b->Type()); } - -class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { - public: - TestAttrProtoMaker(paddle::framework::OpProto* proto, - paddle::framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("scale", "scale of test op"); - AddAttr("scale", "scale of test op"); - } -}; - -TEST(ProtoMaker, DuplicatedAttr) { - paddle::framework::OpProto op_proto; - paddle::framework::OpAttrChecker op_checker; - auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); - ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); -} - -class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { - public: - TestInOutProtoMaker(paddle::framework::OpProto* proto, - paddle::framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of test op"); - AddInput("input", "input of test op"); - } -}; - -TEST(ProtoMaker, DuplicatedInOut) { - paddle::framework::OpProto op_proto; - paddle::framework::OpAttrChecker op_checker; - auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); - ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); -} \ No newline at end of file diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h index 3269116c11..6b78ed295c 100644 --- a/paddle/operators/prelu_op.h +++ b/paddle/operators/prelu_op.h @@ -96,7 +96,7 @@ class PReluGradKernel : public framework::OpKernel { trans(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor(alpha_ptr)); - // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready + // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready } }; From c4e04b07b1f8842d6cbac41c73add20c20746a0d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 20 Sep 2017 11:12:15 -0700 Subject: [PATCH 06/28] add virtual to OpProtoAndCheckerMaker destructor --- paddle/framework/op_proto_maker.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/op_proto_maker.h b/paddle/framework/op_proto_maker.h index fea15a374b..4d55a37db9 100644 --- a/paddle/framework/op_proto_maker.h +++ b/paddle/framework/op_proto_maker.h @@ -25,7 +25,7 @@ class OpProtoAndCheckerMaker { OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : proto_(proto), op_checker_(op_checker) {} - ~OpProtoAndCheckerMaker() { + virtual ~OpProtoAndCheckerMaker() { PADDLE_ENFORCE(validated_, "should call Validate after build"); } From c1bb602e6b837af1e511525f930893fadd7e0696 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Wed, 20 Sep 2017 12:12:27 -0700 Subject: [PATCH 07/28] Add program.md --- doc/design/program.md | 59 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 doc/design/program.md diff --git a/doc/design/program.md b/doc/design/program.md new file mode 100644 index 0000000000..cc1ffcbaf9 --- /dev/null +++ b/doc/design/program.md @@ -0,0 +1,59 @@ +# Design Doc: ProgramDesc + +The basic structure of a PaddlePaddle program is some nested blocks, as a C++ or Java program. + +As described in [graph.md](./graph.md), the first five lines of the following PaddlePaddle program + +```python +x = layer.data("images") +l = layer.data("label") +y = layer.fc(x) +cost = layer.mse(y, l) +optimize(cost) +train(cost, reader=mnist.train()) +``` + +generates, or compiles, a PaddelPaddle program, which is represented by the following protobuf message: + +```protobuf +message ProgramDesc { + repeated BlockDesc blocks = 1; +} + +message BlockDesc { + repeated VarDesc vars = 1; + repeated OpDesc ops = 2; +} + +message OpDesc { + AttrDesc attrs = 1; + ... +} + +message AttrDesc { + required AttrType type = 1; + + // index into ProgramDesc::blocks when type==BLOCK + optional int32 block = 2; + ... +} +``` + +When each of the first five lines runs, related Python function, e.g., `layer.fc`, calls C++ InferShape functions. This InferShape function needs to access the properties of VarDesc's accessed by the current OpDesc. These VarDesc's might not be defined in the current block, but in some ancestor blocks. This requires that we can trace the parent of a block. + +A nested block is often an attribute of an operator, most likely, an IfElseOp or a WhileOp. In above solution, all blocks are in `ProgramDesc::blocks`, this implicitly assigns a zero-based ID to each block -- the index of the block in `ProgramDesc::blocks`. So that `AttrDesc::block` could be an integer block ID. + +With this design, the InferShape function should take the following parameters: + +```c++ +void InferShape(const ProgramDesc* program, + int current_block, + int current_operator) { + ... +} +``` + +where + +- `current_block` indices into `ProgramDesc::blocks`, +- `current_operator` indices into `BlockDesc::ops`. From be56bf9fae11d28821d65ccb59963fdd6d937e5b Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Wed, 20 Sep 2017 12:29:25 -0700 Subject: [PATCH 08/28] Update --- doc/design/program.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/doc/design/program.md b/doc/design/program.md index cc1ffcbaf9..fb8f86ac07 100644 --- a/doc/design/program.md +++ b/doc/design/program.md @@ -21,8 +21,9 @@ message ProgramDesc { } message BlockDesc { - repeated VarDesc vars = 1; - repeated OpDesc ops = 2; + required int32 parent = 1; + repeated VarDesc vars = 2; + repeated OpDesc ops = 3; } message OpDesc { @@ -46,9 +47,10 @@ A nested block is often an attribute of an operator, most likely, an IfElseOp or With this design, the InferShape function should take the following parameters: ```c++ -void InferShape(const ProgramDesc* program, - int current_block, - int current_operator) { +void InferShape(int current_block, + int current_operator, + ProgramDesc* program // might change VarDesc values. + ) { ... } ``` From 40e49c3f8b42ad94d161c719eb447dc4b7234164 Mon Sep 17 00:00:00 2001 From: Zhuoyuan Date: Wed, 20 Sep 2017 15:59:07 -0700 Subject: [PATCH 09/28] update python test --- python/paddle/v2/framework/tests/test_lstm_unit_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/v2/framework/tests/test_lstm_unit_op.py b/python/paddle/v2/framework/tests/test_lstm_unit_op.py index f5888e908d..8ce65bfc31 100644 --- a/python/paddle/v2/framework/tests/test_lstm_unit_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_unit_op.py @@ -1,7 +1,6 @@ import unittest import numpy as np from op_test import OpTest -import glog as log def sigmoid_np(x): From 2b10d322e5fb8811e644c165da87f259f46e7c60 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 20 Sep 2017 16:18:37 -0700 Subject: [PATCH 10/28] lstm kernels --- paddle/operators/lstm_unit_op.cu | 173 +++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 paddle/operators/lstm_unit_op.cu diff --git a/paddle/operators/lstm_unit_op.cu b/paddle/operators/lstm_unit_op.cu new file mode 100644 index 0000000000..fe45360bb3 --- /dev/null +++ b/paddle/operators/lstm_unit_op.cu @@ -0,0 +1,173 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/cross_entropy_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ Dtype cuda_sigmoid(const Dtype x) { + return Dtype(1) / (Dtype(1) + exp(-x)); +} + +template +__device__ Dtype cuda_tanh(const Dtype x) { + return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x)); +} + +template +__global__ void LSTMUnitKernel(const int nthreads, const int dim, const int t, + const T* C_prev, const T* X, T* C, T* H, + const T forget_bias) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + + const T* X_offset = X + 4 * dim * n; + const T i = cuda_sigmoid(X_offset[d]); + const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias); + const T o = cuda_sigmoid(X_offset[2 * dim + d]); + const T g = cuda_tanh(X_offset[3 * dim + d]); + const T c_prev = C_prev[index]; + const T c = f * c_prev + i * g; + C[index] = c; + const T tanh_c = cuda_tanh(c); + H[index] = o * tanh_c; + } +} + +template +__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, + const T* C_prev, const T* X, const T* C, + const T* H, const T* C_diff, + const T* H_diff, T* C_prev_diff, + T* X_diff, const T forget_bias) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const T* X_offset = X + 4 * dim * n; + T* c_prev_diff = C_prev_diff + index; + T* X_diff_offset = X_diff + 4 * dim * n; + T* i_diff = X_diff_offset + d; + T* f_diff = X_diff_offset + 1 * dim + d; + T* o_diff = X_diff_offset + 2 * dim + d; + T* g_diff = X_diff_offset + 3 * dim + d; + + const T i = cuda_sigmoid(X_offset[d]); + const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias); + const T o = cuda_sigmoid(X_offset[2 * dim + d]); + const T g = cuda_tanh(X_offset[3 * dim + d]); + const T c_prev = C_prev[index]; + const T c = C[index]; + const T tanh_c = cuda_tanh(c); + const T c_term_diff = + C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[index] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } +} + +template +class LstmUnitOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + int block = 512; + int n = b_size * D; + int grid = (n + block - 1) / block; + + LSTMUnitKernel<<>>(n, D, C_prev, X, C, H, forget_bias); + } +}; + +template +class LstmUnitGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int block = 512; + int n = N * D; + int grid = (n + block - 1) / block; + + LSTMUnitGradientKernel<<>>(n, D, C_prev, X, C, H, C_diff, + H_diff, C_prev_diff, X_diff, + T forget_bias) + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel); From 757c32f20bcd0f338b237ab3c90d63813dc3d22e Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 20 Sep 2017 17:15:24 -0700 Subject: [PATCH 11/28] lstm unit gpu --- paddle/operators/lstm_unit_op.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/lstm_unit_op.cu b/paddle/operators/lstm_unit_op.cu index fe45360bb3..6e5e497899 100644 --- a/paddle/operators/lstm_unit_op.cu +++ b/paddle/operators/lstm_unit_op.cu @@ -35,7 +35,7 @@ __device__ Dtype cuda_tanh(const Dtype x) { } template -__global__ void LSTMUnitKernel(const int nthreads, const int dim, const int t, +__global__ void LSTMUnitKernel(const int nthreads, const int dim, const T* C_prev, const T* X, T* C, T* H, const T forget_bias) { CUDA_1D_KERNEL_LOOP(index, nthreads) { @@ -159,9 +159,9 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel { int n = N * D; int grid = (n + block - 1) / block; - LSTMUnitGradientKernel<<>>(n, D, C_prev, X, C, H, C_diff, - H_diff, C_prev_diff, X_diff, - T forget_bias) + LSTMUnitGradientKernel<<>>(n, D, C_prev, X, C, H, C_diff, + H_diff, C_prev_diff, X_diff, + forget_bias); } }; From ed27a3be78c3195bae26acb06d825ddc6db50a1b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 20 Sep 2017 23:36:22 +0800 Subject: [PATCH 12/28] add some log info --- paddle/gserver/activations/MKLDNNActivation.h | 3 ++- paddle/gserver/layers/MKLDNNConvLayer.cpp | 4 ++-- paddle/gserver/layers/MKLDNNFcLayer.cpp | 11 +++++++++-- paddle/gserver/layers/MKLDNNLayer.h | 2 ++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/paddle/gserver/activations/MKLDNNActivation.h b/paddle/gserver/activations/MKLDNNActivation.h index 86ffe38736..40dd8c618a 100644 --- a/paddle/gserver/activations/MKLDNNActivation.h +++ b/paddle/gserver/activations/MKLDNNActivation.h @@ -100,6 +100,7 @@ public: if (cnt_ == act.value->getElementCnt()) { return; } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; cnt_ = act.value->getElementCnt(); stream_.reset(new MKLDNNStream()); auto eng = CPUEngine::Instance().getEngine(); @@ -110,7 +111,6 @@ public: float alpha = getAlpha(); float beta = getBeta(); - /// forward pipelineFwd_.clear(); val_ = std::dynamic_pointer_cast(act.value); if (val_ == nullptr) { @@ -152,6 +152,7 @@ public: if (!needResetBwd_) { return; } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; needResetBwd_ = false; mkldnn::algorithm algo = getAlgo(this->getName()); float alpha = getBwdAlpha(); diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 88b047c89b..463880ff4a 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap, // create biases if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_, 0)); } return true; } @@ -535,7 +535,7 @@ void MKLDNNConvLayer::resetWgtValBwdData( } else { wgtValBwdData_ = wgtVal_; } - VLOG(MKLDNN_FMTS) << "weight value format for backward data" + VLOG(MKLDNN_FMTS) << "weight value format for backward data: " << wgtValBwdData_->getFormat(); } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index afd092666b..8cbfbd0d2b 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, // create biases if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_, 0)); } return true; } @@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) { void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias) { + format wgtFmt = format::oihw; + if (inVal_->getFormat() == format::nChw8c) { + wgtFmt = format::oIhw8i; + } else if (inVal_->getFormat() == format::nChw16c) { + wgtFmt = format::oIhw16i; + } wgt = MKLDNNMatrix::create( - weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_); + weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_); wgt->downSpatial(); + VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); bias = (biases_ && biases_->getW()) ? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_) diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index d8555a8331..c09fd89462 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -115,6 +115,7 @@ public: copySeqInfoToOutputs(); size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt(); if (inputElemenCnt_ != elemenCnt) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; // reset when input total sizes changed, not only the batchsize inputElemenCnt_ = elemenCnt; reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); @@ -142,6 +143,7 @@ public: void backward(const UpdateCallback& callback) override { if (needResetBwd_) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); needResetBwd_ = false; } From 0f7c40713f2430744b03baf9f6a3ef89a466cc22 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 21 Sep 2017 22:59:33 +0800 Subject: [PATCH 13/28] add vgg and script for mkldnn benchmark --- benchmark/paddle/image/provider.py | 2 +- benchmark/paddle/image/run.mkldnn.sh | 46 ++++++++++++ benchmark/paddle/image/vgg.py | 103 +++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 1 deletion(-) create mode 100755 benchmark/paddle/image/run.mkldnn.sh create mode 100644 benchmark/paddle/image/vgg.py diff --git a/benchmark/paddle/image/provider.py b/benchmark/paddle/image/provider.py index 1ac47212b5..4703944c87 100644 --- a/benchmark/paddle/image/provider.py +++ b/benchmark/paddle/image/provider.py @@ -22,5 +22,5 @@ def initHook(settings, height, width, color, num_class, **kwargs): def process(settings, file_list): for i in xrange(1024): img = np.random.rand(1, settings.data_size).reshape(-1, 1).flatten() - lab = random.randint(0, settings.num_class) + lab = random.randint(0, settings.num_class - 1) yield img.astype('float32'), int(lab) diff --git a/benchmark/paddle/image/run.mkldnn.sh b/benchmark/paddle/image/run.mkldnn.sh new file mode 100755 index 0000000000..03a87afbc3 --- /dev/null +++ b/benchmark/paddle/image/run.mkldnn.sh @@ -0,0 +1,46 @@ +set -e + +function train() { + topology=$1 + bs=$2 + thread=1 + if [ $3 ]; then + thread=$3 + fi + if [ $thread -eq 1 ]; then + use_mkldnn=1 + log="logs/${topology}-mkldnn-${bs}.log" + else + use_mkldnn=0 + log="logs/${topology}-${thread}mklml-${bs}.log" + fi + args="batch_size=${bs}" + config="${topology}.py" + paddle train --job=time \ + --config=$config \ + --use_mkldnn=$use_mkldnn \ + --use_gpu=False \ + --trainer_count=$thread \ + --log_period=10 \ + --test_period=100 \ + --config_args=$args \ + 2>&1 | tee ${log} +} + +if [ ! -d "train.list" ]; then + echo " " > train.list +fi +if [ ! -d "logs" ]; then + mkdir logs +fi + +#========= mkldnn =========# +# vgg +train vgg 64 +train vgg 128 +train vgg 256 + +#========== mklml ===========# +train vgg 64 16 +train vgg 128 16 +train vgg 256 16 diff --git a/benchmark/paddle/image/vgg.py b/benchmark/paddle/image/vgg.py new file mode 100644 index 0000000000..69e4a4cddb --- /dev/null +++ b/benchmark/paddle/image/vgg.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +from paddle.trainer_config_helpers import * + +height = 224 +width = 224 +num_class = 1000 +batch_size = get_config_arg('batch_size', int, 64) +layer_num = get_config_arg('layer_num', int, 16) + +args = {'height': height, 'width': width, 'color': True, 'num_class': num_class} +define_py_data_sources2( + "train.list", None, module="provider", obj="process", args=args) + +settings( + batch_size=batch_size, + learning_rate=0.01 / batch_size, + learning_method=MomentumOptimizer(0.9), + regularization=L2Regularization(0.0005 * batch_size)) + +img = data_layer(name='image', size=height * width * 3) + + +def vgg_network(vgg_num=3): + tmp = img_conv_group( + input=img, + num_channels=3, + conv_padding=1, + conv_num_filter=[64, 64], + conv_filter_size=3, + conv_act=ReluActivation(), + pool_size=2, + pool_stride=2, + pool_type=MaxPooling()) + + tmp = img_conv_group( + input=tmp, + conv_num_filter=[128, 128], + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + + channels = [] + for i in range(vgg_num): + channels.append(256) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + channels = [] + for i in range(vgg_num): + channels.append(512) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + + tmp = fc_layer( + input=tmp, + size=4096, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + + tmp = fc_layer( + input=tmp, + size=4096, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + + return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation()) + + +if layer_num == 16: + vgg = vgg_network(3) +elif layer_num == 19: + vgg = vgg_network(4) +else: + print("Wrong layer number.") + +lab = data_layer('label', num_class) +loss = cross_entropy(input=vgg, label=lab) +outputs(loss) From f96d31d5525a42637bdf0a04ebfdc09b5d32fa62 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 21 Sep 2017 23:28:55 +0800 Subject: [PATCH 14/28] only link iomp when with MKLDNN and MKLML --- cmake/util.cmake | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmake/util.cmake b/cmake/util.cmake index e814cad36f..3aaa055c3b 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -97,6 +97,10 @@ function(link_paddle_exe TARGET_NAME) target_link_libraries(${TARGET_NAME} log) endif(ANDROID) + if(WITH_MKLDNN AND WITH_MKLML AND MKLDNN_IOMP_DIR) + target_link_libraries(${TARGET_NAME} "-L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed") + endif() + add_dependencies(${TARGET_NAME} ${external_project_dependencies}) endfunction() From 3563f1da3531d7d110c12878074834f4fe1211d4 Mon Sep 17 00:00:00 2001 From: Mimee Date: Thu, 21 Sep 2017 11:22:41 -0700 Subject: [PATCH 15/28] First half translation of operator definition. --- doc/howto/dev/new_op_en.md | 142 +++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 doc/howto/dev/new_op_en.md diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md new file mode 100644 index 0000000000..56b31157f5 --- /dev/null +++ b/doc/howto/dev/new_op_en.md @@ -0,0 +1,142 @@ +# How to write a new operator + + - [Background](#Background) + - [Implementing C++ Types](#Implementing_C++_Types) + - [Defining ProtoMaker](#Defining_ProtoMaker) + - [Defining Operator](#Defining_Operator) + + +## Background + +Here are the base types needed. For details, please refer to the design docs. + +- `framework::OperatorBase`: Operator (Op)base class. +- `framework::OpKernel`: Base class for Op computation. +- `framework::OperatorWithKernel`: Inherited from OperatorBase, describing an operator with computation. +- `class OpProtoAndCheckerMaker`: Describes an Operator's input, output, attributes and description, mainly used to interface with Python API. + +An operator can be differentiated by whether in has kernel methods. An operator with kernel inherits from `OperatorWithKernel` while the ones without inherit from `OperatorBase`. This tutorial focuses on implementing operators with kernels. In short, an operator includes the following information: + + + Information | Where is it defined +-------------- | :---------------------- +OpProtoMake definition | `.cc`files, Backward Op does not need an OpProtoMake interface. +Op definition | `.cc` files +Kernel implementation | The kernel methods shared between CPU and GPU are defined in `.h` files. CPU-specific kernels live in `.cc` files, while GPU-specific kernels are implemented in `.cu`files. +Registering the Op | Ops are registered in `.cc` files; For Kernel registration, `.cc` files contain the CPU implementation, while `.cu` files contain the GPU implementation. + + +New Operator implementations are added to the list [paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators), with file names in the format `*_op.h` (if applicable), `*_op.cc`, `*_op.cu` (if applicable).** The system will use the naming scheme to automatically build operators and their corresponding Python extensions. ** + + +Let's take matrix multiplication operator, [MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc), as an example to introduce the writing of an Operator with Kernel. + + +## Implementing C++ Types + + +### 1. Defining Class ProtoMaker + +Matrix Multiplication can be written as $Out = X * Y$, meaning that the operation consists of two inputs and pne output. + +First, define `ProtoMaker` to describe the Operator's input, output, and additional comments: + +```cpp +class MulOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor), 2D tensor of size (M x K)"); + AddInput("Y", "(Tensor), 2D tensor of size (K x N)"); + AddOutput("Out", "(Tensor), 2D tensor of size (M x N)"); + AddComment(R"DOC( +Two Element Mul Operator. +The equation is: Out = X * Y +)DOC"); + } +}; +``` + +[`MulOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L43)is inherited from`framework::OpProtoAndCheckerMaker`, consisting of 2 variables in the constructor: + + - `framework::OpProto` stores Operator input and variable attribute, used for generating Python API interfaces. + - `framework::OpAttrChecker` is used to validate variable attributes. + +The constructor utilizes `AddInput`, `AddOutput`, and `AddComment`, so that the corresponding information will be added to `OpProto`. + +The code above adds two inputs `X` and `Y` to `MulOp`, an output `Out`, and their corresponding descriptions, in accordance to Paddle's [naming convention](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). + + +An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37) is implemented as follows: + +```cpp +template +class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of scale operator.").NotInGradient(); + AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); + AddComment(R"DOC(Scale operator +The equation is: Out = scale*X +)DOC"); + AddAttr("scale", "scale of scale operator.").SetDefault(1.0); + } +}; +``` + +There are two changes in this example: + +- `AddInput("X","...").NotInGradient()` expresses that input `X` is not involved in `ScaleOp`'s corresponding computation. If an input to an operator is not participating in back-propagation, please explicitly set `.NotInGradient()`. + +- `AddAttr("scale", "...").SetDefault(1.0);` adds `scale`constant as an attribute, and sets the default value to 1.0. + + +### 2. Defining Operator + +The following code defines the interface for MulOp: + +```cpp +class MulOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto dim0 = ctx.Input("X")->dims(); + auto dim1 = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_EQ(dim0.size(), 2, + "input X(%s) should be a tensor with 2 dims, a matrix", + ctx.op_.Input("X")); + PADDLE_ENFORCE_EQ(dim1.size(), 2, + "input Y(%s) should be a tensor with 2 dims, a matrix", + ctx.op_.Input("Y")); + PADDLE_ENFORCE_EQ( + dim0[1], dim1[0], + "First matrix's width must be equal with second matrix's height."); + ctx.Output("Out")->Resize({dim0[0], dim1[1]}); + } +}; +``` + +[`MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L22) is inherited from `OperatorWithKernel`. Its `public` member + +```cpp +using framework::OperatorWithKernel::OperatorWithKernel; +``` + +expresses an operator constructor using base class `OperatorWithKernel`, alternatively written as + +```cpp +MulOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} +``` + +`InferShape` interface needs to be re-written.`InferShape` is a constant method and cannot modify Op's member variables, its constant member `const framework::InferShapeContext &ctx` can be used to extract input, output, and attributes. It functions to + + - 1). validate and error out early: it checks input data dimensions and types. + - 2). configures the tensor shape in the output. + +Usually `OpProtoMaker` and `Op`'s type definitions are written in `.cc` files, which also include the registration methods introduced later. From 2fd798162b7a08a314b9ffed9cb1a57c66dba1a2 Mon Sep 17 00:00:00 2001 From: Mimee Date: Thu, 21 Sep 2017 16:11:55 -0700 Subject: [PATCH 16/28] Another 1/3. --- doc/howto/dev/new_op_en.md | 95 +++++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md index 56b31157f5..afd3775ede 100644 --- a/doc/howto/dev/new_op_en.md +++ b/doc/howto/dev/new_op_en.md @@ -4,7 +4,10 @@ - [Implementing C++ Types](#Implementing_C++_Types) - [Defining ProtoMaker](#Defining_ProtoMaker) - [Defining Operator](#Defining_Operator) - + - [Registering Operator](#Registering_Operator) + - [Compilation](#Compilation) + - [Python Binding](#Python_Binding) + - [Unit Tests](#Unit_Tests) ## Background @@ -140,3 +143,93 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, - 2). configures the tensor shape in the output. Usually `OpProtoMaker` and `Op`'s type definitions are written in `.cc` files, which also include the registration methods introduced later. + +### 3. Defining OpKernel + +`MulKernel` inherits `framework::OpKernel`, which includes the following templates: + +- `typename Place` denotes device type. When different devices, namely the CPU and the GPU, share the same kernel, this template needs to be added. If they don't share kernels, this must not be added. An example of a non-sharing kernel is [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43). + +- `typename T` denotes data type, such as `float` or `double`. + +`MulKernel` types need to rewrite the interface for `Compute`. +- `Compute` takes one input variable `const framework::ExecutionContext& context`. +- Compared with `InferShapeContext`, `ExecutionContext` includes device types, and can similarly extract input, output, and attribute variables. +- `Compute` implements the computation logics of an `OpKernel`. + +`MulKernel`'s implementation of `Compute` is as follows: + + ```cpp + template + class MulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } + }; + ``` + +Note that **different devices (CPU, GPU)share an Op definition; whether or not they share the same `OpKernel` depends on whether `Compute` calls functions that support both devices.** + +`MulOp`'s CPU and GPU share the same `Kernel`. A non-sharing `OpKernel` example can be seen in [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43). + +To ease the writing of `OpKernel` compute, and for reusing code cross-device, `Eigen unsupported Tensor` module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md). + + +This concludes the forward implementation of an operator. Next its operation and kernel need to be registered in a `.cc` file. + +The definition of its corresponding backward operator, if applicable, is similar to that of an forward operator. **Note that a backward operator does not include a `ProtoMaker`**. + +### 4. Registering Operator + +- In `.cc` files, register forward and backward operator classes and the CPU kernel. + + ```cpp + namespace ops = paddle::operators; + REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); + REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); + REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); + ``` + + In that code block, + + - `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`. + - `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient. + - `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`. + + +- Registering GPU Kernel in `.cu` files + - Note that if GPU Kernel is implemented using the `Eigen unsupported` module, then on top of `.cu`, a macro definition `#define EIGEN_USE_GPU` is needed, such as + + ```cpp + // if use Eigen unsupported module before include head files + #define EIGEN_USE_GPU + + namespace ops = paddle::operators; + REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); + REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); + ``` + +### 5. Compilation + +Run the following commands to compile. + +``` +make mul_op +``` + +## Python Binding + +The system will automatically bind to Python and link it to a generated library. + +## Unit Tests + +Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). \ No newline at end of file From d72f636c2edc28cf603436eadfba501319faa6ac Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 22 Sep 2017 11:17:44 +0800 Subject: [PATCH 17/28] add KMP setting and default test vgg19 --- benchmark/paddle/image/run.mkldnn.sh | 4 ++++ benchmark/paddle/image/vgg.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmark/paddle/image/run.mkldnn.sh b/benchmark/paddle/image/run.mkldnn.sh index 03a87afbc3..9f76243695 100755 --- a/benchmark/paddle/image/run.mkldnn.sh +++ b/benchmark/paddle/image/run.mkldnn.sh @@ -1,5 +1,9 @@ set -e +unset OMP_NUM_THREADS MKL_NUM_THREADS +export OMP_DYNAMIC="FALSE" +export KMP_AFFINITY="granularity=fine,compact,0,0" + function train() { topology=$1 bs=$2 diff --git a/benchmark/paddle/image/vgg.py b/benchmark/paddle/image/vgg.py index 69e4a4cddb..b8429975f5 100644 --- a/benchmark/paddle/image/vgg.py +++ b/benchmark/paddle/image/vgg.py @@ -5,7 +5,7 @@ height = 224 width = 224 num_class = 1000 batch_size = get_config_arg('batch_size', int, 64) -layer_num = get_config_arg('layer_num', int, 16) +layer_num = get_config_arg('layer_num', int, 19) args = {'height': height, 'width': width, 'color': True, 'num_class': num_class} define_py_data_sources2( From bea39f631420855ce95dcfb4a94cd7773341b19b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 22 Sep 2017 17:13:54 +0800 Subject: [PATCH 18/28] add compare simple net of mkldnn and cpu --- paddle/gserver/layers/MKLDNNConvLayer.cpp | 27 +++++--- paddle/trainer/tests/CMakeLists.txt | 13 ++++ .../sample_trainer_config_simple_net.conf | 63 +++++++++++++++++++ paddle/trainer/tests/test_CompareTwoNets.cpp | 11 ++++ 4 files changed, 105 insertions(+), 9 deletions(-) create mode 100644 paddle/trainer/tests/sample_trainer_config_simple_net.conf diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 463880ff4a..9a0abd291a 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue( // create buffer and reorder if input value do not match cpuInVal_ = nullptr; cvtInVal_ = nullptr; - if (inputIsOnlyMKLDNN()) { - MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); - CHECK(dnnIn) << "Input should be MKLDNNMatrix"; - if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) { - CHECK_EQ(dnnIn->getFormat(), format::nc); + + MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); + CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr); + if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { + in = dnnIn; + return; + } + if (dnnIn) { + if (dnnIn->getFormat() == format::nc) { CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format"; // create a new one with nchw format and same data memory::dims inDims = memory::dims{bs_, ic_, 1, 1}; dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); - CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()); } - in = dnnIn; + if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { + in = dnnIn; + return; + } + cpuInVal_ = dnnIn; + in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); + cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in); + CHECK(cvtInVal_) << "should not be emptry"; } else { - const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; - cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_); + cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) { // create new mkldnn matrix in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); diff --git a/paddle/trainer/tests/CMakeLists.txt b/paddle/trainer/tests/CMakeLists.txt index f01ad4142d..066837ca95 100644 --- a/paddle/trainer/tests/CMakeLists.txt +++ b/paddle/trainer/tests/CMakeLists.txt @@ -37,6 +37,19 @@ add_test(NAME test_CompareTwoNets --config_file_a=trainer/tests/sample_trainer_config_qb_rnn.conf --config_file_b=trainer/tests/sample_trainer_config_rnn.conf WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +################ test_CompareMKLDNNandCPU ###################### +if(WITH_MKLDNN) + add_unittest_without_exec(test_CompareMKLDNNandCPU + test_CompareTwoNets.cpp) + add_test(NAME test_CompareMKLDNNandCPU + COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python/ + ${CMAKE_CURRENT_BINARY_DIR}/test_CompareMKLDNNandCPU + --config_file_a=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_a=True + --config_file_b=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_b=False + --use_gpu=False + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +endif() + ############### test_CompareTwoOpts ################### add_unittest_without_exec(test_CompareTwoOpts test_CompareTwoOpts.cpp) diff --git a/paddle/trainer/tests/sample_trainer_config_simple_net.conf b/paddle/trainer/tests/sample_trainer_config_simple_net.conf new file mode 100644 index 0000000000..77f7816153 --- /dev/null +++ b/paddle/trainer/tests/sample_trainer_config_simple_net.conf @@ -0,0 +1,63 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +################################### Data Configuration ################################### +TrainData(ProtoData(files = "trainer/tests/mnist.list")) +################################### Algorithm Configuration ################################### +settings(batch_size = 1000, + learning_method = MomentumOptimizer(momentum=0.5, sparse=False)) +################################### Network Configuration ################################### +data = data_layer(name ="input", size=784) + +tmp = img_conv_layer(input=data, + num_channels=1, + filter_size=3, + num_filters=32, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = img_pool_layer(input=tmp, + pool_size=3, + stride=2, + padding=1, + pool_type=AvgPooling()) + +tmp = img_conv_layer(input=tmp, + filter_size=3, + num_filters=64, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = img_pool_layer(input=tmp, + pool_size=3, + stride=2, + padding=1, + pool_type=MaxPooling()) + +tmp = fc_layer(input=tmp, size=64, + bias_attr=True, + act=ReluActivation()) + +output = fc_layer(input=tmp, size=10, + bias_attr=True, + act=SoftmaxActivation()) + +lbl = data_layer(name ="label", size=10) + +cost = classification_cost(input=output, label=lbl) +outputs(cost) diff --git a/paddle/trainer/tests/test_CompareTwoNets.cpp b/paddle/trainer/tests/test_CompareTwoNets.cpp index 94f65e545d..00ad75b363 100644 --- a/paddle/trainer/tests/test_CompareTwoNets.cpp +++ b/paddle/trainer/tests/test_CompareTwoNets.cpp @@ -26,12 +26,15 @@ DECLARE_int32(gpu_id); DECLARE_bool(local); DECLARE_bool(use_gpu); +DECLARE_bool(use_mkldnn); DECLARE_string(config); DECLARE_string(nics); DEFINE_string(config_file_a, "", "config of one network to compare"); DEFINE_string(config_file_b, "", "config of another network to compare"); +DEFINE_bool(use_mkldnn_a, false, "whether to use mkldnn to run network"); +DEFINE_bool(use_mkldnn_b, false, "whether to use mkldnn to run network"); DEFINE_bool(need_high_accuracy, false, "whether need to run in double accuracy"); @@ -128,6 +131,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) { matA.getWidth()); } + if (FLAGS_use_mkldnn_a || FLAGS_use_mkldnn_b) { + // some format of mkldnn parameter is different with cpu + // test_MKLDNN will check the parameters + return; + } + vector& parametersA = comDataA.parameters; vector& parametersB = comDataB.parameters; @@ -167,10 +176,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) { TEST(Trainer, create) { ComData dataA; + FLAGS_use_mkldnn = FLAGS_use_mkldnn_a; calcGradient(dataA, FLAGS_config_file_a); LOG(INFO) << "\n\nforwardBackward of Network A is finished\n\n"; ComData dataB; + FLAGS_use_mkldnn = FLAGS_use_mkldnn_b; calcGradient(dataB, FLAGS_config_file_b); LOG(INFO) << "\n\nforwardBackward of the Network B is finished\n\n"; From ba1f5b5c5857631d41ff0376a815da0e2248a5e4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 22 Sep 2017 11:51:56 -0700 Subject: [PATCH 19/28] Sync computation when Python invoke `run` * Since GPU is an async device by default. We should sync computation when Python invoke `run`. So Python can get the correct computation result --- paddle/platform/device_context.h | 5 +++-- paddle/pybind/pybind.cc | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index a106592e45..f6a39a8e26 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -34,13 +34,14 @@ class DeviceContext { template DeviceType* get_eigen_device() const; + + virtual void Wait() const {} }; class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); explicit CPUDeviceContext(CPUPlace place); - virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext { virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ - void Wait() const; + void Wait() const override; /*! \brief Return place in the device context. */ Place GetPlace() const override; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index c7009a604f..f5cba1b78f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -238,7 +238,13 @@ All parameter, weight, gradient are variables in Paddle. return Backward(forwardOp, no_grad_vars).release(); }) .def("infer_shape", &OperatorBase::InferShape) - .def("run", &OperatorBase::Run) + .def("run", + [](OperatorBase &self, + const Scope &scope, + const platform::DeviceContext &dev_ctx) { + self.Run(scope, dev_ctx); + dev_ctx.Wait(); + }) .def("type", [](const OperatorBase &op) -> std::string { return op.Type(); }) .def("outputs", From 699dd6e9f4425ce63c938a90eaeef9982436c9d5 Mon Sep 17 00:00:00 2001 From: Mimee Date: Fri, 22 Sep 2017 14:44:52 -0700 Subject: [PATCH 20/28] new line at the end --- doc/howto/dev/new_op_en.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md index afd3775ede..b7aa501db9 100644 --- a/doc/howto/dev/new_op_en.md +++ b/doc/howto/dev/new_op_en.md @@ -232,4 +232,4 @@ The system will automatically bind to Python and link it to a generated library. ## Unit Tests -Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). \ No newline at end of file +Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). From 057e8102866e0f2356050d7cf4bb8fcab968267b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 22 Sep 2017 18:15:43 -0700 Subject: [PATCH 21/28] Simplify GetAttrType code --- paddle/framework/attribute.cc | 41 ----------------------------------- paddle/framework/attribute.h | 14 +++++++----- 2 files changed, 9 insertions(+), 46 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index fda89252e3..e705da0131 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -28,47 +28,6 @@ ProgramDesc& GetProgramDesc() { return *g_program_desc; } -template <> -AttrType AttrTypeID() { - return BOOLEAN; -} -template <> -AttrType AttrTypeID() { - return INT; -} -template <> -AttrType AttrTypeID() { - return FLOAT; -} -template <> -AttrType AttrTypeID() { - return STRING; -} -template <> -AttrType AttrTypeID>() { - return BOOLEANS; -} -template <> -AttrType AttrTypeID>() { - return INTS; -} -template <> -AttrType AttrTypeID>() { - return FLOATS; -} -template <> -AttrType AttrTypeID>() { - return STRINGS; -} -template <> -AttrType AttrTypeID>>() { - return INT_PAIRS; -} -template <> -AttrType AttrTypeID() { - return BLOCK; -} - Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 48b54b5422..13f2877226 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,10 +27,11 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, std::vector, std::vector, - std::vector, - std::vector>, BlockDesc*> +// The order should be as same as framework.proto +typedef boost::variant, + std::vector, std::vector, + std::vector>, bool, + std::vector, BlockDesc*> Attribute; typedef std::unordered_map AttributeMap; @@ -38,7 +39,10 @@ typedef std::unordered_map AttributeMap; ProgramDesc& GetProgramDesc(); template -AttrType AttrTypeID(); +inline AttrType AttrTypeID() { + Attribute tmp = T(); + return static_cast(tmp.which() - 1); +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc); From 86a9434c71e94e5a53d1bb247e6a54095ff4928e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 25 Sep 2017 15:40:54 +0800 Subject: [PATCH 22/28] rename script and modify comments --- .../image/{run.mkldnn.sh => run_mkldnn.sh} | 29 ++++++++++--------- paddle/trainer/tests/test_CompareTwoNets.cpp | 4 +-- 2 files changed, 17 insertions(+), 16 deletions(-) rename benchmark/paddle/image/{run.mkldnn.sh => run_mkldnn.sh} (70%) diff --git a/benchmark/paddle/image/run.mkldnn.sh b/benchmark/paddle/image/run_mkldnn.sh similarity index 70% rename from benchmark/paddle/image/run.mkldnn.sh rename to benchmark/paddle/image/run_mkldnn.sh index 9f76243695..5b0a037344 100755 --- a/benchmark/paddle/image/run.mkldnn.sh +++ b/benchmark/paddle/image/run_mkldnn.sh @@ -7,16 +7,17 @@ export KMP_AFFINITY="granularity=fine,compact,0,0" function train() { topology=$1 bs=$2 - thread=1 - if [ $3 ]; then - thread=$3 - fi - if [ $thread -eq 1 ]; then - use_mkldnn=1 + use_mkldnn=$3 + if [ $3 == "True" ]; then + use_mkldnn=$3 + thread=1 log="logs/${topology}-mkldnn-${bs}.log" - else - use_mkldnn=0 + elif [ $3 == "False" ]; then + use_mkldnn=$3 + thread=`nproc` log="logs/${topology}-${thread}mklml-${bs}.log" + else + echo "Wrong input $3, use True or False." fi args="batch_size=${bs}" config="${topology}.py" @@ -40,11 +41,11 @@ fi #========= mkldnn =========# # vgg -train vgg 64 -train vgg 128 -train vgg 256 +train vgg 64 True +train vgg 128 True +train vgg 256 True #========== mklml ===========# -train vgg 64 16 -train vgg 128 16 -train vgg 256 16 +train vgg 64 False +train vgg 128 False +train vgg 256 False diff --git a/paddle/trainer/tests/test_CompareTwoNets.cpp b/paddle/trainer/tests/test_CompareTwoNets.cpp index 00ad75b363..307645d2c3 100644 --- a/paddle/trainer/tests/test_CompareTwoNets.cpp +++ b/paddle/trainer/tests/test_CompareTwoNets.cpp @@ -33,8 +33,8 @@ DECLARE_string(nics); DEFINE_string(config_file_a, "", "config of one network to compare"); DEFINE_string(config_file_b, "", "config of another network to compare"); -DEFINE_bool(use_mkldnn_a, false, "whether to use mkldnn to run network"); -DEFINE_bool(use_mkldnn_b, false, "whether to use mkldnn to run network"); +DEFINE_bool(use_mkldnn_a, false, "whether to use mkldnn to run config_file_a"); +DEFINE_bool(use_mkldnn_b, false, "whether to use mkldnn to run config_file_b"); DEFINE_bool(need_high_accuracy, false, "whether need to run in double accuracy"); From d6a27ade5469dbaf832983fabaa32ec70ab4c2f5 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 25 Sep 2017 17:13:05 +0800 Subject: [PATCH 23/28] add OMP SGD to speedup with CPUs --- paddle/math/Vector.h | 22 +++++++++++++++++++ paddle/parameter/FirstOrderOptimizer.h | 10 +++++++++ paddle/parameter/ParameterUpdateFunctions.cpp | 3 +++ 3 files changed, 35 insertions(+) diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h index 80b9775fcc..7dbf3cfb0d 100644 --- a/paddle/math/Vector.h +++ b/paddle/math/Vector.h @@ -92,6 +92,28 @@ public: const T* getData() const { return this->data_; } T* getData() { return this->data_; } +#ifdef PADDLE_USE_MKLDNN + /** + * sgd update with openmp to speedup + */ + void sgdUpdateWithOMP(VectorT& gradVec, + VectorT& momVec, + T learningRate, + T momentum, + T decayRate) { + size_t size = this->getSize(); + T* val = this->getData(); + T* grd = gradVec.getData(); + T* mom = momVec.getData(); + decayRate *= learningRate; +#pragma omp parallel for + for (size_t i = 0; i < size; ++i) { + mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i]; + val[i] += mom[i]; + } + } +#endif + virtual void zeroMem() = 0; // set all elements to value virtual void reset(const T& value) = 0; diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index caa78acd98..73e09aee23 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -37,6 +37,15 @@ public: real torch_learningRate = optConfig_.learning_method() == "torch_momentum" ? 1.0 - paraConfig.momentum() : 1.0; +#ifdef PADDLE_USE_MKLDNN + vecs[PARAMETER_VALUE]->sgdUpdateWithOMP( + *vecs[PARAMETER_GRADIENT], + *vecs[PARAMETER_MOMENTUM], + learningRate_ * paraConfig.learning_rate() * + (firstTime_ ? 1.0 : torch_learningRate), + paraConfig.momentum(), + applyDecay_ ? paraConfig.decay_rate() : 0); +#else vecs[PARAMETER_VALUE]->sgdUpdate( *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM], @@ -44,6 +53,7 @@ public: (firstTime_ ? 1.0 : torch_learningRate), paraConfig.momentum(), applyDecay_ ? paraConfig.decay_rate() : 0); +#endif } virtual void finishBatch() { firstTime_ = false; } }; diff --git a/paddle/parameter/ParameterUpdateFunctions.cpp b/paddle/parameter/ParameterUpdateFunctions.cpp index c8af7105c7..8b3be062b6 100644 --- a/paddle/parameter/ParameterUpdateFunctions.cpp +++ b/paddle/parameter/ParameterUpdateFunctions.cpp @@ -30,6 +30,9 @@ void sgdUpdateCpu(real learningRate, const real* grad, real* momentumVec) { decayRate *= learningRate; +#ifdef PADDLE_USE_MKLDNN +#pragma omp parallel for +#endif for (size_t i = 0; i < size; ++i) { momentumVec[i] = momentum * momentumVec[i] - learningRate * grad[i] - decayRate * value[i]; From d54e8420be00353af3f3524be8ea6d039e49c6be Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 25 Sep 2017 10:31:32 -0700 Subject: [PATCH 24/28] Stabilize prelu gradient check --- python/paddle/v2/framework/tests/test_prelu_op.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py index 2b6b7db368..676fd9f7c5 100644 --- a/python/paddle/v2/framework/tests/test_prelu_op.py +++ b/python/paddle/v2/framework/tests/test_prelu_op.py @@ -7,6 +7,14 @@ class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" x_np = np.random.normal(size=(10, 10)).astype("float32") + + for pos, val in np.ndenumerate(x_np): + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + while abs(val) < 1e-3: + x_np[pos] = np.random.normal() + val = x_np[pos] + x_np_sign = np.sign(x_np) x_np = x_np_sign * np.maximum(x_np, .005) alpha_np = np.array([.1]) From f4832fe0b477606be8434b4070abe8b6e38b2aa4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 25 Sep 2017 11:46:43 -0700 Subject: [PATCH 25/28] Remove IntPair type in attribute --- paddle/framework/attribute.cc | 8 -------- paddle/framework/attribute.h | 3 +-- paddle/framework/framework.proto | 13 +++---------- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index e705da0131..510dc28c57 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -70,14 +70,6 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } - case framework::AttrType::INT_PAIRS: { - std::vector> val(attr_desc.int_pairs_size()); - for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { - val[i].first = attr_desc.int_pairs(i).first(); - val[i].second = attr_desc.int_pairs(i).second(); - } - return val; - } case framework::AttrType::BLOCK: { return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); } diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 13f2877226..488fa38faf 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -29,8 +29,7 @@ namespace framework { // The order should be as same as framework.proto typedef boost::variant, - std::vector, std::vector, - std::vector>, bool, + std::vector, std::vector, bool, std::vector, BlockDesc*> Attribute; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 6fcfe6de25..951c7afbc1 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -22,17 +22,11 @@ enum AttrType { INTS = 3; FLOATS = 4; STRINGS = 5; - INT_PAIRS = 6; - BOOLEAN = 7; - BOOLEANS = 8; - BLOCK = 9; + BOOLEAN = 6; + BOOLEANS = 7; + BLOCK = 8; } -message IntPair { - required int32 first = 1; - required int32 second = 2; -}; - // OpDesc describes an instance of a C++ framework::OperatorBase // derived class type. message OpDesc { @@ -46,7 +40,6 @@ message OpDesc { repeated int32 ints = 6; repeated float floats = 7; repeated string strings = 8; - repeated IntPair int_pairs = 9; optional bool b = 10; repeated bool bools = 11; optional int32 block_idx = 12; From d0ad82cff1eea9b91bfb9c4abc6c0300cffb99e4 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 25 Sep 2017 12:01:59 -0700 Subject: [PATCH 26/28] fix nv_library (#4370) * fix nv_library * fix symbol in gpu_info.h --- cmake/generic.cmake | 2 +- paddle/platform/gpu_info.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 0bbf922931..ff9868fc4e 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -253,7 +253,7 @@ function(nv_library TARGET_NAME) foreach(source_file ${nv_library_SRCS}) string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) - list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND nv_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) endif() endforeach() add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS}) diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index ed2420b874..f0c825bd9b 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -36,7 +36,7 @@ int GetCurrentDeviceId(); //! Set the GPU device id for next execution. void SetDeviceId(int device_id); -//!Get the memory usage of current GPU device. +//! Get the memory usage of current GPU device. void GpuMemoryUsage(size_t &available, size_t &total); //! Get the maximum allocation size of current GPU device. From 88952fbab9a11fb4f939c9046699ba9ecfa64314 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 25 Sep 2017 20:03:01 +0800 Subject: [PATCH 27/28] use existed sgd updater function --- paddle/math/Vector.h | 22 ---------------------- paddle/parameter/FirstOrderOptimizer.h | 15 ++++++++------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h index 7dbf3cfb0d..80b9775fcc 100644 --- a/paddle/math/Vector.h +++ b/paddle/math/Vector.h @@ -92,28 +92,6 @@ public: const T* getData() const { return this->data_; } T* getData() { return this->data_; } -#ifdef PADDLE_USE_MKLDNN - /** - * sgd update with openmp to speedup - */ - void sgdUpdateWithOMP(VectorT& gradVec, - VectorT& momVec, - T learningRate, - T momentum, - T decayRate) { - size_t size = this->getSize(); - T* val = this->getData(); - T* grd = gradVec.getData(); - T* mom = momVec.getData(); - decayRate *= learningRate; -#pragma omp parallel for - for (size_t i = 0; i < size; ++i) { - mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i]; - val[i] += mom[i]; - } - } -#endif - virtual void zeroMem() = 0; // set all elements to value virtual void reset(const T& value) = 0; diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index 73e09aee23..895e8d6a63 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "ParameterOptimizer.h" +#include "ParameterUpdateFunctions.h" #include "Regularizer.h" namespace paddle { @@ -38,13 +39,13 @@ public: ? 1.0 - paraConfig.momentum() : 1.0; #ifdef PADDLE_USE_MKLDNN - vecs[PARAMETER_VALUE]->sgdUpdateWithOMP( - *vecs[PARAMETER_GRADIENT], - *vecs[PARAMETER_MOMENTUM], - learningRate_ * paraConfig.learning_rate() * - (firstTime_ ? 1.0 : torch_learningRate), - paraConfig.momentum(), - applyDecay_ ? paraConfig.decay_rate() : 0); + sgdUpdate(learningRate_ * paraConfig.learning_rate() * + (firstTime_ ? 1.0 : torch_learningRate), + paraConfig.momentum(), + applyDecay_ ? paraConfig.decay_rate() : 0, + vecs[PARAMETER_VALUE].get(), + vecs[PARAMETER_GRADIENT].get(), + vecs[PARAMETER_MOMENTUM].get()); #else vecs[PARAMETER_VALUE]->sgdUpdate( *vecs[PARAMETER_GRADIENT], From 8419b60ec710e8939cf23892bface93d4a37f39e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 26 Sep 2017 10:49:52 +0800 Subject: [PATCH 28/28] fix typo --- doc/getstarted/build_and_install/docker_install_cn.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index 84e3317774..30b144d849 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -20,7 +20,7 @@ Docker使用入门 docker pull paddlepaddle/paddle:0.10.0 - 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。 + 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用docker.paddlepaddle.org/paddle下载。 - *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。 实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。