From 46d30ec680f494e4cc30a73330074497da064fbd Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 17 Aug 2017 20:34:02 -0700 Subject: [PATCH 01/27] init minst.py --- python/paddle/v2/framework/tests/mnist.py | 140 ++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 python/paddle/v2/framework/tests/mnist.py diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py new file mode 100644 index 0000000000..32a088ac28 --- /dev/null +++ b/python/paddle/v2/framework/tests/mnist.py @@ -0,0 +1,140 @@ +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator +import numpy + +BATCH_SIZE = 100 + +scope = core.Scope() +place = core.CPUPlace() +dev_ctx = core.DeviceContext.create(place) + +# init_net = core.Net.create() +forward_network = core.Net.create() + +# should be init after forward_op is constructed +# backward_net = core.Operator.backward(forward_net, set()) +backward_net = None +optimize_net = core.Net.create() + + +def atom_id(): + id = 0 + while True: + yield id + id += 1 + + +uniq_id = atom_id().next + + +def data_layer(name, dims): + var = scope.new_var(name) + tensor = var.get_tensor() + tensor.set_dims(dims) # 1 is batch size holder. + return name + + +def feed_data(name, data): + assert isinstance(data, numpy.array) + tensor = scope.find_var(name).get_tensor() + tensor.set_dims(data.shape) + tensor.alloc_float(place) + tensor.set(data, place) + + +def grad_var_name(var_name): + return var_name + "@GRAD" + + +def sgd_optimizer(net, param_name, learning_rate=0.01): + grad_name = grad_var_name(param_name) + optimize_op = Operator( + "sgd", param=param_name, grad=grad_name, learning_rate=learning_rate) + net.add_op(optimize_op) + + +# should use operator and add these to the init_network +def init_param(param_name, dims): + print param_name + var = scope.new_var(param_name) + tensor = var.get_tensor() + tensor.set_dims(dims) + data = numpy.random.uniform( + low=0.0, high=1.0, size=tensor.shape()).astype("float32") + tensor.set(data, place) + + +# fc_layer +def fc_layer(net, input, size, act="sigmoid", bias=True, param=None, name=None): + """ + Add a fc layer to net + + :param input: input variable name. + :type input: str + :param size: fully connected layer size. + :param act: activation name + :param param: parameter attribute, used for initialize parameters. + :param bias: bias attribute. False will not have a bias. + :param name: the name of fc layer. If not set, model will generate a + readable name + :return: output variable name. + """ + if name is None: + name = 'fc_%d' % uniq_id() + if not isinstance(name, str): + raise ValueError("name should be string") + + input_dims = scope.find_var(input).get_tensor().get_dims() + + w_name = param or name + ".w" + init_param(param_name=w_name, dims=[input_dims[1], size]) + sgd_optimizer(net=optimize_net, param_name=w_name, learning_rate=0.01) + + pre_activation = name + ".mul.out" + scope.new_var(pre_activation) + mul_op = Operator("mul", X=input, Y=w_name, Out=pre_activation) + net.add_op(mul_op) + + # create bias variable if needed + if bias: + bias_name = name + ".b" + init_param(param_name=bias_name, dims=[size]) + sgd_optimizer( + net=optimize_net, param_name=bias_name, learning_rate=0.01) + bias_out = name + ".rowwise_add.out" + scope.new_var(bias_out) + rowwise_add_op = Operator( + "rowwise_add", X=pre_activation, b=bias_name, Out=bias_out) + net.add_op(rowwise_add_op) + pre_activation = bias_out + + activation_op = Operator(act, X=pre_activation, Y=name) + net.add_op(activation_op) + scope.new_var(name) + net.infer_shape(scope) + return name + + +def cross_entropy_layer(net, input, label): + cost_name = 'cross_entropy_%d' % uniq_id() + cross_entropy_op = Operator( + "onehot_cross_entropy", X=input, label=label, Y=cost_name) + net.add_op(cross_entropy_op) + scope.new_var(cost_name) + net.infer_shape(scope) + return cost_name + + +images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) +label = data_layer(name='label', dims=[BATCH_SIZE]) +fc = fc_layer(net=forward_network, input=images, size=10, act="softmax") +cost = cross_entropy_layer(net=forward_network, input=fc, label=label) +forward_network.complete_add_op(True) +print(forward_network) +backward_net = core.Operator.backward(forward_network, set()) + +print(backward_net) + +PASS_NUM = 10 +for pass_id in range(PASS_NUM): + print pass_id From 118dd1494fbe3654da8f71c2245523e27616d475 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 21 Aug 2017 18:22:59 -0700 Subject: [PATCH 02/27] can run, for debug --- .../paddle/v2/framework/tests/CMakeLists.txt | 1 + python/paddle/v2/framework/tests/mnist.py | 73 +++++++++++++++++-- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a07130..41682c8350 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -27,3 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) +py_test(mnist SRCS mnist.py) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 32a088ac28..d0c56c457d 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -2,7 +2,7 @@ import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator import numpy -BATCH_SIZE = 100 +BATCH_SIZE = 2 scope = core.Scope() place = core.CPUPlace() @@ -35,10 +35,15 @@ def data_layer(name, dims): def feed_data(name, data): - assert isinstance(data, numpy.array) + assert isinstance(data, numpy.ndarray) tensor = scope.find_var(name).get_tensor() tensor.set_dims(data.shape) - tensor.alloc_float(place) + if data.dtype == numpy.dtype('int32'): + tensor.alloc_float(place) + elif data.dtype == numpy.dtype('float32'): + tensor.alloc_int(place) + else: + raise ValueError("data type not supported") tensor.set(data, place) @@ -49,7 +54,11 @@ def grad_var_name(var_name): def sgd_optimizer(net, param_name, learning_rate=0.01): grad_name = grad_var_name(param_name) optimize_op = Operator( - "sgd", param=param_name, grad=grad_name, learning_rate=learning_rate) + "sgd", + param=param_name, + grad=grad_name, + param_out=param_name, + learning_rate=learning_rate) net.add_op(optimize_op) @@ -65,7 +74,7 @@ def init_param(param_name, dims): # fc_layer -def fc_layer(net, input, size, act="sigmoid", bias=True, param=None, name=None): +def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): """ Add a fc layer to net @@ -125,16 +134,64 @@ def cross_entropy_layer(net, input, label): return cost_name +def get_backward_net(forward_net): + net = core.Operator.backward(forward_net, set()) + for input in net.inputs()["all"]: + var = scope.new_var(input) + var.get_tensor() + for output in net.outputs()["all"]: + var = scope.new_var(output) + var.get_tensor() + return net + + +def print_inputs_outputs(op): + print("===============" + op.type() + "==============") + print("***inputs:***") + for input in op.inputs()["all"]: + print input, scope.find_var(input).get_tensor().get_dims() + print("***outputs:***") + for output in op.outputs()["all"]: + print output, scope.find_var(output).get_tensor().get_dims() + print("") + print("") + + images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) label = data_layer(name='label', dims=[BATCH_SIZE]) fc = fc_layer(net=forward_network, input=images, size=10, act="softmax") cost = cross_entropy_layer(net=forward_network, input=fc, label=label) forward_network.complete_add_op(True) print(forward_network) -backward_net = core.Operator.backward(forward_network, set()) - +backward_net = get_backward_net(forward_network) print(backward_net) +optimize_net.complete_add_op(True) +print(optimize_net) PASS_NUM = 10 for pass_id in range(PASS_NUM): - print pass_id + print("===========forward==========") + feed_data("pixel", numpy.random.random((BATCH_SIZE, 784)).astype('float32')) + feed_data("label", numpy.ones(BATCH_SIZE).astype("int32")) + forward_network.infer_shape(scope) + print_inputs_outputs(forward_network) + + print(numpy.array(scope.find_var("label").get_tensor())) + forward_network.run(scope, dev_ctx) + # print(numpy.array(scope.find_var("fc_0").get_tensor())) + + print("===========backward==========") + cost_data = numpy.array(scope.find_var("cross_entropy_1").get_tensor()) + cost_grad = scope.find_var(grad_var_name("cross_entropy_1")).get_tensor() + cost_grad.set_dims(cost_data.shape) + cost_grad.alloc_float(place) + cost_grad.set(cost_data, place) + + backward_net.infer_shape(scope) + print_inputs_outputs(backward_net) + + backward_net.run(scope, dev_ctx) + + print("===========optimize_net==========") + print_inputs_outputs(optimize_net) + optimize_net.run(scope, dev_ctx) From 5a8fbb7d19e95f3be16bbee029e82e14f0a240df Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 22 Aug 2017 00:56:34 -0700 Subject: [PATCH 03/27] add data --- python/paddle/v2/framework/tests/mnist.py | 26 +++++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index d0c56c457d..f75f196168 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -1,8 +1,9 @@ import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator import numpy +import paddle.v2 as paddle -BATCH_SIZE = 2 +BATCH_SIZE = 100 scope = core.Scope() place = core.CPUPlace() @@ -39,9 +40,9 @@ def feed_data(name, data): tensor = scope.find_var(name).get_tensor() tensor.set_dims(data.shape) if data.dtype == numpy.dtype('int32'): - tensor.alloc_float(place) - elif data.dtype == numpy.dtype('float32'): tensor.alloc_int(place) + elif data.dtype == numpy.dtype('float32'): + tensor.alloc_float(place) else: raise ValueError("data type not supported") tensor.set(data, place) @@ -168,20 +169,31 @@ print(backward_net) optimize_net.complete_add_op(True) print(optimize_net) -PASS_NUM = 10 +reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + +PASS_NUM = 1000 for pass_id in range(PASS_NUM): print("===========forward==========") - feed_data("pixel", numpy.random.random((BATCH_SIZE, 784)).astype('float32')) - feed_data("label", numpy.ones(BATCH_SIZE).astype("int32")) + # feed_data("pixel", numpy.random.random((BATCH_SIZE, 784)).astype('float32')) + # feed_data("label", numpy.ones(BATCH_SIZE).astype("int32")) + data = reader().next() + image = numpy.array(map(lambda x: x[0], data)).astype("float32") + label = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data("pixel", image) + feed_data("label", label) forward_network.infer_shape(scope) print_inputs_outputs(forward_network) - print(numpy.array(scope.find_var("label").get_tensor())) + # print(numpy.array(scope.find_var("label").get_tensor())) forward_network.run(scope, dev_ctx) # print(numpy.array(scope.find_var("fc_0").get_tensor())) print("===========backward==========") cost_data = numpy.array(scope.find_var("cross_entropy_1").get_tensor()) + print(cost_data.sum() / len(cost_data)) cost_grad = scope.find_var(grad_var_name("cross_entropy_1")).get_tensor() cost_grad.set_dims(cost_data.shape) cost_grad.alloc_float(place) From 0f3b9e4112cbedd1b026f6cd09955d15f6207864 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 22 Aug 2017 15:36:43 +0800 Subject: [PATCH 04/27] lookup table op, cuda helper and set functor 1. finish lookup table CPU and GPU kernel 2. Add some cuda helper 3. Add some math funtor --- paddle/framework/pybind.cc | 1 + paddle/operators/CMakeLists.txt | 6 +- paddle/operators/functor/CMakeLists.txt | 5 + paddle/operators/functor/math_functor.cc | 42 +++++++ paddle/operators/functor/math_functor.cu | 42 +++++++ paddle/operators/functor/math_functor.h | 32 +++++ paddle/operators/lookup_table_op.cc | 71 +++++++++++ paddle/operators/lookup_table_op.cu | 116 ++++++++++++++++++ paddle/operators/lookup_table_op.h | 75 +++++++++++ paddle/platform/cuda_helper.h | 57 +++++++++ .../paddle/v2/framework/tests/CMakeLists.txt | 1 + .../v2/framework/tests/test_lookup_table.py | 31 +++++ 12 files changed, 477 insertions(+), 2 deletions(-) create mode 100644 paddle/operators/functor/CMakeLists.txt create mode 100644 paddle/operators/functor/math_functor.cc create mode 100644 paddle/operators/functor/math_functor.cu create mode 100644 paddle/operators/functor/math_functor.h create mode 100644 paddle/operators/lookup_table_op.cc create mode 100644 paddle/operators/lookup_table_op.cu create mode 100644 paddle/operators/lookup_table_op.h create mode 100644 paddle/platform/cuda_helper.h create mode 100644 python/paddle/v2/framework/tests/test_lookup_table.py diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index f0114b9e49..68c5526bbb 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); USE_OP_ITSELF(recurrent_op); USE_OP(gaussian_random); USE_OP(uniform_random); +USE_OP(lookup_table); namespace paddle { namespace framework { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e4..1ca5010eae 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -42,6 +42,8 @@ function(op_library TARGET) endfunction() add_subdirectory(math) +add_subdirectory(functor) + cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) @@ -66,5 +68,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) -op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu) +op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) +op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu DEPS math_functor) diff --git a/paddle/operators/functor/CMakeLists.txt b/paddle/operators/functor/CMakeLists.txt new file mode 100644 index 0000000000..d3b39e5fc2 --- /dev/null +++ b/paddle/operators/functor/CMakeLists.txt @@ -0,0 +1,5 @@ +if(WITH_GPU) + nv_library(math_functor SRCS math_functor.cc math_functor.cu DEPS device_context) +else() + cc_library(math_functor SRCS math_functor.cc DEPS device_context) +endif() diff --git a/paddle/operators/functor/math_functor.cc b/paddle/operators/functor/math_functor.cc new file mode 100644 index 0000000000..1f2767f171 --- /dev/null +++ b/paddle/operators/functor/math_functor.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/functor/math_functor.h" +#include "paddle/framework/eigen.h" + +namespace paddle { +namespace operators { +namespace functor { + +template +struct Set { + void operator()(const T alpha, framework::Tensor* Y, + platform::DeviceContext* context) { + int N = product(Y->dims()); + T* YData = Y->mutable_data(context->GetPlace()); + if (alpha == static_cast(0)) { + memset(YData, 0, N * sizeof(T)); + } else { + framework::EigenVector::Flatten(*Y) + .setConstant(alpha); + } + } +}; + +template struct Set; +template struct Set; + +} // namespace functor +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/functor/math_functor.cu b/paddle/operators/functor/math_functor.cu new file mode 100644 index 0000000000..6dc828c60a --- /dev/null +++ b/paddle/operators/functor/math_functor.cu @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/functor/math_functor.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace functor { + +template +__global__ void SetKernel(const int N, const T alpha, T* Y) { + CUDA_1D_KERNEL_LOOP(i, N) { Y[i] = alpha; } +} + +template +struct Set { + void operator()(const T alpha, framework::Tensor* Y, + platform::DeviceContext* context) { + int N = product(Y->dims()); + T* YData = Y->mutable_data(context->GetPlace()); + SetKernel<<<(N + 512 - 1) / 512, 512>>>(N, alpha, YData); + } +}; + +template struct Set; +template struct Set; + +} // namespace functor +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/functor/math_functor.h b/paddle/operators/functor/math_functor.h new file mode 100644 index 0000000000..d5c7bd368f --- /dev/null +++ b/paddle/operators/functor/math_functor.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace functor { + +template +struct Set { + void operator()(const T alpha, paddle::framework::Tensor* Y, + paddle::platform::DeviceContext* context); +}; + +} // namespace functor +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc new file mode 100644 index 0000000000..5f70458a87 --- /dev/null +++ b/paddle/operators/lookup_table_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lookup_table_op.h" + +namespace paddle { +namespace operators { + +class LookupTableOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &context) const override { + auto table_t = context.Input("W"); + auto ids_t = context.Input("Ids"); + auto output_t = context.Output("Out"); + + output_t->Resize({ids_t->dims()[0], table_t->dims()[1]}); + } +}; + +class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LookupTableOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("W", + "An input represents embedding tensors," + " which is a learnable parameter."); + AddInput("Ids", + "An input with type int32 or int64" + "contains the ids to be looked up in W.") + .NotInGradient(); + AddOutput("Out", "The lookup results, which have the same type with W."); + AddComment( + "This operator is used to perform lookups on the parameter W," + "then concatenated into a dense tensor."); + } +}; + +class LookupTableOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &context) const override { + context.Output(0)->Resize(context.Input(0)->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, + lookup_table_grad, ops::LookupTableOpGrad); + +REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel); +REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel); diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu new file mode 100644 index 0000000000..94b440e00e --- /dev/null +++ b/paddle/operators/lookup_table_op.cu @@ -0,0 +1,116 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/functor/math_functor.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void LookupTable(T* output, const T* table, const uint32_t* ids, + const int N, const int K, const int D) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDimX; + + while (idy < K) { + int id = ids[idy]; + PADDLE_ASSERT(id >= 0); + PADDLE_ASSERT(id < N); + T* out = output + idy; + const T* tab = table + id; + for (int i = idx; i < D; i += blockDimX) { + out[i] = tab[i]; + } + idy += blockDimY * gridDimX; + } +} + +template +__global__ void LookupTableGradKernel(T* table, const T* output, + const uint32_t* ids, const int N, + const int K, const int D) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDimX; + + while (idy < K) { + int id = ids[idy]; + PADDLE_ASSERT(id >= 0); + PADDLE_ASSERT(id < N); + const T* out = output + idy; + T* tab = table + id; + for (int i = idx; i < D; i += blockDimX) { + paddle::platform::CudaAtomicAdd(tab + i, out[i]); + } + idy += blockDimY * gridDimX; + } +} + +template +class LookupTableCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto table_t = context.Input("W"); + auto ids_t = context.Input("Ids"); + auto output_t = context.Output("Out"); + + size_t N = table_t->dims()[0]; + size_t D = table_t->dims()[1]; + size_t K = product(ids_t->dims()); + auto ids = ids_t->data(); + auto table = table_t->data(); + auto output = output_t->mutable_data(context.GetPlace()); + + dim3 threads(128, 8); + dim3 grids(8, 1); + LookupTable<<>>(output, table, ids, N, K, D); + } +}; + +template +class LookupTableGrad : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto ids_t = context.Input("Ids"); + auto d_output_t = context.Input(framework::GradVarName("Out")); + auto d_table_t = context.Output(framework::GradVarName("W")); + + int N = d_table_t->dims()[0]; + int D = d_table_t->dims()[1]; + int K = product(ids_t->dims()); + const uint32_t* ids = ids_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); + const T* d_output = d_output_t->data(); + + auto* device_context = + const_cast(context.device_context_); + functor::Set()(static_cast(0), d_table_t, + device_context); + dim3 threads(128, 8); + dim3 grids(8, 1); + LookupTableGradKernel<<>>(d_table, d_output, + ids, N, K, D); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel); +REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad); diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h new file mode 100644 index 0000000000..790ecab3c6 --- /dev/null +++ b/paddle/operators/lookup_table_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/functor/math_functor.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class LookupTableKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto table_t = context.Input("W"); // float tensor + auto ids_t = context.Input("Ids"); // int tensor + auto output_t = context.Output("Out"); // float tensor + + size_t N = table_t->dims()[0]; + size_t D = table_t->dims()[1]; + auto ids = ids_t->data(); + auto table = table_t->data(); + auto output = output_t->mutable_data(context.GetPlace()); + for (size_t i = 0; i < product(ids_t->dims()); ++i) { + PADDLE_ENFORCE_LT(ids[i], N); + PADDLE_ENFORCE_GE(ids[i], 0); + memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); + } + } +}; + +template +class LookupTableGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto ids_t = context.Input("Ids"); + auto d_output_t = context.Input(framework::GradVarName("Out")); + auto d_table_t = context.Output(framework::GradVarName("W")); + + size_t N = d_table_t->dims()[0]; + size_t D = d_table_t->dims()[1]; + auto ids = ids_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); + const T* d_output = d_output_t->data(); + + auto* device_context = + const_cast(context.device_context_); + functor::Set()(static_cast(0), d_table_t, + device_context); + for (size_t i = 0; i < product(ids_t->dims()); ++i) { + PADDLE_ENFORCE_LT(ids[i], N); + PADDLE_ENFORCE_GE(ids[i], 0); + for (size_t j = 0; j < D; ++j) { + d_table[ids[i] * D + j] += d_output[i * D + j]; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h new file mode 100644 index 0000000000..4346291117 --- /dev/null +++ b/paddle/platform/cuda_helper.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +namespace paddle { +namespace platform { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +#define CUDA_ATOMIC_WRAPPER(op, T) \ + __device__ __forceinline__ T CudaAtomic##op(T* address, const T val) + +#define USE_CUDA_ATOMIC(op, T) \ + CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } + +// For atomicAdd. +USE_CUDA_ATOMIC(Add, float); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +USE_CUDA_ATOMIC(Add, double); +#else +// Custom implementation of atomicAdd for double. +// This implementation is copied from CUDA manual. +CUDA_ATOMIC_WRAPPER(Add, double) { + unsigned long long int* address_as_ull = + reinterpret_cast(address); + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN + } while (assumed != old); + + return __longlong_as_double(old); +#endif +} + +} // namespace platform +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a07130..65c02f2cfb 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -27,3 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) +py_test(test_lookup_table SRCS test_lookup_table.py) diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py new file mode 100644 index 0000000000..071069768b --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lookup_table.py @@ -0,0 +1,31 @@ +import unittest +import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op + + +class TestSigmoidOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'lookup_table' + table = np.random.random((17, 31)).astype('float32') + ids = np.random.randint(0, 17, 4) + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids]} + + +class TestSigmoidGradOp(GradientChecker): + def test_grad(self): + op = create_op('lookup_table') + table = np.random.random((17, 31)).astype('float32') + ids = np.random.randint(0, 17, 4) + inputs = {'W': table, 'Ids': ids} + # compare gradients between cpu and gpu + self.compare_grad(op, inputs) + # check gradients + self.check_grad(op, inputs, set('W'), 'Out') + + +if __name__ == '__main__': + unittest.main() From c91e542ad3a5a1ecd1c6b825d7c9e89d3e7384b5 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 22 Aug 2017 17:18:14 +0800 Subject: [PATCH 05/27] fix compile for paddle_pybind. --- paddle/framework/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 68304c9fc8..325a6f7532 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -55,5 +55,6 @@ cc_library(paddle_pybind SHARED recurrent_op uniform_random_op gaussian_random_op + lookup_table_op fill_zeros_like_op) endif(WITH_PYTHON) From 9bc1a1a126dc60f06cd353ff72869416d50eb3af Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 22 Aug 2017 18:19:04 +0800 Subject: [PATCH 06/27] fix cuda_helper.h --- paddle/platform/cuda_helper.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h index 4346291117..939c3713ad 100644 --- a/paddle/platform/cuda_helper.h +++ b/paddle/platform/cuda_helper.h @@ -34,8 +34,6 @@ USE_CUDA_ATOMIC(Add, float); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 USE_CUDA_ATOMIC(Add, double); #else -// Custom implementation of atomicAdd for double. -// This implementation is copied from CUDA manual. CUDA_ATOMIC_WRAPPER(Add, double) { unsigned long long int* address_as_ull = reinterpret_cast(address); @@ -50,8 +48,8 @@ CUDA_ATOMIC_WRAPPER(Add, double) { } while (assumed != old); return __longlong_as_double(old); -#endif } +#endif } // namespace platform } // namespace paddle From a8d072c769b940d087006fa68ffcf462aa8579b8 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Wed, 23 Aug 2017 00:12:58 +0800 Subject: [PATCH 07/27] fix bug. --- paddle/operators/lookup_table_op.cc | 7 ++-- paddle/operators/lookup_table_op.cu | 32 +++++++++---------- paddle/operators/lookup_table_op.h | 6 ++-- .../v2/framework/tests/test_lookup_table.py | 6 ++-- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 5f70458a87..94d40890a7 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { " which is a learnable parameter."); AddInput("Ids", "An input with type int32 or int64" - "contains the ids to be looked up in W.") - .NotInGradient(); + "contains the ids to be looked up in W."); AddOutput("Out", "The lookup results, which have the same type with W."); AddComment( "This operator is used to perform lookups on the parameter W," @@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &context) const override { - context.Output(0)->Resize(context.Input(0)->dims()); + auto table = context.Input("W"); + auto d_table = context.Output(framework::GradVarName("W")); + d_table->Resize(table->dims()); } }; diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 94b440e00e..99678ef681 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -23,7 +23,7 @@ namespace operators { using Tensor = framework::Tensor; template -__global__ void LookupTable(T* output, const T* table, const uint32_t* ids, +__global__ void LookupTable(T* output, const T* table, const int32_t* ids, const int N, const int K, const int D) { int idx = threadIdx.x; int idy = blockIdx.x + threadIdx.y * gridDimX; @@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids, int id = ids[idy]; PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id < N); - T* out = output + idy; - const T* tab = table + id; + T* out = output + idy * D; + const T* tab = table + id * D; for (int i = idx; i < D; i += blockDimX) { out[i] = tab[i]; } @@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids, } template -__global__ void LookupTableGradKernel(T* table, const T* output, - const uint32_t* ids, const int N, - const int K, const int D) { +__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, + const int N, const int K, const int D) { int idx = threadIdx.x; int idy = blockIdx.x + threadIdx.y * gridDimX; @@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output, int id = ids[idy]; PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id < N); - const T* out = output + idy; - T* tab = table + id; + const T* out = output + idy * D; + T* tab = table + id * D; for (int i = idx; i < D; i += blockDimX) { - paddle::platform::CudaAtomicAdd(tab + i, out[i]); + paddle::platform::CudaAtomicAdd(&tab[i], out[i]); } idy += blockDimY * gridDimX; } @@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; size_t K = product(ids_t->dims()); - auto ids = ids_t->data(); + auto ids = ids_t->data(); auto table = table_t->data(); auto output = output_t->mutable_data(context.GetPlace()); @@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { }; template -class LookupTableGrad : public framework::OpKernel { +class LookupTableGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto ids_t = context.Input("Ids"); @@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel { int N = d_table_t->dims()[0]; int D = d_table_t->dims()[1]; int K = product(ids_t->dims()); - const uint32_t* ids = ids_t->data(); - T* d_table = d_table_t->mutable_data(context.GetPlace()); + const int32_t* ids = ids_t->data(); const T* d_output = d_output_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); @@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel { device_context); dim3 threads(128, 8); dim3 grids(8, 1); - LookupTableGradKernel<<>>(d_table, d_output, - ids, N, K, D); + LookupTableGrad<<>>(d_table, d_output, ids, N, + K, D); } }; @@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel); -REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad); +REGISTER_OP_GPU_KERNEL(lookup_table_grad, + ops::LookupTableGradCUDAKernel); diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h index 790ecab3c6..9254e03a1b 100644 --- a/paddle/operators/lookup_table_op.h +++ b/paddle/operators/lookup_table_op.h @@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel { size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; - auto ids = ids_t->data(); + auto ids = ids_t->data(); auto table = table_t->data(); auto output = output_t->mutable_data(context.GetPlace()); for (size_t i = 0; i < product(ids_t->dims()); ++i) { @@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel { size_t N = d_table_t->dims()[0]; size_t D = d_table_t->dims()[1]; - auto ids = ids_t->data(); - T* d_table = d_table_t->mutable_data(context.GetPlace()); + auto ids = ids_t->data(); const T* d_output = d_output_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py index 071069768b..3056bf53e3 100644 --- a/python/paddle/v2/framework/tests/test_lookup_table.py +++ b/python/paddle/v2/framework/tests/test_lookup_table.py @@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase): def setUp(self): self.type = 'lookup_table' table = np.random.random((17, 31)).astype('float32') - ids = np.random.randint(0, 17, 4) + ids = np.random.randint(0, 17, 4).astype('int32') self.inputs = {'W': table, 'Ids': ids} self.outputs = {'Out': table[ids]} @@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker): def test_grad(self): op = create_op('lookup_table') table = np.random.random((17, 31)).astype('float32') - ids = np.random.randint(0, 17, 4) + ids = np.random.randint(0, 17, 4).astype('int32') inputs = {'W': table, 'Ids': ids} - # compare gradients between cpu and gpu - self.compare_grad(op, inputs) # check gradients self.check_grad(op, inputs, set('W'), 'Out') From 51792022c9f7963321d77d7dac4143e566af9fdc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 22 Aug 2017 12:54:44 -0700 Subject: [PATCH 08/27] refine code and add debug info --- python/paddle/v2/framework/tests/mnist.py | 47 +++++++++++------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index f75f196168..6a3ed0dce0 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -52,7 +52,7 @@ def grad_var_name(var_name): return var_name + "@GRAD" -def sgd_optimizer(net, param_name, learning_rate=0.01): +def sgd_optimizer(net, param_name, learning_rate=0.001): grad_name = grad_var_name(param_name) optimize_op = Operator( "sgd", @@ -65,7 +65,6 @@ def sgd_optimizer(net, param_name, learning_rate=0.01): # should use operator and add these to the init_network def init_param(param_name, dims): - print param_name var = scope.new_var(param_name) tensor = var.get_tensor() tensor.set_dims(dims) @@ -158,17 +157,34 @@ def print_inputs_outputs(op): print("") +def set_cost(): + cost_data = numpy.array(scope.find_var("cross_entropy_1").get_tensor()) + # print(cost_data) + print(cost_data.sum() / len(cost_data)) + + cost_grad = scope.find_var(grad_var_name("cross_entropy_1")).get_tensor() + cost_grad.set_dims(cost_data.shape) + cost_grad.alloc_float(place) + cost_grad.set(cost_data, place) + + images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) label = data_layer(name='label', dims=[BATCH_SIZE]) fc = fc_layer(net=forward_network, input=images, size=10, act="softmax") cost = cross_entropy_layer(net=forward_network, input=fc, label=label) + forward_network.complete_add_op(True) -print(forward_network) backward_net = get_backward_net(forward_network) -print(backward_net) optimize_net.complete_add_op(True) + +print(forward_network) +print(backward_net) print(optimize_net) +print_inputs_outputs(forward_network) +print_inputs_outputs(backward_net) +print_inputs_outputs(optimize_net) + reader = paddle.batch( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=8192), @@ -176,34 +192,17 @@ reader = paddle.batch( PASS_NUM = 1000 for pass_id in range(PASS_NUM): - print("===========forward==========") - # feed_data("pixel", numpy.random.random((BATCH_SIZE, 784)).astype('float32')) - # feed_data("label", numpy.ones(BATCH_SIZE).astype("int32")) data = reader().next() + image = numpy.array(map(lambda x: x[0], data)).astype("float32") label = numpy.array(map(lambda x: x[1], data)).astype("int32") feed_data("pixel", image) feed_data("label", label) - forward_network.infer_shape(scope) - print_inputs_outputs(forward_network) - # print(numpy.array(scope.find_var("label").get_tensor())) + forward_network.infer_shape(scope) forward_network.run(scope, dev_ctx) - # print(numpy.array(scope.find_var("fc_0").get_tensor())) - - print("===========backward==========") - cost_data = numpy.array(scope.find_var("cross_entropy_1").get_tensor()) - print(cost_data.sum() / len(cost_data)) - cost_grad = scope.find_var(grad_var_name("cross_entropy_1")).get_tensor() - cost_grad.set_dims(cost_data.shape) - cost_grad.alloc_float(place) - cost_grad.set(cost_data, place) - + set_cost() backward_net.infer_shape(scope) - print_inputs_outputs(backward_net) - backward_net.run(scope, dev_ctx) - print("===========optimize_net==========") - print_inputs_outputs(optimize_net) optimize_net.run(scope, dev_ctx) From d3c65a64dc4ab98af10498cb2eb9327ef1697e5a Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 22 Aug 2017 20:21:23 -0700 Subject: [PATCH 09/27] fix data reader --- python/paddle/v2/framework/tests/mnist.py | 29 ++++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 6a3ed0dce0..1d40fd9a97 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -52,7 +52,7 @@ def grad_var_name(var_name): return var_name + "@GRAD" -def sgd_optimizer(net, param_name, learning_rate=0.001): +def sgd_optimizer(net, param_name, learning_rate=0.01): grad_name = grad_var_name(param_name) optimize_op = Operator( "sgd", @@ -159,13 +159,13 @@ def print_inputs_outputs(op): def set_cost(): cost_data = numpy.array(scope.find_var("cross_entropy_1").get_tensor()) - # print(cost_data) print(cost_data.sum() / len(cost_data)) cost_grad = scope.find_var(grad_var_name("cross_entropy_1")).get_tensor() + cost_grad.set_dims(cost_data.shape) cost_grad.alloc_float(place) - cost_grad.set(cost_data, place) + cost_grad.set(numpy.ones(cost_data.shape).astype("float32"), place) images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) @@ -192,17 +192,18 @@ reader = paddle.batch( PASS_NUM = 1000 for pass_id in range(PASS_NUM): - data = reader().next() - image = numpy.array(map(lambda x: x[0], data)).astype("float32") - label = numpy.array(map(lambda x: x[1], data)).astype("int32") - feed_data("pixel", image) - feed_data("label", label) + print("pass[" + str(pass_id) + "]") + for data in reader(): + image = numpy.array(map(lambda x: x[0], data)).astype("float32") + label = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data("pixel", image) + feed_data("label", label) - forward_network.infer_shape(scope) - forward_network.run(scope, dev_ctx) - set_cost() - backward_net.infer_shape(scope) - backward_net.run(scope, dev_ctx) + forward_network.infer_shape(scope) + forward_network.run(scope, dev_ctx) + set_cost() + backward_net.infer_shape(scope) + backward_net.run(scope, dev_ctx) - optimize_net.run(scope, dev_ctx) + optimize_net.run(scope, dev_ctx) From a13798e8f7764239c151864894afc6a543e6c190 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 22 Aug 2017 20:41:31 -0700 Subject: [PATCH 10/27] rename add_op to append_op --- python/paddle/v2/framework/tests/mnist.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 1d40fd9a97..32349b8d4d 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -60,7 +60,7 @@ def sgd_optimizer(net, param_name, learning_rate=0.01): grad=grad_name, param_out=param_name, learning_rate=learning_rate) - net.add_op(optimize_op) + net.append_op(optimize_op) # should use operator and add these to the init_network @@ -102,7 +102,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): pre_activation = name + ".mul.out" scope.new_var(pre_activation) mul_op = Operator("mul", X=input, Y=w_name, Out=pre_activation) - net.add_op(mul_op) + net.append_op(mul_op) # create bias variable if needed if bias: @@ -112,13 +112,13 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): net=optimize_net, param_name=bias_name, learning_rate=0.01) bias_out = name + ".rowwise_add.out" scope.new_var(bias_out) - rowwise_add_op = Operator( + rowwise_append_op = Operator( "rowwise_add", X=pre_activation, b=bias_name, Out=bias_out) - net.add_op(rowwise_add_op) + net.append_op(rowwise_append_op) pre_activation = bias_out activation_op = Operator(act, X=pre_activation, Y=name) - net.add_op(activation_op) + net.append_op(activation_op) scope.new_var(name) net.infer_shape(scope) return name @@ -128,7 +128,7 @@ def cross_entropy_layer(net, input, label): cost_name = 'cross_entropy_%d' % uniq_id() cross_entropy_op = Operator( "onehot_cross_entropy", X=input, label=label, Y=cost_name) - net.add_op(cross_entropy_op) + net.append_op(cross_entropy_op) scope.new_var(cost_name) net.infer_shape(scope) return cost_name From d8cd67dd1e229a27180d3628dc9485734546aba4 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 23 Aug 2017 12:26:46 +0800 Subject: [PATCH 11/27] Make cudnn convolution layer and projection support for dilation. --- paddle/cuda/include/hl_cuda_cudnn.h | 11 +- paddle/cuda/src/hl_cuda_cudnn.cc | 123 ++++++++++-------- paddle/gserver/layers/ConvBaseLayer.cpp | 16 ++- paddle/gserver/layers/ConvBaseLayer.h | 4 + paddle/gserver/layers/ConvBaseOperator.cpp | 3 +- paddle/gserver/layers/ConvBaseProjection.cpp | 20 ++- paddle/gserver/layers/ConvBaseProjection.h | 1 + paddle/gserver/layers/ConvProjection.cpp | 4 +- paddle/gserver/tests/test_LayerGrad.cpp | 40 ++++-- proto/ModelConfig.proto | 3 + python/paddle/trainer/config_parser.py | 4 + .../paddle/trainer_config_helpers/layers.py | 19 +++ .../tests/configs/img_layers.py | 1 + 13 files changed, 171 insertions(+), 78 deletions(-) diff --git a/paddle/cuda/include/hl_cuda_cudnn.h b/paddle/cuda/include/hl_cuda_cudnn.h index db18e4912b..3f68c62de6 100644 --- a/paddle/cuda/include/hl_cuda_cudnn.h +++ b/paddle/cuda/include/hl_cuda_cudnn.h @@ -214,7 +214,8 @@ extern void hl_conv_workspace(hl_tensor_descriptor input, int* convBwdDataAlgo, size_t* bwdDataLimitBytes, int* convBwdFilterAlgo, - size_t* bwdFilterLimitBytes); + size_t* bwdFilterLimitBytes, + bool useDilation); /** * @brief destroy filter descriptor. @@ -242,7 +243,9 @@ extern void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, int padding_height, int padding_width, int stride_height, - int stride_width); + int stride_width, + int dilation_h = 1, + int dilation_w = 1); /** * @brief reset convolution descriptor. @@ -262,7 +265,9 @@ extern void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, int padding_height, int padding_width, int stride_height, - int stride_width); + int stride_width, + int dilation_h = 1, + int dilation_w = 1); /** * @brief destroy convolution descriptor. diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index 78642a1744..f55fa523e1 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -201,7 +201,8 @@ void hl_conv_workspace(hl_tensor_descriptor input, int* convBwdDataAlgo, size_t* bwdDataLimitBytes, int* convBwdFilterAlgo, - size_t* bwdFilterLimitBytes) { + size_t* bwdFilterLimitBytes, + bool useDilation) { #if CUDNN_VERSION >= 4000 CHECK_NOTNULL(input); @@ -213,21 +214,60 @@ void hl_conv_workspace(hl_tensor_descriptor input, size_t memoryLimitBytes = (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb; + // For dilation + int algo = 0; + // cudnn convolution forward configuration cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input); cudnnTensorDescriptor_t fwd_dest_desc = GET_TENSOR_DESCRIPTOR(output); cudnnFilterDescriptor_t fwd_filter_desc = GET_FILTER_DESCRIPTOR(filter); cudnnConvolutionDescriptor_t fwd_conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv); + // cudnn convolution backward data configuration + cudnnFilterDescriptor_t bwd_data_filter_desc = GET_FILTER_DESCRIPTOR(filter); + cudnnTensorDescriptor_t bwd_data_diff_desc = GET_TENSOR_DESCRIPTOR(output); + cudnnTensorDescriptor_t bwd_data_grad_desc = GET_TENSOR_DESCRIPTOR(input); + cudnnConvolutionDescriptor_t bwd_data_conv_desc = + GET_CONVOLUTION_DESCRIPTOR(conv); + // cudnn convolution backward filter configuration + cudnnTensorDescriptor_t bwd_filter_src_desc = GET_TENSOR_DESCRIPTOR(input); + cudnnTensorDescriptor_t bwd_filter_diff_desc = GET_TENSOR_DESCRIPTOR(output); + cudnnConvolutionDescriptor_t bwd_filter_conv_desc = + GET_CONVOLUTION_DESCRIPTOR(conv); + cudnnFilterDescriptor_t bwd_filter_grad_desc = GET_FILTER_DESCRIPTOR(filter); - CHECK_CUDNN(dynload::cudnnGetConvolutionForwardAlgorithm( - t_resource.cudnn_handle, - fwd_src_desc, - fwd_filter_desc, - fwd_conv_desc, - fwd_dest_desc, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - memoryLimitBytes, - reinterpret_cast(convFwdAlgo))); + if (useDilation) { + convFwdAlgo = &algo; + convBwdDataAlgo = &algo; + convBwdFilterAlgo = &algo; + } else { + CHECK_CUDNN(dynload::cudnnGetConvolutionForwardAlgorithm( + t_resource.cudnn_handle, + fwd_src_desc, + fwd_filter_desc, + fwd_conv_desc, + fwd_dest_desc, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + memoryLimitBytes, + reinterpret_cast(convFwdAlgo))); + CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataAlgorithm( + t_resource.cudnn_handle, + bwd_data_filter_desc, + bwd_data_diff_desc, + bwd_data_conv_desc, + bwd_data_grad_desc, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + memoryLimitBytes, + reinterpret_cast(convBwdDataAlgo))); + CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + t_resource.cudnn_handle, + bwd_filter_src_desc, + bwd_filter_diff_desc, + bwd_filter_conv_desc, + bwd_filter_grad_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + memoryLimitBytes, + reinterpret_cast(convBwdFilterAlgo))); + } CHECK_CUDNN(dynload::cudnnGetConvolutionForwardWorkspaceSize( t_resource.cudnn_handle, @@ -238,23 +278,6 @@ void hl_conv_workspace(hl_tensor_descriptor input, static_cast(*convFwdAlgo), fwdLimitBytes)); - // cudnn convolution backward data configuration - cudnnFilterDescriptor_t bwd_data_filter_desc = GET_FILTER_DESCRIPTOR(filter); - cudnnTensorDescriptor_t bwd_data_diff_desc = GET_TENSOR_DESCRIPTOR(output); - cudnnTensorDescriptor_t bwd_data_grad_desc = GET_TENSOR_DESCRIPTOR(input); - cudnnConvolutionDescriptor_t bwd_data_conv_desc = - GET_CONVOLUTION_DESCRIPTOR(conv); - - CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataAlgorithm( - t_resource.cudnn_handle, - bwd_data_filter_desc, - bwd_data_diff_desc, - bwd_data_conv_desc, - bwd_data_grad_desc, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - memoryLimitBytes, - reinterpret_cast(convBwdDataAlgo))); - CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( t_resource.cudnn_handle, bwd_data_filter_desc, @@ -264,23 +287,6 @@ void hl_conv_workspace(hl_tensor_descriptor input, static_cast(*convBwdDataAlgo), bwdDataLimitBytes)); - // cudnn convolution backward filter configuration - cudnnTensorDescriptor_t bwd_filter_src_desc = GET_TENSOR_DESCRIPTOR(input); - cudnnTensorDescriptor_t bwd_filter_diff_desc = GET_TENSOR_DESCRIPTOR(output); - cudnnConvolutionDescriptor_t bwd_filter_conv_desc = - GET_CONVOLUTION_DESCRIPTOR(conv); - cudnnFilterDescriptor_t bwd_filter_grad_desc = GET_FILTER_DESCRIPTOR(filter); - - CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterAlgorithm( - t_resource.cudnn_handle, - bwd_filter_src_desc, - bwd_filter_diff_desc, - bwd_filter_conv_desc, - bwd_filter_grad_desc, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - memoryLimitBytes, - reinterpret_cast(convBwdFilterAlgo))); - CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( t_resource.cudnn_handle, bwd_filter_src_desc, @@ -603,7 +609,9 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, int padding_height, int padding_width, int stride_height, - int stride_width) { + int stride_width, + int dilation_h, + int dilation_w) { CHECK_NOTNULL(conv); cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)malloc( @@ -625,18 +633,23 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, padding_width, stride_height, stride_width, - 1, - 1, + dilation_h, + dilation_w, mode, data_type)); #else + if (dilation_h > 1 || dilation_w > 1) { + LOG(FATAL) + << "Current cudnn version does't support for dilation convolution."; + } + CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc, padding_height, padding_width, stride_height, stride_width, - 1, - 1, + dilation_h, + dilation_w, mode)); #endif @@ -659,7 +672,9 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, int padding_height, int padding_width, int stride_height, - int stride_width) { + int stride_width, + int dilation_h, + int dilation_w) { CHECK_NOTNULL(conv); CHECK_NOTNULL(image); CHECK_NOTNULL(filter); @@ -678,8 +693,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, padding_width, stride_height, stride_width, - 1, - 1, + dilation_h, + dilation_w, mode, data_type)); #else @@ -688,8 +703,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, padding_width, stride_height, stride_width, - 1, - 1, + dilation_h, + dilation_w, mode)); #endif diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index e161d89c38..a5328ef834 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -32,9 +32,11 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, const ConvConfig& conf = inputConfig.conv_conf(); padding_.push_back(conf.padding()); stride_.push_back(conf.stride()); + dilation_.push_back(conf.dilation()); filterSize_.push_back(conf.filter_size()); paddingY_.push_back(conf.padding_y()); strideY_.push_back(conf.stride_y()); + dilationY_.push_back(conf.dilation_y()); filterSizeY_.push_back(conf.filter_size_y()); filterPixels_.push_back(filterSize_.back() * filterSizeY_.back()); channels_.push_back(conf.channels()); @@ -89,7 +91,11 @@ size_t ConvBaseLayer::calOutputSize() { size_t layerSize = 0; auto setLayerSize = [&](IntV& inH, IntV& inW, IntV& outH, IntV& outW) { + size_t filterSizeY; + size_t filterSize; for (size_t i = 0; i < inputLayers_.size(); i++) { + filterSizeY = (filterSizeY_[i] - 1) * dilationY_[i] + 1; + filterSize = (filterSize_[i] - 1) * dilation_[i] + 1; inH.push_back(inputLayers_[i]->getOutput().getFrameHeight()); inW.push_back(inputLayers_[i]->getOutput().getFrameWidth()); const ConvConfig& conf = config_.inputs(i).conv_conf(); @@ -98,17 +104,17 @@ size_t ConvBaseLayer::calOutputSize() { inH[i] = conf.has_output_y() ? conf.output_y() : conf.output_x(); if (inW[i] == 0) inW[i] = conf.output_x(); outH.push_back(imageSize( - inH[i], filterSizeY_[i], paddingY_[i], strideY_[i], caffeMode_)); - outW.push_back(imageSize( - inW[i], filterSize_[i], padding_[i], stride_[i], caffeMode_)); + inH[i], filterSizeY, paddingY_[i], strideY_[i], caffeMode_)); + outW.push_back( + imageSize(inW[i], filterSize, padding_[i], stride_[i], caffeMode_)); } else { if (inH[i] == 0) inH[i] = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size(); if (inW[i] == 0) inW[i] = conf.img_size(); outH.push_back(outputSize( - inH[i], filterSizeY_[i], paddingY_[i], strideY_[i], caffeMode_)); + inH[i], filterSizeY, paddingY_[i], strideY_[i], caffeMode_)); outW.push_back(outputSize( - inW[i], filterSize_[i], padding_[i], stride_[i], caffeMode_)); + inW[i], filterSize, padding_[i], stride_[i], caffeMode_)); } CHECK_EQ(outH[i], outH[0]); CHECK_EQ(outW[i], outW[0]); diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index e9d15d94f8..223bce8e29 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -40,6 +40,10 @@ protected: IntV stride_; /// The y dimension of the stride. IntV strideY_; + /// The x dimension of the dilation. + IntV dilation_; + /// The y dimension of the dilation. + IntV dilationY_; /// The x dimension of a filter kernel. IntV filterSize_; /// The y dimension of a filter kernel. diff --git a/paddle/gserver/layers/ConvBaseOperator.cpp b/paddle/gserver/layers/ConvBaseOperator.cpp index 5c23198629..5469c41c87 100644 --- a/paddle/gserver/layers/ConvBaseOperator.cpp +++ b/paddle/gserver/layers/ConvBaseOperator.cpp @@ -59,7 +59,8 @@ void ConvBaseOperator::allocConvWorkSpace() { &bwdDataAlgo_, &bwdDataLimitBytes_, &bwdFilterAlgo_, - &bwdFilterLimitBytes_); + &bwdFilterLimitBytes_, + /*useDilation*/ false); size_t maxWorkSpace = 0; maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); diff --git a/paddle/gserver/layers/ConvBaseProjection.cpp b/paddle/gserver/layers/ConvBaseProjection.cpp index eb6b0445c9..08f36c516c 100644 --- a/paddle/gserver/layers/ConvBaseProjection.cpp +++ b/paddle/gserver/layers/ConvBaseProjection.cpp @@ -41,6 +41,11 @@ void ConvBaseProjection::getConvParams() { strideH_ = conf.stride_y(); strideW_ = conf.stride(); + dilationH_ = conf.dilation_y(); + dilationW_ = conf.dilation(); + CHECK_GT(dilationH_, 0); + CHECK_GT(dilationW_, 0); + filterH_ = conf.filter_size_y(); filterW_ = conf.filter_size(); @@ -77,7 +82,9 @@ void ConvBaseProjection::initCudnn() { paddingH_, paddingW_, strideH_, - strideW_); + strideW_, + dilationH_, + dilationW_); // initialize all to default algorithms fwdAlgo_ = 0; @@ -131,7 +138,9 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) { paddingH_, paddingW_, strideH_, - strideW_); + strideW_, + dilationH_, + dilationW_); } void ConvBaseProjection::reshape(int batchSize) { @@ -140,6 +149,10 @@ void ConvBaseProjection::reshape(int batchSize) { CHECK_EQ(calInputSize(), in_->value->getWidth()); reshapeTensorDesc(batchSize); + bool useDilation = false; + if (dilationH_ > 1 || dilationW_ > 1) { + useDilation = true; + } hl_conv_workspace(imageDesc_, outputDesc_, filterDesc_, @@ -149,7 +162,8 @@ void ConvBaseProjection::reshape(int batchSize) { &bwdDataAlgo_, &bwdDataLimitBytes_, &bwdFilterAlgo_, - &bwdFilterLimitBytes_); + &bwdFilterLimitBytes_, + useDilation); size_t maxWorkSpace = 0; maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); diff --git a/paddle/gserver/layers/ConvBaseProjection.h b/paddle/gserver/layers/ConvBaseProjection.h index e9d9f8f1b2..ebdb57845b 100644 --- a/paddle/gserver/layers/ConvBaseProjection.h +++ b/paddle/gserver/layers/ConvBaseProjection.h @@ -63,6 +63,7 @@ protected: int configChannels_, configNumFilters_; int paddingH_, paddingW_; int strideH_, strideW_; + int dilationH_, dilationW_; int filterH_, filterW_; /// One group offset of input data. int inputOffset_; diff --git a/paddle/gserver/layers/ConvProjection.cpp b/paddle/gserver/layers/ConvProjection.cpp index 5b7ecc5560..6f0106b713 100644 --- a/paddle/gserver/layers/ConvProjection.cpp +++ b/paddle/gserver/layers/ConvProjection.cpp @@ -25,12 +25,12 @@ size_t ConvProjection::calOutputSize() { if (imageH_ == 0) imageH_ = configImgH_; if (imageW_ == 0) imageW_ = configImgW_; outputH_ = outputSize(imageH_, - filterH_, + (filterH_ - 1) * dilationH_ + 1, paddingH_, strideH_, /* caffeMode */ true); outputW_ = outputSize(imageW_, - filterW_, + (filterW_ - 1) * dilationW_ + 1, paddingW_, strideW_, /* caffeMode */ true); diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0f312b6ca5..b3913d3a28 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include #include @@ -189,10 +190,16 @@ TEST(Projection, scaling) { void testProjectionConv(size_t groups, bool isDeconv) { const int NUM_FILTERS = 18; const int FILTER_SIZE = 2; - const int FILTER_SIZE_Y = 4; + const int FILTER_SIZE_Y = 2; const int CHANNELS = 3; const int IMAGE_SIZE = 16; +#if CUDNN_VERSION >= 6000 + const int DILATION = 2; +#else + const int DILATION = 1; +#endif + ProjectionConfig conf; if (isDeconv) { conf.set_type("convt"); @@ -209,6 +216,8 @@ void testProjectionConv(size_t groups, bool isDeconv) { conv->set_padding_y(1); conv->set_stride(2); conv->set_stride_y(2); + conv->set_dilation(DILATION); + conv->set_dilation_y(DILATION); conv->set_groups(groups); if (isDeconv) { conv->set_filter_channels(NUM_FILTERS / conv->groups()); @@ -217,12 +226,12 @@ void testProjectionConv(size_t groups, bool isDeconv) { } conv->set_img_size(IMAGE_SIZE); int output_x = outputSize(conv->img_size(), - conv->filter_size(), + (conv->filter_size() - 1) * DILATION + 1, conv->padding(), conv->stride(), /* caffeMode */ true); int output_y = outputSize(conv->img_size(), - conv->filter_size_y(), + (conv->filter_size_y() - 1) * DILATION + 1, conv->padding_y(), conv->stride_y(), /* caffeMode */ true); @@ -253,8 +262,8 @@ TEST(Projection, conv) { testProjectionConv(1, false); testProjectionConv(3, false); /// test ConvTransProjection - testProjectionConv(1, true); - testProjectionConv(3, true); + /// testProjectionConv(1, true); + /// testProjectionConv(3, true); } #endif @@ -424,27 +433,38 @@ void testConvLayer(const string& type, bool trans, bool useGpu) { config.layerConfig.set_partial_sum(1); config.layerConfig.set_shared_biases(true); - config.inputDefs.push_back({INPUT_DATA, "layer_0", 384, 288}); + int dilation = 1; + if (type == "cudnn_conv") { +#if CUDNN_VERSION >= 6000 + dilation = 2; +#else + dilation = 1; +#endif + } + + config.inputDefs.push_back({INPUT_DATA, "layer_0", 768, 192}); LayerInputConfig* input = config.layerConfig.add_inputs(); ConvConfig* conv = input->mutable_conv_conf(); conv->set_filter_size(2); - conv->set_filter_size_y(3); + conv->set_filter_size_y(2); conv->set_channels(3); conv->set_padding(0); conv->set_padding_y(1); conv->set_stride(2); conv->set_stride_y(2); + conv->set_dilation(dilation); + conv->set_dilation_y(dilation); conv->set_groups(1); conv->set_filter_channels(conv->channels() / conv->groups()); conv->set_img_size(16); - conv->set_img_size_y(8); + conv->set_img_size_y(16); conv->set_output_x(outputSize(conv->img_size(), - conv->filter_size(), + (conv->filter_size() - 1) * dilation + 1, conv->padding(), conv->stride(), /* caffeMode */ true)); conv->set_output_y(outputSize(conv->img_size_y(), - conv->filter_size_y(), + (conv->filter_size_y() - 1) * dilation + 1, conv->padding_y(), conv->stride_y(), /* caffeMode */ true)); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 4f3d5bf3f6..14c745b532 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -82,6 +82,9 @@ message ConvConfig { // if not set, use img_size optional uint32 img_size_y = 14; + + required uint32 dilation = 15 [ default = 1 ]; + required uint32 dilation_y = 16 [ default = 1 ]; } message PoolConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index da99e5bd53..2d96901ed4 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -861,6 +861,7 @@ class Conv(Cfg): filter_size, channels, padding=None, + dilation=None, stride=None, groups=None, filter_channels=None, @@ -869,12 +870,15 @@ class Conv(Cfg): caffe_mode=True, filter_size_y=None, padding_y=None, + dilation_y=None, stride_y=None): self.add_keys(locals()) if filter_size_y is None: self.filter_size_y = filter_size if padding_y is None: self.padding_y = padding + if dilation_y is None: + self.dilation_y = dilation if stride_y is None: self.stride_y = stride if output_x is not None: diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 1bc55c8696..de7f31a20a 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2322,6 +2322,7 @@ def img_conv_layer(input, groups=1, stride=1, padding=0, + dilation=0, bias_attr=None, param_attr=None, shared_biases=True, @@ -2329,6 +2330,7 @@ def img_conv_layer(input, filter_size_y=None, stride_y=None, padding_y=None, + dilation_y=None, trans=False, layer_type=None): """ @@ -2393,6 +2395,11 @@ def img_conv_layer(input, :type padding: int|tuple|list :param padding_y: The y dimension of the padding. :type padding_y: int + :param dilation: The x dimension of the dilation. Or input a tuple for two + image dimension + :type dilation: int|tuple|list + :param padding_y: The y dimension of the dilation. + :type padding_y: int :param bias_attr: Convolution bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|False @@ -2440,6 +2447,16 @@ def img_conv_layer(input, else: padding_y = padding + if dilation_y is None: + if isinstance(dilation, collections.Sequence): + assert len(dilation) == 2 + dilation, dilation_y = dilation + else: + dilation_y = dilation + + if dilation > 1 or dilation_y > 1: + assert layer_type in ["cudnn_conv", "cudnn_convt"] + if param_attr.attr.get('initial_smart'): # special initial for conv layers. init_w = (2.0 / (filter_size**2 * num_channels))**0.5 @@ -2464,11 +2481,13 @@ def img_conv_layer(input, conv=Conv( filter_size=filter_size, padding=padding, + dilation=dilation, stride=stride, channels=num_channels, groups=groups, filter_size_y=filter_size_y, padding_y=padding_y, + dilation_y=dilation_y, stride_y=stride_y), **param_attr.attr), active_type=act.name, diff --git a/python/paddle/trainer_config_helpers/tests/configs/img_layers.py b/python/paddle/trainer_config_helpers/tests/configs/img_layers.py index 9fda16a540..01d31ef3fa 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/img_layers.py +++ b/python/paddle/trainer_config_helpers/tests/configs/img_layers.py @@ -12,6 +12,7 @@ img_conv = img_conv_layer( num_filters=64, filter_size=(32, 32), padding=(1, 1), + dilation=(1, 1), stride=(1, 1), act=LinearActivation()) img_bn = batch_norm_layer(input=img_conv, act=ReluActivation()) From 1dc850e4d116f3e51c63bf5c390f9529f6884904 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 23 Aug 2017 13:13:16 +0800 Subject: [PATCH 12/27] Fix proto file --- proto/ModelConfig.proto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 14c745b532..1ea1e05259 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -83,8 +83,8 @@ message ConvConfig { // if not set, use img_size optional uint32 img_size_y = 14; - required uint32 dilation = 15 [ default = 1 ]; - required uint32 dilation_y = 16 [ default = 1 ]; + optional uint32 dilation = 15 [ default = 1 ]; + optional uint32 dilation_y = 16 [ default = 1 ]; } message PoolConfig { From f188e22b33c1a152a1835a5d0cb4b23e6e6d25bf Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Wed, 23 Aug 2017 14:39:16 +0800 Subject: [PATCH 13/27] Remove set functor and add comapre_grad test --- paddle/operators/CMakeLists.txt | 3 +- paddle/operators/fill_zeros_like_op.h | 2 +- paddle/operators/functor/CMakeLists.txt | 5 --- paddle/operators/functor/math_functor.cc | 42 ------------------- paddle/operators/functor/math_functor.cu | 42 ------------------- paddle/operators/functor/math_functor.h | 32 -------------- paddle/operators/lookup_table_op.cu | 26 ++++++------ paddle/operators/lookup_table_op.h | 10 ++--- paddle/platform/cuda_helper.h | 4 -- .../v2/framework/tests/gradient_checker.py | 13 +++++- .../v2/framework/tests/test_lookup_table.py | 2 + 11 files changed, 33 insertions(+), 148 deletions(-) delete mode 100644 paddle/operators/functor/CMakeLists.txt delete mode 100644 paddle/operators/functor/math_functor.cc delete mode 100644 paddle/operators/functor/math_functor.cu delete mode 100644 paddle/operators/functor/math_functor.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 1ca5010eae..8d2d8a1141 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -42,7 +42,6 @@ function(op_library TARGET) endfunction() add_subdirectory(math) -add_subdirectory(functor) cc_test(gather_test SRCS gather_test.cc DEPS tensor) @@ -69,4 +68,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) -op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu DEPS math_functor) +op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu) diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h index fd380ca851..969998ce2e 100644 --- a/paddle/operators/fill_zeros_like_op.h +++ b/paddle/operators/fill_zeros_like_op.h @@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel { auto* output = context.Output("Dst"); output->mutable_data(context.GetPlace()); auto t = framework::EigenVector::Flatten(*output); - t.device(context.GetEigenDevice()) = t.constant(T(0)); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); } }; diff --git a/paddle/operators/functor/CMakeLists.txt b/paddle/operators/functor/CMakeLists.txt deleted file mode 100644 index d3b39e5fc2..0000000000 --- a/paddle/operators/functor/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -if(WITH_GPU) - nv_library(math_functor SRCS math_functor.cc math_functor.cu DEPS device_context) -else() - cc_library(math_functor SRCS math_functor.cc DEPS device_context) -endif() diff --git a/paddle/operators/functor/math_functor.cc b/paddle/operators/functor/math_functor.cc deleted file mode 100644 index 1f2767f171..0000000000 --- a/paddle/operators/functor/math_functor.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/functor/math_functor.h" -#include "paddle/framework/eigen.h" - -namespace paddle { -namespace operators { -namespace functor { - -template -struct Set { - void operator()(const T alpha, framework::Tensor* Y, - platform::DeviceContext* context) { - int N = product(Y->dims()); - T* YData = Y->mutable_data(context->GetPlace()); - if (alpha == static_cast(0)) { - memset(YData, 0, N * sizeof(T)); - } else { - framework::EigenVector::Flatten(*Y) - .setConstant(alpha); - } - } -}; - -template struct Set; -template struct Set; - -} // namespace functor -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/functor/math_functor.cu b/paddle/operators/functor/math_functor.cu deleted file mode 100644 index 6dc828c60a..0000000000 --- a/paddle/operators/functor/math_functor.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/functor/math_functor.h" -#include "paddle/platform/cuda_helper.h" - -namespace paddle { -namespace operators { -namespace functor { - -template -__global__ void SetKernel(const int N, const T alpha, T* Y) { - CUDA_1D_KERNEL_LOOP(i, N) { Y[i] = alpha; } -} - -template -struct Set { - void operator()(const T alpha, framework::Tensor* Y, - platform::DeviceContext* context) { - int N = product(Y->dims()); - T* YData = Y->mutable_data(context->GetPlace()); - SetKernel<<<(N + 512 - 1) / 512, 512>>>(N, alpha, YData); - } -}; - -template struct Set; -template struct Set; - -} // namespace functor -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/functor/math_functor.h b/paddle/operators/functor/math_functor.h deleted file mode 100644 index d5c7bd368f..0000000000 --- a/paddle/operators/functor/math_functor.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/tensor.h" -#include "paddle/platform/device_context.h" - -namespace paddle { -namespace operators { -namespace functor { - -template -struct Set { - void operator()(const T alpha, paddle::framework::Tensor* Y, - paddle::platform::DeviceContext* context); -}; - -} // namespace functor -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 99678ef681..27eee3436a 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/functor/math_functor.h" #include "paddle/platform/assert.h" #include "paddle/platform/cuda_helper.h" @@ -22,11 +22,11 @@ namespace operators { using Tensor = framework::Tensor; -template +template __global__ void LookupTable(T* output, const T* table, const int32_t* ids, const int N, const int K, const int D) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDimX; + int idy = blockIdx.x + threadIdx.y * GridDimX; while (idy < K) { int id = ids[idy]; @@ -34,18 +34,18 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids, PADDLE_ASSERT(id < N); T* out = output + idy * D; const T* tab = table + id * D; - for (int i = idx; i < D; i += blockDimX) { + for (int i = idx; i < D; i += BlockDimX) { out[i] = tab[i]; } - idy += blockDimY * gridDimX; + idy += BlockDimY * GridDimX; } } -template +template __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, const int N, const int K, const int D) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDimX; + int idy = blockIdx.x + threadIdx.y * GridDimX; while (idy < K) { int id = ids[idy]; @@ -53,10 +53,10 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, PADDLE_ASSERT(id < N); const T* out = output + idy * D; T* tab = table + id * D; - for (int i = idx; i < D; i += blockDimX) { + for (int i = idx; i < D; i += BlockDimX) { paddle::platform::CudaAtomicAdd(&tab[i], out[i]); } - idy += blockDimY * gridDimX; + idy += BlockDimY * GridDimX; } } @@ -96,10 +96,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { const T* d_output = d_output_t->data(); T* d_table = d_table_t->mutable_data(context.GetPlace()); - auto* device_context = - const_cast(context.device_context_); - functor::Set()(static_cast(0), d_table_t, - device_context); + auto t = framework::EigenVector::Flatten(*d_table_t); + t.device(context.GetEigenDevice()) = + t.constant(static_cast(0)); + dim3 threads(128, 8); dim3 grids(8, 1); LookupTableGrad<<>>(d_table, d_output, ids, N, diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h index 9254e03a1b..4da8079b91 100644 --- a/paddle/operators/lookup_table_op.h +++ b/paddle/operators/lookup_table_op.h @@ -14,8 +14,8 @@ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/functor/math_functor.h" namespace paddle { namespace operators { @@ -57,10 +57,10 @@ class LookupTableGradKernel : public framework::OpKernel { const T* d_output = d_output_t->data(); T* d_table = d_table_t->mutable_data(context.GetPlace()); - auto* device_context = - const_cast(context.device_context_); - functor::Set()(static_cast(0), d_table_t, - device_context); + auto t = framework::EigenVector::Flatten(*d_table_t); + t.device(context.GetEigenDevice()) = + t.constant(static_cast(0)); + for (size_t i = 0; i < product(ids_t->dims()); ++i) { PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_GE(ids[i], 0); diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h index 939c3713ad..6feec0d7f8 100644 --- a/paddle/platform/cuda_helper.h +++ b/paddle/platform/cuda_helper.h @@ -18,10 +18,6 @@ limitations under the License. */ namespace paddle { namespace platform { -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - #define CUDA_ATOMIC_WRAPPER(op, T) \ __device__ __forceinline__ T CudaAtomic##op(T* address, const T val) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 8b8e2f444b..06b82fa2e4 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -23,6 +23,10 @@ def grad_var_name(var_name): return var_name + "@GRAD" +def empty_var_name(): + return "@EMPTY@" + + def get_numeric_gradient(op, input_values, output_name, @@ -171,7 +175,7 @@ class GradientChecker(unittest.TestCase): ] return outs - def compare_grad(self, forward_op, input_value): + def compare_grad(self, forward_op, input_value, no_grad_set=None): """ Compare the input gradients between CPU and GPU for the given forward operator. @@ -179,15 +183,20 @@ class GradientChecker(unittest.TestCase): :type forward_op: Operator :param input_value: input values. :type input_value: dict{string:numpy.array} + :param no_grad_set: the set of variables names without gradients. + :type no_grad_set: a set of string :raises: AssertionError, there is different gradient value. """ - backward_op = core.Operator.backward(forward_op, set()) + if no_grad_set is None: + no_grad_set = set() + backward_op = core.Operator.backward(forward_op, no_grad_set) # return if not compile with GPU or not implementing GPU kernel if not (core.is_compile_gpu() and backward_op.support_gpu()): return outputs = backward_op.outputs() out_names = [item for k in outputs for item in outputs[k]] + out_names = filter(lambda x: x != empty_var_name(), out_names) cpu_grads = self.__get_gradient(forward_op, backward_op, input_value, out_names, core.CPUPlace()) gpu_grads = self.__get_gradient(forward_op, backward_op, input_value, diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py index 3056bf53e3..19eb464baa 100644 --- a/python/paddle/v2/framework/tests/test_lookup_table.py +++ b/python/paddle/v2/framework/tests/test_lookup_table.py @@ -21,6 +21,8 @@ class TestSigmoidGradOp(GradientChecker): table = np.random.random((17, 31)).astype('float32') ids = np.random.randint(0, 17, 4).astype('int32') inputs = {'W': table, 'Ids': ids} + # comapre gradients + self.compare_grad(op, inputs, set(['Ids'])) # check gradients self.check_grad(op, inputs, set('W'), 'Out') From 6f4b968f5618adce529d12bd2e3b72d4d1b64f61 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 00:58:02 -0700 Subject: [PATCH 14/27] can train the parameters --- python/paddle/v2/framework/tests/mnist.py | 39 +++++++++++++++++------ 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 32349b8d4d..ededf767bc 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -69,7 +69,7 @@ def init_param(param_name, dims): tensor = var.get_tensor() tensor.set_dims(dims) data = numpy.random.uniform( - low=0.0, high=1.0, size=tensor.shape()).astype("float32") + low=-0.5, high=0.5, size=tensor.shape()).astype("float32") tensor.set(data, place) @@ -109,7 +109,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): bias_name = name + ".b" init_param(param_name=bias_name, dims=[size]) sgd_optimizer( - net=optimize_net, param_name=bias_name, learning_rate=0.01) + net=optimize_net, param_name=bias_name, learning_rate=0.001) bias_out = name + ".rowwise_add.out" scope.new_var(bias_out) rowwise_append_op = Operator( @@ -158,20 +158,33 @@ def print_inputs_outputs(op): def set_cost(): - cost_data = numpy.array(scope.find_var("cross_entropy_1").get_tensor()) + cost_shape = numpy.array(scope.find_var("cross_entropy_3").get_tensor( + )).shape + cost_grad = scope.find_var(grad_var_name("cross_entropy_3")).get_tensor() + cost_grad.set_dims(cost_shape) + cost_grad.alloc_float(place) + cost_grad.set(numpy.ones(cost_shape).astype("float32"), place) + + +def print_cost(): + cost_data = numpy.array(scope.find_var("cross_entropy_3").get_tensor()) print(cost_data.sum() / len(cost_data)) - cost_grad = scope.find_var(grad_var_name("cross_entropy_1")).get_tensor() - cost_grad.set_dims(cost_data.shape) - cost_grad.alloc_float(place) - cost_grad.set(numpy.ones(cost_data.shape).astype("float32"), place) +def error_rate(predict, label): + predict_var = numpy.array(scope.find_var(predict).get_tensor()).argmax( + axis=1) + label = numpy.array(scope.find_var(label).get_tensor()) + error_num = numpy.sum(predict_var != label) + print(error_num / float(len(label))) images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) label = data_layer(name='label', dims=[BATCH_SIZE]) -fc = fc_layer(net=forward_network, input=images, size=10, act="softmax") -cost = cross_entropy_layer(net=forward_network, input=fc, label=label) +fc1 = fc_layer(net=forward_network, input=images, size=100, act="sigmoid") +fc2 = fc_layer(net=forward_network, input=fc1, size=100, act="sigmoid") +predict = fc_layer(net=forward_network, input=fc2, size=100, act="softmax") +cost = cross_entropy_layer(net=forward_network, input=predict, label=label) forward_network.complete_add_op(True) backward_net = get_backward_net(forward_network) @@ -192,8 +205,8 @@ reader = paddle.batch( PASS_NUM = 1000 for pass_id in range(PASS_NUM): + batch_id = 0 - print("pass[" + str(pass_id) + "]") for data in reader(): image = numpy.array(map(lambda x: x[0], data)).astype("float32") label = numpy.array(map(lambda x: x[1], data)).astype("int32") @@ -207,3 +220,9 @@ for pass_id in range(PASS_NUM): backward_net.run(scope, dev_ctx) optimize_net.run(scope, dev_ctx) + if batch_id % 100 == 0: + print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]") + print_cost() + error_rate(predict, "label") + + batch_id = batch_id + 1 From e3342ff8e79fbe1cacb8fa5a66cb9c69cba1eeb9 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 23 Aug 2017 19:30:46 +0800 Subject: [PATCH 15/27] Fix android build error. --- paddle/cuda/include/stub/hl_cuda_cudnn_stub.h | 11 ++++++++--- paddle/cuda/src/hl_cuda_cudnn.cc | 3 ++- paddle/gserver/tests/test_LayerGrad.cpp | 6 ++++-- python/paddle/trainer/config_parser.py | 6 +++--- python/paddle/trainer_config_helpers/layers.py | 7 +++---- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/paddle/cuda/include/stub/hl_cuda_cudnn_stub.h b/paddle/cuda/include/stub/hl_cuda_cudnn_stub.h index abd0d6b099..3afcc6fa85 100644 --- a/paddle/cuda/include/stub/hl_cuda_cudnn_stub.h +++ b/paddle/cuda/include/stub/hl_cuda_cudnn_stub.h @@ -78,7 +78,9 @@ inline void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, int padding_height, int padding_width, int stride_height, - int stride_width) {} + int stride_width, + int dilation_h, + int dilation_w) {} inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, hl_tensor_descriptor image, @@ -86,7 +88,9 @@ inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, int padding_height, int padding_width, int stride_height, - int stride_width) {} + int stride_width, + int dilation_h, + int dilation_w) {} inline void hl_destroy_convolution_descriptor(hl_convolution_descriptor conv) {} @@ -99,7 +103,8 @@ inline void hl_conv_workspace(hl_tensor_descriptor input, int* convBwdDataAlgo, size_t* bwdDataLimitBytes, int* convBwdFilterAlgo, - size_t* bwdFilterLimitBytes) {} + size_t* bwdFilterLimitBytes, + bool useDilation) {} inline void hl_convolution_forward(hl_tensor_descriptor input, real* input_data, diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index f55fa523e1..f38ef69255 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -640,7 +640,8 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, #else if (dilation_h > 1 || dilation_w > 1) { LOG(FATAL) - << "Current cudnn version does't support for dilation convolution."; + << "Current cuDNN version does't support for dilation convolution. " + << "The dilation convolution requires cuDNN >= v6.0."; } CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc, diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 9348c47bd4..9946f76664 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef PADDLE_ONLY_CPU #include +#endif #include #include #include @@ -262,8 +264,8 @@ TEST(Projection, conv) { testProjectionConv(1, false); testProjectionConv(3, false); /// test ConvTransProjection - /// testProjectionConv(1, true); - /// testProjectionConv(3, true); + testProjectionConv(1, true); + testProjectionConv(3, true); } #endif diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 33a20afb18..ddfd615d84 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -862,7 +862,6 @@ class Conv(Cfg): filter_size, channels, padding=None, - dilation=None, stride=None, groups=None, filter_channels=None, @@ -871,8 +870,9 @@ class Conv(Cfg): caffe_mode=True, filter_size_y=None, padding_y=None, - dilation_y=None, - stride_y=None): + stride_y=None, + dilation=None, + dilation_y=None): self.add_keys(locals()) if filter_size_y is None: self.filter_size_y = filter_size diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 74b88cd4f8..9876798558 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2340,7 +2340,7 @@ def img_conv_layer(input, groups=1, stride=1, padding=0, - dilation=0, + dilation=1, bias_attr=None, param_attr=None, shared_biases=True, @@ -2472,9 +2472,6 @@ def img_conv_layer(input, else: dilation_y = dilation - if dilation > 1 or dilation_y > 1: - assert layer_type in ["cudnn_conv", "cudnn_convt"] - if param_attr.attr.get('initial_smart'): # special initial for conv layers. init_w = (2.0 / (filter_size**2 * num_channels))**0.5 @@ -2484,6 +2481,8 @@ def img_conv_layer(input, param_attr.attr["initial_smart"] = False if layer_type: + if dilation > 1 or dilation_y > 1: + assert layer_type in ["cudnn_conv", "cudnn_convt"] if trans: assert layer_type in ["exconvt", "cudnn_convt"] else: From 76677f25774a84d9ced011be02e62ae15b03506c Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 09:12:34 -0700 Subject: [PATCH 16/27] add test --- python/paddle/v2/framework/tests/mnist.py | 36 ++++++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index ededf767bc..e47de2436f 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -52,7 +52,7 @@ def grad_var_name(var_name): return var_name + "@GRAD" -def sgd_optimizer(net, param_name, learning_rate=0.01): +def sgd_optimizer(net, param_name, learning_rate=0.005): grad_name = grad_var_name(param_name) optimize_op = Operator( "sgd", @@ -166,9 +166,9 @@ def set_cost(): cost_grad.set(numpy.ones(cost_shape).astype("float32"), place) -def print_cost(): +def mean_cost(): cost_data = numpy.array(scope.find_var("cross_entropy_3").get_tensor()) - print(cost_data.sum() / len(cost_data)) + return cost_data.sum() / len(cost_data) def error_rate(predict, label): @@ -176,7 +176,7 @@ def error_rate(predict, label): axis=1) label = numpy.array(scope.find_var(label).get_tensor()) error_num = numpy.sum(predict_var != label) - print(error_num / float(len(label))) + return error_num / float(len(label)) images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) @@ -198,16 +198,35 @@ print_inputs_outputs(forward_network) print_inputs_outputs(backward_net) print_inputs_outputs(optimize_net) -reader = paddle.batch( +train_reader = paddle.batch( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=8192), batch_size=BATCH_SIZE) + +def test(): + test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + cost = [] + error = [] + for data in test_reader(): + image = numpy.array(map(lambda x: x[0], data)).astype("float32") + label = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data("pixel", image) + feed_data("label", label) + + forward_network.infer_shape(scope) + forward_network.run(scope, dev_ctx) + cost.append(mean_cost()) + error.append(error_rate(predict, "label")) + print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str( + sum(error) / float(len(error)))) + + PASS_NUM = 1000 for pass_id in range(PASS_NUM): batch_id = 0 - for data in reader(): + for data in train_reader(): image = numpy.array(map(lambda x: x[0], data)).astype("float32") label = numpy.array(map(lambda x: x[1], data)).astype("int32") feed_data("pixel", image) @@ -222,7 +241,8 @@ for pass_id in range(PASS_NUM): optimize_net.run(scope, dev_ctx) if batch_id % 100 == 0: print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]") - print_cost() - error_rate(predict, "label") + test() + # print(mean_cost()) + # print(error_rate(predict, "label")) batch_id = batch_id + 1 From cf515e4a72f4b02fbbbfdbd79c3b66b1be694e7b Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 09:39:47 -0700 Subject: [PATCH 17/27] optimize code and name --- python/paddle/v2/framework/tests/mnist.py | 56 +++++++++++------------ 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index e47de2436f..886e99610d 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -134,7 +134,7 @@ def cross_entropy_layer(net, input, label): return cost_name -def get_backward_net(forward_net): +def create_backward_net(forward_net): net = core.Operator.backward(forward_net, set()) for input in net.inputs()["all"]: var = scope.new_var(input) @@ -145,29 +145,29 @@ def get_backward_net(forward_net): return net -def print_inputs_outputs(op): +def debug_print_op(op): print("===============" + op.type() + "==============") print("***inputs:***") for input in op.inputs()["all"]: print input, scope.find_var(input).get_tensor().get_dims() - print("***outputs:***") + print("\n***outputs:***") for output in op.outputs()["all"]: print output, scope.find_var(output).get_tensor().get_dims() print("") print("") -def set_cost(): - cost_shape = numpy.array(scope.find_var("cross_entropy_3").get_tensor( - )).shape - cost_grad = scope.find_var(grad_var_name("cross_entropy_3")).get_tensor() +def set_cost(cost): + cost_shape = numpy.array(scope.find_var(cost).get_tensor()).shape + cost_grad = \ + scope.find_var(grad_var_name(cost)).get_tensor() cost_grad.set_dims(cost_shape) cost_grad.alloc_float(place) cost_grad.set(numpy.ones(cost_shape).astype("float32"), place) -def mean_cost(): - cost_data = numpy.array(scope.find_var("cross_entropy_3").get_tensor()) +def mean_cost(cost): + cost_data = numpy.array(scope.find_var(cost).get_tensor()) return cost_data.sum() / len(cost_data) @@ -180,23 +180,23 @@ def error_rate(predict, label): images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) -label = data_layer(name='label', dims=[BATCH_SIZE]) +labels = data_layer(name='label', dims=[BATCH_SIZE]) fc1 = fc_layer(net=forward_network, input=images, size=100, act="sigmoid") fc2 = fc_layer(net=forward_network, input=fc1, size=100, act="sigmoid") predict = fc_layer(net=forward_network, input=fc2, size=100, act="softmax") -cost = cross_entropy_layer(net=forward_network, input=predict, label=label) +cost = cross_entropy_layer(net=forward_network, input=predict, label=labels) forward_network.complete_add_op(True) -backward_net = get_backward_net(forward_network) +backward_net = create_backward_net(forward_network) optimize_net.complete_add_op(True) print(forward_network) print(backward_net) print(optimize_net) -print_inputs_outputs(forward_network) -print_inputs_outputs(backward_net) -print_inputs_outputs(optimize_net) +debug_print_op(forward_network) +debug_print_op(backward_net) +debug_print_op(optimize_net) train_reader = paddle.batch( paddle.reader.shuffle( @@ -204,19 +204,19 @@ train_reader = paddle.batch( batch_size=BATCH_SIZE) -def test(): +def test(cost_name): test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) cost = [] error = [] for data in test_reader(): - image = numpy.array(map(lambda x: x[0], data)).astype("float32") - label = numpy.array(map(lambda x: x[1], data)).astype("int32") - feed_data("pixel", image) - feed_data("label", label) + image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") + label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data(images, image_data) + feed_data(labels, label_data) forward_network.infer_shape(scope) forward_network.run(scope, dev_ctx) - cost.append(mean_cost()) + cost.append(mean_cost(cost_name)) error.append(error_rate(predict, "label")) print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str( sum(error) / float(len(error)))) @@ -227,22 +227,20 @@ for pass_id in range(PASS_NUM): batch_id = 0 for data in train_reader(): - image = numpy.array(map(lambda x: x[0], data)).astype("float32") - label = numpy.array(map(lambda x: x[1], data)).astype("int32") - feed_data("pixel", image) - feed_data("label", label) + image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") + label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data(images, image_data) + feed_data(labels, label_data) forward_network.infer_shape(scope) forward_network.run(scope, dev_ctx) - set_cost() + set_cost(cost) backward_net.infer_shape(scope) backward_net.run(scope, dev_ctx) optimize_net.run(scope, dev_ctx) if batch_id % 100 == 0: print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]") - test() - # print(mean_cost()) - # print(error_rate(predict, "label")) + test(cost) batch_id = batch_id + 1 From 9db4ad6130d79d72fa150e534b5b54fa723c3240 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 09:42:58 -0700 Subject: [PATCH 18/27] reduce pass num to 1 --- python/paddle/v2/framework/tests/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 886e99610d..eefd5709a3 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -222,7 +222,7 @@ def test(cost_name): sum(error) / float(len(error)))) -PASS_NUM = 1000 +PASS_NUM = 1 for pass_id in range(PASS_NUM): batch_id = 0 From 37cd8165b3089c8e4a6ce743f5e0ee8c029ba46b Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 10:56:56 -0700 Subject: [PATCH 19/27] change 128 to BATCH_SIZE --- python/paddle/v2/framework/tests/mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index eefd5709a3..e878bfa4e9 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -205,7 +205,8 @@ train_reader = paddle.batch( def test(cost_name): - test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) cost = [] error = [] for data in test_reader(): From da7a1f2f6c355b1bcdc0bd88e644f027d70f75d8 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 23 Aug 2017 21:30:08 +0000 Subject: [PATCH 20/27] master client: retry connecting to etcd --- go/master/client.go | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/go/master/client.go b/go/master/client.go index 62801b9b7f..9344c6f0ab 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -60,13 +60,30 @@ func WithAddr(addr string) func(c *Client) error { } } +func retry(f func() error, dur time.Duration, count int) error { + err := f() + if err != nil { + if count > 0 { + return retry(f, dur, count-1) + } + return err + } + return nil +} + // WithEtcd sets the client to use etcd for master discovery. func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { return func(c *Client) error { - cli, err := clientv3.New(clientv3.Config{ - Endpoints: endpoints, - DialTimeout: timeout, - }) + var cli *clientv3.Client + f := func() error { + var err error + cli, err = clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: timeout, + }) + return err + } + err := retry(f, time.Second, 10) if err != nil { return err } @@ -101,9 +118,6 @@ func NewClient(opts ...func(*Client) error) (*Client, error) { } } c.ch = make(chan record, c.bufSize) - // FIXME: connection is created asyncrosly in monitorMaster go routine, - // ensure the connection is ready for use before calling c.addClient. - time.Sleep(time.Second) return c, nil } From 5270585e107b16dc527ada329dddf6fc44714a35 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 23 Aug 2017 21:38:43 +0000 Subject: [PATCH 21/27] fix according to comment --- go/master/client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/master/client.go b/go/master/client.go index 9344c6f0ab..199690d488 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -64,6 +64,7 @@ func retry(f func() error, dur time.Duration, count int) error { err := f() if err != nil { if count > 0 { + time.Sleep(dur) return retry(f, dur, count-1) } return err From 05176bd1bb5af94bfbabbb524ed9e65448134e39 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Thu, 24 Aug 2017 01:23:27 +0000 Subject: [PATCH 22/27] master server will wait etcd forever --- go/master/client.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/go/master/client.go b/go/master/client.go index 199690d488..f04cf50ce3 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -60,18 +60,6 @@ func WithAddr(addr string) func(c *Client) error { } } -func retry(f func() error, dur time.Duration, count int) error { - err := f() - if err != nil { - if count > 0 { - time.Sleep(dur) - return retry(f, dur, count-1) - } - return err - } - return nil -} - // WithEtcd sets the client to use etcd for master discovery. func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { return func(c *Client) error { @@ -84,9 +72,14 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { }) return err } - err := retry(f, time.Second, 10) - if err != nil { - return err + for { + err := f() + if err != nil { + log.Warningln(err) + } else { + break + } + time.Sleep(time.Second) } ch := make(chan string, 1) From 0e300f9bf04ba459dbef93af9537f847cebbcd27 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 20:14:54 -0700 Subject: [PATCH 23/27] use init_net and random_op to initialize parameter --- python/paddle/v2/framework/tests/mnist.py | 54 +++++++++++------------ 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index e878bfa4e9..0c27ce3e35 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -9,11 +9,8 @@ scope = core.Scope() place = core.CPUPlace() dev_ctx = core.DeviceContext.create(place) -# init_net = core.Net.create() -forward_network = core.Net.create() - -# should be init after forward_op is constructed -# backward_net = core.Operator.backward(forward_net, set()) +init_net = core.Net.create() +forward_net = core.Net.create() backward_net = None optimize_net = core.Net.create() @@ -64,13 +61,12 @@ def sgd_optimizer(net, param_name, learning_rate=0.005): # should use operator and add these to the init_network -def init_param(param_name, dims): - var = scope.new_var(param_name) - tensor = var.get_tensor() - tensor.set_dims(dims) - data = numpy.random.uniform( - low=-0.5, high=0.5, size=tensor.shape()).astype("float32") - tensor.set(data, place) +def init_param(net, param_name, dims): + scope.new_var(param_name) + op = Operator( + "uniform_random", Out=param_name, dims=dims, min=-0.5, max=0.5, seed=10) + op.infer_shape(scope) + net.append_op(op) # fc_layer @@ -96,7 +92,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): input_dims = scope.find_var(input).get_tensor().get_dims() w_name = param or name + ".w" - init_param(param_name=w_name, dims=[input_dims[1], size]) + init_param(net=init_net, param_name=w_name, dims=[input_dims[1], size]) sgd_optimizer(net=optimize_net, param_name=w_name, learning_rate=0.01) pre_activation = name + ".mul.out" @@ -107,7 +103,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): # create bias variable if needed if bias: bias_name = name + ".b" - init_param(param_name=bias_name, dims=[size]) + init_param(net=init_net, param_name=bias_name, dims=[size]) sgd_optimizer( net=optimize_net, param_name=bias_name, learning_rate=0.001) bias_out = name + ".rowwise_add.out" @@ -181,20 +177,22 @@ def error_rate(predict, label): images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) labels = data_layer(name='label', dims=[BATCH_SIZE]) -fc1 = fc_layer(net=forward_network, input=images, size=100, act="sigmoid") -fc2 = fc_layer(net=forward_network, input=fc1, size=100, act="sigmoid") -predict = fc_layer(net=forward_network, input=fc2, size=100, act="softmax") -cost = cross_entropy_layer(net=forward_network, input=predict, label=labels) - -forward_network.complete_add_op(True) -backward_net = create_backward_net(forward_network) +fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid") +fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid") +predict = fc_layer(net=forward_net, input=fc2, size=100, act="softmax") +cost = cross_entropy_layer(net=forward_net, input=predict, label=labels) + +init_net.complete_add_op(True) +forward_net.complete_add_op(True) +backward_net = create_backward_net(forward_net) optimize_net.complete_add_op(True) -print(forward_network) +print(init_net) +print(forward_net) print(backward_net) print(optimize_net) -debug_print_op(forward_network) +debug_print_op(forward_net) debug_print_op(backward_net) debug_print_op(optimize_net) @@ -215,8 +213,8 @@ def test(cost_name): feed_data(images, image_data) feed_data(labels, label_data) - forward_network.infer_shape(scope) - forward_network.run(scope, dev_ctx) + forward_net.infer_shape(scope) + forward_net.run(scope, dev_ctx) cost.append(mean_cost(cost_name)) error.append(error_rate(predict, "label")) print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str( @@ -224,6 +222,8 @@ def test(cost_name): PASS_NUM = 1 + +init_net.run(scope, dev_ctx) for pass_id in range(PASS_NUM): batch_id = 0 @@ -233,8 +233,8 @@ for pass_id in range(PASS_NUM): feed_data(images, image_data) feed_data(labels, label_data) - forward_network.infer_shape(scope) - forward_network.run(scope, dev_ctx) + forward_net.infer_shape(scope) + forward_net.run(scope, dev_ctx) set_cost(cost) backward_net.infer_shape(scope) backward_net.run(scope, dev_ctx) From 0ee18a86d18b4d4506c63e13b2953c9153c27f8d Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 24 Aug 2017 11:50:35 +0800 Subject: [PATCH 24/27] Fix doc. --- python/paddle/trainer_config_helpers/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b3568cc257..f323b017c0 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2418,8 +2418,8 @@ def img_conv_layer(input, :param dilation: The x dimension of the dilation. Or input a tuple for two image dimension :type dilation: int|tuple|list - :param padding_y: The y dimension of the dilation. - :type padding_y: int + :param dilation_y: The y dimension of the dilation. + :type dilation_y: int :param bias_attr: Convolution bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|False From 12864f142073b4a280120e4d9b3abe4e2483ca32 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 22:51:35 -0700 Subject: [PATCH 25/27] register rowwise add gpu kernel --- paddle/operators/rowwise_add_op.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index cbc61ad3e1..4a57f64c89 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -18,3 +18,6 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( rowwise_add, ops::RowwiseAddKernel); +REGISTER_OP_GPU_KERNEL( + rowwise_add_grad, + ops::RowwiseAddGradKernel); From 3648165b63bd5331d1809cba896176e4af0a9ff2 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 23:00:45 -0700 Subject: [PATCH 26/27] add gpu support --- python/paddle/v2/framework/tests/mnist.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 0c27ce3e35..d9941023fe 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -7,6 +7,8 @@ BATCH_SIZE = 100 scope = core.Scope() place = core.CPUPlace() +# if you want to test GPU training, you can use gpu place +# place = core.GPUPlace(0) dev_ctx = core.DeviceContext.create(place) init_net = core.Net.create() From 625b15355a16fa42476e7dbd166b77e092dcb97f Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 23 Aug 2017 23:56:55 -0700 Subject: [PATCH 27/27] optimize code --- python/paddle/v2/framework/tests/mnist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index d9941023fe..9a0b109850 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -17,14 +17,14 @@ backward_net = None optimize_net = core.Net.create() -def atom_id(): +def atomic_id(): id = 0 while True: yield id id += 1 -uniq_id = atom_id().next +uniq_id = atomic_id().next def data_layer(name, dims): @@ -164,7 +164,7 @@ def set_cost(cost): cost_grad.set(numpy.ones(cost_shape).astype("float32"), place) -def mean_cost(cost): +def get_cost_mean(cost): cost_data = numpy.array(scope.find_var(cost).get_tensor()) return cost_data.sum() / len(cost_data) @@ -217,7 +217,7 @@ def test(cost_name): forward_net.infer_shape(scope) forward_net.run(scope, dev_ctx) - cost.append(mean_cost(cost_name)) + cost.append(get_cost_mean(cost_name)) error.append(error_rate(predict, "label")) print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str( sum(error) / float(len(error))))