From 8a645685ce592789e9c706e7581f8939ae38f311 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 15 Mar 2018 01:03:34 +0800 Subject: [PATCH 01/79] Add sum accumulator with window for model average --- .../fluid/operators/average_accumulates_op.cc | 152 ++++++++++++++++++ .../fluid/operators/average_accumulates_op.cu | 59 +++++++ .../fluid/operators/average_accumulates_op.h | 118 ++++++++++++++ 3 files changed, 329 insertions(+) create mode 100644 paddle/fluid/operators/average_accumulates_op.cc create mode 100644 paddle/fluid/operators/average_accumulates_op.cu create mode 100644 paddle/fluid/operators/average_accumulates_op.h diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc new file mode 100644 index 0000000000..808693b61c --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -0,0 +1,152 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/average_accumulates_op.h" + +namespace paddle { +namespace operators { + +template <> +void getAccumulators( + const framework::ExecutionContext& ctx, int64_t& num_updates_, + int64_t& num_accumulates_, int64_t& old_num_accumulates_) { + auto* in_old_num_accumulates = ctx.Input("old_num_accumulates"); + auto* in_num_accumulates = ctx.Input("num_accumulates"); + auto* in_num_updates = ctx.Input("num_updates"); + + old_num_accumulates_ = in_old_num_accumulates->data()[0]; + num_accumulates_ = in_num_accumulates->data()[0]; + num_updates_ = in_num_updates->data()[0]; +} + +template <> +void setAccumulators( + const framework::ExecutionContext& ctx, int64_t num_updates_, + int64_t num_accumulates_, int64_t old_num_accumulates_) { + auto* out_old_num_accumulates = ctx.Output("old_num_accumulates"); + auto* out_num_accumulates = ctx.Output("num_accumulates"); + auto* out_num_updates = ctx.Output("num_updates"); + + out_old_num_accumulates->data()[0] = old_num_accumulates_; + out_num_accumulates->data()[0] = num_accumulates_; + out_num_updates->data()[0] = num_updates_; +} + +class AverageAccumulatesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("Param"), + "Input (Param) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("Grad"), + "Input (Grad) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("sum_1"), + "Input (sum_1) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("sum_2"), + "Input (sum_2) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("sum_3"), + "Input (sum_3) of average_accumulates op should not be null."); + PADDLE_ENFORCE(ctx->HasInput("num_accumulates"), + "Input (num_accumulates) of average_accumulates op should " + "not be null."); + PADDLE_ENFORCE(ctx->HasInput("old_num_accumulates"), + "Input (old_num_accumulates) of average_accumulates op " + "should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("num_updates"), + "Input (num_updates) of average_accumulates op should not be null."); + + PADDLE_ENFORCE( + ctx->HasOutput("sum_1"), + "Output (sum_1) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("sum_2"), + "Output (sum_2) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("sum_3"), + "Output (sum_3) of average_accumulates op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("num_accumulates"), + "Output (num_accumulates) of average_accumulates op should " + "not be null."); + PADDLE_ENFORCE(ctx->HasOutput("old_num_accumulates"), + "Output (old_num_accumulates) of average_accumulates op " + "should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("num_updates"), + "Output (num_updates) of average_accumulates op should not be null."); + + auto in_dim = ctx->GetInputDim("Param"); + + ctx->SetOutputDim("sum_1", in_dim); + ctx->SetOutputDim("sum_2", in_dim); + ctx->SetOutputDim("sum_3", in_dim); + ctx->SetOutputDim("num_accumulates", {1}); + ctx->SetOutputDim("old_num_accumulates", {1}); + ctx->SetOutputDim("num_updates", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Param")->type()), + ctx.GetPlace()); + } +}; + +class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("sum_1", ""); + AddInput("sum_2", ""); + AddInput("sum_3", ""); + AddInput("num_accumulates", ""); + AddInput("old_num_accumulates", ""); + AddInput("num_updates", ""); + + AddOutput("sum_1", ""); + AddOutput("sum_2", ""); + AddOutput("sum_3", ""); + AddOutput("num_accumulates", ""); + AddOutput("old_num_accumulates", ""); + AddOutput("num_updates", ""); + + AddAttr("", "average_window"); + AddAttr("", "max_average_window"); + AddAttr("", "min_average_window"); + + AddComment(R"DOC( +AverageAccumulates Operator. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(average_accumulate, ops::AverageAccumulatesOp, + ops::AverageAccumulatesOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + average_accumulate, + ops::AverageAccumulatesKernel, + ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.cu b/paddle/fluid/operators/average_accumulates_op.cu new file mode 100644 index 0000000000..56f2f02fd2 --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.cu @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/average_accumulates_op.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { +template <> +void getAccumulators( + const framework::ExecutionContext& ctx, int64_t& num_updates_, + int64_t& num_accumulates_, int64_t& old_num_accumulates_) { + auto* in_old_num_accumulates = ctx.Input("old_num_accumulates"); + auto* in_num_accumulates = ctx.Input("num_accumulates"); + auto* in_num_updates = ctx.Input("num_updates"); + + memory::Copy(platform::CPUPlace(), &old_num_accumulates_, + platform::CUDAPlace(), in_old_num_accumulates->data(), + sizeof(int64_t)); + memory::Copy(platform::CPUPlace(), &num_accumulates_, platform::CUDAPlace(), + in_old_num_accumulates->data(), sizeof(int64_t)); + memory::Copy(platform::CPUPlace(), &num_updates_, platform::CUDAPlace(), + in_num_updates->data(), sizeof(int64_t)); +} + +template <> +void setAccumulators( + const framework::ExecutionContext& ctx, int64_t num_updates_, + int64_t num_accumulates_, int64_t old_num_accumulates_) { + auto* out_old_num_accumulates = ctx.Output("old_num_accumulates"); + auto* out_num_accumulates = ctx.Output("num_accumulates"); + auto* out_num_updates = ctx.Output("num_updates"); + + memory::Copy(platform::CUDAPlace(), out_old_num_accumulates->data(), + platform::CPUPlace(), &old_num_accumulates_, sizeof(int64_t)); + memory::Copy(platform::CUDAPlace(), out_num_accumulates->data(), + platform::CPUPlace(), &num_accumulates_, sizeof(int64_t)); + memory::Copy(platform::CUDAPlace(), out_num_updates->data(), + platform::CPUPlace(), &num_updates_, sizeof(int64_t)); +} +} +} + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + average_accumulate, + ops::AverageAccumulatesKernel, + ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h new file mode 100644 index 0000000000..73814dd24b --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenVector = framework::EigenVector; + +template +void getAccumulators(const framework::ExecutionContext& ctx, + int64_t& num_updates_, int64_t& num_accumulates_, + int64_t& old_num_accumulates_); + +template +void setAccumulators(const framework::ExecutionContext& ctx, + int64_t num_updates_, int64_t num_accumulates_, + int64_t old_num_accumulates_); + +template +class AverageAccumulatesKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + static const int64_t kMaxNumAccumulates = 16384; + // accumulators + int64_t num_updates_ = 0; + int64_t num_accumulates_ = 0; + int64_t old_num_accumulates_ = 0; + // attrs + int64_t min_average_window_; + int64_t max_average_window_; + float average_window_; + + auto* param = ctx.Input("Param"); + auto* in_sum_1 = ctx.Input("sum_1"); + auto* in_sum_2 = ctx.Input("sum_2"); + auto* in_sum_3 = ctx.Input("sum_3"); + + auto* out_sum_1 = ctx.Output("sum_1"); + auto* out_sum_2 = ctx.Output("sum_2"); + auto* out_sum_3 = ctx.Output("sum_3"); + + getAccumulators(ctx, num_updates_, num_accumulates_, + old_num_accumulates_); + average_window_ = ctx.Attr("average_window"); + max_average_window_ = + ctx.Attr("max_average_window"); // default bach number + min_average_window_ = + ctx.Attr("min_average_window"); // default 10000L + min_average_window_ = + std::min(min_average_window_, max_average_window_); + + auto param_tensor = EigenVector::Flatten(*param); + auto in_sum_1_tensor = EigenVector::Flatten(*in_sum_1); + auto in_sum_2_tensor = EigenVector::Flatten(*in_sum_2); + auto in_sum_3_tensor = EigenVector::Flatten(*in_sum_3); + auto out_sum_1_tensor = EigenVector::Flatten(*out_sum_1); + auto out_sum_2_tensor = EigenVector::Flatten(*out_sum_2); + auto out_sum_3_tensor = EigenVector::Flatten(*out_sum_3); + + auto& place = *ctx.template device_context().eigen_device(); + math::SetConstant constant_functor; + // start batch + ++num_updates_; + ++num_accumulates_; + + // update + out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor; + + out_sum_2_tensor.device(place) = in_sum_2_tensor; + out_sum_3_tensor.device(place) = in_sum_3_tensor; + // needSpecialTraversal + if (num_updates_ % kMaxNumAccumulates == 0) { + out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor; + constant_functor(ctx.template device_context(), out_sum_1, + 0.0); + } + + if (num_accumulates_ >= min_average_window_ && + num_accumulates_ >= std::min(max_average_window_, + num_updates_ * average_window_)) { + out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor; + constant_functor(ctx.template device_context(), out_sum_1, + 0.0); + constant_functor(ctx.template device_context(), out_sum_2, + 0.0); + + // finishBatch + old_num_accumulates_ = num_accumulates_; + num_accumulates_ = 0; + } + setAccumulators(ctx, num_updates_, num_accumulates_, + old_num_accumulates_); + } +}; + +} // namespace operators +} // namespace paddle From 39c676e20861ecd37a055fc48e0e294803bc3e4a Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Fri, 16 Mar 2018 16:59:55 -0700 Subject: [PATCH 02/79] initial commit --- paddle/fluid/operators/batch_norm_op.cu.cc | 8 ++++---- paddle/fluid/operators/math/math_function.cc | 1 + paddle/fluid/operators/math/math_function.cu | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index 2d1556efc6..949497f48c 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -270,9 +270,9 @@ class BatchNormGradKernel } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - batch_norm, - ops::BatchNormKernel); + batch_norm, ops::BatchNormKernel, + ops::BatchNormKernel); REGISTER_OP_CUDA_KERNEL( - batch_norm_grad, - ops::BatchNormGradKernel); + batch_norm_grad, ops::BatchNormGradKernel); diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 35d251f71a..1cbd2fa870 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -278,6 +278,7 @@ void axpy( cblas_daxpy(n, alpha, x, 1, y, 1); } +template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 3abbcdb71d..bccfaef9ce 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -348,6 +348,7 @@ void axpy( &alpha, x, 1, y, 1)); } +template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; From 0a95a44b9a70e16601e5bee9133db58695abd16c Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 17 Mar 2018 17:00:01 -0700 Subject: [PATCH 03/79] add python batch norm inference test --- paddle/fluid/operators/batch_norm_op.cu.cc | 4 +- .../tests/unittests/test_batch_norm_op.py | 69 ++++++++++++++++++- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index 949497f48c..f4919398eb 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -125,8 +125,8 @@ class BatchNormKernel auto &dev_ctx = ctx.template device_context(); math::SetConstant functor; - functor(dev_ctx, saved_mean, 0); - functor(dev_ctx, saved_variance, 0); + functor(dev_ctx, saved_mean, static_cast(0)); + functor(dev_ctx, saved_variance, static_cast(0)); auto handle = dev_ctx.cudnn_handle(); diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 80e6fa6df3..d5a57bdd73 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set): return backward_op +def _reference_testing(x, scale, offset, mean, var, epsilon, data_format): + x_shape = x.shape + if len(x_shape) == 2: + if data_format == "NCHW": + x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1)) + else: + x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1])) + + if data_format == "NCHW": + n, c, h, w = x.shape + mean_tile = np.reshape(mean, (1, c, 1, 1)) + mean_tile = np.tile(mean_tile, (n, 1, h, w)) + var_tile = np.reshape(var, (1, c, 1, 1)) + var_tile = np.tile(var_tile, (n, 1, h, w)) + normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon) + scale_tile = np.reshape(scale, (1, c, 1, 1)) + scale_tile = np.tile(scale_tile, (n, 1, h, w)) + offset_tile = np.reshape(offset, (1, c, 1, 1)) + offset_tile = np.reshape(offset_tile, (1, c, 1, 1)) + y = normalized * scale_tile + offset_tile + elif data_format == "NHWC": + normalized = (x - mean) / np.sqrt(var + epsilon) + y = normalized * scale + offset + else: + raise ValueError("Unknown data order.") + + if len(x_shape) == 2: + y = np.reshape(y, x_shape) + return y + + def _reference_training(x, scale, offset, epsilon, data_format): x_shape = x.shape if len(x_shape) == 2: @@ -155,7 +186,43 @@ def set_output_grad(scope, outputs, place, feed_dict=None): __set_tensor__(output, data) -class TestBatchNormOp(OpTest): +class TestBatchNormOpInference(OpTest): + def setUp(self): + self.dtype = np.float32 + + def test_python(self): + data_format = "NHWC" + epsilon = 0.00001 + + n, h, w, c = 2, 3, 4, 5 + x_shape = [n, h, w, c] + scale_shape = [c] + + x_val = np.random.random_sample(x_shape).astype(self.dtype) + scale_val = np.random.random_sample(scale_shape).astype(self.dtype) + bias_val = np.random.random_sample(scale_shape).astype(self.dtype) + + mean = np.zeros(scale_shape).astype(self.dtype) + variance = np.ones(scale_shape).astype(self.dtype) + + # run forward + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, + epsilon, "NHWC") + + # running N, C, H, W case + # should produce the same results + x_shape2 = [n, c, h, w] + x_val2 = np.transpose(x_val, (0, 3, 1, 2)) + y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance, + epsilon, "NCHW") + + # transfer (N, C, H, W) back to (N, H, W, C) + y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) + self.__assert_close(y_out, y_out2_trans, "inference output") + print 'python: NHWC, NCHW, inference checking passed' + + +class TestBatchNormOpTraining(OpTest): def __assert_close(self, tensor, np_array, msg, atol=1e-4): self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) From 151cfff90b4e860baba5a374cf044d6ce8b8c7ff Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 17 Mar 2018 20:51:20 -0700 Subject: [PATCH 04/79] add more tests --- .../tests/unittests/test_batch_norm_op.py | 94 ++++++++++++++++--- 1 file changed, 80 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index d5a57bdd73..f631050e2a 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -188,14 +188,27 @@ def set_output_grad(scope, outputs, place, feed_dict=None): class TestBatchNormOpInference(OpTest): def setUp(self): + self.op_type = "conv2d" + self.is_test = True self.dtype = np.float32 + self.data_layout = "NCHW" + init_dtype() + init_data_layout() + init_test_case() - def test_python(self): - data_format = "NHWC" epsilon = 0.00001 - - n, h, w, c = 2, 3, 4, 5 - x_shape = [n, h, w, c] + shape = self.shape + if len(shape) == 2: + x_shape = shape + c = x_shape[1] + else: + n, h, w, c = shape[0], shape[1], shape[2], shape[3] + if self.data_layout == "NHWC": + x_shape = [n, h, w, c] + elif self.data_layout == "NCHW": + x_shape = [n, c, h, w] + else: + raise ValueError("Unknown data layout.") scale_shape = [c] x_val = np.random.random_sample(x_shape).astype(self.dtype) @@ -205,7 +218,64 @@ class TestBatchNormOpInference(OpTest): mean = np.zeros(scale_shape).astype(self.dtype) variance = np.ones(scale_shape).astype(self.dtype) - # run forward + saved_mean = np.zeros(scale_shape).astype(self.dtype) + saved_variance = np.ones(scale_shape).astype(self.dtype) + + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, + epsilon, self.data_layout).astype(self.dtype) + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(x_val), + 'Scale': OpTest.np_dtype_to_fluid_dtype(scale_val), + 'Bias': OpTest.np_dtype_to_fluid_dtype(bias_val), + 'Mean': OpTest.np_dtype_to_fluid_dtype(mean), + 'Variance': OpTest.np_dtype_to_fluid_dtype(variance) + } + self.attrs = { + 'is_test': self.is_test, + 'epsilon': epsilon, + 'data_layout': self.data_layout + } + self.outputs = { + 'Y': y_out, + 'MeanOut': mean, + 'VarianceOut': variance, + 'SavedMean': saved_mean, + 'SavedVariance': saved_variance + } + + def test_check_output(self): + self.check_output() + + def init_dtype(self): + pass + + def init_data_layout(self): + pass + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + + +class TestBatchNormOpTraining(OpTest): + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def test_python_testing(self): + data_format = "NHWC" + epsilon = 0.00001 + + n, h, w, c = 2, 3, 4, 5 + x_shape = [n, h, w, c] + scale_shape = [c] + + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + bias_val = np.random.random_sample(scale_shape).astype(np.float32) + + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, epsilon, "NHWC") @@ -218,15 +288,11 @@ class TestBatchNormOpInference(OpTest): # transfer (N, C, H, W) back to (N, H, W, C) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) - self.__assert_close(y_out, y_out2_trans, "inference output") + self.__assert_close(y_out, y_out2_trans, + "inference outputs of two formats have differences") print 'python: NHWC, NCHW, inference checking passed' - -class TestBatchNormOpTraining(OpTest): - def __assert_close(self, tensor, np_array, msg, atol=1e-4): - self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - - def test_python(self): + def test_python_training(self): data_format = "NHWC" epsilon = 0.00001 momentum = 0.9 @@ -264,7 +330,7 @@ class TestBatchNormOpTraining(OpTest): # transfer (N, C, H, W) back to (N, H, W, C) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) - self.__assert_close(y_out, y_out2_trans, "batch variance") + self.__assert_close(y_out, y_out2_trans, "batch output") print 'python: NHWC, NCHW, forward checking passed' # test backward now From 5e36757c374b0e9c17dc7ab00ab87b10afc43c26 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 17 Mar 2018 23:14:03 -0700 Subject: [PATCH 05/79] fix test --- .../tests/unittests/test_batch_norm_op.py | 152 ++++++++++-------- 1 file changed, 88 insertions(+), 64 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index f631050e2a..2f2873c183 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -187,74 +187,99 @@ def set_output_grad(scope, outputs, place, feed_dict=None): class TestBatchNormOpInference(OpTest): - def setUp(self): - self.op_type = "conv2d" - self.is_test = True - self.dtype = np.float32 - self.data_layout = "NCHW" - init_dtype() - init_data_layout() - init_test_case() + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - epsilon = 0.00001 - shape = self.shape - if len(shape) == 2: - x_shape = shape - c = x_shape[1] - else: - n, h, w, c = shape[0], shape[1], shape[2], shape[3] - if self.data_layout == "NHWC": - x_shape = [n, h, w, c] - elif self.data_layout == "NCHW": - x_shape = [n, c, h, w] + def test_inference(self): + def test_with_place(place, data_layout, dtype, shape): + epsilon = 0.00001 + if len(shape) == 2: + x_shape = shape + c = x_shape[1] else: - raise ValueError("Unknown data layout.") - scale_shape = [c] + n, h, w, c = shape[0], shape[1], shape[2], shape[3] + if data_layout == "NHWC": + x_shape = [n, h, w, c] + elif data_layout == "NCHW": + x_shape = [n, c, h, w] + else: + raise ValueError("Unknown data layout.") + scale_shape = [c] - x_val = np.random.random_sample(x_shape).astype(self.dtype) - scale_val = np.random.random_sample(scale_shape).astype(self.dtype) - bias_val = np.random.random_sample(scale_shape).astype(self.dtype) + x_val = np.random.random_sample(x_shape).astype(dtype) + scale_val = np.random.random_sample(scale_shape).astype(dtype) + bias_val = np.random.random_sample(scale_shape).astype(dtype) - mean = np.zeros(scale_shape).astype(self.dtype) - variance = np.ones(scale_shape).astype(self.dtype) + mean = np.zeros(scale_shape).astype(dtype) + variance = np.ones(scale_shape).astype(dtype) - saved_mean = np.zeros(scale_shape).astype(self.dtype) - saved_variance = np.ones(scale_shape).astype(self.dtype) + y_out = _reference_testing(x_val, scale_val, bias_val, mean, + variance, epsilon, + data_layout).astype(dtype) - y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, - epsilon, self.data_layout).astype(self.dtype) - - self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(x_val), - 'Scale': OpTest.np_dtype_to_fluid_dtype(scale_val), - 'Bias': OpTest.np_dtype_to_fluid_dtype(bias_val), - 'Mean': OpTest.np_dtype_to_fluid_dtype(mean), - 'Variance': OpTest.np_dtype_to_fluid_dtype(variance) - } - self.attrs = { - 'is_test': self.is_test, - 'epsilon': epsilon, - 'data_layout': self.data_layout - } - self.outputs = { - 'Y': y_out, - 'MeanOut': mean, - 'VarianceOut': variance, - 'SavedMean': saved_mean, - 'SavedVariance': saved_variance - } - - def test_check_output(self): - self.check_output() - - def init_dtype(self): - pass - - def init_data_layout(self): - pass - - def init_test_case(self): - self.shape = [2, 3, 4, 5] + scope = core.Scope() + + # create input + x_tensor = create_or_get_tensor( + scope, "x_val", OpTest.np_dtype_to_fluid_dtype(x_val), place) + scale_tensor = create_or_get_tensor( + scope, "scale_val", + OpTest.np_dtype_to_fluid_dtype(scale_val), place) + bias_tensor = create_or_get_tensor( + scope, "bias_val", + OpTest.np_dtype_to_fluid_dtype(bias_val), place) + mean_tensor = create_or_get_tensor( + scope, "mean", OpTest.np_dtype_to_fluid_dtype(mean), place) + variance_tensor = create_or_get_tensor( + scope, "variance", + OpTest.np_dtype_to_fluid_dtype(variance), place) + + # create output + y_tensor = create_or_get_tensor(scope, "y_out", None, place) + saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None, + place) + saved_variance_tensor = create_or_get_tensor( + scope, "saved_variance", None, place) + mean_out_tensor = mean_tensor + variance_out_tensor = variance_tensor + + batch_norm_op = Operator( + "batch_norm", + # inputs + X="x_val", + Scale="scale_val", + Bias="bias_val", + Mean="mean", + Variance="variance", + # outputs + Y="y_out", + MeanOut="mean", + VarianceOut="variance", + SavedMean="saved_mean", + SavedVariance="saved_variance", + # attrs + is_test=True, + data_layout=data_layout, + epsilon=epsilon) + + batch_norm_op.run(scope, place) + + # check inference result + self.__assert_close( + y_tensor, y_out, "inference output are different at " + + str(place) + ", " + data_layout + ", " + str(np.dtype(dtype))) + + places = [core.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + place = core.CUDAPlace(0) + if self.dtype != np.float16 or core.is_float16_supported(place): + places.append(place) + + for place in places: + for data_format in ["NCHW", "NHWC"]: + for dtype in [np.float32, np.float16]: + test_with_place(place, data_format, dtype, [2, 3, 4, 5]) + test_with_place(place, data_format, dtype, [2, 3]) class TestBatchNormOpTraining(OpTest): @@ -288,8 +313,7 @@ class TestBatchNormOpTraining(OpTest): # transfer (N, C, H, W) back to (N, H, W, C) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) - self.__assert_close(y_out, y_out2_trans, - "inference outputs of two formats have differences") + self.__assert_close(y_out, y_out2_trans, "inference output") print 'python: NHWC, NCHW, inference checking passed' def test_python_training(self): From 3233b2b3234b290de5da9a89ca7db1fd223f2789 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 00:17:18 -0700 Subject: [PATCH 06/79] update test --- .../tests/unittests/test_batch_norm_op.py | 169 ++++++++++-------- 1 file changed, 93 insertions(+), 76 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 2f2873c183..91a9d826a0 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -187,99 +187,116 @@ def set_output_grad(scope, outputs, place, feed_dict=None): class TestBatchNormOpInference(OpTest): + def setUp(self): + self.dtype = np.float32 + def __assert_close(self, tensor, np_array, msg, atol=1e-4): self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - def test_inference(self): - def test_with_place(place, data_layout, dtype, shape): - epsilon = 0.00001 - if len(shape) == 2: - x_shape = shape - c = x_shape[1] + def check_with_place(place, data_layout, dtype, shape): + epsilon = 0.00001 + if len(shape) == 2: + x_shape = shape + c = x_shape[1] + else: + n, h, w, c = shape[0], shape[1], shape[2], shape[3] + if data_layout == "NHWC": + x_shape = [n, h, w, c] + elif data_layout == "NCHW": + x_shape = [n, c, h, w] else: - n, h, w, c = shape[0], shape[1], shape[2], shape[3] - if data_layout == "NHWC": - x_shape = [n, h, w, c] - elif data_layout == "NCHW": - x_shape = [n, c, h, w] - else: - raise ValueError("Unknown data layout.") - scale_shape = [c] - - x_val = np.random.random_sample(x_shape).astype(dtype) - scale_val = np.random.random_sample(scale_shape).astype(dtype) - bias_val = np.random.random_sample(scale_shape).astype(dtype) - - mean = np.zeros(scale_shape).astype(dtype) - variance = np.ones(scale_shape).astype(dtype) - - y_out = _reference_testing(x_val, scale_val, bias_val, mean, - variance, epsilon, - data_layout).astype(dtype) + raise ValueError("Unknown data layout.") + scale_shape = [c] - scope = core.Scope() + x_val = np.random.random_sample(x_shape).astype(dtype) + scale_val = np.random.random_sample(scale_shape).astype(dtype) + bias_val = np.random.random_sample(scale_shape).astype(dtype) - # create input - x_tensor = create_or_get_tensor( - scope, "x_val", OpTest.np_dtype_to_fluid_dtype(x_val), place) - scale_tensor = create_or_get_tensor( - scope, "scale_val", - OpTest.np_dtype_to_fluid_dtype(scale_val), place) - bias_tensor = create_or_get_tensor( - scope, "bias_val", - OpTest.np_dtype_to_fluid_dtype(bias_val), place) - mean_tensor = create_or_get_tensor( - scope, "mean", OpTest.np_dtype_to_fluid_dtype(mean), place) - variance_tensor = create_or_get_tensor( - scope, "variance", - OpTest.np_dtype_to_fluid_dtype(variance), place) + mean = np.zeros(scale_shape).astype(dtype) + variance = np.ones(scale_shape).astype(dtype) - # create output - y_tensor = create_or_get_tensor(scope, "y_out", None, place) - saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None, - place) - saved_variance_tensor = create_or_get_tensor( - scope, "saved_variance", None, place) - mean_out_tensor = mean_tensor - variance_out_tensor = variance_tensor + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, + epsilon, data_layout).astype(dtype) + + scope = core.Scope() + + # create input + x_tensor = create_or_get_tensor(scope, "x_val", + OpTest.np_dtype_to_fluid_dtype(x_val), + place) + scale_tensor = create_or_get_tensor( + scope, "scale_val", + OpTest.np_dtype_to_fluid_dtype(scale_val), place) + bias_tensor = create_or_get_tensor( + scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place) + mean_tensor = create_or_get_tensor(scope, "mean", + OpTest.np_dtype_to_fluid_dtype(mean), + place) + variance_tensor = create_or_get_tensor( + scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place) + + # create output + y_tensor = create_or_get_tensor(scope, "y_out", None, place) + saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None, + place) + saved_variance_tensor = create_or_get_tensor(scope, "saved_variance", + None, place) + mean_out_tensor = mean_tensor + variance_out_tensor = variance_tensor + + batch_norm_op = Operator( + "batch_norm", + # inputs + X="x_val", + Scale="scale_val", + Bias="bias_val", + Mean="mean", + Variance="variance", + # outputs + Y="y_out", + MeanOut="mean", + VarianceOut="variance", + SavedMean="saved_mean", + SavedVariance="saved_variance", + # attrs + is_test=True, + data_layout=data_layout, + epsilon=epsilon) + + batch_norm_op.run(scope, place) + + # check inference result + self.__assert_close(y_tensor, y_out, + "inference output are different at " + str(place) + + ", " + data_layout + ", " + str(np.dtype(dtype))) + + def test_check_output(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(core.CUDAPlace(0)) - batch_norm_op = Operator( - "batch_norm", - # inputs - X="x_val", - Scale="scale_val", - Bias="bias_val", - Mean="mean", - Variance="variance", - # outputs - Y="y_out", - MeanOut="mean", - VarianceOut="variance", - SavedMean="saved_mean", - SavedVariance="saved_variance", - # attrs - is_test=True, - data_layout=data_layout, - epsilon=epsilon) + for place in places: + for data_format in ["NCHW", "NHWC"]: + check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) + check_with_place(place, data_format, self.dtype, [2, 3]) - batch_norm_op.run(scope, place) - # check inference result - self.__assert_close( - y_tensor, y_out, "inference output are different at " + - str(place) + ", " + data_layout + ", " + str(np.dtype(dtype))) +class TestFP16BatchNormOpInference(TestBatchNormOpInference): + def setUp(self): + self.dtype = np.float16 - places = [core.CPUPlace()] + def test_check_output(self): + places = [] if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): place = core.CUDAPlace(0) - if self.dtype != np.float16 or core.is_float16_supported(place): + if core.is_float16_supported(place): places.append(place) for place in places: for data_format in ["NCHW", "NHWC"]: - for dtype in [np.float32, np.float16]: - test_with_place(place, data_format, dtype, [2, 3, 4, 5]) - test_with_place(place, data_format, dtype, [2, 3]) + check_output_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + check_output_with_place(place, data_format, self.dtype, [2, 3]) class TestBatchNormOpTraining(OpTest): From aee686771c71389aef854bf332a14db0435614fd Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Sun, 18 Mar 2018 18:28:40 +0800 Subject: [PATCH 07/79] Add clone_variable function for Block class. --- python/paddle/fluid/framework.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d14d6349b1..50f6b5e0c3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -918,6 +918,24 @@ class Block(object): name=v.name) self.vars[new_p.name] = new_p + def clone_variable(self, var): + """ + Clone a variable into current block. + Args: + var: the variable to be cloned. + + Returns: + The new variable cloned from 'var' in current block. + """ + assert isinstance(var, Variable) + return self.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + class Program(object): def __init__(self): @@ -960,14 +978,14 @@ class Program(object): """Clone the Program object Set for_test to False when we want to clone the program for training. - Set for_test to True when we want to clone the program for testing. + Set for_test to True when we want to clone the program for testing. Args: for_test(bool): Some operators, such as batch_norm and drop_out ops, behave differently in training and testing. If for_test is True, the is_test attributes in these operators will be set to True for - testing purposes, otherwise, they remain unchanged. - + testing purposes, otherwise, they remain unchanged. + Returns(Program): The cloned Program object. """ From 016d0eb7f7656b4a7a2f6828b8058faa23ce86ec Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Sun, 18 Mar 2018 18:32:37 +0800 Subject: [PATCH 08/79] Add python API for sum op. --- python/paddle/fluid/layers/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 14ad18d508..1b9aeb6b47 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -51,7 +51,7 @@ __all__ = [ 'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or', 'logical_xor', 'logical_not', 'uniform_random', 'uniform_random_batch_size_like', 'gaussian_random', - 'gaussian_random_batch_size_like', 'cumsum', 'scatter' + 'gaussian_random_batch_size_like', 'cumsum', 'scatter', 'sum' ] + __activations__ for _OP in set(__all__): From 550622529cf09ae4cb11c46817d73c9a1a5c88a2 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 20:04:03 +0800 Subject: [PATCH 09/79] Add MultipleReader and open_files_op --- paddle/fluid/operators/reader/CMakeLists.txt | 1 + .../reader/create_double_buffer_reader_op.cc | 5 +- .../fluid/operators/reader/open_files_op.cc | 199 ++++++++++++++++++ .../operators/reader/reader_op_registry.h | 22 +- 4 files changed, 224 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/reader/open_files_op.cc diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 744bd3b7ef..1254783d69 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index d0de092947..447fae1053 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -120,10 +120,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { }; void DoubleBufferReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + if (local_buffer_.payloads_.empty()) { buffer_->Receive(&local_buffer_); } - *out = local_buffer_.payloads_; local_buffer_.payloads_.clear(); if (local_buffer_.ctx_) { diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc new file mode 100644 index 0000000000..473c002e93 --- /dev/null +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -0,0 +1,199 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class MultipleReader : public framework::ReaderBase { + public: + struct Quota {}; + + MultipleReader(const std::vector& file_names, + const std::vector& dims, size_t thread_num) + : file_names_(file_names), dims_(dims), thread_num_(thread_num) { + PADDLE_ENFORCE_GT(thread_num_, 0); + StartNewScheduler(); + } + + void ReadNext(std::vector* out) override; + bool HasNext() const override; + void ReInit() override; + + private: + void StartNewScheduler(); + void ScheduleThreadFunc(); + void PrefetchThreadFunc(std::string file_name); + + std::vector file_names_; + std::vector dims_; + size_t thread_num_; + framework::Channel* waiting_file_idx_; + framework::Channel* thread_quotas_; + framework::Channel>* buffer_; + mutable std::vector local_buffer_; +}; + +void MultipleReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + + if (local_buffer_.empty()) { + buffer_->Receive(&local_buffer_); + } + *out = local_buffer_; + local_buffer_.clear(); +} + +bool MultipleReader::HasNext() const { + return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true; +} + +void MultipleReader::ReInit() { + buffer_->Close(); + thread_quotas_->Close(); + waiting_file_idx_->Close(); + local_buffer_.clear(); + + StartNewScheduler(); +} + +void MultipleReader::StartNewScheduler() { + waiting_file_idx_ = framework::MakeChannel(file_names_.size()); + thread_quotas_ = framework::MakeChannel(thread_num_); + buffer_ = + framework::MakeChannel>(thread_num_); + + for (size_t i = 0; i < file_names_.size(); ++i) { + waiting_file_idx_->Send(&i); + } + waiting_file_idx_->Close(); + for (size_t i = 0; i < thread_num_; ++i) { + Quota quota; + thread_quotas_->Send("a); + } + + std::thread scheduler([this] { ScheduleThreadFunc(); }); + scheduler.detach(); +} + +void MultipleReader::ScheduleThreadFunc() { + VLOG(5) << "MultipleReader schedule thread starts."; + size_t completed_thread_num = 0; + Quota quota; + while (thread_quotas_->Receive("a)) { + size_t file_idx; + if (waiting_file_idx_->Receive(&file_idx)) { + // Still have files to read. Start a new prefetch thread. + std::string file_name = file_names_[file_idx]; + std::thread prefetcher( + [this, file_name] { PrefetchThreadFunc(file_name); }); + prefetcher.detach(); + } else { + // No more file to read. + ++completed_thread_num; + if (completed_thread_num == thread_num_) { + thread_quotas_->Close(); + buffer_->Close(); + break; + } + } + } + VLOG(5) << "MultipleReader schedule thread terminates."; +} + +void MultipleReader::PrefetchThreadFunc(std::string file_name) { + VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; + std::unique_ptr reader = + CreateReaderByFileName(file_name, dims_); + while (reader->HasNext()) { + std::vector ins; + reader->ReadNext(&ins); + if (!buffer_->Send(&ins)) { + VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " + "thread of file '" + << file_name << "' will terminate."; + break; + } + } + Quota quota; + thread_quotas_->Send("a); + VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; +} + +class OpenFilesOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + const auto& file_names = Attr>("file_names"); + PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); + const size_t thread_num = Attr("thread_num"); + + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new MultipleReader( + file_names, RestoreShapes(shape_concat, ranks), thread_num)); + } +}; + +class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddComment(R"DOC( + OpenFiles Operator + + An OpenFilesOp creates a MultipleReader, which is able to + read data multi-threaded from multiple files. + )DOC"); + AddOutput("Out", "(ReaderHolder) The created MultipleReader."); + AddAttr>("shape_concat", + "The concat of all data's shapes."); + AddAttr>( + "ranks", + "The ranks of each data." + "e.g." + "shape_concat = [2,3,4,5,6]" + "ranks = [3,2]" + "It means the reader will generate two data each time," + "whose shapes are [2,3,4] and [5,6] respectively."); + AddAttr>("lod_levels", "The LoD levels of each data."); + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, + reader::OpenFilesOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 58f9b4ba35..feab7c63a3 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,6 +21,8 @@ namespace paddle { namespace operators { namespace reader { +static constexpr char kFileFormatSeparator[] = ":"; + using FileReaderCreator = std::function&)>; @@ -29,12 +31,28 @@ std::unordered_map& FileReaderRegistry(); template int RegisterFileReader(const std::string& filetype) { FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector& dim) { - return new Reader(fn, dim); + const std::string& fn, const std::vector& dims) { + return new Reader(fn, dims); }; return 0; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); From 3d677b1eca75733adbc1939dd0a50cbacead6718 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 20:29:48 +0800 Subject: [PATCH 10/79] fix compile errors and make OpenFilesOpMaker derived from FileReaderMakerBase --- paddle/fluid/operators/reader/CMakeLists.txt | 2 +- .../fluid/operators/reader/open_files_op.cc | 25 ++++++------------- .../operators/reader/reader_op_registry.cc | 16 ++++++++++++ .../operators/reader/reader_op_registry.h | 15 +---------- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 1254783d69..4a43fc02d2 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -15,11 +15,11 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +reader_library(open_files_op SRCS open_files_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) -reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 473c002e93..6b62e1db49 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase { } }; -class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { +class OpenFilesOpMaker : public FileReaderMakerBase { public: OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { + : FileReaderMakerBase(op_proto, op_checker) { + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + AddComment(R"DOC( OpenFiles Operator An OpenFilesOp creates a MultipleReader, which is able to read data multi-threaded from multiple files. )DOC"); - AddOutput("Out", "(ReaderHolder) The created MultipleReader."); - AddAttr>("shape_concat", - "The concat of all data's shapes."); - AddAttr>( - "ranks", - "The ranks of each data." - "e.g." - "shape_concat = [2,3,4,5,6]" - "ranks = [3,2]" - "It means the reader will generate two data each time," - "whose shapes are [2,3,4] and [5,6] respectively."); - AddAttr>("lod_levels", "The LoD levels of each data."); - AddAttr>("file_names", "Files to be read."); - AddAttr("thread_num", "The maximal concurrent prefetch thread number.") - .GreaterThan(0); } }; @@ -196,4 +185,4 @@ class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { namespace reader = paddle::operators::reader; REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, - reader::OpenFilesOpMaker); \ No newline at end of file + reader::OpenFilesOpMaker); diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 0ba4f38544..05d79c76d5 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -36,6 +36,22 @@ std::unordered_map& FileReaderRegistry() { return regs; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + FileReaderMakerBase::FileReaderMakerBase( framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpAttrChecker* op_checker) diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index feab7c63a3..dd19b982da 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) { } std::unique_ptr CreateReaderByFileName( - const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); - PADDLE_ENFORCE_NE(separator_pos, std::string::npos, - "File name illegal! A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); - - auto itor = FileReaderRegistry().find(filetype); - PADDLE_ENFORCE(itor != FileReaderRegistry().end(), - "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); - return std::unique_ptr(reader); -} + const std::string& file_name, const std::vector& dims); extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); From 87fe52c10987e6c1f13890e49a98f4d8b85bbd24 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Sun, 18 Mar 2018 20:57:17 +0800 Subject: [PATCH 11/79] Add ModelAverage class to optimizer.py --- python/paddle/fluid/optimizer.py | 149 +++++++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 421963a2f9..5d90bb532d 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict - +from paddle.fluid.framework import Program import framework import layers from backward import append_backward @@ -24,7 +24,10 @@ from layer_helper import LayerHelper from regularizer import append_regularization_ops from clip import append_gradient_clip_ops, error_clip_callback -__all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad'] +__all__ = [ + 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', + 'ModelAverage' +] class Optimizer(object): @@ -119,7 +122,12 @@ class Optimizer(object): """ pass - def _add_accumulator(self, name, param, dtype=None, fill_value=0.0): + def _add_accumulator(self, + name, + param, + dtype=None, + fill_value=0.0, + shape=None): """Utility function to add an accumulator for a parameter Args: @@ -133,17 +141,19 @@ class Optimizer(object): param.name in self._accumulators[name]): raise Exception("Accumulator {} already exists for parameter {}". format(name, param.name)) - + if shape == None: + shape = param.shape assert isinstance(self.helper, LayerHelper) var = self.helper.create_global_variable( name=unique_name.generate(name), persistable=True, dtype=dtype or param.dtype, type=param.type, - shape=param.shape) + shape=shape) self.helper.set_variable_initializer( var, initializer=Constant(value=float(fill_value))) self._accumulators[name][param.name] = var + return var def _get_accumulator(self, name, param): """Utility function to fetch an accumulator for a parameter @@ -592,3 +602,132 @@ Adagrad = AdagradOptimizer Adam = AdamOptimizer Adamax = AdamaxOptimizer DecayedAdagrad = DecayedAdagradOptimizer + + +class ModelAverage(Optimizer): + """Accumulate the average of parameters whtin sliding window. The average + result will be saved in temporary variables which can be applied to + parameter variables of current model by calling 'apply()' method. And the + 'restore()' method is used to restored the parameter values of current model. + + The size of average window is determined by average_window_rate, + min_average_window, max_average_window and current update times. + + Args: + params_grads: A list of parameter-grad variable pairs. + average_window_rate: The rate of average window. + min_average_window: The minimum size of average window. + max_average_window: The maximum size of average window. + + Examples: + ... + optimizer = fluid.optimizer.Momentum() + _, params_grads = optimizer.minimize(cost) + model_average = fluid.optimizer.ModelAverage(params_grads, 0.15, + min_average_window=10000, + max_average_window=20000) + for pass_id in range(args.pass_num): + for data in train_reader(): + exe.run(fluid.default_main_program()...) + model_average.apply() + for data in test_reader(): + exe.run(inference_program...) + model_average.restore(exe) + """ + + def __init__(self, + params_grads, + average_window_rate, + min_average_window=10000, + max_average_window=10000, + **kwargs): + super(ModelAverage, self).__init__(0.0, **kwargs) + self.average_window = average_window_rate + self.min_average_window = min_average_window + self.max_average_window = max_average_window + self.params_grads = params_grads + for param, _ in self.params_grads: + self._append_average_accumulate_op(param) + + def _add_average_apply_op(self, block, param_grad): + param = block.clone_variable(param_grad[0]) + grad = block.clone_variable(param_grad[1]) + sum_1 = block.clone_variable(self._get_accumulator('sum_1', param)) + sum_2 = block.clone_variable(self._get_accumulator('sum_2', param)) + sum_3 = block.clone_variable(self._get_accumulator('sum_3', param)) + num_accumulates = block.clone_variable( + self._get_accumulator('num_accumulates', param)) + old_num_accumulates = block.clone_variable( + self._get_accumulator('old_num_accumulates', param)) + num_updates = block.clone_variable( + self._get_accumulator('num_updates', param)) + # backup param value to grad + layers.assign(input=param, output=grad) + # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates) + tmp = layers.sum(x=[num_accumulates, old_num_accumulates]) + sum = layers.sum(x=[sum_1, sum_2, sum_3]) + tmp = layers.cast(x=tmp, dtype='float32') + sum = layers.cast(x=sum, dtype='float32') + layers.elementwise_div(x=sum, y=tmp, out=param) + + def _add_average_restore_op(self, block, param_grad): + param = block.clone_variable(param_grad[0]) + grad = block.clone_variable(param_grad[1]) + layers.assign(input=grad, output=param) + + def _append_average_accumulate_op(self, param): + self.helper = LayerHelper("average_accumulate") + sum_1 = self._add_accumulator('sum_1', param) + sum_2 = self._add_accumulator('sum_2', param) + sum_3 = self._add_accumulator('sum_3', param) + num_accumulates = self._add_accumulator( + 'num_accumulates', param, dtype='int64', shape=[1]) + old_num_accumulates = self._add_accumulator( + 'old_num_accumulates', param, dtype='int64', shape=[1]) + num_updates = self._add_accumulator( + 'num_updates', param, dtype='int64', shape=[1]) + + self.helper.append_op( + type='average_accumulates', + inputs={ + "param": param, + "in_sum_1": sum_1, + "in_sum_2": sum_2, + "in_sum_3": sum_3, + "in_num_accumulates": num_accumulates, + "in_old_num_accumulates": old_num_accumulates, + "in_num_updates": num_updates + }, + outputs={ + "out_sum_1": sum_1, + "out_sum_2": sum_2, + "out_sum_3": sum_3, + "out_num_accumulates": num_accumulates, + "out_old_num_accumulates": old_num_accumulates, + "out_num_updates": num_updates, + }, + attrs={ + "average_window": self.average_window, + "min_average_window": self.min_average_window, + "max_average_window": self.max_average_window, + }) + + def apply(self, executor): + """Apply average values to parameters of current model. + """ + apply_program = Program() + block = apply_program.global_block() + with program_guard(main_program=apply_program): + for param_grad in self.params_grads: + self._add_average_apply_op(block, param_grad) + executor.run(apply_program) + + def restore(self, executor): + """Restore parameter values of current model. + """ + restore_program = Program() + block = restore_program.global_block() + with program_guard(main_program=restore_program): + for param_grad in self.params_grads: + self._add_average_restore_op(block, param_grad) + executor.run(restore_program) From e0b136c0f972813d87e8f03d67e97b7b7c4dfcb3 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Sun, 18 Mar 2018 22:24:43 +0800 Subject: [PATCH 12/79] Refine average accumulates op 1. Rename inputs and outputs 2. Add some comments --- .../fluid/operators/average_accumulates_op.cc | 138 +++++++++++------- .../fluid/operators/average_accumulates_op.cu | 36 +++-- .../fluid/operators/average_accumulates_op.h | 92 ++++++------ 3 files changed, 147 insertions(+), 119 deletions(-) diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 808693b61c..368a1f5612 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,9 +21,9 @@ template <> void getAccumulators( const framework::ExecutionContext& ctx, int64_t& num_updates_, int64_t& num_accumulates_, int64_t& old_num_accumulates_) { - auto* in_old_num_accumulates = ctx.Input("old_num_accumulates"); - auto* in_num_accumulates = ctx.Input("num_accumulates"); - auto* in_num_updates = ctx.Input("num_updates"); + auto* in_old_num_accumulates = ctx.Input("in_old_num_accumulates"); + auto* in_num_accumulates = ctx.Input("in_num_accumulates"); + auto* in_num_updates = ctx.Input("in_num_updates"); old_num_accumulates_ = in_old_num_accumulates->data()[0]; num_accumulates_ = in_num_accumulates->data()[0]; @@ -34,9 +34,9 @@ template <> void setAccumulators( const framework::ExecutionContext& ctx, int64_t num_updates_, int64_t num_accumulates_, int64_t old_num_accumulates_) { - auto* out_old_num_accumulates = ctx.Output("old_num_accumulates"); - auto* out_num_accumulates = ctx.Output("num_accumulates"); - auto* out_num_updates = ctx.Output("num_updates"); + auto* out_old_num_accumulates = ctx.Output("out_old_num_accumulates"); + auto* out_num_accumulates = ctx.Output("out_num_accumulates"); + auto* out_num_updates = ctx.Output("out_num_updates"); out_old_num_accumulates->data()[0] = old_num_accumulates_; out_num_accumulates->data()[0] = num_accumulates_; @@ -49,64 +49,62 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE( - ctx->HasInput("Param"), - "Input (Param) of average_accumulates op should not be null."); + ctx->HasInput("param"), + "Input (param) of average_accumulates op should not be null."); PADDLE_ENFORCE( - ctx->HasInput("Grad"), - "Input (Grad) of average_accumulates op should not be null."); - PADDLE_ENFORCE( - ctx->HasInput("sum_1"), + ctx->HasInput("in_sum_1"), "Input (sum_1) of average_accumulates op should not be null."); PADDLE_ENFORCE( - ctx->HasInput("sum_2"), + ctx->HasInput("in_sum_2"), "Input (sum_2) of average_accumulates op should not be null."); PADDLE_ENFORCE( - ctx->HasInput("sum_3"), + ctx->HasInput("in_sum_3"), "Input (sum_3) of average_accumulates op should not be null."); - PADDLE_ENFORCE(ctx->HasInput("num_accumulates"), - "Input (num_accumulates) of average_accumulates op should " - "not be null."); - PADDLE_ENFORCE(ctx->HasInput("old_num_accumulates"), + PADDLE_ENFORCE( + ctx->HasInput("in_num_accumulates"), + "Input (in_num_accumulates) of average_accumulates op should " + "not be null."); + PADDLE_ENFORCE(ctx->HasInput("in_old_num_accumulates"), "Input (old_num_accumulates) of average_accumulates op " "should not be null."); PADDLE_ENFORCE( - ctx->HasInput("num_updates"), + ctx->HasInput("in_num_updates"), "Input (num_updates) of average_accumulates op should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("sum_1"), + ctx->HasOutput("out_sum_1"), "Output (sum_1) of average_accumulates op should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("sum_2"), + ctx->HasOutput("out_sum_2"), "Output (sum_2) of average_accumulates op should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("sum_3"), + ctx->HasOutput("out_sum_3"), "Output (sum_3) of average_accumulates op should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("num_accumulates"), + PADDLE_ENFORCE(ctx->HasOutput("out_num_accumulates"), "Output (num_accumulates) of average_accumulates op should " "not be null."); - PADDLE_ENFORCE(ctx->HasOutput("old_num_accumulates"), + PADDLE_ENFORCE(ctx->HasOutput("out_old_num_accumulates"), "Output (old_num_accumulates) of average_accumulates op " "should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("num_updates"), + ctx->HasOutput("out_num_updates"), "Output (num_updates) of average_accumulates op should not be null."); - auto in_dim = ctx->GetInputDim("Param"); + auto in_dim = ctx->GetInputDim("param"); - ctx->SetOutputDim("sum_1", in_dim); - ctx->SetOutputDim("sum_2", in_dim); - ctx->SetOutputDim("sum_3", in_dim); - ctx->SetOutputDim("num_accumulates", {1}); - ctx->SetOutputDim("old_num_accumulates", {1}); - ctx->SetOutputDim("num_updates", {1}); + ctx->SetOutputDim("out_sum_1", in_dim); + ctx->SetOutputDim("out_sum_2", in_dim); + ctx->SetOutputDim("out_sum_3", in_dim); + ctx->SetOutputDim("out_num_accumulates", {1}); + ctx->SetOutputDim("out_old_num_accumulates", {1}); + ctx->SetOutputDim("out_num_updates", {1}); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Param")->type()), + framework::ToDataType(ctx.Input("param")->type()), ctx.GetPlace()); } }; @@ -115,26 +113,60 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { public: AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("sum_1", ""); - AddInput("sum_2", ""); - AddInput("sum_3", ""); - AddInput("num_accumulates", ""); - AddInput("old_num_accumulates", ""); - AddInput("num_updates", ""); - - AddOutput("sum_1", ""); - AddOutput("sum_2", ""); - AddOutput("sum_3", ""); - AddOutput("num_accumulates", ""); - AddOutput("old_num_accumulates", ""); - AddOutput("num_updates", ""); - - AddAttr("", "average_window"); - AddAttr("", "max_average_window"); - AddAttr("", "min_average_window"); + AddInput("param", + "Input(Tensor or LoDTensor): The parameter to be accumulated."); + AddInput("in_sum_1", + "Input(Tensor or LoDTensor): A tensor used to store the parameter " + "sums with the same shape as input(param)."); + AddInput("in_sum_2", + "Input(Tensor or LoDTensor): A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param). It is used to avoid loss of precision due to too " + "many sums."); + AddInput("in_sum_3", + "Input(Tensor or LoDTensor): A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param)."); + AddInput("in_num_accumulates", + "Input(Tensor): The accumulating times of current window with " + "shape [1]."); + AddInput("in_old_num_accumulates", + "Input(Tensor): The accumulating times of previous window with " + "shape [1]."); + AddInput("in_num_updates", + "Input(Tensor): The total number of batches used by trainning " + "before this batch with shape [1]."); + + AddOutput("out_sum_1", + "Output(Tensor or LoDTensor): A tensor used to store the " + "parameter sums with the same shape as input(param)."); + AddOutput("out_sum_2", + "Output(Tensor or LoDTensor): A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param). It is used to avoid loss of precision due to too " + "many sums."); + AddOutput("out_sum_3", + "Output(Tensor or LoDTensor): A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param)."); + AddOutput("out_num_accumulates", + "Output(Tensor): The accumulating times of current window with " + "shape [1]."); + AddOutput("out_old_num_accumulates", + "Output(Tensor): The accumulating times of previous window with " + "shape [1]."); + AddOutput("out_num_updates", + "Output(Tensor): The total number of batches used by trainning " + "before this batch with shape [1]."); + + AddAttr("average_window", + "The rate of average window size relative to num_updates."); + AddAttr("max_average_window", "Maximum size of average window."); + AddAttr("min_average_window", "Minimu size of average window."); AddComment(R"DOC( AverageAccumulates Operator. +Accumulate the sum of parameter whtin sliding window. The size of sliding window is determined by 'average_window', 'max_average_window' and 'min_average_window'. )DOC"); } }; @@ -143,10 +175,10 @@ AverageAccumulates Operator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(average_accumulate, ops::AverageAccumulatesOp, +REGISTER_OPERATOR(average_accumulates, ops::AverageAccumulatesOp, ops::AverageAccumulatesOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - average_accumulate, + average_accumulates, ops::AverageAccumulatesKernel, ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.cu b/paddle/fluid/operators/average_accumulates_op.cu index 56f2f02fd2..dbaa8ba6c9 100644 --- a/paddle/fluid/operators/average_accumulates_op.cu +++ b/paddle/fluid/operators/average_accumulates_op.cu @@ -21,39 +21,43 @@ template <> void getAccumulators( const framework::ExecutionContext& ctx, int64_t& num_updates_, int64_t& num_accumulates_, int64_t& old_num_accumulates_) { - auto* in_old_num_accumulates = ctx.Input("old_num_accumulates"); - auto* in_num_accumulates = ctx.Input("num_accumulates"); - auto* in_num_updates = ctx.Input("num_updates"); - + auto* in_old_num_accumulates = ctx.Input("in_old_num_accumulates"); + auto* in_num_accumulates = ctx.Input("in_num_accumulates"); + auto* in_num_updates = ctx.Input("in_num_updates"); + auto stream = ctx.cuda_device_context().stream(); memory::Copy(platform::CPUPlace(), &old_num_accumulates_, platform::CUDAPlace(), in_old_num_accumulates->data(), - sizeof(int64_t)); + sizeof(int64_t), stream); memory::Copy(platform::CPUPlace(), &num_accumulates_, platform::CUDAPlace(), - in_old_num_accumulates->data(), sizeof(int64_t)); + in_num_accumulates->data(), sizeof(int64_t), stream); memory::Copy(platform::CPUPlace(), &num_updates_, platform::CUDAPlace(), - in_num_updates->data(), sizeof(int64_t)); + in_num_updates->data(), sizeof(int64_t), stream); } template <> void setAccumulators( const framework::ExecutionContext& ctx, int64_t num_updates_, int64_t num_accumulates_, int64_t old_num_accumulates_) { - auto* out_old_num_accumulates = ctx.Output("old_num_accumulates"); - auto* out_num_accumulates = ctx.Output("num_accumulates"); - auto* out_num_updates = ctx.Output("num_updates"); + auto stream = ctx.cuda_device_context().stream(); + auto* out_old_num_accumulates = ctx.Output("out_old_num_accumulates"); + auto* out_num_accumulates = ctx.Output("out_num_accumulates"); + auto* out_num_updates = ctx.Output("out_num_updates"); memory::Copy(platform::CUDAPlace(), out_old_num_accumulates->data(), - platform::CPUPlace(), &old_num_accumulates_, sizeof(int64_t)); + platform::CPUPlace(), &old_num_accumulates_, sizeof(int64_t), + stream); memory::Copy(platform::CUDAPlace(), out_num_accumulates->data(), - platform::CPUPlace(), &num_accumulates_, sizeof(int64_t)); + platform::CPUPlace(), &num_accumulates_, sizeof(int64_t), + stream); memory::Copy(platform::CUDAPlace(), out_num_updates->data(), - platform::CPUPlace(), &num_updates_, sizeof(int64_t)); -} -} + platform::CPUPlace(), &num_updates_, sizeof(int64_t), stream); } +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - average_accumulate, + average_accumulates, ops::AverageAccumulatesKernel, ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h index 73814dd24b..d33fd5519a 100644 --- a/paddle/fluid/operators/average_accumulates_op.h +++ b/paddle/fluid/operators/average_accumulates_op.h @@ -29,88 +29,80 @@ using EigenVector = framework::EigenVector; template void getAccumulators(const framework::ExecutionContext& ctx, - int64_t& num_updates_, int64_t& num_accumulates_, - int64_t& old_num_accumulates_); + int64_t& num_updates, int64_t& num_accumulates, + int64_t& old_num_accumulates); template void setAccumulators(const framework::ExecutionContext& ctx, - int64_t num_updates_, int64_t num_accumulates_, - int64_t old_num_accumulates_); + int64_t num_updates, int64_t num_accumulates, + int64_t old_num_accumulates); template class AverageAccumulatesKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + // It is used to avoid loss of precision static const int64_t kMaxNumAccumulates = 16384; - // accumulators - int64_t num_updates_ = 0; - int64_t num_accumulates_ = 0; - int64_t old_num_accumulates_ = 0; - // attrs - int64_t min_average_window_; - int64_t max_average_window_; - float average_window_; - - auto* param = ctx.Input("Param"); - auto* in_sum_1 = ctx.Input("sum_1"); - auto* in_sum_2 = ctx.Input("sum_2"); - auto* in_sum_3 = ctx.Input("sum_3"); - - auto* out_sum_1 = ctx.Output("sum_1"); - auto* out_sum_2 = ctx.Output("sum_2"); - auto* out_sum_3 = ctx.Output("sum_3"); - - getAccumulators(ctx, num_updates_, num_accumulates_, - old_num_accumulates_); - average_window_ = ctx.Attr("average_window"); - max_average_window_ = - ctx.Attr("max_average_window"); // default bach number - min_average_window_ = - ctx.Attr("min_average_window"); // default 10000L - min_average_window_ = - std::min(min_average_window_, max_average_window_); - + // Get accumulators from input + int64_t num_updates = 0; + int64_t num_accumulates = 0; + int64_t old_num_accumulates = 0; + getAccumulators(ctx, num_updates, num_accumulates, + old_num_accumulates); + + // Get attrs + float average_window = ctx.Attr("average_window"); + int64_t max_average_window = ctx.Attr("max_average_window"); + int64_t min_average_window = ctx.Attr("min_average_window"); + min_average_window = + std::min(min_average_window, max_average_window); + + // Get inputs + auto* param = ctx.Input("param"); + auto* in_sum_1 = ctx.Input("in_sum_1"); + auto* in_sum_2 = ctx.Input("in_sum_2"); + auto* in_sum_3 = ctx.Input("in_sum_3"); auto param_tensor = EigenVector::Flatten(*param); auto in_sum_1_tensor = EigenVector::Flatten(*in_sum_1); auto in_sum_2_tensor = EigenVector::Flatten(*in_sum_2); auto in_sum_3_tensor = EigenVector::Flatten(*in_sum_3); + + // Get outputs + auto* out_sum_1 = ctx.Output("out_sum_1"); + auto* out_sum_2 = ctx.Output("out_sum_2"); + auto* out_sum_3 = ctx.Output("out_sum_3"); auto out_sum_1_tensor = EigenVector::Flatten(*out_sum_1); auto out_sum_2_tensor = EigenVector::Flatten(*out_sum_2); auto out_sum_3_tensor = EigenVector::Flatten(*out_sum_3); + // Compute auto& place = *ctx.template device_context().eigen_device(); math::SetConstant constant_functor; - // start batch - ++num_updates_; - ++num_accumulates_; - - // update + ++num_updates; + ++num_accumulates; out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor; - out_sum_2_tensor.device(place) = in_sum_2_tensor; out_sum_3_tensor.device(place) = in_sum_3_tensor; - // needSpecialTraversal - if (num_updates_ % kMaxNumAccumulates == 0) { + if (num_updates % kMaxNumAccumulates == 0) { out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor; constant_functor(ctx.template device_context(), out_sum_1, 0.0); } - - if (num_accumulates_ >= min_average_window_ && - num_accumulates_ >= std::min(max_average_window_, - num_updates_ * average_window_)) { + if (num_accumulates >= min_average_window && + num_accumulates >= std::min(max_average_window, + num_updates * average_window)) { out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor; constant_functor(ctx.template device_context(), out_sum_1, 0.0); constant_functor(ctx.template device_context(), out_sum_2, 0.0); - - // finishBatch - old_num_accumulates_ = num_accumulates_; - num_accumulates_ = 0; + old_num_accumulates = num_accumulates; + num_accumulates = 0; } - setAccumulators(ctx, num_updates_, num_accumulates_, - old_num_accumulates_); + + // Set accumulators to output + setAccumulators(ctx, num_updates, num_accumulates, + old_num_accumulates); } }; From 87ac675ae7365cdc8afc8f12503df962ce9aaabc Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 23:49:11 +0800 Subject: [PATCH 13/79] Add python wrapper for open_files_op --- python/paddle/fluid/layers/io.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 9c91f395e7..89153f325b 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -287,6 +287,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var) +def open_files(filenames, thread_num, shapes, lod_levels, dtypes): + dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] + shape_concat = [] + ranks = [] + + for shape in shapes: + shape_concat.extend(shape) + ranks.append(len(shape)) + + var_name = unique_name('multiple_reader') + + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=var_name) + startup_blk.append_op( + type='open_files', + outputs={'Out': [startup_var]}, + attrs={ + 'shape_concat': shape_concat, + 'lod_levels': lod_levels, + 'ranks': ranks, + 'filename': filenames, + 'thread_num': thread_num + }) + + startup_var.desc.set_dtypes(dtypes) + startup_var.persistable = True + return _copy_reader_var_(default_main_program().current_block(), + startup_var) + + def __create_decorated_reader__(op_type, reader, attrs): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() From e870947cfd1f0a2d86d5d422d445c41a99913090 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 17:58:42 -0700 Subject: [PATCH 14/79] fix batch norm fp16 param type --- paddle/fluid/operators/batch_norm_op.cc | 23 +++++++++++ paddle/fluid/operators/batch_norm_op.cu.cc | 38 +++++++++++-------- paddle/fluid/platform/cudnn_helper.h | 5 +++ .../tests/unittests/test_batch_norm_op.py | 31 ++++++++------- 4 files changed, 69 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 215ae229af..ae970acc27 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel { ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + // For float or float16 input tensor, the type of the scale, bias, mean, + // and var tensors should both be float. + auto bn_param_type = framework::proto::VarType::FP32; + PADDLE_ENFORCE_EQ(bn_param_type, + framework::ToDataType(ctx.Input("Scale")->type()), + "Scale input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, + framework::ToDataType(ctx.Input("Bias")->type()), + "Bias input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, + framework::ToDataType(ctx.Input("Mean")->type()), + "Mean input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType( + ctx.Input("Variance")->type()), + "Variance input should be of float type"); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index f4919398eb..5e97678862 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -26,6 +26,8 @@ using Tensor = framework::Tensor; using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; +template +using bn_param_type = CudnnDataType::bn_param_type; void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, int *N, int *C, int *H, int *W, int *D) { @@ -104,8 +106,9 @@ class BatchNormKernel CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( data_desc_, CudnnDataType::type, x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + // Note: PERSISTENT not implemented for inference CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); + bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_)); const auto *scale = ctx.Input("Scale"); const auto *bias = ctx.Input("Bias"); @@ -118,15 +121,15 @@ class BatchNormKernel // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data(ctx.GetPlace()); - variance_out->mutable_data(ctx.GetPlace()); - saved_mean->mutable_data(ctx.GetPlace()); - saved_variance->mutable_data(ctx.GetPlace()); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant functor; - functor(dev_ctx, saved_mean, static_cast(0)); - functor(dev_ctx, saved_variance, static_cast(0)); + math::SetConstant> functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -147,8 +150,10 @@ class BatchNormKernel CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), - bn_param_desc_, scale->template data(), bias->template data(), - est_mean->template data(), est_var->template data(), epsilon)); + bn_param_desc_, scale->template data>(), + bias->template data>(), + est_mean->template data>(), + est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to @@ -159,11 +164,14 @@ class BatchNormKernel handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data(), bias->template data(), this_factor, - mean_out->template mutable_data(ctx.GetPlace()), - variance_out->template mutable_data(ctx.GetPlace()), epsilon, - saved_mean->template mutable_data(ctx.GetPlace()), - saved_variance->template mutable_data(ctx.GetPlace()))); + scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>(ctx.GetPlace()), + variance_out->template mutable_data>(ctx.GetPlace()), + epsilon, + saved_mean->template mutable_data>(ctx.GetPlace()), + saved_variance->template mutable_data>( + ctx.GetPlace()))); } // clean when exit. diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 7e001ecc56..a40c366241 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -85,6 +85,9 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_HALF; + // cudnn batch norm requires that Scale, Bias, Mean, and Variance + // to be FLOAT tensors when the input x is HALF tensor + static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; // The scaling param type is float for HALF and FLOAT tensors typedef const float ScalingParamType; static ScalingParamType* kOne() { @@ -101,6 +104,7 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; + static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; typedef const float ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; @@ -116,6 +120,7 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; + static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE; typedef const double ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 91a9d826a0..261c457708 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -193,7 +193,7 @@ class TestBatchNormOpInference(OpTest): def __assert_close(self, tensor, np_array, msg, atol=1e-4): self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - def check_with_place(place, data_layout, dtype, shape): + def check_with_place(self, place, data_layout, dtype, shape): epsilon = 0.00001 if len(shape) == 2: x_shape = shape @@ -209,11 +209,11 @@ class TestBatchNormOpInference(OpTest): scale_shape = [c] x_val = np.random.random_sample(x_shape).astype(dtype) - scale_val = np.random.random_sample(scale_shape).astype(dtype) - bias_val = np.random.random_sample(scale_shape).astype(dtype) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + bias_val = np.random.random_sample(scale_shape).astype(np.float32) - mean = np.zeros(scale_shape).astype(dtype) - variance = np.ones(scale_shape).astype(dtype) + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, epsilon, data_layout).astype(dtype) @@ -266,9 +266,13 @@ class TestBatchNormOpInference(OpTest): batch_norm_op.run(scope, place) # check inference result - self.__assert_close(y_tensor, y_out, - "inference output are different at " + str(place) + - ", " + data_layout + ", " + str(np.dtype(dtype))) + self.__assert_close( + y_tensor, + y_out, + "inference output are different at " + str(place) + ", " + + data_layout + ", " + str(np.dtype(dtype)) + + str(np.array(y_tensor)) + str(y_out), + atol=2e-2) def test_check_output(self): places = [core.CPUPlace()] @@ -277,8 +281,9 @@ class TestBatchNormOpInference(OpTest): for place in places: for data_format in ["NCHW", "NHWC"]: - check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) - check_with_place(place, data_format, self.dtype, [2, 3]) + self.check_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + self.check_with_place(place, data_format, self.dtype, [2, 3]) class TestFP16BatchNormOpInference(TestBatchNormOpInference): @@ -294,9 +299,9 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): for place in places: for data_format in ["NCHW", "NHWC"]: - check_output_with_place(place, data_format, self.dtype, - [2, 3, 4, 5]) - check_output_with_place(place, data_format, self.dtype, [2, 3]) + self.check_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + self.check_with_place(place, data_format, self.dtype, [2, 3]) class TestBatchNormOpTraining(OpTest): From ffa22a5f905daa7b7a4a9e4e6a1e3f17b9fea073 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 20:01:27 -0700 Subject: [PATCH 15/79] fix scaling param type --- paddle/fluid/operators/batch_norm_op.cu.cc | 40 ++++++++++++---------- paddle/fluid/platform/cudnn_helper.h | 5 --- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index 5e97678862..2de935d087 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -27,7 +28,7 @@ using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template -using bn_param_type = CudnnDataType::bn_param_type; +using ScalingParamType = typename CudnnDataType::ScalingParamType; void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, int *N, int *C, int *H, int *W, int *D) { @@ -121,15 +122,15 @@ class BatchNormKernel // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data>(ctx.GetPlace()); - variance_out->mutable_data>(ctx.GetPlace()); - saved_mean->mutable_data>(ctx.GetPlace()); - saved_variance->mutable_data>(ctx.GetPlace()); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant> functor; - functor(dev_ctx, saved_mean, static_cast>(0)); - functor(dev_ctx, saved_variance, static_cast>(0)); + math::SetConstant> functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -150,10 +151,10 @@ class BatchNormKernel CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), - bn_param_desc_, scale->template data>(), - bias->template data>(), - est_mean->template data>(), - est_var->template data>(), epsilon)); + bn_param_desc_, scale->template data>(), + bias->template data>(), + est_mean->template data>(), + est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to @@ -164,13 +165,14 @@ class BatchNormKernel handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - bias->template data>(), this_factor, - mean_out->template mutable_data>(ctx.GetPlace()), - variance_out->template mutable_data>(ctx.GetPlace()), - epsilon, - saved_mean->template mutable_data>(ctx.GetPlace()), - saved_variance->template mutable_data>( + scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>(ctx.GetPlace()), + variance_out->template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean->template mutable_data>( + ctx.GetPlace()), + saved_variance->template mutable_data>( ctx.GetPlace()))); } diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index a40c366241..7e001ecc56 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -85,9 +85,6 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_HALF; - // cudnn batch norm requires that Scale, Bias, Mean, and Variance - // to be FLOAT tensors when the input x is HALF tensor - static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; // The scaling param type is float for HALF and FLOAT tensors typedef const float ScalingParamType; static ScalingParamType* kOne() { @@ -104,7 +101,6 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; - static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; typedef const float ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; @@ -120,7 +116,6 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; - static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE; typedef const double ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; From 446d54f5c32d8cf15ad83ba71783f92b19621931 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 20:24:11 -0700 Subject: [PATCH 16/79] update --- paddle/fluid/operators/batch_norm_op.cc | 2 +- paddle/fluid/operators/batch_norm_op.cu.cc | 38 ++++++++++++---------- paddle/fluid/platform/cudnn_helper.h | 9 +++-- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index ae970acc27..5d27f5b60c 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -83,7 +83,7 @@ class BatchNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const ExecutionContext &ctx) const override { + const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); // For float or float16 input tensor, the type of the scale, bias, mean, diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index 2de935d087..6ceacc3992 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -28,7 +28,7 @@ using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template -using ScalingParamType = typename CudnnDataType::ScalingParamType; +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, int *N, int *C, int *H, int *W, int *D) { @@ -122,15 +122,16 @@ class BatchNormKernel // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data>(ctx.GetPlace()); - variance_out->mutable_data>(ctx.GetPlace()); - saved_mean->mutable_data>(ctx.GetPlace()); - saved_variance->mutable_data>(ctx.GetPlace()); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant> functor; - functor(dev_ctx, saved_mean, static_cast>(0)); - functor(dev_ctx, saved_variance, static_cast>(0)); + math::SetConstant> + functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -151,10 +152,10 @@ class BatchNormKernel CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), - bn_param_desc_, scale->template data>(), - bias->template data>(), - est_mean->template data>(), - est_var->template data>(), epsilon)); + bn_param_desc_, scale->template data>(), + bias->template data>(), + est_mean->template data>(), + est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to @@ -165,14 +166,15 @@ class BatchNormKernel handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - bias->template data>(), this_factor, - mean_out->template mutable_data>(ctx.GetPlace()), - variance_out->template mutable_data>( + scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>( ctx.GetPlace()), - epsilon, saved_mean->template mutable_data>( + variance_out->template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean->template mutable_data>( ctx.GetPlace()), - saved_variance->template mutable_data>( + saved_variance->template mutable_data>( ctx.GetPlace()))); } diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 7e001ecc56..7c604e14eb 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -86,7 +86,8 @@ class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_HALF; // The scaling param type is float for HALF and FLOAT tensors - typedef const float ScalingParamType; + using ScalingParamType = const float; + using BatchNormParamType = float; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; return &v; @@ -101,7 +102,8 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; - typedef const float ScalingParamType; + using ScalingParamType = const float; + using BatchNormParamType = float; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; return &v; @@ -116,7 +118,8 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; - typedef const double ScalingParamType; + using ScalingParamType = const double; + using BatchNormParamType = double; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; return &v; From 6ec0f91273f097b4efaccc3e54e86c0a74c4173d Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 20:40:53 -0700 Subject: [PATCH 17/79] decrease atol --- python/paddle/fluid/tests/unittests/test_batch_norm_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 261c457708..10aa63e18a 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -272,7 +272,7 @@ class TestBatchNormOpInference(OpTest): "inference output are different at " + str(place) + ", " + data_layout + ", " + str(np.dtype(dtype)) + str(np.array(y_tensor)) + str(y_out), - atol=2e-2) + atol=1e-3) def test_check_output(self): places = [core.CPUPlace()] From cad4d7f325b810a154469a02c2b5aa0f7b50dc66 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 19 Mar 2018 16:40:35 +0800 Subject: [PATCH 18/79] Refine initial and API of ModelAverage API 1. Implement 'with model_average.apply()' syntax 2. Init apply_program and restore_program in __init__ functin of ModelAverage --- python/paddle/fluid/optimizer.py | 45 ++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 5473e61468..394cf050a7 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -23,6 +23,7 @@ from initializer import Constant from layer_helper import LayerHelper from regularizer import append_regularization_ops from clip import append_gradient_clip_ops, error_clip_callback +from contextlib import contextmanager __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', @@ -631,10 +632,10 @@ class ModelAverage(Optimizer): for pass_id in range(args.pass_num): for data in train_reader(): exe.run(fluid.default_main_program()...) - model_average.apply() - for data in test_reader(): - exe.run(inference_program...) - model_average.restore(exe) + + with model_average.apply(exe): + for data in test_reader(): + exe.run(inference_program...) """ def __init__(self, @@ -651,6 +652,18 @@ class ModelAverage(Optimizer): for param, _ in self.params_grads: self._append_average_accumulate_op(param) + self.apply_program = Program() + block = self.apply_program.global_block() + with program_guard(main_program=self.apply_program): + for param_grad in self.params_grads: + self._add_average_apply_op(block, param_grad) + + self.restore_program = Program() + block = self.restore_program.global_block() + with program_guard(main_program=self.restore_program): + for param_grad in self.params_grads: + self._add_average_restore_op(block, param_grad) + def _add_average_apply_op(self, block, param_grad): param = block.clone_variable(param_grad[0]) grad = block.clone_variable(param_grad[1]) @@ -714,22 +727,20 @@ class ModelAverage(Optimizer): "max_average_window": self.max_average_window, }) - def apply(self, executor): + @contextmanager + def apply(self, executor, need_restore=True): """Apply average values to parameters of current model. """ - apply_program = Program() - block = apply_program.global_block() - with program_guard(main_program=apply_program): - for param_grad in self.params_grads: - self._add_average_apply_op(block, param_grad) - executor.run(apply_program) + executor.run(self.apply_program) + print "finish apply" + try: + yield + finally: + if need_restore: + self.restore(executor) def restore(self, executor): """Restore parameter values of current model. """ - restore_program = Program() - block = restore_program.global_block() - with program_guard(main_program=restore_program): - for param_grad in self.params_grads: - self._add_average_restore_op(block, param_grad) - executor.run(restore_program) + executor.run(self.restore_program) + print "finish restore" From d9868b08392d831d7f4bd1a1a098217cc4573c8f Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 19 Mar 2018 19:06:38 +0800 Subject: [PATCH 19/79] Add multi_pass_reader --- paddle/fluid/operators/reader/CMakeLists.txt | 1 + .../reader/create_multi_pass_reader_op.cc | 101 ++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 paddle/fluid/operators/reader/create_multi_pass_reader_op.cc diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 744bd3b7ef..fc7ef227f0 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc new file mode 100644 index 0000000000..4d4e9fb909 --- /dev/null +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class MultiPassReader : public framework::DecoratedReader { + public: + MultiPassReader(ReaderBase* reader, int pass_num) + : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} + + void ReadNext(std::vector* out) override { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + reader_->ReadNext(out); + } + + bool HasNext() const override { + if (reader_->HasNext()) { + return true; + } else { + ++pass_count_; + if (pass_count_ >= pass_num_) { + return false; + } else { + reader_->ReInit(); + return true; + } + } + } + + void ReInit() override { + pass_count_ = 0; + reader_->ReInit(); + } + + private: + int pass_num_; + mutable int pass_count_; +}; + +class CreateMultiPassReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) + ->Get(); + auto& out = detail::Ref(scope.FindVar(Output("Out"))); + int pass_num = Attr("pass_num"); + out.GetMutable()->Reset( + new MultiPassReader(underlying_reader.Get(), pass_num)); + } +}; + +class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase { + public: + CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : DecoratedReaderMakerBase(op_proto, op_checker) { + AddAttr("pass_num", "The number of pass to run.").GreaterThan(0); + AddComment(R"DOC( + CreateMultiPassReader Operator + + This operator creates a multi-pass reader. A multi-pass reader + is used to yield data for several pass training continuously. + It takes the the number of pass to run as one of its attributes + ('pass_num'), and maintains a pass counter to record how many + passes it has completed. When the underlying reader reach the EOF, + the multi-pass reader checks whether it has completed training + of the given number of pass. If not, the underlying reader will + be re-initialized and starts a new pass automatically. + )DOC"); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators::reader; +REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader, + ops::CreateMultiPassReaderOp, + ops::CreateMultiPassReaderOpMaker); From 26734cfe77d17a48094b2fa7ab768e5afe635cc5 Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Mon, 19 Mar 2018 16:46:03 -0700 Subject: [PATCH 20/79] expose dilation option to conv2d and add bias/activation option to con2d_trans --- python/paddle/fluid/layers/nn.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9656dcf94f..beed54bd0a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1117,12 +1117,14 @@ def conv2d(input, filter_size, stride=1, padding=0, + dilation=1, groups=None, param_attr=None, bias_attr=None, use_cudnn=True, use_mkldnn=False, - act=None): + act=None, + name=None): """ **Convlution2D Layer** @@ -1183,6 +1185,9 @@ def conv2d(input, padding(int|tuple): The padding size. If padding is a tuple, it must contain two integers, (padding_H, padding_W). Otherwise, the padding_H = padding_W = padding. Default: padding = 0. + dilation(int|tuple): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: dilation = 1. groups(int): The groups number of the Conv2d Layer. According to grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2, the first half of the filters is only connected to the first half @@ -1193,6 +1198,8 @@ def conv2d(input, use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn library is installed. Default: True act(str): Activation type. Default: None + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. Returns: Variable: The tensor variable storing the convolution and \ @@ -1233,6 +1240,7 @@ def conv2d(input, filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') stride = utils.convert_to_list(stride, 2, 'stride') padding = utils.convert_to_list(padding, 2, 'padding') + dilation = utils.convert_to_list(dilation, 2, 'dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1262,6 +1270,7 @@ def conv2d(input, attrs={ 'strides': stride, 'paddings': padding, + 'dilations': dilation, 'groups': groups, 'use_cudnn': use_cudnn, 'use_mkldnn': use_mkldnn @@ -1670,7 +1679,9 @@ def conv2d_transpose(input, stride=1, dilation=1, param_attr=None, + bias_attr=None, use_cudnn=True, + act=None, name=None): """ **Convlution2D transpose layer** @@ -1739,8 +1750,10 @@ def conv2d_transpose(input, dilation_H = dilation_W = dilation. Default: dilation = 1. param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer. Default: None + bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn library is installed. Default: True + act(str): Activation type. Default: None name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -1793,12 +1806,12 @@ def conv2d_transpose(input, img_filter = helper.create_parameter( dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) - out = helper.create_tmp_variable(dtype=input.dtype) + pre_bias = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( type='conv2d_transpose', inputs={'Input': [input], 'Filter': [img_filter]}, - outputs={'Output': out}, + outputs={'Output': pre_bias}, attrs={ 'strides': stride, 'paddings': padding, @@ -1806,6 +1819,8 @@ def conv2d_transpose(input, 'use_cudnn': use_cudnn }) + pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + out = helper.append_activation(pre_act) return out From 4bf168b2745964077d39483334e6d6bb9d9b8087 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 17:15:46 -0700 Subject: [PATCH 21/79] add fp16 kernel for elementwise add --- paddle/fluid/operators/elementwise_add_op.cu | 21 ++++---- .../unittests/test_elementwise_add_op.py | 54 ++++++++++++++----- 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index 19dc4a5215..c8bf524144 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -14,19 +14,20 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_add_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = padddle::platform; REGISTER_OP_CUDA_KERNEL( - elementwise_add, - ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel); + elementwise_add, ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel + ops::ElementwiseAddKernel); REGISTER_OP_CUDA_KERNEL( elementwise_add_grad, - ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 5b2384e94d..28286d79ea 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -13,34 +13,60 @@ # limitations under the License. import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest -class TestElementwiseOp(OpTest): +class TestElementwiseAddOp(OpTest): def setUp(self): self.op_type = "elementwise_add" + self.dtype = np.float32 + init_dtype() + + x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) self.inputs = { - 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"), - 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32") + 'X': OpTest.np_dtype_to_fluid_dtype(x), + 'Y': OpTest.np_dtype_to_fluid_dtype(y) } - self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])} + self.outputs = {'Out': np.add(x, y)} def test_check_output(self): self.check_output() def test_check_grad_normal(self): + if self.dtype == np.float16: + return self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) def test_check_grad_ingore_x(self): + if self.dtype == np.float16: + return self.check_grad( ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) def test_check_grad_ingore_y(self): + if self.dtype == np.float16: + return self.check_grad( ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + def init_dtype(): + pass + + +class TestFP16ElementwiseAddOp(TestElementwiseAddOp): + def init_dtype(): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + -class TestElementwiseAddOp_scalar(TestElementwiseOp): +class TestElementwiseAddOp_scalar(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -50,7 +76,7 @@ class TestElementwiseAddOp_scalar(TestElementwiseOp): self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} -class TestElementwiseAddOp_scalar2(TestElementwiseOp): +class TestElementwiseAddOp_scalar2(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -60,7 +86,7 @@ class TestElementwiseAddOp_scalar2(TestElementwiseOp): self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} -class TestElementwiseAddOp_Vector(TestElementwiseOp): +class TestElementwiseAddOp_Vector(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -70,7 +96,7 @@ class TestElementwiseAddOp_Vector(TestElementwiseOp): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])} -class TestElementwiseAddOp_broadcast_0(TestElementwiseOp): +class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -84,7 +110,7 @@ class TestElementwiseAddOp_broadcast_0(TestElementwiseOp): } -class TestElementwiseAddOp_broadcast_1(TestElementwiseOp): +class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -98,7 +124,7 @@ class TestElementwiseAddOp_broadcast_1(TestElementwiseOp): } -class TestElementwiseAddOp_broadcast_2(TestElementwiseOp): +class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -111,7 +137,7 @@ class TestElementwiseAddOp_broadcast_2(TestElementwiseOp): } -class TestElementwiseAddOp_broadcast_3(TestElementwiseOp): +class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -125,7 +151,7 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp): } -class TestElementwiseAddOp_broadcast_4(TestElementwiseOp): +class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -139,7 +165,7 @@ class TestElementwiseAddOp_broadcast_4(TestElementwiseOp): } -class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp): +class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { @@ -153,7 +179,7 @@ class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp): } -class TestElementwiseAddOp_rowwise_add_1(TestElementwiseOp): +class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp): def setUp(self): self.op_type = "elementwise_add" self.inputs = { From d22f4de79474d86f415210229ae1f0f750e7e91c Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 20 Mar 2018 11:09:20 +0800 Subject: [PATCH 22/79] Refine sum_accumulates_op. --- .../fluid/operators/average_accumulates_op.cc | 88 +++++++++++++------ .../fluid/operators/average_accumulates_op.cu | 4 +- .../fluid/operators/average_accumulates_op.h | 11 ++- python/paddle/fluid/optimizer.py | 2 - 4 files changed, 69 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 368a1f5612..c95077fcbd 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -18,7 +18,7 @@ namespace paddle { namespace operators { template <> -void getAccumulators( +void GetAccumulators( const framework::ExecutionContext& ctx, int64_t& num_updates_, int64_t& num_accumulates_, int64_t& old_num_accumulates_) { auto* in_old_num_accumulates = ctx.Input("in_old_num_accumulates"); @@ -31,7 +31,7 @@ void getAccumulators( } template <> -void setAccumulators( +void SetAccumulators( const framework::ExecutionContext& ctx, int64_t num_updates_, int64_t num_accumulates_, int64_t old_num_accumulates_) { auto* out_old_num_accumulates = ctx.Output("out_old_num_accumulates"); @@ -113,60 +113,92 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { public: AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("param", - "Input(Tensor or LoDTensor): The parameter to be accumulated."); + AddInput("param", "(Tensor), The parameter to be accumulated."); AddInput("in_sum_1", - "Input(Tensor or LoDTensor): A tensor used to store the parameter " + "(Tensor), A tensor used to store the parameter " "sums with the same shape as input(param)."); AddInput("in_sum_2", - "Input(Tensor or LoDTensor): A auxiliary tensor to help " + "(Tensor), A auxiliary tensor to help " "accumulating sums of parameter values with the same shape as " "input(param). It is used to avoid loss of precision due to too " "many sums."); AddInput("in_sum_3", - "Input(Tensor or LoDTensor): A auxiliary tensor to help " + "(Tensor), A auxiliary tensor to help " "accumulating sums of parameter values with the same shape as " "input(param)."); AddInput("in_num_accumulates", - "Input(Tensor): The accumulating times of current window with " - "shape [1]."); - AddInput("in_old_num_accumulates", - "Input(Tensor): The accumulating times of previous window with " + "(Tensor), The accumulating times of current window with " "shape [1]."); + AddInput( + "in_old_num_accumulates", + "(Tensor), The accumulating times of previous window with " + "shape [1]."); AddInput("in_num_updates", - "Input(Tensor): The total number of batches used by trainning " + "(Tensor), The total number of batches used by trainning " "before this batch with shape [1]."); AddOutput("out_sum_1", - "Output(Tensor or LoDTensor): A tensor used to store the " + "(Tensor), A tensor used to store the " "parameter sums with the same shape as input(param)."); AddOutput("out_sum_2", - "Output(Tensor or LoDTensor): A auxiliary tensor to help " + "(Tensor), A auxiliary tensor to help " "accumulating sums of parameter values with the same shape as " "input(param). It is used to avoid loss of precision due to too " "many sums."); AddOutput("out_sum_3", - "Output(Tensor or LoDTensor): A auxiliary tensor to help " + "(Tensor), A auxiliary tensor to help " "accumulating sums of parameter values with the same shape as " "input(param)."); - AddOutput("out_num_accumulates", - "Output(Tensor): The accumulating times of current window with " - "shape [1]."); - AddOutput("out_old_num_accumulates", - "Output(Tensor): The accumulating times of previous window with " - "shape [1]."); - AddOutput("out_num_updates", - "Output(Tensor): The total number of batches used by trainning " - "before this batch with shape [1]."); + AddOutput( + "out_num_accumulates", + "(Tensor), The accumulating times of current window with " + "shape [1]."); + AddOutput( + "out_old_num_accumulates", + "(Tensor) The accumulating times of previous window with " + "shape [1]."); + AddOutput( + "out_num_updates", + "(Tensor), The total number of batches used by trainning " + "before this batch with shape [1]."); AddAttr("average_window", - "The rate of average window size relative to num_updates."); - AddAttr("max_average_window", "Maximum size of average window."); - AddAttr("min_average_window", "Minimu size of average window."); + "(float, default 0) " + "The rate of average window size relative to num_updates.") + .SetDefault(0); + AddAttr("max_average_window", + "(int64_t) " + "Maximum size of average window. It suggests that the " + "number of mini-batches " + "in one pass is appropriate value to set."); + AddAttr("min_average_window", + "(int64_t, default 10000L) " + "Minimu size of average window.") + .SetDefault(10000L); AddComment(R"DOC( AverageAccumulates Operator. -Accumulate the sum of parameter whtin sliding window. The size of sliding window is determined by 'average_window', 'max_average_window' and 'min_average_window'. +Accumulate the sum of parameter whtin sliding window. The size of sliding window is +determined by 'average_window', 'max_average_window' and 'min_average_window'. +Memory was shared by Input(in_sum_1) and Output(out_sum_1) which acts as an accumulator 'sum_1'. +'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' were the same as 'sum_1'. + +All the accumulators were inited to zero before training. + +And for a mini-batch in training, accumulators were computed as below steps: + num_updates += 1 + num_accumulates += 1 + sum_1 += param + if num_updates % kMaxNumAccumulates == 0: + sum_2 += sum_1 + sum_1 = 0 + if num_accumulates >= min_average_window && num_accumulates >= min(max_average_window, num_updates * average_window): + sum_3 = sum_1 + sum_2 + sum_1 = 0 + sum_2 = 0 + old_num_accumulates = num_accumulates + num_accumulates = 0 + )DOC"); } }; diff --git a/paddle/fluid/operators/average_accumulates_op.cu b/paddle/fluid/operators/average_accumulates_op.cu index dbaa8ba6c9..270c469844 100644 --- a/paddle/fluid/operators/average_accumulates_op.cu +++ b/paddle/fluid/operators/average_accumulates_op.cu @@ -18,7 +18,7 @@ limitations under the License. */ namespace paddle { namespace operators { template <> -void getAccumulators( +void GetAccumulators( const framework::ExecutionContext& ctx, int64_t& num_updates_, int64_t& num_accumulates_, int64_t& old_num_accumulates_) { auto* in_old_num_accumulates = ctx.Input("in_old_num_accumulates"); @@ -35,7 +35,7 @@ void getAccumulators( } template <> -void setAccumulators( +void SetAccumulators( const framework::ExecutionContext& ctx, int64_t num_updates_, int64_t num_accumulates_, int64_t old_num_accumulates_) { auto stream = ctx.cuda_device_context().stream(); diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h index d33fd5519a..f858109d14 100644 --- a/paddle/fluid/operators/average_accumulates_op.h +++ b/paddle/fluid/operators/average_accumulates_op.h @@ -28,12 +28,12 @@ template ; template -void getAccumulators(const framework::ExecutionContext& ctx, +void GetAccumulators(const framework::ExecutionContext& ctx, int64_t& num_updates, int64_t& num_accumulates, int64_t& old_num_accumulates); template -void setAccumulators(const framework::ExecutionContext& ctx, +void SetAccumulators(const framework::ExecutionContext& ctx, int64_t num_updates, int64_t num_accumulates, int64_t old_num_accumulates); @@ -47,7 +47,7 @@ class AverageAccumulatesKernel : public framework::OpKernel { int64_t num_updates = 0; int64_t num_accumulates = 0; int64_t old_num_accumulates = 0; - getAccumulators(ctx, num_updates, num_accumulates, + GetAccumulators(ctx, num_updates, num_accumulates, old_num_accumulates); // Get attrs @@ -84,6 +84,8 @@ class AverageAccumulatesKernel : public framework::OpKernel { out_sum_2_tensor.device(place) = in_sum_2_tensor; out_sum_3_tensor.device(place) = in_sum_3_tensor; if (num_updates % kMaxNumAccumulates == 0) { + // Move the sum to a different buffer to avoid loss of precision due to + // too many sums. out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor; constant_functor(ctx.template device_context(), out_sum_1, 0.0); @@ -91,6 +93,7 @@ class AverageAccumulatesKernel : public framework::OpKernel { if (num_accumulates >= min_average_window && num_accumulates >= std::min(max_average_window, num_updates * average_window)) { + // Now the average window is too long, discard the old sum. out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor; constant_functor(ctx.template device_context(), out_sum_1, 0.0); @@ -101,7 +104,7 @@ class AverageAccumulatesKernel : public framework::OpKernel { } // Set accumulators to output - setAccumulators(ctx, num_updates, num_accumulates, + SetAccumulators(ctx, num_updates, num_accumulates, old_num_accumulates); } }; diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 394cf050a7..d8373eaab4 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -732,7 +732,6 @@ class ModelAverage(Optimizer): """Apply average values to parameters of current model. """ executor.run(self.apply_program) - print "finish apply" try: yield finally: @@ -743,4 +742,3 @@ class ModelAverage(Optimizer): """Restore parameter values of current model. """ executor.run(self.restore_program) - print "finish restore" From 3da094fd7ba6fe75618ec3a72ddf514b32726efa Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 21:29:59 -0700 Subject: [PATCH 23/79] rearrange test --- paddle/fluid/operators/elementwise_add_op.cu | 6 +- .../unittests/test_elementwise_add_op.py | 255 +++++++++++------- 2 files changed, 160 insertions(+), 101 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index c8bf524144..dfff518f17 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -17,14 +17,14 @@ limitations under the License. */ #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = padddle::platform; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_add, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel - ops::ElementwiseAddKernel); + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel); REGISTER_OP_CUDA_KERNEL( elementwise_add_grad, ops::ElementwiseAddGradKernel, diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 28286d79ea..1f52bd90d0 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -21,15 +21,17 @@ class TestElementwiseAddOp(OpTest): def setUp(self): self.op_type = "elementwise_add" self.dtype = np.float32 - init_dtype() + self.axis = -1 + self.init_dtype() + self.init_input_output() + self.init_axis() - x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) - y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(x), - 'Y': OpTest.np_dtype_to_fluid_dtype(y) + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) } - self.outputs = {'Out': np.add(x, y)} + self.attrs = {'axis': self.axis} + self.outputs = {'Out': self.out} def test_check_output(self): self.check_output() @@ -51,12 +53,20 @@ class TestElementwiseAddOp(OpTest): self.check_grad( ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) - def init_dtype(): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + def init_dtype(self): + pass + + def init_axis(self): pass class TestFP16ElementwiseAddOp(TestElementwiseAddOp): - def init_dtype(): + def init_dtype(self): self.dtype = np.float16 def test_check_output(self): @@ -67,130 +77,179 @@ class TestFP16ElementwiseAddOp(TestElementwiseAddOp): class TestElementwiseAddOp_scalar(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(1).astype(np.float32) - } - self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y + + +class TestFP16ElementwiseAddOp_scalar(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y class TestElementwiseAddOp_scalar2(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(1, 1).astype(np.float32) - } - self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + self.out = self.x + self.y + + +class TestFP16ElementwiseAddOp_scalar2(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + self.out = self.x + self.y class TestElementwiseAddOp_Vector(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.random((32, )).astype("float32"), - 'Y': np.random.random((32, )).astype("float32") - } - self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])} + def init_input_output(self): + self.x = np.random.random((32, )).astype(self.dtype) + self.y = np.random.random((32, )).astype(self.dtype) + self.out = np.add(self.x, self.y) + + +class TestFP16ElementwiseAddOp_Vector(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.random((32, )).astype(self.dtype) + self.y = np.random.random((32, )).astype(self.dtype) + self.out = np.add(self.x, self.y) class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(2).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.out = self.x + self.y.reshape(2, 1, 1) - self.attrs = {'axis': 0} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1) - } + def init_axis(self): + self.axis = 0 + + +class TestFP16ElementwiseAddOp_broadcast_0(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.out = self.x + self.y.reshape(2, 1, 1) + + def init_axis(self): + self.axis = 0 class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(3).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 1) - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 1) - } + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_broadcast_1(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 1) + + def init_axis(self): + self.axis = 1 class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(4).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1, 4) - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1, 4) - } + +class TestFP16ElementwiseAddOp_broadcast_2(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1, 4) class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4, 5).astype(np.float32), - 'Y': np.random.rand(3, 4).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 4, 1) - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4, 1) - } + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_broadcast_3(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 4, 1) + + def init_axis(self): + self.axis = 1 class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4, 5).astype(np.float32), - 'Y': np.random.rand(2, 1).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(2, 1).astype(self.dtype) + self.out = self.x + self.y.reshape(2, 1, 1, 1) + + def init_axis(self): + self.axis = 0 - self.attrs = {'axis': 0} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1, 1) - } + +class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(2, 1).astype(self.dtype) + self.out = self.x + self.y.reshape(2, 1, 1, 1) + + def init_axis(self): + self.axis = 0 class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(3, 4).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 4) - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4) - } + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_rowwise_add_0(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 4) + + def init_axis(self): + self.axis = 1 class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 1).astype(np.float32), - 'Y': np.random.rand(1).astype(np.float32) - } + def init_input_output(self): + self.x = np.random.rand(2, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1) - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1) - } + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1) + + def init_axis(self): + self.axis = 1 if __name__ == '__main__': From a2981f5c5018c23aa969389c64a329e53f8cf290 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 15:16:42 +0800 Subject: [PATCH 24/79] fix a bug --- .../fluid/operators/reader/open_files_op.cc | 79 ++++++++++++------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 6b62e1db49..49cdf5365c 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -21,12 +21,10 @@ namespace reader { class MultipleReader : public framework::ReaderBase { public: - struct Quota {}; - MultipleReader(const std::vector& file_names, const std::vector& dims, size_t thread_num) - : file_names_(file_names), dims_(dims), thread_num_(thread_num) { - PADDLE_ENFORCE_GT(thread_num_, 0); + : file_names_(file_names), dims_(dims) { + prefetchers_.resize(thread_num); StartNewScheduler(); } @@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase { bool HasNext() const override; void ReInit() override; + ~MultipleReader() { EndScheduler(); } + private: void StartNewScheduler(); + void EndScheduler(); void ScheduleThreadFunc(); - void PrefetchThreadFunc(std::string file_name); + void PrefetchThreadFunc(std::string file_name, size_t thread_idx); std::vector file_names_; std::vector dims_; - size_t thread_num_; + std::thread scheduler_; + std::vector prefetchers_; framework::Channel* waiting_file_idx_; - framework::Channel* thread_quotas_; + framework::Channel* available_thread_idx_; framework::Channel>* buffer_; mutable std::vector local_buffer_; }; @@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const { } void MultipleReader::ReInit() { - buffer_->Close(); - thread_quotas_->Close(); - waiting_file_idx_->Close(); + EndScheduler(); local_buffer_.clear(); - StartNewScheduler(); } void MultipleReader::StartNewScheduler() { + size_t thread_num = prefetchers_.size(); waiting_file_idx_ = framework::MakeChannel(file_names_.size()); - thread_quotas_ = framework::MakeChannel(thread_num_); + available_thread_idx_ = framework::MakeChannel(thread_num); buffer_ = - framework::MakeChannel>(thread_num_); + framework::MakeChannel>(thread_num); for (size_t i = 0; i < file_names_.size(); ++i) { waiting_file_idx_->Send(&i); } waiting_file_idx_->Close(); - for (size_t i = 0; i < thread_num_; ++i) { - Quota quota; - thread_quotas_->Send("a); + for (size_t i = 0; i < thread_num; ++i) { + available_thread_idx_->Send(&i); } - std::thread scheduler([this] { ScheduleThreadFunc(); }); - scheduler.detach(); + scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); +} + +void MultipleReader::EndScheduler() { + available_thread_idx_->Close(); + buffer_->Close(); + waiting_file_idx_->Close(); + scheduler_.join(); + delete buffer_; + delete available_thread_idx_; + delete waiting_file_idx_; } void MultipleReader::ScheduleThreadFunc() { VLOG(5) << "MultipleReader schedule thread starts."; size_t completed_thread_num = 0; - Quota quota; - while (thread_quotas_->Receive("a)) { + size_t thread_idx; + while (available_thread_idx_->Receive(&thread_idx)) { + std::thread& prefetcher = prefetchers_[thread_idx]; + if (prefetcher.joinable()) { + prefetcher.join(); + } size_t file_idx; if (waiting_file_idx_->Receive(&file_idx)) { // Still have files to read. Start a new prefetch thread. std::string file_name = file_names_[file_idx]; - std::thread prefetcher( - [this, file_name] { PrefetchThreadFunc(file_name); }); - prefetcher.detach(); + prefetcher = std::thread([this, file_name, thread_idx] { + PrefetchThreadFunc(file_name, thread_idx); + }); } else { // No more file to read. ++completed_thread_num; - if (completed_thread_num == thread_num_) { - thread_quotas_->Close(); - buffer_->Close(); + if (completed_thread_num == prefetchers_.size()) { break; } } } + // If users invoke ReInit() when scheduler is running, it will close the + // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler + // to release their resource. So a check is needed before scheduler ends. + for (auto& p : prefetchers_) { + if (p.joinable()) { + p.join(); + } + } VLOG(5) << "MultipleReader schedule thread terminates."; } -void MultipleReader::PrefetchThreadFunc(std::string file_name) { +void MultipleReader::PrefetchThreadFunc(std::string file_name, + size_t thread_idx) { VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; std::unique_ptr reader = CreateReaderByFileName(file_name, dims_); @@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) { break; } } - Quota quota; - thread_quotas_->Send("a); + if (!available_thread_idx_->Send(&thread_idx)) { + VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " + "Fail to send thread_idx."; + } VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; } From c346a345e05e2e17203d693c61be13c541016834 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 16:35:52 +0800 Subject: [PATCH 25/79] fix a bug --- paddle/fluid/recordio/header.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/recordio/header.cc b/paddle/fluid/recordio/header.cc index e50de15b7c..ed09d58f6a 100644 --- a/paddle/fluid/recordio/header.cc +++ b/paddle/fluid/recordio/header.cc @@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) bool Header::Parse(std::istream& is) { uint32_t magic; - size_t read_size = - is.readsome(reinterpret_cast(&magic), sizeof(uint32_t)); + is.read(reinterpret_cast(&magic), sizeof(uint32_t)); + size_t read_size = is.gcount(); if (read_size < sizeof(uint32_t)) { return false; } From 7c14f49ffb8ae9c87f9161161c60d05ace307c2c Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 17:04:50 +0800 Subject: [PATCH 26/79] Update index_en.rst fix #8956 --- doc/v2/getstarted/index_en.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/doc/v2/getstarted/index_en.rst b/doc/v2/getstarted/index_en.rst index 33f299be56..62325c799a 100644 --- a/doc/v2/getstarted/index_en.rst +++ b/doc/v2/getstarted/index_en.rst @@ -1,8 +1,18 @@ GET STARTED ============ +If you want to quickly know how to use PaddlePaddle, please refer to the following guide: + .. toctree:: :maxdepth: 1 quickstart_en.rst + + +While using PaddlePaddle to build applications, please understand some basic concepts. +Here is an example of linear regression. It introduces use flow of PaddlePaddle, including data format, model configuration and training, etc. + + .. toctree:: + :maxdepth: 1 + concepts/use_concepts_en.rst From f863866471f285015201183994d45dc5637919bb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 17:54:37 +0800 Subject: [PATCH 27/79] Add an unitest --- .../fluid/operators/reader/open_files_op.cc | 4 +- .../operators/reader/reader_op_registry.cc | 9 ++- .../operators/reader/reader_op_registry.h | 2 +- python/paddle/fluid/layers/io.py | 5 +- .../tests/unittests/test_multiple_reader.py | 71 +++++++++++++++++++ 5 files changed, 82 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_multiple_reader.py diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 49cdf5365c..1ab4111efe 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -94,7 +94,9 @@ void MultipleReader::EndScheduler() { available_thread_idx_->Close(); buffer_->Close(); waiting_file_idx_->Close(); - scheduler_.join(); + if (scheduler_.joinable()) { + scheduler_.join(); + } delete buffer_; delete available_thread_idx_; delete waiting_file_idx_; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 05d79c76d5..fc8dc747ff 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -38,17 +38,16 @@ std::unordered_map& FileReaderRegistry() { std::unique_ptr CreateReaderByFileName( const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); + size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); PADDLE_ENFORCE_NE(separator_pos, std::string::npos, "File name illegal! A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); + "[file_name].[file_format] (e.g., 'data_file.recordio')."); + std::string filetype = file_name.substr(separator_pos + 1); auto itor = FileReaderRegistry().find(filetype); PADDLE_ENFORCE(itor != FileReaderRegistry().end(), "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); + framework::ReaderBase* reader = (itor->second)(file_name, dims); return std::unique_ptr(reader); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index dd19b982da..929d32ad8b 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { namespace reader { -static constexpr char kFileFormatSeparator[] = ":"; +static constexpr char kFileFormatSeparator[] = "."; using FileReaderCreator = std::function&)>; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 89153f325b..f169642eaa 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,8 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' + 'open_files', 'read_file', 'create_shuffle_reader', + 'create_double_buffer_reader' ] @@ -307,7 +308,7 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): 'shape_concat': shape_concat, 'lod_levels': lod_levels, 'ranks': ranks, - 'filename': filenames, + 'file_names': filenames, 'thread_num': thread_num }) diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py new file mode 100644 index 0000000000..cb1aaaae5a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist +from shutil import copyfile + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(mnist.train(), batch_size=32) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist_0.recordio', reader, feeder) + copyfile('./mnist_0.recordio', './mnist_1.recordio') + copyfile('./mnist_0.recordio', './mnist_2.recordio') + print(self.num_batch) + + def test_multiple_reader(self, thread_num=3): + file_list = [ + './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' + ] + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_files = fluid.layers.open_files( + filenames=file_list, + thread_num=thread_num, + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(data_files) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_files.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + print(batch_count) + # data_files.reset() + print("FUCK") + + self.assertEqual(batch_count, self.num_batch * 3) From 68c9f6ef1160ff965b4740f369ee4077ab8c113f Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 20 Mar 2018 18:03:22 +0800 Subject: [PATCH 28/79] Fix error while params_grads[1]==None --- python/paddle/fluid/optimizer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index d8373eaab4..4993fe39e0 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -649,20 +649,23 @@ class ModelAverage(Optimizer): self.min_average_window = min_average_window self.max_average_window = max_average_window self.params_grads = params_grads - for param, _ in self.params_grads: - self._append_average_accumulate_op(param) + for param, grad in self.params_grads: + if grad is not None: + self._append_average_accumulate_op(param) self.apply_program = Program() block = self.apply_program.global_block() with program_guard(main_program=self.apply_program): for param_grad in self.params_grads: - self._add_average_apply_op(block, param_grad) + if param_grad[1] is not None: + self._add_average_apply_op(block, param_grad) self.restore_program = Program() block = self.restore_program.global_block() with program_guard(main_program=self.restore_program): for param_grad in self.params_grads: - self._add_average_restore_op(block, param_grad) + if param_grad[1] is not None: + self._add_average_restore_op(block, param_grad) def _add_average_apply_op(self, block, param_grad): param = block.clone_variable(param_grad[0]) From 2532b922dc4897478589d7b4064cde40113f943b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 19:20:58 +0800 Subject: [PATCH 29/79] Add more unittests and fix bugs --- paddle/fluid/operators/reader/open_files_op.cc | 1 + python/paddle/fluid/tests/unittests/.gitignore | 3 +++ .../tests/unittests/test_multiple_reader.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 1ab4111efe..414c76fea0 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() { // No more file to read. ++completed_thread_num; if (completed_thread_num == prefetchers_.size()) { + buffer_->Close(); break; } } diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 6b3fc2a83c..ad02bdecf4 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -1 +1,4 @@ mnist.recordio +mnist_0.recordio +mnist_1.recordio +mnist_2.recordio diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py index cb1aaaae5a..69f8acf81e 100644 --- a/python/paddle/fluid/tests/unittests/test_multiple_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -22,9 +22,10 @@ from shutil import copyfile class TestMultipleReader(unittest.TestCase): def setUp(self): + self.batch_size = 64 # Convert mnist to recordio file with fluid.program_guard(fluid.Program(), fluid.Program()): - reader = paddle.batch(mnist.train(), batch_size=32) + reader = paddle.batch(mnist.train(), batch_size=self.batch_size) feeder = fluid.DataFeeder( feed_list=[ # order is image and label fluid.layers.data( @@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase): './mnist_0.recordio', reader, feeder) copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio') - print(self.num_batch) - def test_multiple_reader(self, thread_num=3): + def main(self, thread_num): file_list = [ './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' ] @@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase): while not data_files.eof(): img_val, = exe.run(fetch_list=[img]) batch_count += 1 - print(batch_count) - # data_files.reset() - print("FUCK") - + self.assertLessEqual(img_val.shape[0], self.batch_size) + data_files.reset() self.assertEqual(batch_count, self.num_batch * 3) + + def test_main(self): + self.main(thread_num=3) # thread number equals to file number + self.main(thread_num=10) # thread number is larger than file number + self.main(thread_num=2) # thread number is less than file number From bce08d19ccac8a5302a162d371c80bae7c74c289 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 19:39:19 +0800 Subject: [PATCH 30/79] Python wrapper for MultiPassReader --- python/paddle/fluid/layers/io.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 9c91f395e7..e0c4cffa2d 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -314,6 +314,11 @@ def create_double_buffer_reader(reader, place=None): attrs) +def create_multi_pass_reader(reader, pass_num): + return __create_decorated_reader__('create_multi_pass_reader', reader, + {'pass_num': int(pass_num)}) + + def read_file(file_obj): helper = LayerHelper('read_file') out = [ From dc2bc077a2f2479fcfb55c5b029d6eed6bb628c9 Mon Sep 17 00:00:00 2001 From: weixing02 <564445201@qq.com> Date: Tue, 20 Mar 2018 19:40:03 +0800 Subject: [PATCH 31/79] Build basic sphinx doctree for doc/fluid --- doc/CMakeLists.txt | 1 + doc/fluid/CMakeLists.txt | 49 ++++++++++++++++++++++++ doc/fluid/build_and_install/index_cn.rst | 2 + doc/fluid/build_and_install/index_en.rst | 2 + doc/fluid/design/index_cn.rst | 2 + doc/fluid/design/index_en.rst | 2 + doc/fluid/dev/index_cn.rst | 2 + doc/fluid/dev/index_en.rst | 4 ++ doc/fluid/faq/index_cn.rst | 2 + doc/fluid/faq/index_en.rst | 2 + doc/fluid/getstarted/index_cn.rst | 4 ++ doc/fluid/getstarted/index_en.rst | 4 ++ doc/fluid/howto/index_cn.rst | 2 + doc/fluid/howto/index_en.rst | 4 ++ doc/fluid/index_cn.rst | 12 ++++++ doc/fluid/index_en.rst | 12 ++++++ 16 files changed, 106 insertions(+) create mode 100644 doc/fluid/CMakeLists.txt create mode 100644 doc/fluid/build_and_install/index_cn.rst create mode 100644 doc/fluid/build_and_install/index_en.rst create mode 100644 doc/fluid/design/index_cn.rst create mode 100644 doc/fluid/design/index_en.rst create mode 100644 doc/fluid/dev/index_cn.rst create mode 100644 doc/fluid/dev/index_en.rst create mode 100644 doc/fluid/faq/index_cn.rst create mode 100644 doc/fluid/faq/index_en.rst create mode 100644 doc/fluid/getstarted/index_cn.rst create mode 100644 doc/fluid/getstarted/index_en.rst create mode 100644 doc/fluid/howto/index_cn.rst create mode 100644 doc/fluid/howto/index_en.rst create mode 100644 doc/fluid/index_cn.rst create mode 100644 doc/fluid/index_en.rst diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt index da67701ec1..a9b27933a5 100644 --- a/doc/CMakeLists.txt +++ b/doc/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(v2) +add_subdirectory(fluid) diff --git a/doc/fluid/CMakeLists.txt b/doc/fluid/CMakeLists.txt new file mode 100644 index 0000000000..cc999f5a8d --- /dev/null +++ b/doc/fluid/CMakeLists.txt @@ -0,0 +1,49 @@ +if(NOT DEFINED SPHINX_THEME) + set(SPHINX_THEME default) +endif() + +if(NOT DEFINED SPHINX_THEME_DIR) + set(SPHINX_THEME_DIR) +endif() + +# configured documentation tools and intermediate build results +set(BINARY_BUILD_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_build") + +# Sphinx cache with pickled ReST documents +set(SPHINX_CACHE_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_doctrees") + +# HTML output director +set(SPHINX_HTML_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/html") + +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.en.in" + "${BINARY_BUILD_DIR_EN}/conf.py" + @ONLY) + +sphinx_add_target(paddle_fluid_docs + html + ${BINARY_BUILD_DIR_EN} + ${SPHINX_CACHE_DIR_EN} + ${CMAKE_CURRENT_SOURCE_DIR} + ${SPHINX_HTML_DIR_EN}) + +# configured documentation tools and intermediate build results +set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build") + +# Sphinx cache with pickled ReST documents +set(SPHINX_CACHE_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_doctrees") + +# HTML output directory +set(SPHINX_HTML_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/html") + +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.cn.in" + "${BINARY_BUILD_DIR_CN}/conf.py" + @ONLY) + +sphinx_add_target(paddle_fluid_docs_cn + html + ${BINARY_BUILD_DIR_CN} + ${SPHINX_CACHE_DIR_CN} + ${CMAKE_CURRENT_SOURCE_DIR} + ${SPHINX_HTML_DIR_CN}) diff --git a/doc/fluid/build_and_install/index_cn.rst b/doc/fluid/build_and_install/index_cn.rst new file mode 100644 index 0000000000..9276236f9f --- /dev/null +++ b/doc/fluid/build_and_install/index_cn.rst @@ -0,0 +1,2 @@ +安装与使用 +------------ diff --git a/doc/fluid/build_and_install/index_en.rst b/doc/fluid/build_and_install/index_en.rst new file mode 100644 index 0000000000..cc1e61a58a --- /dev/null +++ b/doc/fluid/build_and_install/index_en.rst @@ -0,0 +1,2 @@ +Build and Install +------------ diff --git a/doc/fluid/design/index_cn.rst b/doc/fluid/design/index_cn.rst new file mode 100644 index 0000000000..f1887be690 --- /dev/null +++ b/doc/fluid/design/index_cn.rst @@ -0,0 +1,2 @@ +设计思想 +------------ diff --git a/doc/fluid/design/index_en.rst b/doc/fluid/design/index_en.rst new file mode 100644 index 0000000000..18a4b4122f --- /dev/null +++ b/doc/fluid/design/index_en.rst @@ -0,0 +1,2 @@ +Design +------------ diff --git a/doc/fluid/dev/index_cn.rst b/doc/fluid/dev/index_cn.rst new file mode 100644 index 0000000000..e1edf079fa --- /dev/null +++ b/doc/fluid/dev/index_cn.rst @@ -0,0 +1,2 @@ +开发标准 +------------ diff --git a/doc/fluid/dev/index_en.rst b/doc/fluid/dev/index_en.rst new file mode 100644 index 0000000000..faf9dfcd31 --- /dev/null +++ b/doc/fluid/dev/index_en.rst @@ -0,0 +1,4 @@ +Development +------------ + +This is Development page diff --git a/doc/fluid/faq/index_cn.rst b/doc/fluid/faq/index_cn.rst new file mode 100644 index 0000000000..395c110989 --- /dev/null +++ b/doc/fluid/faq/index_cn.rst @@ -0,0 +1,2 @@ +FAQ +------------ diff --git a/doc/fluid/faq/index_en.rst b/doc/fluid/faq/index_en.rst new file mode 100644 index 0000000000..395c110989 --- /dev/null +++ b/doc/fluid/faq/index_en.rst @@ -0,0 +1,2 @@ +FAQ +------------ diff --git a/doc/fluid/getstarted/index_cn.rst b/doc/fluid/getstarted/index_cn.rst new file mode 100644 index 0000000000..c4d8525f23 --- /dev/null +++ b/doc/fluid/getstarted/index_cn.rst @@ -0,0 +1,4 @@ +新手入门 +------------ + +新手入门 diff --git a/doc/fluid/getstarted/index_en.rst b/doc/fluid/getstarted/index_en.rst new file mode 100644 index 0000000000..a4efd05e2f --- /dev/null +++ b/doc/fluid/getstarted/index_en.rst @@ -0,0 +1,4 @@ +GET STARTED +------------ + +This is get started page diff --git a/doc/fluid/howto/index_cn.rst b/doc/fluid/howto/index_cn.rst new file mode 100644 index 0000000000..a92abad0c5 --- /dev/null +++ b/doc/fluid/howto/index_cn.rst @@ -0,0 +1,2 @@ +进阶使用 +------------ diff --git a/doc/fluid/howto/index_en.rst b/doc/fluid/howto/index_en.rst new file mode 100644 index 0000000000..06036bdce5 --- /dev/null +++ b/doc/fluid/howto/index_en.rst @@ -0,0 +1,4 @@ +HOW TO +------------ + +This is how to page diff --git a/doc/fluid/index_cn.rst b/doc/fluid/index_cn.rst new file mode 100644 index 0000000000..be3bed4393 --- /dev/null +++ b/doc/fluid/index_cn.rst @@ -0,0 +1,12 @@ + PaddlePaddle Fluid +========================== + +.. toctree:: + :maxdepth: 1 + + getstarted/index_cn.rst + design/index_cn.rst + build_and_install/index_cn.rst + howto/index_cn.rst + dev/index_cn.rst + faq/index_cn.rst diff --git a/doc/fluid/index_en.rst b/doc/fluid/index_en.rst new file mode 100644 index 0000000000..87c831420a --- /dev/null +++ b/doc/fluid/index_en.rst @@ -0,0 +1,12 @@ + PaddlePaddle Fluid +========================== + +.. toctree:: + :maxdepth: 1 + + getstarted/index_en.rst + design/index_en.rst + build_and_install/index_en.rst + howto/index_en.rst + dev/index_en.rst + faq/index_en.rst From 30b70323b4dc04ff1270c520711fa5428f509ae5 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 20 Mar 2018 19:43:46 +0800 Subject: [PATCH 32/79] Expose RMSProp optimizer. (#9247) * Add RMSProp optimizer warpper. * Follow comments. --- python/paddle/fluid/optimizer.py | 118 +++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index e8623ee0da..a33760a528 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -664,6 +664,123 @@ class AdadeltaOptimizer(Optimizer): return adadelta_op +class RMSPropOptimizer(Optimizer): + """ + Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning + rate method. The original slides proposed RMSProp: Slide 29 of + http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf . + + The original equation is as follows: + + .. math:: + + r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2 \\\\ + + w & = w - \\frac{\\eta} {\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w) + + The first equation calculates moving average of the squared gradient for + each weight. Then dividing the gradient by :math: `sqrt{v(w,t)}`. + + In some cases, adding a momentum term :math: `\\beta` is beneficial. + In our implementation, Nesterov momentum is used: + + .. math:: + + r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2 \\\\ + + v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{v(w,t) + + \\epsilon}} \\nabla Q_{i}(w) + + w & = w - v(w, t) + + where, :math: `\\rho` is a hyperparameter and typical values are 0.9, 0.95 + and so on. :math: `beta` is the momentum term. :math: `\\epsilon` is a + smoothing term to avoid division by zero, usually set somewhere in range + from 1e-4 to 1e-8. + + + Args: + learning_rate(float): global leraning rate. + rho(float): rho is :math: `\\rho` in equation, set 0.95 by default. + epsilon(float): :math: `\\epsilon` in equation is smoothing term to + avoid division by zero, set 1e-6 by default. + momentum(float): :math: `\\beta` in equation is the momentum term, + set 0.0 by default. + + Raises: + ValueError: If learning_rate, rho, epsilon, momentum are None. + + Examples: + .. code-block:: python + + optimizer = fluid.optimizer.RMSProp(0.0001) + _, params_grads = optimizer.minimize(cost) + """ + + _momentum_acc_str = "momentum" + _mean_square_acc_str = "mean_square" + + def __init__(self, + learning_rate, + rho=0.95, + epsilon=1.0e-6, + momentum=0.0, + **kwargs): + super(RMSPropOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) + if learning_rate is None: + raise ValueError("learning_rate is not set.") + if rho is None: + raise ValueError("rho is not set.") + if epsilon is None: + raise ValueError("epsilon is not set.") + if momentum is None: + raise ValueError("momentum is not set.") + + self.type = "rmsprop" + self._rho = rho + self._epsilon = epsilon + self._momentum = momentum + + def _create_accumulators(self, block, parameters): + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") + + for p in parameters: + self._add_accumulator(self._momentum_acc_str, p) + self._add_accumulator(self._mean_square_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") + + momentum_acc = self._get_accumulator(self._momentum_acc_str, + param_and_grad[0]) + mean_square_acc = self._get_accumulator(self._mean_square_acc_str, + param_and_grad[0]) + rmsprop_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Moment": momentum_acc, + "MeanSquare": mean_square_acc, + "LearningRate": self._create_param_lr(param_and_grad), + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": momentum_acc, + "MeanSquareOut": mean_square_acc + }, + attrs={ + "epsilon": self._epsilon, + "decay": self._rho, + "momentum": self._momentum + }) + + return rmsprop_op + + # We short the class name, since users will use the optimizer with the package # name. The sample code: # @@ -679,3 +796,4 @@ Adam = AdamOptimizer Adamax = AdamaxOptimizer DecayedAdagrad = DecayedAdagradOptimizer Adadelta = AdadeltaOptimizer +RMSProp = RMSPropOptimizer From 963e20beb5d2927112b12989feb68a557f44f9c0 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 20:22:35 +0800 Subject: [PATCH 33/79] Update index_en.rst fix https://github.com/PaddlePaddle/Paddle/issues/8921 --- doc/v2/faq/index_en.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/v2/faq/index_en.rst b/doc/v2/faq/index_en.rst index 57df868f76..4c73d2105e 100644 --- a/doc/v2/faq/index_en.rst +++ b/doc/v2/faq/index_en.rst @@ -1,7 +1,8 @@ FAQ ==== - +This document provides frequently asked questions of PaddlePaddle. If your questions are not here, please go to `PaddlePaddle Community `_to find answers or open an `issue `_ , we will reply in time. + .. toctree:: :maxdepth: 1 From a4f397fb681fd6b0bc28216c3d9bae3b660b50a8 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 20:50:03 +0800 Subject: [PATCH 34/79] add an unittest --- python/paddle/fluid/layers/io.py | 3 +- .../tests/unittests/test_multi_pass_reader.py | 66 +++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_multi_pass_reader.py diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index e0c4cffa2d..4ff5cd9bf5 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,8 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' + 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader', + 'create_multi_pass_reader' ] diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py new file mode 100644 index 0000000000..17374aec1b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py @@ -0,0 +1,66 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + self.batch_size = 64 + self.pass_num = 3 + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = paddle.batch(mnist.train(), batch_size=self.batch_size) + feeder = fluid.DataFeeder( + feed_list=[ + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist.recordio', data_file, feeder) + + def test_main(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = fluid.layers.open_recordio_file( + filename='./mnist.recordio', + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + data_file = fluid.layers.create_multi_pass_reader( + reader=data_file, pass_num=self.pass_num) + img, label = fluid.layers.read_file(data_file) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_file.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + self.assertLessEqual(img_val.shape[0], self.batch_size) + print(batch_count) + data_file.reset() + self.assertEqual(batch_count, self.num_batch * self.pass_num) From 0b2f1b3f45aa11cf161c922da49ce30fe588ecd1 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 20:51:29 +0800 Subject: [PATCH 35/79] clear stream during Scanner::Reset() --- paddle/fluid/recordio/scanner.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc index d842f8fe5a..c22281dc97 100644 --- a/paddle/fluid/recordio/scanner.cc +++ b/paddle/fluid/recordio/scanner.cc @@ -28,6 +28,7 @@ Scanner::Scanner(const std::string &filename) { } void Scanner::Reset() { + stream_->clear(); stream_->seekg(0, std::ios::beg); ParseNextChunk(); } From a944d57181a3cd00a3f55c191daba825b5aabcda Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 20:54:55 +0800 Subject: [PATCH 36/79] refine code --- python/paddle/fluid/tests/unittests/test_multi_pass_reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py index 17374aec1b..8add353303 100644 --- a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py @@ -61,6 +61,5 @@ class TestMultipleReader(unittest.TestCase): img_val, = exe.run(fetch_list=[img]) batch_count += 1 self.assertLessEqual(img_val.shape[0], self.batch_size) - print(batch_count) data_file.reset() self.assertEqual(batch_count, self.num_batch * self.pass_num) From e84c8932c1862eca63ea9b1a9d23c387e7a4707b Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 21:15:41 +0800 Subject: [PATCH 37/79] =?UTF-8?q?change=20"=E4=BD=BF=E7=94=A8=E6=B5=81?= =?UTF-8?q?=E7=A8=8B"=20translation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Done --- doc/v2/getstarted/index_en.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/v2/getstarted/index_en.rst b/doc/v2/getstarted/index_en.rst index 62325c799a..2cf3fc07ea 100644 --- a/doc/v2/getstarted/index_en.rst +++ b/doc/v2/getstarted/index_en.rst @@ -10,7 +10,7 @@ If you want to quickly know how to use PaddlePaddle, please refer to the followi While using PaddlePaddle to build applications, please understand some basic concepts. -Here is an example of linear regression. It introduces use flow of PaddlePaddle, including data format, model configuration and training, etc. +Here is an example of linear regression. It introduces workflow of PaddlePaddle, including data format, model configuration and training, etc. .. toctree:: :maxdepth: 1 From 976a1bb0d44d59c8299d20f09aa88c2a185b2177 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 21:26:15 +0800 Subject: [PATCH 38/79] Update index_en.rst fix space and toctree format --- doc/v2/getstarted/index_en.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/v2/getstarted/index_en.rst b/doc/v2/getstarted/index_en.rst index 2cf3fc07ea..94b306895c 100644 --- a/doc/v2/getstarted/index_en.rst +++ b/doc/v2/getstarted/index_en.rst @@ -10,9 +10,10 @@ If you want to quickly know how to use PaddlePaddle, please refer to the followi While using PaddlePaddle to build applications, please understand some basic concepts. + Here is an example of linear regression. It introduces workflow of PaddlePaddle, including data format, model configuration and training, etc. - .. toctree:: +.. toctree:: :maxdepth: 1 concepts/use_concepts_en.rst From ce58c6ea2448b1f51e05db346eb5566383b0b308 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 21:38:16 +0800 Subject: [PATCH 39/79] Update index_en.rst repair hyperlink format --- doc/v2/faq/index_en.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/v2/faq/index_en.rst b/doc/v2/faq/index_en.rst index 4c73d2105e..136c14c717 100644 --- a/doc/v2/faq/index_en.rst +++ b/doc/v2/faq/index_en.rst @@ -1,7 +1,7 @@ FAQ ==== -This document provides frequently asked questions of PaddlePaddle. If your questions are not here, please go to `PaddlePaddle Community `_to find answers or open an `issue `_ , we will reply in time. +This document provides frequently asked questions of PaddlePaddle. If your questions are not here, please go to `PaddlePaddle Community`_, to find answers or open an `issue`_ , we will reply in time. .. toctree:: :maxdepth: 1 From 9b278a3eeaf3977ad7618c8237cc316031d1f504 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 21:47:47 +0800 Subject: [PATCH 40/79] fix format problem --- doc/v2/faq/index_en.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/v2/faq/index_en.rst b/doc/v2/faq/index_en.rst index 136c14c717..5ce5cfbae7 100644 --- a/doc/v2/faq/index_en.rst +++ b/doc/v2/faq/index_en.rst @@ -1,7 +1,7 @@ FAQ ==== -This document provides frequently asked questions of PaddlePaddle. If your questions are not here, please go to `PaddlePaddle Community`_, to find answers or open an `issue`_ , we will reply in time. +This document provides frequently asked questions of PaddlePaddle. If your questions are not here, please go to `PaddlePaddle Community `_ , to find answers or open an `issue `_ , we will reply in time. .. toctree:: :maxdepth: 1 From 37a272e670c55b66701f98d51e7cb90c43b97f87 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 20 Mar 2018 21:48:45 +0800 Subject: [PATCH 41/79] add executor.prepare (#9022) optimize executor.run --- paddle/fluid/framework/executor.cc | 28 ++- paddle/fluid/framework/executor.h | 15 +- python/paddle/fluid/executor.py | 165 ++++++++++-------- .../tests/unittests/test_executor_and_mul.py | 1 - 4 files changed, 116 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7155d5ef2f..a688115b11 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -14,12 +14,8 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" -#include - -#include "gflags/gflags.h" #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/feed_fetch_method.h" -#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" @@ -40,14 +36,13 @@ namespace { int kProgramId = -1; } // namespace -struct ExecutorPrepareContext { - ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), block_id_(block_id) {} +ExecutorPrepareContext::ExecutorPrepareContext( + const framework::ProgramDesc& prog, size_t block_id) + : prog_(prog), block_id_(block_id) {} - const framework::ProgramDesc& prog_; - size_t block_id_; - std::vector> ops_; -}; +ExecutorPrepareContext::~ExecutorPrepareContext() { + VLOG(5) << "destroy ExecutorPrepareContext"; +} Executor::Executor(const platform::Place& place) : place_(place) {} @@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { platform::RecordBlock b(block_id); - auto* ctx = Prepare(pdesc, block_id); - RunPreparedContext(ctx, scope, create_local_scope, create_vars); - delete ctx; + auto ctx = Prepare(pdesc, block_id); + RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); } // Check whether the block already has feed operators and feed_holder. @@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } -ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, - int block_id) { +std::unique_ptr Executor::Prepare( + const ProgramDesc& program, int block_id) { auto* ctx = new ExecutorPrepareContext(program, block_id); PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); auto& block = program.Block(block_id); for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); } - return ctx; + return std::unique_ptr(ctx); } void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 28ce331515..fb29c70f14 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -22,7 +22,16 @@ limitations under the License. */ namespace paddle { namespace framework { -struct ExecutorPrepareContext; + +struct ExecutorPrepareContext { + ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); + ~ExecutorPrepareContext(); + + const framework::ProgramDesc& prog_; + size_t block_id_; + std::vector> ops_; +}; + class Executor { public: // TODO(dzhwinter) : Do not rely on this function, it will be removed @@ -47,8 +56,8 @@ class Executor { const std::string& feed_holder_name = "feed", const std::string& fetch_holder_name = "fetch"); - static ExecutorPrepareContext* Prepare(const ProgramDesc& program, - int block_id); + static std::unique_ptr Prepare( + const ProgramDesc& program, int block_id); void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 4490f2bf15..2612fb1ae4 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -235,6 +235,77 @@ class Executor(object): tensor.set_lod(lod) return tensor + def _get_program_cache(self, program_cache_key): + return self.program_caches.get(program_cache_key, None) + + def _add_program_cache(self, program_cache_key, program): + self.program_caches[program_cache_key] = program + + def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, + fetch_var_name): + tmp_program = program.clone() + + global_block = tmp_program.global_block() + + if feed_var_name in global_block.vars: + feed_var = global_block.var(feed_var_name) + else: + feed_var = global_block.create_var( + name=feed_var_name, + type=core.VarDesc.VarType.FEED_MINIBATCH, + persistable=True) + + if fetch_var_name in global_block.vars: + fetch_var = global_block.var(fetch_var_name) + else: + fetch_var = global_block.create_var( + name=fetch_var_name, + type=core.VarDesc.VarType.FETCH_LIST, + persistable=True) + + # prepend feed operators + if not has_feed_operators(global_block, feed, feed_var_name): + for i, name in enumerate(feed): + out = global_block.var(name) + global_block.prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + # append fetch_operators + if not has_fetch_operators(global_block, fetch_list, fetch_var_name): + for i, var in enumerate(fetch_list): + assert isinstance(var, Variable) or isinstance(var, str), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) + global_block.append_op( + type='fetch', + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + return tmp_program + + def _feed_data(self, program, feed, feed_var_name, scope): + # feed var to framework + for op in program.global_block().ops: + if op.desc.type() == 'feed': + feed_target_name = op.desc.output('Out')[0] + cur_feed = feed[feed_target_name] + if not isinstance(cur_feed, core.LoDTensor): + cur_feed = self.aslodtensor(cur_feed) + idx = op.desc.attr('col') + core.set_feed_variable(scope, cur_feed, feed_var_name, idx) + else: + break + + def _fetch_data(self, fetch_list, fetch_var_name, scope): + outs = [ + core.get_fetch_variable(scope, fetch_var_name, i) + for i in xrange(len(fetch_list)) + ] + return outs + def run(self, program=None, feed=None, @@ -268,7 +339,6 @@ class Executor(object): raise TypeError("feed should be a map") if fetch_list is None: fetch_list = [] - if program is None: program = default_main_program() @@ -278,79 +348,30 @@ class Executor(object): if scope is None: scope = global_scope() - program_cache = None - program_cache_key = get_program_cache_key(feed, fetch_list) - + cache_key = get_program_cache_key(feed, fetch_list) if use_program_cache: - # find program cache by cache_key - program_cache = self.program_caches.get(program_cache_key, None) - # TODO(qiao): Should check program_cache and program are exactly the same. + cached_program = self._get_program_cache(cache_key) + if cached_program is None: + cached_program = self._add_feed_fetch_ops( + program=program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name) + self._add_program_cache(cache_key, cached_program) + program = cached_program else: - self.program_caches.pop(program_cache_key, None) - - if program_cache is None: - program_cache = program.clone() - - if use_program_cache: - self.program_caches[program_cache_key] = program_cache - - global_block = program_cache.global_block() - - if feed_var_name in global_block.vars: - feed_var = global_block.var(feed_var_name) - else: - feed_var = global_block.create_var( - name=feed_var_name, - type=core.VarDesc.VarType.FEED_MINIBATCH, - persistable=True) - - if fetch_var_name in global_block.vars: - fetch_var = global_block.var(fetch_var_name) - else: - fetch_var = global_block.create_var( - name=fetch_var_name, - type=core.VarDesc.VarType.FETCH_LIST, - persistable=True) - - # prepend feed operators - if not has_feed_operators(global_block, feed, feed_var_name): - for i, name in enumerate(feed): - out = global_block.var(name) - global_block.prepend_op( - type='feed', - inputs={'X': [feed_var]}, - outputs={'Out': [out]}, - attrs={'col': i}) - - # append fetch_operators - if not has_fetch_operators(global_block, fetch_list, - fetch_var_name): - for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance(var, str), ( - "Wrong type for fetch_list[%s]: %s" % (i, type(var))) - global_block.append_op( - type='fetch', - inputs={'X': [var]}, - outputs={'Out': [fetch_var]}, - attrs={'col': i}) - - # feed var to framework - for op in program_cache.global_block().ops: - if op.desc.type() == 'feed': - feed_target_name = op.desc.output('Out')[0] - cur_feed = feed[feed_target_name] - if not isinstance(cur_feed, core.LoDTensor): - cur_feed = self.aslodtensor(cur_feed) - idx = op.desc.attr('col') - core.set_feed_variable(scope, cur_feed, feed_var_name, idx) - else: - break - - self.executor.run(program_cache.desc, scope, 0, True, True) - outs = [ - core.get_fetch_variable(scope, fetch_var_name, i) - for i in xrange(len(fetch_list)) - ] + self.program_caches.pop(cache_key, None) + program = self._add_feed_fetch_ops( + program=program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name) + + self._feed_data(program, feed, feed_var_name, scope) + self.executor.run(program.desc, scope, 0, True, True) + outs = self._fetch_data(fetch_list, fetch_var_name, scope) if return_numpy: outs = as_numpy(outs) return outs diff --git a/python/paddle/fluid/tests/unittests/test_executor_and_mul.py b/python/paddle/fluid/tests/unittests/test_executor_and_mul.py index 4958bef3ef..e1272c1d6d 100644 --- a/python/paddle/fluid/tests/unittests/test_executor_and_mul.py +++ b/python/paddle/fluid/tests/unittests/test_executor_and_mul.py @@ -16,7 +16,6 @@ import unittest import numpy import paddle.fluid.core as core - from paddle.fluid.executor import Executor from paddle.fluid.layers import mul, data From 873cb9bcc71e8bcc089e50186a30e0cc35f06665 Mon Sep 17 00:00:00 2001 From: Thuan Nguyen Date: Tue, 20 Mar 2018 12:10:42 -0700 Subject: [PATCH 42/79] Create select_op design document (#9139) * Create select_op design document * Fix pre-commit issues * Update select op as per varun's comments --- .../concurrent/images/select_op_workflow.png | Bin 0 -> 101447 bytes doc/fluid/design/concurrent/select_op.md | 265 ++++++++++++++++++ 2 files changed, 265 insertions(+) create mode 100644 doc/fluid/design/concurrent/images/select_op_workflow.png create mode 100644 doc/fluid/design/concurrent/select_op.md diff --git a/doc/fluid/design/concurrent/images/select_op_workflow.png b/doc/fluid/design/concurrent/images/select_op_workflow.png new file mode 100644 index 0000000000000000000000000000000000000000..719ed76f9d542d6c4f20c30f27656bb53325aa85 GIT binary patch literal 101447 zcmeEuWmJ@H_b&_s3?(Vujl|I1jYx>n9RkuJ(hbs53eq8nNJukucY}1dh~&_D?s?wl z{jamuS?hc|-_DoYg142a_BV2?+^DMOj`82?+%o2?^8!K?6RC z%HF3zLZU@dk(YVxW%4%*1Ew~8+qY_^tiV4wE{%f2D(9X?q4_;w&-#1+rv&Wc{ulJJ zgDx*wjo*A9e%snV9uytGB*!52EGpnWXKb?Z(LdgORAP6D@4$cMTs(W@Z_9@x1K*>| z+mC(s@G+^!NI`}6a7d|Xo5}lw1t#Ci!_WS=7im%#kN5Y6J<(6!bbY|lDp1B}GpMnx zd%U@=Ct!LGxwM6&fqYDwE8|ZL!n0Go=*l)(LD4}Jc*5!!p%is0yH|TfIfDze_SN^r zrfsd1B5sVTS7T9-V_(+p&ZIbT@6-N`o|r1UTrvph{=kB{7V;C10Vf|jq041 zeTIz&KeLnU7u!A}R~n?;ww*8~#<|E9|gqw~?T)1adbI!8)s zr&;nZHH-VO!j98v7KN9B-_&c-6Yc#b%GEF;I~68#5EXpZ>}5gWfgc@2pAZINHCOIb zJK6a3O(TcvuH%sgIbG`Mv)9KPTGeGyr7u#%%^a0gRg-nf47>xD8;oT^1DqH7?zD-sk>wQf7;c^p>V1|&h;M}k8FB!HgC*TBEkl)@S#1s-x9}}&I!eS&= zkgK9JhFm7{<&6uNoF1DnCx|-}(@81ukS#o~{fBfg2(9EbcESvt;AU?E?#7>>zY5LE zB9nmTtrzIBfuf~t0l7eMSX!Q;g9C2%SoY?ti!}L=*DrT6oOOm}<<~}pR{3SewRMf( zREVAbdL`F9)shO_v)|EK^Y2@x%&Fw)T_#GuGXn)_dVE_PaajbDL5V!gwn;l)|}RHP7<+?qs@kBbt4y z3s;I?`{(k?S#$S$^z-I!V$+}JhqdsH+WU!fR?t8sa648@pBZze3RIF~(eA#+OFlnM zYYg$c;uXd69f(~i|1#I#bmh7Hn4-7F>ivZS-;twMR>L^&U~ z+Y=|epZOG7h#@zuW2A|ssNra_UL*P`JM&R9{G-Z!F((hC8IY(}kvdggg6?`}I3ZsY zo+7Xbkg_XGeoU28`eGcfnfr7TE0+zr9r%YG-K%eS-V7ste(^{N?|Qt z>Vf-nUMiRIvvAC(&;MH8#g^|fHPiu1xE&}}jWN%Usa9r#s%V9FoQc#n?*ZoR)SOwRiw^?0qp-xy;@lb21hVWv}V^tMX+{wVM&;<}IxSq65JtH7!2O>ZbU;lbiGU*|(x-whmtFg)ChHsx8}2U_Uzyb&8f@7A zY+**ApErDmX{1|D0;)&nQf5$X9aouT3_i#X}Q``&XmvP&mVAUBsoyfuwAa7H;dg%`rnWSH+dedV3=Bc-$fnjT=rab zj2vBbZgY%8aeREZKix{!z*Hy%;Qbki@w=BF?Pb5yVNC&Q(#4D&_;6iha0CKSftq`hW&&o>Jh0CdZT07|c{3BQtv zA%PNCS#ucF^l!3zJf1qWUR&?2wD_j1Wko&>cQ@hX<5OF^`#WD(sFKS5LGbSY3gxXz5w#NDOk(ci&xu7piUIuZ8R~q%#}eA3W=$^32pE zYf??+Ivtn8RcUbB9$t{VKYPU|n1Wv#c=vw(O=@`afs+`N4*CKbivb#NF|nxPK6nj4 zbSDB|=Mubc(ZK9?je%F|=+4kzawTF1HB(z9KU<$i4ynR+w+mBElAM+s2U0i z_+U+aASEjzz^cp^L#ysHkNNajzp>{sld0cHIR0<^eE;R7Y_AOr%UWd(4WX1Tt^T*< z`9b3!rKzLCk!K5kXqOCz;!1vbUqQrYQOq@ML++Wi4*-Mnmgy8N>3@%|?r zglNcORPRX8ae?uCe>OEX7W-8$uK0D}1{G>T#|RY!@6M>j z$Kx7s2pf{?E@MN$q3yuVEauWv>ewi|OTi^CNu8eQcn@yldE*CDSD& zmRe##j}q@D6>1uQeeJ?U?MZJG=n@)YmU|)vALg@^B^7j-Gb^R$-z7hw6mnj1*x;Txc8-H@Rhs0CpcdM zamJP`pmqR|c3Qh1ugP?r#je&1deSFI1qN*PCP7VGzd&gx;UVK6Q)t{aq2X*DeXwo4 zI58(7JVFS{!QbjJ!N>T=ht6;;6aqBR#Y8jV{`faF3PPb^X*F6KCr<2yjQ#Sa(KAa< zY%9M`|7UpSv=_NAzbbx+6|0usmM5)ndXXc!+|5>jwn0^)XLj9Q7A=gSjHg5NK&|fH z<&9w@;p6ClAGzxtb@05#69KAV1gPfjbod?@0Z`Z$Hzrh0aHYBn-1Ope07$6M#m@&o z+e$S4R$!=>N3|XjS@3u%g-lRY)MxGLp~Lo2YqfQK(k{pbemL*dFOV@k#3Wnie=} ztqqS6lrtX)O7G|?N9F~YAA7`Es3GE3E;X27_#%P734}W041dl92>2Uz((Q-IyO<=` z%QN;xP{ceB`FZF@tGk+TFNGm=aF>}s5dJ4eNE|8&uP;J!{j^Z0G~18G-O=2m=cBBs z8XEfQiB;77d*IDm{3Q3r0LpB{1E_mtnf;03bYcr78n@x;>pQhiHK`B{6Z^!i^?I3N zaJT643)!CDn?6EV3^ADuS#VU_=1hoUW;}mzWgaX8k1;m>76{VG{}9o)j(VGP9ux#V zKJ@NVk4VAr1CiKrO^|UYE99j&AY9Y}cAu_sWP@p>LN1TjyS5QAUb;slk^oU=2(t|t zy^;emazA7l;+7!W+u4YZX-}dM<@JF$0|1}Lb&F{=jPhte|dr}^DL`Y)bExd z<|!k=WUp^>wcV`wYVt#*J*$0GpFnG^vKu*H_AXOV!|x*?+Q$n7Oud^1UO?eGS)igw znj!3(pO!HK3jKBv%QuP_#w^JNak}~vkHGk4m zb{&41W~AqR2u>Inq0%Cwn*?v`s;Jv`C{VeV2uezy_tJsB4%_@s<^;cYtn>mL8tO!cM3^T}_^v5<^6gG(L23JTobjnglhB3G51qX;XZ( zO~V)dja9~XnSX6_@kF55}^E8zFbm27Pl|5Ptos#BVnIuKqLmS<3F z*Xs|lM+(_9bZ&W51=ovX0cH0P9|&?`LMUKQg7N1jq`HIJX1~9bQ^t7iVaB_!*NOL? zIhc^9pECVz*u0*=Dx6?Xxb8dIC7~1w#__BRp@_>`&;ZA~LAD~RCa{nwcn^e%dDTKf zM@g^v_*bH%6H*%7i&;aw^SnY}mwn0C|5pALR^GaIt|Ur#r?sMm#9y)f6XTMD2={hT zTH19T*hQ8q{b-B_yo#7vj3t}JsD@Tr@2(wUkF7&Q&SfZGB5rAWIS``V^xjbpM(2yC z9+_54KmV@2)qZEb)2^!tlef`vv5q6PQxhY!)-16$##=Z;4Sl!?-vVSVJugccQ#OS> z5el0OX;CVAk!AsT>Z-B^5~(A8b~v){O0>juJuLO-`&BoRgOObe|6ewR!W)D2sNaS~WP_OtZ&JdCg zRYvo^*&{6D#F-63luj})lUXzbXX!HY@a4^9lHVGeXBP7uctJA2)i)VJs9kuk@w9%& z^SIWUOP12)f%1CD@M5sQ&U{kZ22YxjxbZau(|+&foLj((YJaw}iz@SZ7mZ$>602ixH`9HTXRFWd*xL!? z{jT?RR!|)^@IVkg zqyy{hM4sb3yBTt9JI@>STpHRV`3Qgax5n?+95u~MR^$Y~w~1By_cv7v?}NAw%^v_h>Z>3Z3zwmFUm z2=&fKl*WFW?+s?(6%JO@h)61`r=o}HOKL62$;m;yE4ZHOQJ0Q8N;M0H;^-htYJO>Q zZCIDXb(O}l4R?^tOu$;gslZ(0C1UMQUmf5Idr#v?!=*SwLTMl6d_kGfN|1!)>S8?> ztVvs%n|SG=f+^@6$b(HRCiw&QNje>bkGAwYGq}E$3#8r|B&JxA4pyJ7RDPBH!7WIj z)`FF_VLu)4fne=^luGIZ7V>D6Y^;9CL)jV52?J)Am!0H$e(zyp*MH}U>PNHIMEafw zUr^FfUh;Azu`-qRO1|z4GNe1gQWf{}q$Cl9(9STdlA^5|Fm|kbZLaVC)GIbZ40pF4 z!n_dUN`u@6uax{ktt35Fujvu>jz1092?$KfvhX%pI^X*3@HF@rBG9xu@T{BOIF)@* z#_v7~;6U)R-)@}gqIVG!X6uGAX*9^IZ_TBXkAAh&`F-W-2AS2a+TL9spT6vIPI=r7 z4a*TeLZuHRa?jb znm#pCCAh;!3qv(Zj8uh0a3|a|>8XKC+iqOgxg$Qgu4MJk+C>&P1yG;n>a2p8Ijqtwn0Ve?9xA4Ee2=%vA&b0C5s(7`h081K z+&mWU_;-nbiEz7%wN~5CI)v+o2)M#4$wK{5ittY-cJPs6Z?R>C1NsU^<@b;F_v7(2 zeB_6>T2WgasQq5t=-mFkt3YGGhM}y9ZVx+HNgcyTuLwSv!4>&bzEMSnwv6S3ci2HE zTsGW-XA*^@T_AGm0vLYovkcvMVQ~U12At#G7{3CY>|e{xy&D!7_ED z3&VX)h3zqfqymKfH4B;Gsm$0MW=F^R9<%D%3 z!5C4ixWUov?Nvb`!K)MB7Zr1%N2nZyRx4G}AI-!X&<3lrwtCEFyj6e#;B=M~(8C49NU{qwr9+6!=AnJBfJFXP? z$hZgExgTP%Xsq~G)`9dR6QftomYJJ;om$HHE(*h_S7P9;d4&$LQ7nX)c^`+__0DZV zFi8!=-1=mrhxzLWkZxgj9P8rG83-LqF#FSkgIit9X&E4G4DDhV@q93%s-R=(U=Ds4 zTjI2)Hd2I4UFf{{@XV#xg5~Lmf^oZJYYk_MT_7l2uRH7rt`2#;5-@)WdAzOYBH8mk znY96g+I>S8Rda}@1Lb&|P^nkz0^f|_!v$L^H#s}TyNqF_rMs}Weakq{Qp&x=N&$yC z?gZ81578f?)GRhjIj9k2bMk^uN(nM!H6+IHjyNtlQn}}BK6A(d<=!FY?&Rj6OO!E& zvBqSxSC*lx6*qYzxZ&ItI;DkHjFu10ShTs~a*$=g87y{3ZY-2acEE1Wwy~o<4fMT= zuF*!s0&O(#|Dv?$W`zY`@h1r%KXm_zWhW)ZoKaZP;$MNTXTECD0U$037s$B=>RDc9 z;B)bGqtGthetWRQ1MdRay$EGh*e|>(RgHYrN4CjQr}?)V;S~*KHt;|ZI}0XQT1u= zmMryC1cV((#zn20rQwROAne@xdhMX=@uc^^>_p1y64hN-LrCc;^|4S6QR;|;?&lWU zzD6-tEPh+U23&~eB+<3(cu;Qm0=)$E;bzP+E?RS!a?j&0$u47NBSmbFh(bmz5=9BI zO2D1wuIva>WaWSELq}%AuA*WhM`{D__kE|4EAB=BB((aZ^yF`nG?GlJ4MkzFe}I>@ zE7O_3&K1=V_7O4x|M8Q2;S9%_)4oRkGAF?ib`YOTt9s{}pX9wcwoE&w-D;>&mv0y& zc0q4#@SJ){MA@H%8e6bWyzKM8zC(?SG6n4^!*-*W>20f_z*Uf1G_30fHP7{mrEsVw zRNFS2%1V6o$pQ*yj4S|TB)OaB^Rn}M*tWTn$bRjO7*?UYl0HqWc&gh)MT8@Ah;v1^ zn0eyo7rOm{MS>M1R_1lGG0*RRM`55EhMFNHHJ5osN)R?w%g>*BC-Li9<$DgJ7jAYU zBVu#+Flg>N_^akF@%J_}0gTK`Jg_?PGyV^Gb;_eyvgDh3x@ce~BT``$h+(kZ>$!Vb zX~u0X5Mg)B>+}NGZUJx!HV==HC4AP>!=g>LxE2|sdm)?&BgZq6hpfz2y%Q-|E;*V; zEKZoD{JiX--2lH7<2#e8ZUyC4f*2-E{z&}^6L?o{O%ElEK{F0nm6W#rFGDvu-GX3- zKm<`MtRYxRK27n)tW-~s2N}8NFGhFfD&Dioc6|c7+t{t4hy z>Ee#tDF1VRNiWa)<1XBt>vew_=C7IL_ozRH5~o@c@0@n~uh+wWhhwdU}x0yMB)SRmLWOfA|w?_ROOQ*c!;)lB3Vzj^%q zCdo$%?2&KBR!_Q*f?np5qs%O$bw$s&z|ECqu2*K$6B+GiL#7J902 z)`P(KBtaAj9#QV!T;sNyrMby)jcikd=e{r+xt`q<{7K#|-=`OW9vbA0A}rkx(RAwK zTc?O5@5<}^L(WMRLc+xk=Cp@1gKYir_1qUhStR>T!QUuxByq}wbvVH9{qj(zze^`< zi!OygN(=I?omQF~q9kP8vEkwde6h?PrK-vyzF1^)o8pFPnJFRfB1-6}Oc^}~A z?(Cr*9V7)L>t1(8!UDZ8HepU|i2D!ZvmNBf*C?c;-ZCyqNarI8@n&jWOap0ggF+gT zT$EQCHrCeUd8ul~bl68j{9@2{WA|j*hR4SU)+Nq(9+#a@kDlwE;bX-k1M$aSX@K;6dvkz+npE zbb4WX?fSM;yAKGz+34=^(E@7)t;ez*AtXBaPbulr8)^Bo zaUKrOiZ7K&LEG&%Mi#6tN){gxO*SwZ2_wj}gYy@liB+IIT6dg|dze}6}KX3?8L_o7c^>1u@CMZO`rZ-MSoSRG*~uFwUEG)FRVGAZdf_R7iejs7lW7WcFW&Wt@L#_~2`Dr#;g%^1;3K z{eeNUBeUk9{@rraOBbD`9TTwk4b!-Qp6GLXbBH{tvd11}#<{sSZ5%Z1N2w~oW6#ql zd&feb|QS?G!R_&?rsR$MkzsIL3{G!YSVvlNcF-Hl6MTN`$^sp9{VxtnK}^_p6#DT% z6W@Z4MKTH*4~7xa3oro^hQ6)V$TR9a@yjUx91xi-z1Gz-NCu;{i$ezcmqNN3##74} zg6s*s%-_&Dq445d{KcwdFU?ww+-$e#T%eOR3cXEzWd`VpP;OO@U>{|yu2$GX3e?tFrhHh75YlTKEBKU>H> zY&G}fiOyG$E|u2Di~ZTLrHHkdZ#lo(igE%P@tv~HIba_GRlJw$E%Fuf8*^t@FiGpmyXC1?A5hfF9iD`JPCg%(S2c`i50#+6gfd zyz*NpTO%#=2J)M$(N{$w!oe=a3<3A&y2+NSGt0EnWad%^It(?(X9HxI%xC&CN!9Hu ziuQMAb4ps4?WTdH@gBT3hRNODSv~EhSt!27rDIrO!39`Wwi_Z_^K$yZzND4hKnsRj z;903G(t+d*%jCLf%+|VAm>4$dvX=>+?-}|@&27w(?oX4JSfKcaUMikoK(7Cd76h{; zu#~O!3NtO=T}B!xfcDlaFf^okB00V5u=6#0iTqe*Sf8Hz4z!Y05PPzhdik{cjhVEx zXf&=bmHcN9SL!r$Mlvk2u#W3&`aC2q^MXeOZ=i@SB$+}~6%x$QowR=q<j{6GnaKUM?P;(at8{W~ z1~LfAz{E%)WVYdK<`UC^=@uW&EWkjY4(p!WUQpU#j@p;bRi zI}K@8FX!XC4eAKu>}Q8C3w%pCKTKl2rE1%Rr<1>u!ZA~!JMOc_o@DmKW8!2*k^a*C zxUUyOMXD8d!YVX$xNPhke;8H=tO)5n(X=>TsnXH1S5sh<5ri5Q*$UQHlzE4Af)vom z+HHx2R!_Wb@+mvC6C_D-btNLP8JPeHnYHztHWv)?745$UF98$;G@(NSqTQuZeZI|o zD(B?)tg35Y)y+7NlV`~BezyW5ay5NBksp4tF9|flGRGLrdjaY5T;$Q9V*tA!4o~F zclk{3`&7TGG#`mBOWN6{I;=zgwhMFNERG%tD}*Y-7r%lmbe9GcT$k;Mb!Z%H=S{0; zdA>%4jM3o}%0#c{hr`qKUxn!7}v~aCBk_z+J zy3aqO{TS{^Hex+c71rVD7fE(0YU%0{CuiO%_W9&)nx^AVM*#STTBlc~Puu??mAjk$ zpAh#)wV(XNgA8?2Ek3J?X1+&P~QP<7?HB4a?#*oW3pdOBrx-bCyg6?XfH&z@8VycAK99Yj@WG&*dI$6m??o9-1 zU)wUCyk%~}Z5Ia z!dyCd-CnH+)X@fiV0h>r?dips=g^<&1R|=^7y>SBLB$YH;zx}-PKbQP4Z}k~=+?g7puD_%JLmTZ__+f z+$00VS798og5$Z#PM`EpDWs&+jW}>0jFSAqT~_exNOPF5W23@@r8hP0e6NOVyIX0; z;n{gE)bf$Jo<<3g#jxFM-&oWxecO1CV==C;7gW2Ye(GidU(o2W-;0B~IL+7Y+-O{B zy6vw@q(JI0tLs~w9_&B|KLyyX+D~0D5%MKH5TUjBl1!&1qq+Beq|$=9tbf1ASR8dQ zcn-oSJ60zqFbXaGMq|zZ*uoN$%UH)7WZej$r;v6-1mW?!13;pWKVBN;ZLrbtNl(rf zk=(Z$-xPc4jEEqoA3apJyk!#TEpZm~DXz%&&nE+GPFF0S$`MN1j^{cs3RBy=e8eAv z%T|bUg}t6+eu7pe#TB4;?Q-&Uo*Y9w9#RYAxfWC<^;+>J#-$Rq;s{F+fuJ0Uzz*>? z#{oXu!0_-VL#KwGtEFUZ)*`$0(FJlnu9wf&CcAw@FBMp8C4iFM3RAaM0)6Tw*)8rh zGaumzc>*eT`$4=XmX3CS5}PbXAJ8+l{}YDSjj+y zF8EKp3zhGNVFvicz=z(3iEluGuNYAp{34mcO>CcX-eLl;v7{COQyq!d{+!0fQvKnTt`tD#XwZwEz>W)(pdxr@DAaErAm{182dq;%rq5Nj)FPj27-YL zeVJ}I?ou>?hnHnv@iWDrd_ND3^{cb9v-0L5T)X6|63VZc{HkQ6V`v(#J)&?={QW z>#NhW_r0QsxWo>g!eeO)m{}YCA&=^Ke4@MG(^^D3B%=MJZZlLwFPHM9M_lh7e{~fm zWpb42a-sZF8{k)5(E>~tN#Q07BL0tGI;W+kpDuoOOoDOIy3B95pXkZtkH$>ZQ(WTm zm*~agwPOs*OS2%&>>W9tTmke3hVp^n3M)?Rt7`O;)!@a9GV~w;OfAb$?-Sg<^RU|p zTpFjaEw6ejQ7B}~_F>|~W^`EZlD=<8R2&obsDsQ&FW1vP)oyPUg-nl;XlY&q=_bvS z1=(HBi!P4s;bvhe7^#!Zz`myZ#3@P`D(IVnBi*x%Ekp-aJH1&He}HVZ23KEwRWeB4 z1lrmC^w_pO?mt##{sNWCss+6ylstw$iH?Ptr#ePpEfwxI>WKp@w{{T5prqEOU6a1oQKkBv=gE%x`Xo>fXFF99sduV{% zkK3Id+l=U~+Q%t%G0qrnZj;;_Dny%4C0QlFKakAbVaiz?lDy|Zv!|6N6vCW5daJ4q?tciuyX|L)XY>mJOtFCu8@8xmkQ^%g+TJ{f z8(gy5k)2>312RkL*1W=UQ@TSTMhoO2u3YMW4fw0PsKuv=eQ1cGz?@D6*TG95msf90+k(Z8f%=EqBBmXEFprxL~{`ug$n`g6L#P z{a6~<=rXRP%R$& zcjO=ZG2MEq3U9olnCeXailvRkketj2uz{Rp0GS%TeNnXKewv-TS8!MH z7hU@Yz=g{Gbi1AcCwQy&WEF+9Gyr`z1RWeKicd6j$^;#~d#Ds%@-eXp{^5E1;sU;5 zyyfvGTKW{B^VgTG>wH@bvaXCFhAK=5X6jm?hBRkd0BioIQZ;}_`ZqQ^b!!g2ez=o! zw};c~kN8vMp8yHG!(NQM>)a^t{v4nhPB}Oc^H8ccy{F~29wr^=@8@>naC+-wLQww- zcto^ND81kGc4Zh9wNYpvL0ljFkjd;YOUyF`;C-bBsX3m=1Pw0!(xa_!TTIVRy*l~= zJS-1)@R$8emSUSmjqzP55TQq%r3$wxNRj2m7h3z-aDg^Dkb!K(0p8M! z;c(*Q;&9o%6q_iYT*R%Q*sX+vm z9SjDB9nLss;q>_k?)`7b@6U)#Tn!MVJfYeFI~MXER9b9PfRU?63S{zO`EV#m?EPQv zIY0P_5LyO>b2HEXoh`3XyN)QYg>5Rvk*feVPza276r;(^_@r{^)5j2V@}kbsVgmGR zb$t#Hr_tEZBg5BSrrkcdu^6Bfq7h}12ZRw50My9wYyOgSlQ-KLPD{$p=AsaC`ULRO zGo*^w%H_+V^{cHh_S;Se*_k*w)zSs*Qy5H&1A-C!|8;Q`9}{50!QJFP%y}B%n$|79 zd2WG%X#(#iAmev*qsSZM`3hJf0BzfC!rpDU5ui_S0Nsvi<4s_GB-@D4qr>5>8D^H8 zWE<^cl4Q&G?inDXH8w>3(N1i+mJ~!KO!(4D7MD}QyzVuUWWbqj&BMFA!+XW=sY(jV^%pj%^m&(%3*V4#&7 zHMq^t7^SL=`~_G{ct&xVfzMO9I7y6Lf-6&u{LiOV_;i1&%a6tAdI}2e1N?pp zqgq{nGNeoe?5tqaq{+I_G~oJ;pt&E=BX&D%*tKTB$BVwZhWGiYMLeMrwdPt4xIN%; zr=$2(*G&uP0S~On13+S@Ws$6UM7xCmkX4ktY2OpFV1Qjn^L6?s&=LlwXyl`b*yt5v zhy?;OACGQ50~YI@b2Ry5rA8kwFE0^OH2?|zW99o#?7CcA-uvOe5yw<5Ts-%``qn}{s1!#{JK={CBW;%BWMV6z|(aF$lOZ)D+}fpz-USm z>C&>O9Z>Ll#wh#Jb}d zPZ<#}OR|@GC;0C#amfQXizTF1gka**App|q`^`=dPA(=jj(dpV10B1aTJU&G@17OihoyKL7=zig_dqUuxcLwCo&r0|Lh6)d4ei% zCtC#4EeHx?5;?GnN{fW3$iJ@?0z#w?oTdWz*Z?K;`Pyf+TAjohr zS|2aiEn$BJ)}jKIC=LL?x9bnp>c3qh39hmD3G2hCWozX44Os+pHM-_kv<_7c!6@S@q`HC!2Ec?DzdHm`qGctXNc3Sy#=R~U+Ep%T}Q@J zO}v^+pC(i#{nl19-rQ%ZLc{D~XL9fs`<*L74&)H8Z0=@J%kS_0@W?W23~`IyER^C; zRs?ul@9nnmy3Apn{w)x3vp=vegrlf!%>#=mo9+Rc9Tg4)k(j_2OeeL>y)#_S_?bHY3SBjarEf^dUMl%^l3ckhm~d+?Ct@qy=2m4>v3A5` zOf5h`UF@7_{${%V^g89>^z@U=>FVPHb60j6SC^tjs%*G=Qtr6A>_iOD(r^D`0)khh^ky^Q& zzZ|#{Uh(rnj{n$GH)b&5C`|E`I!70in1fd8u}VW-sw|RH{Hlmj^0c1v=DJG{r{w;h zOUBa$YlU#fW(h0#Tu@komm8)cl&@8xp9ZH52A$ieKDjc`Bwk7`(OR;M3Wj)2H94hR zW9h1NSxXEnE_pi3dHc8skzN0N;qBHBPkOjZ*807b9J_Y6n2xK2n2ph27 zj+jq>SJk`izTP))yKbQ>;X_c~bawXG%l-M@gsc+#x)ceiJNo3_vqaO&2i=ADl&p@1 z%XDM$2lhJ?#eFXGRTtX8uyKX{tc2rDf2}sV^6rIN)y(ahzvs}?ZD`7^sMMQsWQrTu zkKUH5z3QLX4Yo@aDcF<``PbY5Ep%OYAe*l6y@dtGs-3F!DJYK9&rwol&tEMnq{ zt@lp#X(awN`~2Kksr4xVwDrr=+P(3An07xsvtDX$zp9Sk>CXv#c=~Rv-C;0RXgCx1 z{yep&`Da@P>PU{@KlW@9-acMbM5SGK=OJtlArjCyYc=)pbfzjXuQmPW7XUon#X#Xw zsbMd7b^7H9K0Ho%=3Bn?fAs-dJO047vIbx{=;arh~nH^Ok?Hudc3sz0Wrz{gdyg zAzxyfz2k*+lE}|RNhjYN$&Eky!QT&fy$z!?C&SViu%Q`C1m?KN+0lv9gZ;(&f9X4e zm#RHVrnVbJ4+>;j>`tw6(1*VqHP@@_w6&SsTsT)PriXOi47x*` z>4!*O^b-iW@`(QVan|69;`8>HK zBgykGf0cfYnr5rzdiI)fOHnA9?K+41Q^W_Kn|FU9t2zlyWprL?-obNRd}uM!tPC#I z|LkNT0>b~`n=vaSGvfa3N#dZciN)cYGCk1bVX-|2b}J9IspG9WS7UzS-r|a-#^JC? z%$v0e5}cDN_iG?0-tJFqi}U4+J!|I$GkC)4irDS@RHUD^EJbQ4a4q7c&H@V}1DIiUP zVCm-sX#10)SP9dTz&zt#63ZAcb!`GLKjtAQG9%OblC%I+Ph3qW|9Gk;JB zGKQ=@|13TPHYVZJ>bISz5Gv(liq1)M{OPVrS`lET(4KBF3qP-;a$A3jgF^5Dz43{l z4rkMHu z2vWJ8UXV){W{nL?re-@VWnEqH&wQFSwuNkaH$kerxS8lTAXX?q1I-NP&V>J(b2cud zPq2d&fX#)yjuT#&Wv6`JL&v#KH4S7bA=VfqsBmTI6hKR50(%vOcQmdEm!ZtW{VqBG zNp`ySca{xTTY$}X!dvUUcW_5&pB(iN`XYzMWz0aZMGXzxw?i-C zd5ZAR5|(Qak;rm|*1`JBIp79>s&@OJ^d#yMy}*0A0NB1V;70Ihz~ZW;F6Q$m_qh6R zwNIwnew-Fgw$1vVZqKf!56gc=^Gu!@*&n)5Gc2^ee==EZ&rT*;5#vhb1`xo0C%tjz zle||u`oUm7<4;$=GT&p{Mn^^RiR;YAQl;xrCHI}gVcQ%am zk55cl%`DYhos?J(i|e%1s~_4f7}o{xj#()K+7x;t#roOVj*Z0Bs~qs_a+q6VEQBV0 z+TVu#7z{I?`%$w|DIg*^d}NM^1JNg+1~zY~l5wCEe`KMkBh?Bv#K9T4j2QZ6^|n2$&vpz~N;TX*BkipC@%jvhOp@8Sb-B{ey5;y& zWso#!1u-6>jE@!vBTON&V?b~+e`TrS)y!%!!TrtOm)7OojN&&}Fb-X+Dyqb!=@}j2 zcfX-u^K9a3#E!n3^Rl;yIG3sv`D~V5V$DDdBRYDO=1SFCMH#cN}6&s0fo~ z{a1WYf7>C_0dmN@1A&EieAasPw(pAk{}Xt>hIplLU&)bmTwQI#t*&d1Rv$AcuMTiU zt_ujOjUqx0MuF)*-6u367Iw62BpTcTY_H0C+Uo-Yl=uQVF|N03Ud4cyy==@}3Rt#) z$_#(;bN#Q*nP7k3)J^zS;wudarbWD?fC`i7EDr;n9eTs%c`^vqVdfVsd33BOkD)66 zzbo+St+5;-EW+~bBQwk=K=fCVTh>MYvOqb{K@NS9@1~;>#3Qml{3^b3E366aA0j4E zuqNL3Sy%AzF#iK^cfz%+B%@@Rbe!2~?j3v+)rt-(+rA z{mJ@YT)kyj9nrF`i)UiNg9Q)n?j9V16Wrb1-Gc;ocXyYd!5xCTyF+leP1Zi^>~sIm zXLgV7Ayr>}^|p&JPLq2BZobiW>U80*d=lwJaPC_D^H-O7mFr`!3TBL@PsOh{~i~5Xs@V=s4KtL#| z60lqQ|0&1dTh#8bEO9*gBUn}w;)V?Lw!)B~@)+Xj zuba}}lV=6(&-3h++ef-#zRSoQwi_;#@{N^--szM^BMITIj~Z+?iKX9_6EZYDN#qET zi>$~}triCoB{*x$m;TQmWBMK4a!HP;oaW0cJHrZ^4u{i=Ki&-0nol{vV(WZRhKj}W zqt$s8D^=fUNVsCCz6e`?r1L+Pjl*{ssM)qp0wk?1T__e@s zkq_UYVqyJu%jU$3#6(B-#x(NyD(rXZ67721B*x6Ah5|p`!jee<;YF+=i6c&!05J;I z+cPzj*Ixz2i(nykGZ_o1fv59tJMoT|B7}RO>>3)MG2U9MWj0Cx@G?38Z{&-|?JSUP6i&5*Z&)>`hId{qXlR@eqxq%sg#`^6)fFU=&65=C8-}cFn$iZcRl0&XoJl z&-dAXemy+?=S-;lsyjGySZlOQ&&2CxJi4DXR~i&g>yUcm-K;!1V80CpC`io~Kv|0r z4us63(jba^xcbCsmf&m|oCwS_qrqsu*cSaTm>>~wn^R&pWGrsGN*c8WiBELy%x+}6 zCS2=-YuJCt^G25t6Cpwj^W*6 zE7X8d=-Eijn0BLrkpBwiGZ{SQA!)paL#fed*_GL3xkwD)VPYcGNd9y(huvFz{~?@h zQ8h*E7Cx9#mQSU@V8M0A>K>{-x=m64&!SIix+Ur@7-b29Fy4rQ{NI}zHn9KEC(euhN z>lvMswHcvSE8h+kP5)0f(ZqIEoloO?fWh z<_#~h$-g-4{wcoM5BdfaPV_LH5W@sQK`fT9#qn)7!=cjc$>!`nuRjb&!)$&&A}19l zQ3C2*_Tt1}WIVFE+7w_ixMok|oM%!`|TpR}^3 zq)cabpVm4Y=O#yP7XD0NXf|tdtPB$f1f49s#3@xwsrU18^LlE47UIxSKjsuGV<~ao>T)iTK@encQic;CRBQz{ET+x+^jY zUoV*;$>eQ_WF}XoZb?@03wXN%8yRHhMs{J93V5|n*-S3EZS7q4-9wcGKwmQyOq^D6CA7_-IZ!g{mA zZLo&h~AH+ijWkGxU6Bo2}NUsVFWDkm%yWEbUzSEv6GldiCS6Z&Djil z559ee)!>F;t>+@K+TX`IPMz}a05VG```q_O?S4}4)mFTj9S$R7Z@ra66*@O<7Q#_g zJ})D8cHSP0-(}qXQ0;7%ME^58fREBVw_4b6E^*p6n!fDd-sZAqLMQ}Ev(AG105b;# zymVfr!n?}Q&`^9VwIO$*?yWMBUWW`3mn|OH^Xi$bMyKY$?S-<+qyGU&J#<|L2wZvk!+S(97a97^Wc3u!~H>BJXJu@^37~L@mZIvS%r>QC`98dmm zZr1IRa(OMm$F|r@_Uz}PUGV+?BMNt)Zb6l1mve1O#lm`;AuQ^Is`f7hHIsb<_2>hpXV67iW+0R{=51|Xp&gwwTBd^dn_jcnk9Of)wHY#!~tSMk~ zT342CnX$#e&O@&0w> z3ZM|lj<;@nYPsJEhjz&YFbvB0UN`bB0EY*VWk~kk&YQPCJ^+ng1%Tb$elNvj82Ab4 zyK7>7lL+DMVZz4{zzfEb%Vw5a6mQA_UVxd2+3z}5kC^~Jg9pG+?4^Q3!eEUE0KXbr zM|rFf<#{@(>ko2yr%VA(h?|{V4_m#(%8;p7ag=C%{#RZ%gy52@;zI&ri(mMzbED& zixWe=KI{{`>%wS7497u5wXkR7X0d_W!?@~$-2pz5fCXcB$)YM@vCahEnpYKKayn1* zaVwmD2Vg5DdYrY00sJpIpuke+2~fnm4K4v$>~ecR3P3*@KK(tl2dD%?00N8_peQou zgHKP_7>Hdz0TirA5&sqd!zD&0lWIuwNtW*+za2ovO&81OW#4rG_K@ic4fY2B36u{2 zsuT*PvELh+;(-;Wox!MNZ9OTmmd>tHS?ZahR5_;rl-{!=a34b3?>1qux$PR;d5of2 zejV?9SsI%x1gd6$gM&(humoA$f?o~}gL4@Jn{1H6PIuipg^!W7cJ9lp;N^kg{|Y^C z58(3Z$~(O96L`5h<1q*N=o;)DjZ1$uBl@lT@KfwoSyBBWu%Ipu@Hq%j@y5s4({dKx0y3YP)VSdI@7y z!zq<0hB;PdPWQsa%nFcZ@LpIL9^u=hKP))m==d;d1}O zmC&On_T97Q!$}^eLiGl$aRgCb;c&eg{Ohn8-QaBhPeiTD7UoI!kksiY@6s1_ZDzDyh_S;gu6{a0QiQ{%WrcU(D zVVu!y?!JQAV(T0hiEo?iDL0qo>CtzIEPSh<*oZl7LD;=jFYA_bR($X(tWLl`nITp(` zMIh`DU#vCCo3&bb6vS0*Fvl25W)(75ks6a`vx8g?fW+hT{5q^;`sMJDxj_B)2wzk0 zB=KgvRC&7A$YHVZvcb26&Var3?};)Lmoov`h>9_O9xcI@V=Ikzt4hc7kwsq6K&@{Z z($WjCEs7sb7wnF6pO1=SO&*a10?9x%ssLR2VzezUaiK3GU|+Ld37ejtmezzc^vwrF zLxw_qf?LV<+$O{#Xv){U6XKxJqfp2XWwBaP8s|EuIb3Qe)3oh|4h}@XD1ufyJ?hA#bj0hvp0<8d__+Vf+2_`RQr?oBJY zlmFs+jzgO33a+~qhsZP(@Ih6iYT`*(Hh@d%iN-EK$U-LpH0wWZcnSE{S~Y)W1>B9Q z6p1uy`ho5%soywv(H(}6_@8yRu-*Tjm}w^8XIQ{v@}&l2eJ9T}3Pr5>$hiv7=odJN zc_yvO?m7-&RE!&=jeiBoJ6{eA!<-_o! zAfBv#$Wu6*rZ&2W){{SBtSn)uVI; zPP@RZT-8s6Q0-`;dA{`m2#+_11}pK^?z&pK`{AXw;gDiVSXIBt|CPpA6dse?8Si{(y> z(ZqzS441RJ9MgB{1k}EXo0UgL!4@&yliw}1hO=a9%?G9e+uY~v27yQ{aRB+2;Vl>* z;_Ra;(!I+x2{F2+)2#AduWUcYb9M&b^ycWOxv)fmYK?xMfi)@XHo(L)(nMXEKAQcd ze56E8ueHSPJQDAeDG>%Hf+kTV`IlfBxe=ulRN|=y5B6X)neDd% znElik5x&6(dW~m{44_^BbsJo0Kjwa z{hz-f05|iYaPf?!71SQRCMljrMEvMzL;q|~nH7)_0cn{LHhXRxOY!%yvXMVdBEACi z+Ib>ZpJEsd28>o^N=2ArRQ7+MtIdWFuuK#TW7GqDdBIoto4ALjWO6U)uX@EXx%2LuuRLzr{a6H8U^jz#Q;%nmbHXD+K z*ax1gK@$63M0PCjEYVNcn%Eb4f0LT<+2Lr&bZr4_;1Iy%9dn8%(WJrx0{a1ckJ0?N ze4Q2MNzW@C6I|{+Y_+R7Pc;sgcziJtLQAM>|D@)76+l2OavQIQ76XIEmJB~{qH!)z z)`AVlv=t%r1Ia<#uN_nGw4L82$6H)~DvxbDw#&TW@pxG@l@)w)K!MmTWr>L)5tCYVzkL}`ji3f? zyc3>9Qjn)9!9=FWUVmcvYMwx85PRkg)w``-(u>9oE)FR8X3N50j2pYVyQ@q(p0E8v zgcTf-MJ5sA^3nnzuLGk)pJH;cW?APo4OXZLaac5q-}#E=IN#=f^oAg%HNypZ?+w{F ze~SpW5Xo-8h4Nx1fS|Q$+5N=)u@bpiaDzUQG_Xg(&^m``bkl#ceXS)_1Rhg73V+q` zI+c@-i z*}P<=yBO=yypE#$+M8*foyf-aw+>Ss8Do~N>PnoOpfYHGSU+_Tc1K6j6jl`9&qBot zbC%}0b%0%qE7$7zz?A1SrpYP8-15}|QgeFZ=lPv@3LIi2DWU3;u6>{(a)=N=#SEKhE(F%ny$bm#3!r)}rpfMq)ZjIJ&q%UgnJ zzC8STa1FcCJVA1Lf77_qzE8PlBjdV=+71O`xzUFnQ^}IBX|6-Xy38RX!pE+|(72Kr zf~N=Z4EbJBtPRNhMw_MD0l+iaE*>I*8B5;1)ZM`p$&zp?x_fgV2lMiR;4^f*{{NHi zA6-v~a7HY!iadYO_ehwzI0{RqbXiWEN zm?jV6_L1~f_e_yO7<4H&7X5jU7bNO}@-6KsUdJ+@1cFL{`|NDLZ6;aVQcx8DmF6>Q z1BPkD7NUjtH-6WP>>A>&u;0w%SptYY?uV`o7mp?Y z%?DS4k$_z`S1(|>U@kIAV&Lh_v5yJx#_Oj!#%c@umZPF+yDA^8*~dN8LL0e1mY0)m z1xP<~9y>^5y9c_h&{F6{$+DiofXFq&MNx_1{XpOV_Enoxg^YYl{f^u-Z1GVVGkQIT;PjnXkWgErhGlX@5 zmR;6Y1elXQH38^!uBtHnYeX==#G0f?5vKYSJrtqrzk+q|ovHjX+&yx7Pn#w8>BIu? zURaNPB9_7ZKP*YrN3nIG_)y#7qSLx~2q9oO0w7r}LKG@}XnkljAYnVwW0lTrW|EgF7Cvt~$LkNQ;_ z0Stom5Hn!cZNTV*% zG8NkGN|7V6L1-0`OgBv=VIqX`3ad1af(5e&LpdKAA>_Gk$HhLW@18Ub7Md8-79rL5^N5^tFcyhs={rw5i*DZB;}8S+TTPRWXNX2AFMFk*t(@p)g@fA;icq z5M}F6U#(kDz0p8FZ0xH0LkKogz-0Ynr#>mLq6y*)LZ6bQ%HaCwt}-ESq1=0=DT+6B z5{+)VL|*uH+0;!E5>8>odH#h8W-`GQkgO(yz@7s=XVqvT7x`ZFdET?*sOgHoVU3aw z!Nd!pA}`!isDi^s()!O2i=Y)OAtv++sME_BL!~x)OAqVVMFQ%$>cY%UJ94>f;TQ2b zD^*Cs?&f^~Q8UwP&dzb3^S@l<*z>{w=CJUx-7eT}V*BCuQr|k!HnJ`g6(_Dg|9c0S2PTZ(Bs`CAWR_W7@d;LYE~nVDUiV@u-p76vNa)bv$}lozz=u;Nm2qxks3=Z zog!RoZ*LEKdG670CF{;gmpAi^etCki|1F#PFhBhETtS8rm<*0z);tB>Jg@fSBMW>k z&h)SIINN^cQ5R)o*9ZBtR`wl<3xO4->3(xNetTMf!#VyC8K2nUb}q3w$UNwteu+jT z5aTrTUtM?hkF7P@lnEtIJF4rg1UUvv%I^rNqr5u;;m?EWY}PAKfI99hWGoSh-XYUG zkPFqTA1o4OgSn4SiZ%pi`rJ=o9=CxXx*KVN)ufGX!M^Gk_t1FZ_0G_(`EzL{jNR3Q zw42DBrq+(x?zyON_@_j8+L%VM2l*T=XS;yry@V>6Lu`WkCL8Ew-?zQD{lqAy_`?|; zJZ(cN6naUF<+Pb3`_@q%3#e12gCHR$0HrWeBWQ*dTo0O^Tss?d2Ni~@)AgB%&{G`Q zGZ`{f%AXDrI{I4%*KGUfqxBf$s-F=-g|vk?y8PA)%p=*R+wKgesc6S&CKAI3p#5$_ z8UcLAq;xr$5PMG@EnH9|vJ7%1pGxmJ!uaS9hK-T2U+rJLXmNLniey2Me3RdVboVW; z+YMj=%1$!%i!LXHW^=XLr`Nq^CnyqxIeLsc)V(3k@_uxeFQGpXA!M`ow1N?^DN?8l z#dCz@%TbSz?QLn*Iy_sVIfI=vvh3#qU{bQ}G_(%rsfdUu3c=Dc0)-KWV4ONLNB z)cQrwihL+UnK&n+rvQ0`-P)%aCy`ETsY~EQBZJE+LgC{vk9A1C=AM=b8iJ}6?|uS$ zpUbLKlKeN)T0yK*r3zwRE_S2`H+Z9t*~+Rf)R?kxyVJh&6ncXgp6+$T$e;h!k$d-e zqs+~Q6USN?9+kJldZogXF#Ivdlt#-yWzV>&KA!1rMEb_0AgoTSTI3vTKZfk7CciAs zmM?ntq}!<0;_aX9IE9vlU@%tvs;!Z$e*{71pypt?kcYxo3XW2$f*k^ORsCr$eW89O z1POubZTbUj24w9F@B*mJTux$dh%zW+6B>LdZEml|s5$3iV(dCm@iB8{A->=d6>pk` z(hs03YNer_ijaDNgRZownp>+Tw-c*MkJ_((vP`Ulp9rK!?_T$e`O8PKQ4QyV|Bm(N zfA!#9JhI(cwQiLaP0BZ|T84kw?I#^No=}J*labDD>+FA8G)@nn<*fF5>wLZnFA+gC zb5O^++Hd+4`~(FL=_6Y&g?cOm-q*H3m7(pni(z25!Tlo_f#$waaI3s;t)KETGKf$B zs+@Slo7`_rN6s{JPBadpm7I^?Z3zXWqZ8J3T5gg|fPN*HRx&9Lj&~J66*go8Ji862#20Onf z?h`ga4#BT(&gVRC!LY5be1R!o-oQF1Dd7A`VSO=J1d`Y{@dMGyBGw4DTv#J+_lg=4 z#4!|^nUD8EgSkplD!V4#uQq~)6)5unv2QWtA$;`YpLhiSo~iD>)e0pAM5%Qi$-+PRT_+oUU2Oo3Uxo!K=%;lM}J}IR&B~%t)1PxPhwyY zwf4`<;}ttSSB0|2xH-8rfvuODyZGqqxa@*9Csv-*&OPMX=RvG8)<3K}H7G+8yg7)1 z+iRP3bGG)Fe1?j5la8j>r99Msul8M4|3(}F;Fu=b-h0Sief+)4E+0rVCpaOA^?AVk zI!z{j^d=f6Hi`f6W3gQN9h~o?8(2!wE#sNswn4Lb>Cl8ouJ2ktBTC zzRJmyh2wr;(^SF2ZA!2m=xGSc9ux3-`niiM2bZ$t?JxZ!3pDq5b4HrR|Mr#17?F36 zDYOSLB#w0&G-k0BM%A0zTlV94ueWq2+mRC9MJk;IZ7sFXu!SI>yJW<-7OZoIt1ipXg^$H zs@1-iwOn7S?UT6hLB~a|jeXAis9pfZUAtvKttr?A`qLU9N~YXlz+q@IS+ViP^BUHX zL}T)Mmv@3%DKbMX^7arBG1FdOo^T^MD>>2`(S_Qhs%dfww>s+~psz;IegW`o^z?n& zDKq5&H7r~CEj*A!T^Mrt48=sbRBU8GuQKx@NSJ53luFMw-Ix^N@-hB^gOkYiZ&1^4 zAS_yeV6P~rR%xdgtltCn&bSSITc^9?yUB&5Ki>o;2oO2aIOjqq+%zHJK}}w zF#AxKAjG-dRpBs;HI!)AKraYqB1BaXiuxj`ueNUMzHFq3 zEef_npH=<2+T{KH(q6Z-K5CScp7#~QdSnQ|_`>Q6d*XkIz@cr2x0~OUtp&P{F{~AP zfNsWk6V0k2Xoy)G4*`{5E-V;^#TDW)Jxor+O2GG~AU*{h6C#{)^ALJ;xJ#ur9FkOZ zbo3yH*y7m1vvDg|>l#a8rL%(F%jHdNZ$~2gGJNYtl-^?5;+%9`jKcH>Bd`t&ND=Y| zKCM$OA^Y+)2>+~SKI@M6{GHODU@1ITaZe2pW zV>U~{v`S&ygUOiaj9TWdjMvd2$k$(s!kK@KM48s8E*lR!p3bC7jr7qCZyY~oybSu} z2;Grb_uv*^1^Q41o6xPCyQ*y=Ldu>T@|myD;4lidqv)!eXP&ghejmn7@G|br+Z^m) z0Ijju6I>}}oFn2RR5RA$YjLDVrg;p=`KvVrGzE(`*yB*H=)^k=Ga^DeTnen^UKuMH z9X>gPd|x(w-l4C)I$lj?_sRJ_@i4>L(k;Nb`81mzY4dQkOXbxxh1J`qO1CA;uGUWx zt+RH8!-!ChWY}yS+(TMomj_dPX((}X;At`T#pgouP>m5wx_A&V;Ovxi8wFe^XcRZ; zi{KJJRBYRqDLGdcOPb(_c6rycBjpKetU0R}tqe)VZ+!N6|dIbFNkV!uVD}qu{RT1GYBVg58R}f6Of|ZrXqj>7($NGhs)IO-&^^ zSB>s$x9>0k-DnOsGp@RV^qHo$F>F^8*?)a!H^uvOH=fD8c>;tAaH&T}VIiKE@iKYG zmZ7OkSs)f&6~FFA4C69aRM#Ly3O#bp1cky7APN<_0kubdM&}b%40gkNxuU(fzMA6n zY6c>VZ*0A)o9ivF-PSj!g-^pdc-L}-ZxtTv8ty%Ui4=88t%s80-<%?iyA|LPX$6fQHn<57{p!DG>0mYmRj+{!6ty#LB%OV8^kPDQ=}2=m73fPohry@0 z0P`(6J?RD0n%HhrDIZ48&K2)lS&4|RSE9mYG0lEtOdjXRV8TT>85xzebhOjM z;m;(*5p)ag*8Rq8@i)Z#uplXZOD{bWkhbOBc2{_2^X+9t65znJVzC&$8L_bj_?cm) zCB3|m+GRjf0l_4eq4Rg%15UXQG$D>0|XU&~!G z(JXrdyofZ?Ok%~wtqNrReF^f&&zT1}i|v{%PM7t`k~y3hWHz5oPg)#8;GqaF!!e1d zN~7lV<-aA7Rzeyj2J=RGI;>!?hF~cQJQQ9XNou=6Wb&xas~U7DHVb=|?iw5X1FpNi>iA3COA(ARa;;^Ij%YfEJf*D7arC-0z_j6aCUYcoGt<`>rRKM*V ziYxMR=SwFT&U&*RO}qR^Jdy;$L3ll2f{D>*fU;^aB<3dGsERH~RGG!4ajlTySMlbQ z=ge0aRLT?e{emHwX{<1GOHRO|6h=SuDsRSygF&xdbR8b`+OS7#pZW9HN?jS} zqs=d7>=fpLOwz+1ek6*5wC{5WqPwuwQ?8{0^kz7)qDKqBCWQ0Z@se=`5@EN80m2ar zVXx(IuFJ7aI#No6kTsxN{KE(iDvuuvd!{8b8YW`!j+w(^)S0^ak*%vPgKdYq%v0{1 zV&vpgUlPwMs`nkq&v;sxKvX0gI0Qac>-14;d<1p&eW(4M(Q1)rity3ugj5Yo+4!-s z3%jCY5Jt0ZquqH_-&NI~!R5-7dXZBrO^VTyxS180_=C-CX@cbD!+RKE->jSX9sf(d zS$>*wPMnAjXaA6=V8YZ(h zOqyOgX7?NTogOiFsMW< z1%FY^8i3w1GQ9j7v{>c(C`I5@M-3av(fY&gg|e)q#Yw&pv3D($pAk;{q0cu7sdJ%? z-1dcMjQ|@AvpDohvND583b-^0Skl!XM0({;0y`QH1Ear_o5dvHvTEpza3#el^u&}` zaWn1q8ixAdnUuUvf{qt#Fc07?8c33ZB5RNx6ZC^gy+fkOaPU?})}VNc8vCv= z3gG92zN(ZT?9w4*ou#3n0+2AsbRJ{^)Iz5NCVhrL!$U)bHo|P-K?Ln(C)}{m6E_s* zSCXg&b|x(biiD0@#{^viGK|cHYpGy59{Sl~`6Ir4sfr5{m_f6Bxh)|hxxDE6iC>Z82bj)KGSFLEz=|xjKDTN;3 zvX}RyHc$czZaZc~7E%mzzB)&HU_5-X(h=@BH{8z38UgH*O23PQ z9S=1aa{VQ_%ovh!ncJetND>UXj%+tX@!(~bn_(SlJWs#T{M0!*B;mW0vSzULVS2*A zefd+efX&XmGCxx!+evva4I0hYK(PT!c_1TgqRWK9rK9;A#m;3ik%?GZqf=##_oq|= z?pGV@EjE1;Qm`M*RCUo4(YAI(oQJ-Cb0s!RGqszM&vg#++k^}8cG!L8W?SaD&$ia@ z0Ad^r(RK)<#1CNof+$)!u8@KOJ)cARaC_dd7qr+0Rtr(=b6f~8A2n;tKOkV;7k+sX z7i36|H@L56kEu2`*WDQlkpK5uT*2xcVRe9&vM3X_m_LEc?GYr_t=~VAMfWkdMmIu) z9i}zsB776@qChwtGyAZI!^oTM*71PvCa_=-P=%Hgb(M}kL*)7oqtQHCU%OCcKND{y zVeD@uD_NG@*y2AAI+K8NX5$ z+$wEWAc}4R+F1CtY{%C25&rry>wXfi%NF*+Sf2`RG-+G|j4kux`i&AJ``TDDkG`s@ zWW&MJW?(0kQ8zV=dY-XhA5tNtHyr&ZOxEQH8=?7R$U@V(FFk|XbA)`6e?N^D>w+up z@0&yxFjog%tj3{$mBl1GP&@{e+o(zc&vNYMoN}IM@e2Lu; z;EOfdVh0vnR~aJ=hZHR}hp5*PSyU+}hSpijgM`)c?QPnncrs>ICW<3_U5nE8E(FOG z(>m?Syo=?_z};_7GW3&d9jovNm_Je{(f@8wIkRATsIAQ8%M(sQ3Fua=FVImX2G&I^ z93N7rHifVInZYHJ2(~Ck8SG0e)#bR=XvoWuMpE z7j-2&0k`|Aqt!$YS1>#xWOgLa-jLg8f@^U7BCm)*Rz--;-LFHiJpJAj!jq-4DY+j4+XBfVXRg6e; zXe}P7Yz-l%@gr1ew0#0H;+`Bl9b(`U%&!Q$i*>|OnVW_|+TMeP<~vbXXdy2l-olnE z2(zT0S1pzstRnS;>~r2mtSYL9`HDi>(dc0UND~!8tEnLTSg-~iCW@z;@As`jBh%c_ zpp-Zc?#6v?%w*RNn0l!#3kO(MNDjD15cn&n?@93{xY4))j)uu`ab&nyGO`zQ+zEMF<-y$>R`Qy zb*ZuV_#@u~4+6@;9|x?f9cq8Q-gQ{4JD%Gj%PDAjH!5`;1V<63qEZAbB5wp9+b#frkrdY1H}SQ5;4x0eTh@0}R2 z`0!YH!MXI^B47Z+q(k2W{NOs@XK$qVVXz;0BhcESV5T=_1w>d>AznSHW`1{de+~8W zql{3n__#iC-(~ciaAti76eaoQcZu+9Qfb!{-#8k(^9>vhq=)};seQ4ILiyPm;b3{Z zmpng>Fq^=yg8Oe?Pn?y~kqORvWh#+_vtyMJii~lcq$Z}t&^g_|7xcI*Eb531!4)`) z%VExP5_H6Rl>6rEH@q9lH_fwZk`5-jd-yb=`7B*9FeCd)TM$7js_l~*$h6PqRh7~m zM29-sFtLxl+_fKvd33cG>=*VvrlEk$AJAHSs$65#nRV3-bZ z`$3}*`p~%sd-@K!=Y2)#ARWUX`_Q9B^s&@Gj>H~4agoT`9##$Rt2LB%NY0k)KlR_p zrWjOh{XJ+qv*>kR(TX~C!a_qw6E*SvjUiD}ooJZqeUjO@VPDKaop8ItT<_5g*1JNu*hD94p(qGoF=7gUh;(O zzVR64xB01;og)|!*bkvqthsw%$cToC>7v--jw6IHzL)+vq{1%mkwTjTDH!Q&k5+S< ze-uT0r4_RQ=dM3TI3z%Qq${lY+DpDzPLL*x{UgROuc-)XP$9$DY_ati)fO`Frh*o{ z?EzVVTbwuST4$FkqsL3dBJbZ5A&F}4M*j|TRd>LACPxCCC7zpluUicVC>W7}fnO6I z<3%RC-(Fl2*t4*IM)b&L{v2xtYh_}LV!w7vlXc?3>s9Vla`0@!PH!2aRw?&QmiL+} z*2cigAH>oN?P_trlq;4t0DS5Y7{nF3+@H&L#aG&~8_ef_=bTCI4-C8e66hl%&C;+;C;KV zzCnez%Mld_iGCRFEt%=e)^GS2`k>oRQf5-S>ts*;f4l&XMp2pcgnGA1J8e&*7#3#! zzWkuCDu=$s-(j-VuJ>LM4BAZgM>8W2WJ6_G4E5v z$Ioo#h!ve){wz_5aocgH!GRYY?y8RK@$}miAE%SfRW+By>wR>@2 zxVARUPp1mJ3~mml|8NWVuV{qOk-!K?W6BiPVFK$-`ojQ;=z#w0qZUUH)k+S^D zkB~{7+NK3BN?NJ6E+RDBy(Wu$!(o<@>EeLuZ+OcjvXzcle}jL&9&{=8gLoj`?tD78 z@A&B`o0jM?D|$NckPuvAVQ^3n1ZBJIzV;~Dtf1_ zw}NTR$!N+IS5QkV6+eGk9er?6Pmw~N3!I90;g}PH)_v# zzhC`T@9K4BbJ;T_MmmTqU7d2>aOo`;$Ko{K!I};M^~9@ zjb?Qs$6!zgr$3Lu!kYdm2LFa1io{9#ikkiJQtQgm`*@MQABo_<2(%AZs5amMc|$%G zn?^wm7we&fDw@SIwPu+vrQ?Bcw)RX%@vRccjrS#~R{)G&{{PDJ1(Rh8k)edb1^P(_fje*Fk!9T!(Ff4>${y**K##%FLmp-%{(2rB1bt(KjHqB4K^GiHo% z_78aaE|?LJ=#9lk&k#d_Acjy}Yt}skmQGxKG$Uxc+-h88Hy+*E`D>$S4={+K!SO&W z*0V^g=Bxt1UpOJ`&th#G^y}?j1pzZK|NDemOWl9wY*qE-C;Ru;Se<-;&5DEh){aR|A_TowQ2sa|yXlT!d&mGl?TJpj0W~T?p5h}nH!4# z2yrx5i9YFOXe-z?42t+~L>S>4w^(UqfleI{|Dt1!d(q<0>-TS^=vJkUno$26DfZgi zi@Z|#5dR^w>-_+ZlW*PYKo+IfK9VWbz}-h6f*(1j73bz}iY7m+f1r z_ky$k`v)3@QZcf#bvo(Q4rc6CwIXt8S_2fYia*rbQo@HdTB#sfY4Dt)FIaC0{~I$& zQ!qmT{QH>0KiLeFOjTo+&KHNjLc+K|6%V|KArJ&BQ*YPUcYna=C&7aTrzWyJU2;OF zdO_IQGR~HqEhQqgT3u=|ohiilczTZtnEVkbX_;CGbxEOAo9p+>&Y^ciqnq3q&;Qyv z3d;XGu~lM0IeqrPrcf)##JW)dyWnzxus5cYpApd?d6eh6$%Dn$oc6cINI;Nq0i;wh z4{5c;7xgnE6Y1|C^A5@^&xq!QAywQrz~`7G0S#V8qPvor#p5O!K0>QycD>uiO=CUq zpF>njD6k{?Js}es2J?LgpU)LjVnIbkRR+iL!G&M^Dw;TOe?Y-f`n_*`6CkX2F@YS9 zI1${gFuJQZ&!G|u{UEh_Q{vXQzb{r@$)^PgQu4YsPkLR91v9drCL2R|BfI zkd!MS!+tnM=+Bh!u^$?c!?ynozFezr01^PUCUP}6M8u56awYsyqs2Mn`*Dwq$xV5( zb@uNpb!{KZH{f`GsO7=Sp;N#&`@ECBW0NNf=02h!<#eX|t6OmfWr7h!H#^arvz|o) z^)wQ(pP$uDI9<*@!(mW`?q2uj2(^4@_(>Z83$Ym}&}Jnmi0WAIRmQH+dT64ds4EKR=5S?2Q+9}=g%;Ve@j0x>9%83JER?mz%b!ssJjtZDyz*I=9Gi+gy>E2=0dx3 zsm}IbufH#$Ei=S6xsP9h&Ut_P1H6r2;eX36D_<=2X8v8Zaaj!O?tSL$4e{ih%v3JR zeKnl$m^nm{HK5Dnk6eHa5Ryn8r3IA$epY-qSO_AY(S6| z3MJh50+oznd!H;DQ|%^&s)0c=V$d~0+sl`i_3y!wON3wK;=M2ZjAAeZ4)1A`(2QR~ zzf2$et>)Xw+#TKMz?g@k{e8Vqus94M#-E)Lv!qk@Fwio%8pSd-4;fAnOVqY7Y^T+b zFHrr)G=MMHmw5xHBg@t9NcaP@7;&v9EWMbWw#&a&Kfa9F6C?T_eu&ybaG~qepSE>I zni2e>qgNdhw^?o>DEjsm=CDt{@%#|Z@dm^;D{q(k1RO=LKHCIddk=NeB_C?M_O%Ts zKIJ^G%ubvgN0eR*65OjG3lA7O79RaMmQ3rp7q>F$(H*&r!M zBOM|l-5nBAB9ao)UD6e zhX^3hy%cL5G1m=Pe?ihBTRD<oTzE=WS=*|qWn1Y1g={LI z3For82~TMI%B#z7jFzis;}}psu|>D0bLy_a5glV#x+fvW3tV86`7}+?{u8gMtZGQ5 za;s7jmoeY~CG#@i^hDE0mlmkzzL`k%tbFlwVAKjK_4D!Xk1tgpYnXB0+lUj7P}xI? zIb6>zC7g*}c){R^iRhb?dz89hVNMbX7w+!rN(4 z17D}ajxEsQ{(b69^9}zNhs@-!_0bY@rODD|T+mR%*1vr5{q197?$@DnUzBg$9cz8c z-Dgc!vDa5QB$igXTIe0`>6nkVe(yUk8=9#@PzIb%ioZ8a=8TL#^M<{hm;^ z4j|q-#MYKvofSrGNFsOq&L>d0*0AB0CmTN49x_^5gw*9~h-1=?be6+*nd0Qcw;fOaJ<<{Hx}vo= z?O-nF8 z;88_n0lc}nyl;N?%zhImww@@!_I65VlUU;V(wkNn>NRb+F~pxftMu<%h$fNj79+sV zp-{~7^Du9<1{D$>dn>>r6WGHf>xgIrIsxk6?DXtj31buW2SHm@;%OW`n*Kg%2L7bx zXT;q`xPk5bdNHq4anb9Nq9!gfgx=5<_%#wP`zB>u%k}aU)8UfoQ5(!dyH4yo>)Zn; zHhSh4xsJCuOIY0*+`wI>7weOx>%|9wx6Rz&_bJ@C`+Eu8HHaY+Fe}$i z;LoN`4`*b_1vf zg1BYD;+fuYz%Z31ca38(TITH`(td#cvZtrZXu7a!w$koB!^KfBMqzI>RI~=1;JVP= zUo^wV;81N}%l9-XKu|}H5;KiMLH$%iVCTMl4nj}wNt@=2ItMChzYAlLrw5N-wq9;g zOqOW>Z99F@s}Tx^s%zXyOKuF2>x;bFd~P;&a$C6GWbl7IYb&W=DZ}#G*EWf>yX^$u=bV9%<&@!{J?hC z=6p+1c%HR~oR)DpWY|nCXs`GT2w?3YH{@E2x$lv(N$6q7d5m}nWTCpcCxR^{58qOV z%gblS#OmTsv<3_K^;Pyf)lnXCq!7UuR&=~>@T7ZG{or=ceY4K^N*?Ra$PS?Jc|BP(wbT->-jv=lk+Ws-sAMbR8D`-=`cn>Uzz4F;Wd@ zihTr(Fv#viVJk*`u3YrmN_*fV%>u~XdF#E`!;%G_l`O2VrFDz9QV`VEmAH_MC`qA5 z54IlHQnV^1I(A*KD&f+!cJ1nqo!v~88({ewdZKScH`NE3obi);>J7o z_`1tdm2BsAShA=WUnbCfXFTLGZ21!VB3;-O_4@n(q^82nXgcuQE5e}0(VCyI;!nwW zntUcM@+;aEHxydoU+!z?+lIqyIdb}|fNh#6W=DYRve_R-^UW_RUASiMgULZ}ZQPA# z4;d}|H|88TLcOk8l!X86%@Kg%VNa%*w_3I&0?HaF&BlNN09;eRb1jU;2kGx$1yhpe z?i;KlxK7oGpYRdg*XksXu4XR`5RbP<4(3c0D?Twu*1a#$&g_aYIFVBMNKgAod?H|l z=YwK&uRblYK&9zRm8FF$3ia2E^czEIBmjb|0mTMV1sVsQRKssc|Xa-<&)I)(eqIuis(-sRRBe?t7uDogwRhrH?4}NMwH&2e$pBWxkt< zm6V;yb(Z+=Y1{hU;n)g=<$4w?S}o_~W~>msMDP*y!pEVJZ#!%!epJGO8Yt(ju`$Gu z%rV!ealD0-@ag)b^P8FP@L)mh%hGi~rb9=tLGnOkC z%XmRQmpPz^B(NOq6Z%yM5xaJsb(d1xDF_ zaJ{O}j!y#)44>6=W$S-`n{R%*wBR+6$l061&$8cof0x#S-p*Rwn>I3Fm4e2};JA-_ zIb)PfH=6wvsjH`_l*Nix^kk6#0uW4(F4m*jJ(#4$b<1Cc0egux@bdj?$xz2kK3wn; zA%OOzu=or@^@MQFQDVcvUCzJ!?}Z`fh-W+upV+3f3Lu|&i5!ZUZD`%R@#bnioz_LVfRdw|AYR0PaaJn)dNzMUAwx4pS!c+UG(4uBqg1{guA z(mj~MAH$huc|;5uzuAaW9c}hwLcGQR>kT-g0Ug z>aY+Blne~W9@>3>d#VvEVoE*P#DWa_Hgkh^h(NXc8^CRx0?e$g;jQnz5AUfihEglo zR5C@wLPJ9tt>*vz0$~#a)>&2lkw?0)sqj|0N(RC-BOd!)-NipSg^7ENSg2OuGyjjmJvX* zCbz^?^@?o1ZEmb7+yTyMP_p_9~GovQMBM0O@ zG4wMzuCJ}FO`bn^s4;u?d%o~h(xkCs*u8^LV1Jwp^Tg327iypkv1Tu5lz@3vYI?qM z!qg?_t1~_Dz3_p$U1PTqkBZq4jT&9aIf|*_ zDmXE~dW2HQNs=|}vCU)~po{p$;rC9kqmGZMCbH;i-4yj!l^HbYu!c3W9W1xDx^$C? zxN~5VaDDNyzN{P05MdD;;jZuPl_wgNW8(GV_i@!!WRlO^sqUC35G03yaHexJ<{MOK z|7ZI#CgoUJcfxUh8rJLilYtaPZX%9B`v(>E+;Jzt%$_cG?QIVV-X0y4@ zScWlDLhJcR;a0TqkJP1BWjouoE13-F>suBixLa?kHIBUy`z@{99A+yK99u4#iw+OB z-lHO@N;E{*bKhUAKYSA?U^z?)gsABHjEZ#1$TNM<^nj$JuCc^U{@dR4IS&#M#{gol zrhsIS5sCEl+&MskffJ0(cmGGaK6T^R2?wmy<)fSle6k*GOiGC>OWp_C2WBlz7v-$6 z=bxjW39s}1J*Yya+~_0-fBW2_weKLd=6>6pF9Ub89NUr%hg*X@*o-*#q=Oc$M8gWl zRfFz;Vve@ql_)^uz}SBSJ(vezJ|+L;=@tg}y{s+^Ip$4=akrK<#)dTs(Fr&@?357F5%rb`*{-}JXErSOR;N|O+_Tm=A9l__*xmeSC)izxLqJX)Y^ik<|EPdDCw%G zR%4CDQ0mXOC$8h&vJ4+uUHvjGCeCN%m39=p@(;vXL**h!b6U(&&@!OJ0`6=&Uy61Z zmN{S6s=d%iDL9H}QX)u6_6Up%SJJ;O97qD&uYPe52IGQ=Oe}Tms7%#68S@mvXglW2^Lkg}8S8By z1rdAZk`7Se@$DkeXXr##-6`JVC4L7mVq%uL_~+n;XOqe77wZvn-%Hl`!Dl)A9WsBq(*Hq@_29dXCdox#y7Ep6K^x$jd`x_O4Rs` z_R7Eh9)H=DBAxCa7p#%;L!6J0)b)+lFwmWHe3Zll^RCx?-N*0M3v}7}jEZO{c(SWx zs6DYdd3V_d^&y$T6P)i@{J`!vS?M>LZT8|{q7ftHXMTtimJCh+gfbYLqF_3TC8Zce z-z9R$8OysM!gdED)D~VEt>mQN3}ISnS#K%AOhsvV}iOMxsiOtci!l0>OA+g$W#BY6_)=wAZ0aM%I<&boDBlO}Z`Jz9%~+{0WgITg4wV%Uq2SEg#W_O;=IY?LHyjq06977=f3? zSoQ#ZWZ;Iw^I&{=U8>Ake%W_*II{dI{>6i1L7@)ML6NTbJn-_);9qPNU#W*qkX+ceFf?K z%Fo~cp_6a@d$DGu9XDilfloJnZ!Iz%Z&lQK9X3n1l!~qh>=^uQ?jrlnnMsMz!8_(d z2g3}zZ5)rI#mdCNF{f&`oPsOGvXQVYHP^CEmBeR>-vUrkr0Dn%fvjOvilR zR|zbHz}}MET2T+on)W~b{f)-0^0zw#9>43-7;&k(jL?xwY?S}@DiQRl$!0PBPAJxS z{HHKB&rDjRSe{!jJ;^6M}?&O4MeG;@)A$1)$U}a;& z%**%%(abCE>1><$NY2Ym`RH%y!z#h;l71;Q z8Bo>1%{y>Kv+B|>(ei#RX);o3lubUz4CLM|fr+)x5;i_ozXN)8U#8=D>EWsWu0Du| zMTq>uh9)MT8^?H~yLqifV-x!~t1sez59kj*TY9$CyvVQf>)^wILq_t?HfacI20Oi} zFb@2}u~ED0(@;T~*HBosWT*^^(+oJuDAE{=Tmggtf4 z?-Yvj^Y&N{(Av4PeialJ2wFV)VSRX2jOSWv(V$AdYVkBKS_tol7%Nw^DpN+v? zJvw239(*q>M=$*b3j!Oh=eNFlBTJamH7XbT(_slXqFqL#4fK48>+o&|9U9YT(IU6e zvLG&IEy2h^VJ)rX;t(!^;?3mkurnYRr(J--Cg~PuT*WJ<%q{g+c zN%PBydRYkT&W<@6;i!h#s0H^M7~IxSa%^0-&ftzD(Fy!WbZHDm98(xTMqbOBmMC2n z@WbAfGYld*4Esn2P3*zqo1@v27>C=16&Oyg(kKGs$wb;@>#~9J&warYNM69m;bf|< zYQEUx%>-0}!0(YiyI?~Ei-2<+lOD2h(x8#IcqDga-J2m)O<@|m_#z}*klTP3gg@rO zG+(c4X%D8{mjaRU4B44{J=yNCocdB`d&2kur#BvID|lX5s$LY-VkiLij*-H*6yVA- zjcQ(jVuLT|<6(;|hemC}WAzG`tltZM)IQ}NS7>w*II*PwgGSB5JxFEVJeDcRtIcG~ zGn&j`4d?H`rCCedlP$qVaQ*ateCD##a;6p&#Oqs#W8dt6V~_zGoi=b`Tu{OR0T(gb~3q&AaKL8~>Tc zdCP#C1-#`&a)j}K5K;&^$+KV=YIsk6qv)oTkT=o_s9wd6M^+3Fu}OKMfa8R1OfobRDz8SJVNngfmWiV z0rdfEyr71#TFMhtNq_S3@vj!_NdZ0wpq2OXlJ|?=&C}50L&vLuqzc*|v8luDk>v~M zCzQ_(G}_ly&)*%AS$x^Qw$VS@Adbe zUnqi&jI0ZV|Hi#p>+YcGYEX~uut%zs*d=c7=^fclib2@A>V zBN2N68-301`MQT7houE9{!lr^yoruXyiTXvM!e2ZvfFESdGaQQiAR5SPQw#PQ8|$Z zXn)|=<8|_of83537qobU`o#Z^exG~l!FE&yM)#kW+k{>dn1%VMN(xKCkyng;PV#Bf zWW10>#6`ref556XDNKxM8F+${E&@djY@$3FWF3J#QA*@Mn6Ro2B|3rCWWR`MCiz;~ z!=aFr!c^>^lmbLDOW(i5Fi@?`a}a+`J=Xf@;Sj$~sf7dc2&2aV+E263A#D z=(MQ3c20lHytCF@@-x68o}HP@NVd|+I^dIn?y+p=YMBQ!gDdm3r2=M>*cxlr-|3pD zjQ&*+)8Et41`=O--*OF8z9jKq*gP_YW3tiIW%1_%{&o``OUK&FIXb6u8dL)YuP@@M z&6rt;%}kfPb>_@p>Jj&w=(qZP*N8wG16f`O(EBQ#Cu4^eq?p6l$<+qM9|J*IZRLn! zao4RuGJ$bvy^yCYah85pMCR;YJ9cRt_IO3=HUMYtvq~5Xc~3kq`FenjvRoQS{S)kr z0r@{5m3d=*Tyn)bA?PmDlGbdTz{^#agk?h#1K~B;BXFqq|AP1gQF1@zZ?9L3>~kx0 z(`eijc_nKuoBJL(LAzt1?2qFUk9=#Vvv?l(#<7ygy!nt=}TgVF;)q(iSMa83e!6*eI)juK3TFxG@^l zyAN`uFO^J>)O5OHj3#=zCtqluzd zb!bu4Vq3f^gNWUV)6I0eB(Rr2p9eNRnGawoQ%u9;fn@s zTZ=~G68n3DPG?*M_7(EUMx5dPjR;mwRjivLU;(jKy71qRWdxoacX9h({mGVk<`Do& zzS5(et8SZ4%84S9)o`6N%hO%ZxuXC^sXd_PZjYM-5+S-ruuIQhM?kBUUb8^CU(E%R$Taq^kJ9h{4!T46aNrtpLQ!D(7y-jS-UcRN~C>P5ai|9+e!M?+!xMwmEgt zFU$R1yp)Pl>z@jo6A-;>yccmC@KVzKyh_&;^Vj1Lcv@L_(Vu*0IH$@pSOWapELL+$*$}kmRjY zUm&*j8AnpoSGOulb7ROw%GCM&-z(~cPfSqrDEuQy4In%~BZYMFbly{IqJU52JziAL zKmFctzf8YQHVmKX6Tro#QRi1nGJPmEH8nAbcah9Do&KU%(q9$hlK4;^q%U`(Kx*PM zp}kIkh0G=_l7by2wAf;Iyc1RxO97EZG1qyWDdbEK+VrLI^ml0j*7(2>_5;04c+s}R z&7pMuN}crt>glf;T};H=UV|~4(T;5qkNq!@pqF25irI!Kv$R7L_=z4JcSPt4&=z^} zAQ2Psvgx=<|)5u>X;h&l& zl9E?F3~`ByFT&DmqQX+VFN;_ITVbpx{GfmO@ERVN0 zrHXj#1)|^4$;V8VB&~Qa8->l++GTf8C)ch}yq`*G{b;R+mz(zWkDZn8vW{>m)3{W% zUKV{i=^8scRp~QY8A3Hhu_7f!7}E+Znsnuw{+iu>2y~^{gF=SXgm$i;lCXKyj?44j z^c6#RYlw)ba#yH`u!m)kKI}7?4OI#}JJP}*E>Q%N z#i4mu=Oymf++XKfE%0Gp<4e>DWMtx7CPF)nvTsPMB;3?D6eK7dd#oawVp;Z$VQ0X} zSe{9pX8_BQxc=cK*|`J-*7Nb(Fz;k%ELnFb`Pf1oYzP6{`Yx-FjMP@9K@j2>B`q!F ze|M98bL8sk$`S_yv@K<=lbW&F)r-885R~Gv!Fg%v0?`BoawEvkt?N)UN{f^{gr`Ws zLgWW*OVvI=hqa&dGs3D0a|WYh=Dz8!2NJ^HdaXJQ8G)D7t_qE~tDFu5+A4 z&Lto#I#jJyu{5!>)FSwE8b|rTTOR1y{&%#o&+C$VR}LyXh1j#NEWG#u-yu(HtziUI zH=iLd**XO*if3I_^O#lLl0au6HI}0Bf-Lg5?+D(^a%i`rbaJd7o85knOyupS0(~<4 zbp4%oGK^bwHO1;3Sxf!$M-uLGmDNk=cyH@icQn*kv9ar3%kt@py< zgKL)U>Abftd-QfoUWyqSjW(iHW)MUJrdUhqHIl-ddgWxEry0t8ZKG;yUM776GJd=h zx)OlH?LgG5w^E#c#57E6+Bx&&5*w)%u;@@yj8VLXZ z*w}i5BUll6a;)&n9l8<8a;w!9$9_Fn+YK=brdJY{w~ywF>izO3&X$24*RkzUMetK2 zS`*S+tGTFMm!My8f6#k3qNB%GzJQykG8UMax4vu4Y3O4pc6aPsA2lN|6eq9wGU$Z z2a`2zl&XjtGi@WHMc_qC_N?ZY{H}%FgGm6*bgAJ{g8ZhMj3ilxQf+U~|{9b>-m) zPoRjdAo~u_D22Q%58B{j0_)CD!bfVM@yzD&6KU+=y*_RlY8u`C$~#+cQ)22#&kzsY zms|Rr?0dSyDe2)ySrI8KmLsAk5pQWe5i>lPpK1zs#sM6IH}TEF zioNgKBH*Q7FfB*Cvhm-^r*A_Y^QL2Tcz1$-b#Z*&KKUX>^3$$`9gjiow^vVUjtJud z;$?FK+T%pno@zaLp&I~w5F8+tACeN{+1bYJnHe$OlNwo0;Fe^>oUn0n4CJbk*jp`g z)phkq*XwLB)ID@s7X?for@^@{G$`S`nmEcolPO>_l zG0+H*`F5k7fWNA+t{`?Bahc!0N6)IwYf(T{V3&_?R6e)Tcd_UCbCQ*4M=8tJBT|vT zw|s8-730uhot#&bY~54xg1&R8Ez2S5 zC(LSf|1J^>NXkZws?+MtBmvb+qNEDLxWmiPuuv-{EgO^EyE$06p1D-X4HoO;5NI(q8)V zp^D0C44%;$g(MJ@O9#VQ7Gy^_Sr;v_k(K+uuuF>K2*?dMlBLF-FkBE-`@D?QEPub8 za(24LPGV>ITA}0f@tkE4HqH$b4l84PozCo-$$`JXUDKh%` zZ%K*6&M6QKf%jPX)~Ljf94$w3tu6~LPG^=Whl^cj2QtMdqJF;)wVQJ&um6RNL&AAW z9#=Tj(;;k{-ZsoB8rnSPsF%Uvc%~mYs_`GfhV7$&gbfR4V*BSy)!G=IjvYQer&E0T zCwMqrGezoDw}y7%Kea_mC+(riTkz>!P14FhO75O)BhdD0_JU~liaaq`lESGRLjRlG z8r8^nd$-}#GC*<&Ib4*2@kkSawn7Am9#SWOl1 zXIhVU#oWxL6gRkTt$Z9=is5}{1pCZ1jsy~nm%-GIRRBgm{`=jJl0 zrk>wJ<#h1R+|di56S2lDJkP`T0N?AxrW859W@3}`5$zK<{J*&-&LU+;A~pPcQ7f7e zdZ53mKNQJ)D~QRiQ+~?s-_+!_;W~nA#C7-bajD7F`|}uKSJ@HjYP|{==VzyjSp{XX zdsO%RON3voi&4%8va6BC1bD4>V&8nkkbr|vHJNT^4EymGnM(UlBAjDl4vq|==)c`D zmVL;o0tM)tC{ytAebPrR>RK@PV-05Cl)z)Me*SKpjtBjVejN@r`D9Pu?NO9=a@fNA zi#Jwt)(;$1T}6&Jha%@4R0E|${NYv-kvh+R8U6SGMMA2SktOn~fk}P%*sOpf^*TcgTJtyWa& zqPGLj<3I}urSQ=-7H|!kbaZoF?2>nH_Vae=PC;?Kf1OQ{(o;J}vza+^_T%q<81aW2TbpVkAC257%Z=Ji!l(;oKyVY;yD7iDC2IL)qe zJ3mRQ@90MTn?Cr&x zRSw4`I;QVa-mJDLH?DMJ+&loYg(dDay9>J3e1)cFZ|;CIQgle^*G_<2P_H&qD2=lf zxeB!f)*4mzogDG3r|*96%sf4e-^Ciep0aw?>PO}4{S8*xF6IWY9pHVogxF+9?-XmJWx*j+;;LK^ZRA!WSL1N` zow?vhRR9*Nji%v30x%mgb#>n<>nn;1jHp-VBij`nm-kKYw&x$?B-yVds)9)wuGi#K zIJ_$_4h#8jCe=$!fQ-mthFJL(SIA=@@wJz!Xv ziq&hJ3ekio8TUtBEm^qk+bJrG;|AIlBCsn4OEL0YM{#EX!Ri_I6|Vn{7uDF9sTx34 zW)Iodg$X(?fhL9>-Y%Zl9~44XpNhp7PD%&mVrVh-_S^vfFSqs$2{3)JPJf97{hGN4 z2C2d^Vsh?gBbfxkE*p|yoD=8Ae1Cg$d9u>#3Iy8!oV2!PbT`!8>`{`Ed^g+0)Ow_g zPt;GWdEcMJ8Wo~Q=&n?;_$0L1tHorM)-f}ePV5c%>`5^cs8sB9roiSBFhE{7QFhG! zIMv5}{0Q}-s2b&1ZV#h}8 z?b{mz6tAT*__mir+MwOC@Nc114}`6zz(%tD0%ePU0YPB>NmK$z!WjIHqz2O@9*#+S z=kD@9Zl-8^|Go9jg5ARJ8)8UwMmwMf-`El^@Zfkh#f$sC9Z>T<{W*0X5!!`(>J6s5 zIV<12MjNd;wcgf7m|LAD;KtP5e^QFA^E$!3{l*Uv9!ADrkozhHZLHkOzShp>AQkg}W(v&z z-u|C?0`o^E%r7>EgpVzLMeaDQ{Uu-t`NW}n-@z;1n_S6E93n)CdlzEC*#^^V z2>ov(^WF?(N%&XF11}oDRZx9d_h;8ZuCxGnNLatwX!zxikR_r<5Za7_sYj-Z`>Deo z`Wpe`EpZGwGB^Oph}zeYeeZ?<{h}*k?x1*mksQqs4@D^DdJEH5&sbhEwH%nfU;z4gJMA%NEI1J673 zdVrYVVt_)b2apj+;LDFNtB3B>WQ}GqY!DzVr@G!-6<2rBVc8$w#|N2+WMa)eGc>$- ze@-H~dVT!)P0dSB#$&pH(Vy&e|1bIA^v$DriMqF^HXN^;ace9`bXpEMvr6r~GoLNl zwSMS$hl2Yr9SH~`#Qy?MM{2k{q*7YZef$qe@$2@b*UpgG)`t`Ao1B3kLAF#-umAmG z?>B;ocR$#5%abDlhi0pp@IYS>I3z#Jk(&Elq*g%67SA4qhPk-TJ)h|5xB26l`6E%S zk}oBN&NBucdhfZ|ah^`ZbZ1i0wwix6?c7%o<1qQsMo3pPmGg)WE-) zB6$^G3={6G$4YPOM(POY4fXmuNDE4sa52Ac6eoTLOdJpDJj<>;Lq ztI2YEe~&aH`#)f_&;tO&L)}LKJ9LYm_$|K~J*SHGYEFRaY!uvtH3ykBn2poIuQtFV zM#o~(oF#$hW`_Gg&YaCv|AhrOR->c6-nZC5!%2D9S9QW`qa^^Q0wg8oRkzKWg(})m z9Bwy1Vrew@uVNC7q_|!aL3t6T+Q@Eu&b*{w8XKjOnBW8VM5AY+l|+cu2M!!Fpf{{u zph4lb`-2M18NN=Dzf76^KPePn`HG+BJm27a#52H1_e`>fnk10!1J)x=eH!=+`f>$H zm4cVTnT@YUOZer6FcD15t}|$(Waj9XM={M(Cw{;*-&ct>#pmk_Vs;fG-`dhf--;Cw!PpZw;385 zJh7xyG$&@E!&SMCf=T^1ri98%S^Z-dpvIw9s)5126)p4ZnT7uG3g_4=qvMhZ!)->u zzn9`ad+&rt3!fIhlaDi|ijWlg-E`Ca%Xp&T*~UQ9$T<5IFw`;voi$*T$#CHf?mPAG zZ{KVTXUvbgT}v&1X^h%spQ2EscMEs){C^Xz zhl(~jT(0Z#|IL@+aWb+Y1+oJJ01At`Q>_X765-H$03O$5s4#4$PGHk^5g90Wp2h>8 zMCLznwfL2_;D6aok}T}Y*ODwg$m!^QfXj7kasCGX%2)2+%6Qyt_VGx{IoP0kuP5~34-n5_gVAPC3%6+C%PD&MDxDVGT{o{2 z6ap`p9{SGgx(SbvttvV+6`q%-D#WgKJV0PMXv*>z{un}S9|UmobAs#*PODM??TiaX z?s5Pa0{nC_Hu~bWnXg7NL_z?ecm+U+<&>2p8Y4-e4wL}<83Zu*>vNV_-2l)%R$(Xs zAi%wXbr3T?-SYGc<6qmoDJK%9eS#mi5pq-#p$}&P8}L6XLTX27Jmvu}?G9=N;#R}o z#hn`jtepah5#)`LtYyK00orXqQ&>;XeXRgg!s!vkywAFTVb}+NQzZa_CXMsE(i9yt2Z;O^)H+qdl?j$lrF+v_P z8eJG|6HU36G|xH2VYDMi|GWK+1{jb;5a3%Mnes!(3jIo#@A-CCQp~TG)a-HlIzdwy zWb=KOvceWI4Ndgh)1Ss*8WQ^Pqh=MB0HqaAC2tzO|G{K20pJbV*#-vOuhFE)}`S z*(@~+s{k6Dl3*SZ8csH_L+dw6<}pTuTP%MuXcFu*f)ga7iEAx5?3VwikBNpzy2q>_+5^uf0s!&iC!-jC;W!1rR?2vf7TXu z&1KtC0$2Ljxpu4r26#RK4ht_(!~ogcP&-fD^N<$6fZqe|YKhSnV91tCmPWkU8O;d< z!+&EaOsx5OzsX28A8f_|gh>#DVD@#37Esc*wtcA}%(cC$gm1ZvMC~f^P&v*|JnK-z zCujAf5I^of$chbtSD#og`KN$H*xKwTrD63ZU-+n*e0CWWXZdO9> z*RO6f%9Vf;nZi$Lr1jirag8Gg$a2-^Vu5|0D$%oaeoGa7srLC;xd9@{Lcvoz1%Jh7 zBF~p7Yj`?B1^eMEnT;sKILWh{;@a0^mrf83kY+?tm7edbt0O zT5E0Qy!iM#EE292Ag5O{{Jy~gN#Pp6=K&A;%^PXs8CmF;VCZMCALm-jnQ--Bns+5g6SWoOJ3M~}?7Q!-mJ;AXbo1pj!uY2YkdNCu>>xE+#hxT94)-A3rM z$GvST@Vj0Z0d1A89$m=R4q$T_gO3*)6(zehm|`zo2mPP| zC>y|DwX_6p)mAJNFhO(Osup&ptBgxDD_+%LJT56{k!a2l{^a``^d-6^VeW$ew&?H@ zLq|pyJgt;a_|(&hi_o;!WOvE7-k(tF=_;w)eCBekuTwJe?x?;_jkWpDxY9ShZg^DE z0*1QcdoKcWKr|?EIR$r?k09v_|Fy1|MN}RMY5uBx-!$aU?wzae}(z> zPk0B=sLfbItaLd9yh3<*D#9FvJ36vOxnKXTTL+<2izqg7#FP9EjwG?2{&E2*J3K(8 z*gW(Gs6_9~RxyA#jsl$JRn_|e?p7mkwtvn66Vq%!Gx}NBbOth=BAX#Qz@5@|Jxp8$ zJ_$P78FNC@;4nmiU@)=h^wV|P!06~(-Ljwltv?_}mf#C`d7e(`Y9AiHLqO894*|Fq z_C|m(2>}&q!zP2c`0FDTNUETHc-_)lEHE)aa>RY{@V%$A!3&Ac9tTU96@b4-%+Ufn z117~LfQ2R>$BYdw*)LZFO@Lzf8Q6#P!fzr`m;Gp1bo-j{0rP224v*P}*{Of|YS%eh z_M4jcnI(q9_mh2G$Ha_6X5?t}e^bPjwn#dTT{!V*H&<2`lLUFM_P}M?5UKUf{Wa(F zV90JPB&6m0%iOL`PTW>AGHTD20R6Bm1<&tjR+Y&1}B^6JPl}CK_Z8MVdFjCz)}|%4+&c$`kADFu$)hQ@A7~ zNoIC`4wh-)SQo`Z5Zk6y*?N>`6q&$`Ni?JISV=k{tBRa|J`|5Z>lI?w_C@)Rjv%bD z23Jc;VHac|S}OyrJlk5QDhv(9-2d#T!)-w?{!FNX$vIY%>5;%a0F~&TXy94uH@Gm_ zOn(UjNgo3MZu}Gs=x=q-eae*4FJ|DUu^tFh1Yig4%{?G*P(p#YqympgLywI+kY!kU zHZ27{E5<*{0eaS&0_au|c+%vi&78axS^%A-RSQogRe^6{iKtH54e5UahuSw8mTIv@ zKUQM?k?2SF3fC#*Z$*ph@n~S%_^;BH?Smk%9tMES{_hx2^Aae8-(pY=vqdkg(3Zf! zlh3U89Z=yg1t)2?!IcF7bx<4IMt=u~{{~pPgz+KRgj?qDNqhzjN%atpX@c9jvkrJH zIEQGoc&$t3KLN~Ou9*)|O3FarU1FJ_dywcV>y^U^~Fb{D4`N${7QBi!PW025UEPwrV89Q)2NQ)HW z{jcl#;Z_j;!7F$iLpFZeR*X){3z?~);Duz2wK|3{y8jZ4l!|^`kw@Bj2z<_2XK2<5 z7b5V8-em68j6E&_@*6*yI~MN2-3Y+7<>B7h*Vtfubv!(e4C6W!CWqL$F|8eF!K@|g zZ7c!<6vYhOizxZs-$$Y%?Ro*PzXfs#~-IVMCeoKv8vefj|SXT5c?=XjS6Q;VUO!6hRrc2Iv2A4xh!v^yq8 zqkXSo12rL)H9bmKSD<-FaSFCZy?XSsh8VM*q48PtW3LR}euG!TbOVQAOsR zUMv^ClEyt3lnfKK)fcsEWl1Vub1>HI(UygIdN>D>3tLySaJGI7o~Nyj3)zWeP=n1> zi47Z!dUP`cQ(Scy5PkoXKxW(GI`WaCt#U;crEoPqQBM(UB2;hRUifqGu^np`A@dd# z`vq2;$q|-DxKKwVQi66;)JBLeAh(9Q!+btXPhMCpHh$H(8It|D%JycZ4Z-|b2Efp< zV#ThOn~{<>f9DY|T$Q07f}xpmf~yX7So*v){zoK-LKBb0Y*`(L0OitKOFll!YPm$l z!v4ZTCiKwuPaV{=|9aTk?jf|2 z+D0lsv0ij9%I5zS+a%lS3)A3n2q>3xBl&DHLK+GxIR-X1MWWBSct@@Gw8YS%d_sdzU#OwS7f_3f1GylZCP% zA1-us#JBKnAMl~b{Q60qG8|M!u8#yd5@}b65s3H6FWZwqIba^{JM%>^m(V9RDXm|G;`W!dq?%r#qu8Q5X@nBftd#(6vS56=@ zA@I%L6lvs#W1cbJmR}e(BRU$0Uv9)PHiDtYhj1eH67XSS*2)klBKc6&DC*E*qRk1$ z5$M*r>c)Sz`6IsaXbcu~S%FByQ>m&cr}g&CiY`Bn*2whVF`xBBLf;LL0}%ryQ!(Y~ z|3}?lM^&|bVZbnPz(XHGIF!=e5~2tP5Tv_75F{ieL_oTckd&5^E@=>H45UktMjAmH zX?WMh`+H-&-}t`2zd!E1V_c7W?Y&~IIiLB=XCgrD>$f+VvdfJR91qqiwQzD5dPvOLWCv6R?1lHC!w`bz4}sKRjKo?8jX;lBrfm!mCmTTl<*SnP3S!ZDwOA_aT? za|^qmiJ9;Nei=^8uJ}cp!OyG5ABRBPpEu?&*x19OorUtTGvdcN*`F!&rdLKFU=T73 zDNk;ce^9eThk^{Op9m6|e++%hFW+ePFYm&A3aOvtyK`iMJHb%Sv<$=2aG@`RHEJRV zDUHIqpC{|aJR6d0spc*as(rTeH(2%ysCP4b`!HJ(JY5NX%>szwW`;LFzBEo-0O~H| z{}`9-mistXMw`9&7dgeLi3%=;LsuH;;q*?yDBrq5N%nlRwW}t5#Dm7X%6+lF?sV%LPv85&6$O?@d1V`7|2K z#U^vbY$lAAVf#_&nCjU9T236O+$ej}*97?K;YHy4C+bh?u<)lO_ z;=VJVhJS+}Db+wqD^?56ET33MD&0Ik#(!)dQ{Ypxi+3Irat&LS8*lE;qa_hTsfnZ-(>-p}|ZC^?EKfBBG>H zK@e|h$ousGX%bMZ@ZZvEGX3SQDt?kn6Rh_0TAu6nGTiO&7m z3m1_KCA@Oe@I_#gSxx)XVY6M4sY&A$bPKRff$J-tm{Asqm8>w3@dwnX!~>4MO5rVn zU#O4yQNSnvyEVfJ%4PY3f;NDuHHW|6{nolGyACjzkBK&22<$L5G$ z4c29C`m-u0iX>#(C>=>V6&1=Qo2_0+=I@7K)7<6paQk7y{<+gcQaGOAF2_T;O_=RUw~>rUo>nhVVJ z*NZSZJKPfQ{XP*55PrRHlNO}o{?Kw-EE<^u3LjOSK`_+=zz}A3c6QBQA&w}XrdVV- z>#e}w!@|j_XjWV}2z)NAaU+)~dO2M%iSq9A<)5e3PXYAgF)a3xVxtF`42zB+oE337 zoIP(q+WxwSZe6jzybm1aF|tzWzne?@MtUom$!aFTzQPh&LgQ&<=rVL}#E#NQuc*1;u+sD*i(8T$fUHl5YO+~?!N|5R>-B9WD zG(hah1RWyPoLY~~(WJw`-P`_n}W$m%+O;9z|5N{P2`P`k;d)^^?qQXz|Pg1-FvxSDat_q~akBq;K+Ek)E% zJ4etfBHlshZyck^o=3SYh_lS6LdNstVoS@2{5LV19}%m;Izs)Z88*i}rsL|FSD z;@Q|t;d3Do-(siTyB3H=-s%(Y1fwRyqb;UgBFrbMYI=LpgV6CXdG0o>?^N42jG|!Y z%=rBqT&8VhILx#bQ~6JV(Xv48;I34epeox}ftb}5YDKAMrIw)%5 z+mad6_{DyQAIluo@!Ov{u+C%Oa?$23t!bMLCm~M=o4?=Ab zTnjXGf`X--T?y~J`YWjf3wbXZ>k>{~Ro?|}IfG3eMhKf8w8#nJ%PlCdqD15e(ccrc zt?3^ID`{Xc?!k$X{cygBzi97U<+*^}Pfg2YP!=T(d~*wx9eNzE-Fp3MOABn0MVIeIM&LFAoQbKTzD^O?hErMfHcPd6TXTJtrWlAPd;nlc&jg{w@uYpKgEc0I!lr8 zZ4tDeJRobJyBjHOG)eY}mctTWhh6KvMc@P_Q3hXkfw zgo>{6-GuZPVu~bjUk37|33vho+`&;o#!tZjet1ZpE|ElNX41ReR(zZll z?B33{$Tr#V1Hq-#R%s<p)s>e?cP6S(h^U^g8MLFXz{I0h_qsYq-Aq^W%yNrm}$ zlJ{Qmzu;J2AP2%yD?*NeQFC%X5t|E+@YBdscqhnhtNf&eZPQ$7A?S8o8%^v4wP_Y3 z7n!+5hX$IusY48r9J+b12g&`sb8s-QIGNN$O|`+bjT)s!ym3cIuizpaOO*N#K3bZc!&P7&WOaIG86DhM{P9|pHQwUucLm%5!1<&MS02zxcsuo@*Y%a@28E+ctrO1&lP^>w;J z7*X?wNe5<6%0%#ehvhcHc-}J zUp2$Gqu=*0jMfL_z*sEk^*?$E*YYiGN&Q|^k7xrL=FCMNs^9f4-vEWlK^P?|%1$1= zU@tvJ)01^D$8NT5qVj1dWshO&UzbvR{a7+z$w#{dA&953Ji-M8_+e>&S$F5u=Q}*Z zPl&Dj4>3tF-qZlV3uh@zrwVkY=~B|!8#0mVd{t$(3uD*YdH9mUKHQH1F#}dXtphQ0 z|1O5*)wgxs)Pr!)$sQydQe>i&CqW%o)Cc~8#0*M(m6S%35V)T5JELN;%VLd&_+dY_$= zf5_@>tPGAc4{ti*?Rfa8{CV`Wx2mB>uFGn20&iS!DxXgiml2=gtjET>Tty3sh;M+tI;KIR<*t2lJa}6`P8h*6=CRa%qb@T;P|~YBPH-6`j9^)GWgxf@gOuLt7$rTUq@-lA`b3Y+Bm5+G=Z^|;*mKxe*_CX0 zbKX07dIcGC`z;2v@?4Y0t*6>W8bwW?-q}8f7FPY=^s7?~P_aJ6>69jVV7^!9yVHIn zgh?&zwZaDk9FA#?cocoh-i*esj;yQGBI)tV9JJw^tOXj8vRp6v!%?~PsIg2A*Gajl zx+~V^S4XdF)@A;9$wpDpf9Ox;*_cZst`%7O3~X1MFVJ~rGu`p6Uwf<1FfWg62y0p; z)6QYAIs5yEz?9D67eKLvu~JneE?$fzad+niNs1&fAPivG;)hL283MYp(JK57ev=r| ze)ot$1~leYbWNha_swr@YNMaX{MV$f7iDd#mRNYP6pnc^T-0(3o{0P9O;=Wp*%#b) z`_~*LSh(qOB67?63sj@b>>cFV1*GwC;8U5rBJeVpHK~@=(D>Q~l4j)b&LL84xQ&uz z4lUW?x%l3C{oi%DP?AnQnthZ#II zJTk(}4+?KnH_AWte-=h@2$mIp+Q?~U$C#ks>G#~oUG@8}0s2?SQ{EBI!tI`jA}-4W zF~Q?kf7f{Cd|Tx;Lc|$FX4HfXWQM@q819F6y&>j1gv0lFa91~9>;d>Opl6uUD&46n zAjcKC;Kf*U{4uy~b9^{&)$jFijl8C{*SE*R6{eljsS4%MJ=tFkpSSSfQE(?DWIxWe z7{UH1&ic-eAq&1Fj-h}|;9I)niTq2s6vO9!Qx`|NW*hd-H-Rzcp66?&%y^zm-;**X z*80SyrKJW(U4u_A%X-sAH8nJ?=+!}=Kk7l+2)^duf%n$wY(P@qPq2oNbV*)>q7Y8e+)jYko)M2<7u zsGXB;*z~Not1y5Ke7WbpFQ0W7&Zv6W78R9oNy|ott_pGdq;C~` zM2OwVFRK%uDb;+or%_b3y>0^B6ED8U^4Xf@ISSp|J*+(Ov7XRy)=$C9b1i{=EMF0d z4&>@9Aal$sT~j{>`cV_VK$;)DLls)Rz^C)XM@A zc-`c?R&y}H79 z3w6|Ow3>$zLBF**7yxU(K!7gtf)cb^uCvverJZlH>cUZglBLA z)IK&x1~>EgVTjOI8 znBWve%^?RUp6DPRJdz3GjZsl04%389yNj@!hK=;SH#-B4Hc^T% zQwNE!-+Rh`@_bHl*5Yn4mNuL>{wzhrJC|v+m1pwv@CoE7MAHA8-LJyI;fp-b>~8hT zd1`{$m6M{I^Qc|!z0;on$iLm&pZ=yAzyZZG_Nng$x9{bpDU-7HY7S&g!ofa_Bni1Y zBgRqZ3vbwsIT497hV0BxSX)>nZ%*zby+6A!Dg{_O?LgbEZ{sdpd> z(jI;JON1_ROW;a&xCHtW17b%5%F#wdSxMo(4$O?=-{UJkNoWaBAfU&aQcVU!kH3pC z)}1vL3^!qWm%)Gt{cqAcquw)|MMKkh^?v;=G@WJeCl+f*pUpjN20N)fA<81IpG)@N zYd4B&ak*VT*NSB_0(!0h05mU{J;F__1Q>rM=0dImm`%wxg2%Er(z&Ji{frFPCsMBc zg$8*3{5_1*7Drl*KWgBIMyR$Zw|Ek43UaAWVUNKIcbQSpzEuK$k`7-xlRCOSfD9N8 zO=$^y4;Q+Br!!!KZ?6g?mSu&Of?ftSM?oK~dtVm}i3UcQ$qk}M?Htw}gXe~FiN|QL zgG&bdrtt4u0aD?9Exb~=A6}M& z0nr)ezkZU(k@lP4u;2BEY>*dSo(9;^;5RRbpmCZ0`vp^ZX+OK!^;k&L7uV-Y19sp| ztl?hk0(~d~U1U~S0f_-Kc^*LoP51&P`&$h#-hp?vCkp(Rf%xAq7@RCzRfwNhy^eelf}Qi7;vvm1&JK%d+^ukCDAxL+e`<(&tDQf+mw<+ zjHy5;RPEnifQ7frQczrP87wu9Tn5A{oaX-Dr0Js_wUm1&PVLZAM!zf?U&P6 ze)s4hIC9eDu>tG3;OpD8VUyx~pASU5yHUDlKl;J{0;i4qpVe$H5173@a{YKg6h}Jr zc&|#+cqYMhhWC#OkW7Gs$G(oos*o4HyhuZN(*!z(SpWWNY*p&L=rG8O;$pxWk!J^U zHVDsEN@u^blB@XYD=xcwPSE=E)&6&JgB2|`)^F(^d}6xsAmk?@Rx)ij|Bc^JMK&p4dOg)8KMRe* zQB8Ce`{R$}26IbltVbkY{rycF^OODkqwkTFyz-|VI2sG*3+RX0B4DO~Nr3-z-Y_JM zw_j6Vn>%c-n_9$Sk}!=`uD2P{stE@N%Nj)*d09;7fJnqdpY0!cYf@%%+S^vmd%X`b zXYk{Q>qW=*9U9wMxSogxm`X6XYtj7ci0~5^3n=0!i>i&@=G&*4$F>r69k2%){n5P- zoC>hPIRAac+N=ib+Dhi5LEHWX{d{TJV7~HHjFJy%Dq&Df&(vj807jdC{co%p-Ea`I zT-(td30PHe8G^&H#!|X4=7$g1wr#J$AKsw)=k;RX$&T3P!QG|7@+u8M2YC%e_h}M0 zJ5I`0I0#q&k1p6wDClT&<2NReq1W<3R&{~OKYdgOep3yMt<`1I3VmsiBscykSr8Ih zYlN!$C+JqCL%?h5q!7}lP1^r^T9^s+O`{!%*0V>H-k@$sucbMb zN{}t3bEXYjApN(EcR#vm*+@Wsz`hGxGjzds1hqv@DCUi$x!G_A*>wN!H9)NX@0?%$ zjZ#R*gqF;iFWnfj0Q9ncleo^ZrYYi2EyEr;7kHLU%oDx=P6%S)KWAo@Hn++OtQL`e zUEmrPwEoVp4(NrXp7n*O(U)9+C<5G`&-C;3UcWY+z!Q*-85&e_6Jjy2t z4*#2J|DTWknZ(U)(zkqY=w$+s#?PTSrXz-H+XsJWy)+=Cxl1>0>}el!r+4BH_zWxe z^&ik|x*}C3F%24&Z@*Rp=oCF!p-pBuYK<7~Y99>9yx&FUychvRsVLp)D`hrC7k|K* z9h`IeUNQsX5c-)X0Gd|G-Ee zyKEwcvtlDF95o8f-lB+U*ozlnfz~<6-tD^pw?pN->sUDI;Q9+Qf=oeLVYvReiO1E& z5yWuan|2wH2ec$9!W{4tqMz@9iMxIY;FO=*XJG_gZ+f+U)iWLN<)rYSPjvzDd=Ot5 zC{Z&5Y&8}buzMd^Q3iw zJm~4t>o)awd+xcHAehHj-$AJK86%%OBB(U=26yC?FII+7=7?NCkAmO%3(~$kt8J3 z_fnsy<=8Ax3#7C1WFp8lT752zPBM#xs2SWle~*&+8%cpJ@)Z#RQiHcc$)1csoXI|R z`E04+NIjsY6$g?^$#w_@q_cIN>rW%-kaCT)e8A+y;0I%j>+D3D{;al&?gc4I%id(u z1>`dTU&T>BZ{pu9$r7UOApA&J2!ALDnXv){Hq$L}uLhji?r^>x4L}6TJ>f8Fr~tC` zTDQD*+Kt?FHSnY#ZvQBJmkwrc6!!zEE{uW9U+srZ*XstAjvb&eni42{DZ6-O4unHv zUe$ujTy%Zxhi$E8V2PuR71%z|bsR5`EY|(N1Jsrh0Tr(sDEsq5i8dqGAKI+?NYK>a z`E)!;Mo<7GAbfMS*%0&{c-8_nW-wCZ{)Wq88aC-oGeqhJw%s#yS7(6L&?A0N;)?n& zfZR~I_|{}q>*oE@3SDn019F8Qj(?;V`r^@UTDxlXAbeFL{Q;L>UU{i8c+pd#hxXQ4vHIEMkfYj1qd~dhyC~TUW&P17Uk5glhsV1I0M?|2#f_iy5;fk^l1CJ zfF$7E|3u)G5(D5%7XJC*14y^*5wSLNSV7iz+3sPp0fezSSJz23(PyUKlLkHU=>O*p zDSN%uB@P2OH`oFSB?_nYAPHfK;eLMXMEXDE&J_ueJLSf^Q=sp~J<=((qOLwL@ghUf zGnLT9U&oV<6#moYW%K2NLqlm}c42Q^!;AFE%eUU#y)pK2M3v@X;o3YoB})=c0z?~@ z#0<3JAHPs1+`^y9Q$R|E+KU}+0$*?7xM(xnuND4K${x#il`3M~x-Ue}=m-2Q2+e*{ zd?k3yLf3(KDSv@f!wI}ix6Ukj?Q(q!{7nEMmZ-9zkb0wydRVNM;V)qE;>5jA#9pt)zrCq$jt+AK>oy6-#ll-Tg?cL~@VieeT(bdCkcmok)Q zYniI^z4AE$`QBxK&JMKDRTv=8F^wKxyKTn=^~T8PYnPMJfy3pTL2W4~Z5fnSbdz4( zj%k^hI)gC`+iW&;~;rWyHdfkVTv$Zn#l%ID%0qpc1%s}OaDf`>YGw7W5Apvvl5 z*Gd`@Q>4VhxC62l3V8}|%|RM#Z|l=`x3b7gf@lXwT_`iW12rv_BJP>WIDvDVi`O@i zANtCm%lGG!(+z6e_zqcx-tLeumtP33;-EIp>e+_%(0!)Jipny9sztv5>5@H(8+9JY|hRRB^)l#*&pyCmiKb#+$7(g;Y7;A(r z!IJuB64-zoqcU{zHzqFn?{AA`_V~G?=kU4y)^eE%z`|1+js|>s$`Z_ddYx!9J1y70 z_1-{fPT)x!8tKM(<T z(fm;+jC%u$ljOkWH|GP@ECH`2VTxa~qCt0=6l($UaDBiy#!vjr#DM4fK`lF=-5O*- z@_EI#r$1JE)&5dU;a_ByH@PdRkZy8L=0uMxEt}DXd$hz(daAggQ)&Hn*RAaE-WB3s zHm8D!9Ye|f11LX9N%NTfJzDgXJ44veJkjtg<>_ljxqiNA-YKHZ!-BdC z;8;{A`zsGP5Tf^2J(tN{JD#u1fPL0R9b>2);wKqbT=`?tSRXSYfw04hW`t&#E-Q zgzJ0h(HrJ->G8#UH8HYbW}CykAK^%Xy#-8igD9daoDUUQg#MM{GtvyQ_Zg1hUJaaU z`vjqsJJb%%e7A@T$Jfw3tZFUI!oJn&rp(C9msa;Dd`e^PKC>0>5}9*|qjqR$XS=v# z)lY`&@0=5s6Cu#Y4vH)|`-(&fz8rVp@d_&J0GffVT<1TY#&@qb0$%M{6v^@O3Em|y z_9s*&F3b~LRu5;H&Ol{DE;#TjK^K5UbB76?nL1buqKpWuY*jupe}r0#(1V+M0e(nt z^Em;hJkTVKTv%ixgUZ?DHB~CLG?LHa0(~l+fw)LNg2Vtg7)!l(1_X6NIm;Usm^%2L z&x(1hq*&pzw+Ep6AntU1X>eM^SD$y|WW=S(4XDi;a862$<<1;P^1S-c+EnMBM^88j zcpSu8RKV~1(P!EYDErHx1Yn?adAWwK;&$ajJ+mgyBL^2`8^wUUL;fC-*5B#I4-W5s z@BNKqmn8TQM9+47U8L&DSBcYPG((^*iAnka;Ehd;(pLRB95)e3jM_9}xdgp7NC)L*+d(I*UN%`lmE+37 zm~xun8&#cpbtn!8WTKcyJmwJxJXCvaoq~d|lnD_4+}vqS!1S@`kjR<~dPw(v_gj^K_0=)_g&v(> zn1*Esf9L$E>(haS#)q9!Nii{w0Z$U{rSG>NzVQQ@F$CX>SMRTZhj2eu4B-%PYuMpS zv#yW7@03p}yEx!1(yxE_VtrblMJe4KHw>zVaH?$Q##u(QK(D{GVHs!^xu9Ys3Ft5% zeS4GnSAuR?d;~{##Y#m$-OowB<06#T7YCCZo6J*1;$%X2+DEObnh`&S$n#$J;rSa* zb)ZBVd59K))@n$>Vg5PJe+sLhv}K9z)&i7X{GzSC_pWbt)f_5f)E;{8#^l2jY3Cj_z7L=WQb_X>|3yv!(RGnGRF|LR!g4$Uqj<4M*zpTAY zDLVI-pR1qwP5U~BWu9Ke&Gc+;b6)n9FI}y@vH)O)DV$8SS@n9u$N-{eSKNyrF8PZs ziD0&$5z#a!n#FUYvOG=mVxkwl0dX0X9qC+?n-9K!|1K+EPmX>@iR-$%@FgnxLAfLw zam)=av~<~p{Q33P=rcn0hkzq!e|q$zRNK#qGL!dG7~HYFqtSXW4GoPdS9xaxPCEQL zNk_*_o53vq$_yiyJ|Pxa25#;?*fYkdmA*99N|L+FZKqP1F@eBFO}&U(7y!2OuON^7 zlVI1NCpM=uptgn{naKBZAa)lhzT80{OynCa)|U>y`JJ(cTuF?5*+S>1Bn@_SG6xpk zvqUCu+T&0^R?~{JBuuo7UqFzCy2I~BPdswp4xyVH)Gkpv)1-hCnD)%Z1PU|ZpG7g& zImP3GPRA`Gq~;{!mr^K<4x88Tf72U`~lgY#vhauzDu}U!bvBCW< zz#|Kq3B-V6%w#&b1V;=%_0Q@YiCp?-U04b?6z;OqfUS_Ly_`$hgXTZl@F14`H8Kbk zM0bPIE4%5;yRnaO2Fia;jx>9^#s67r+$g`Zi_qhW!Q8-E%vxGBMvo12&}IA|7XWoD z3{zmfcv+c{i?pkQhE5TbU}RUeJ&}LnlUvH9gK-YIaINOq40TTd=MP)Lyb*9ml>0r+ zBq|0I*fkA6OD>1czkk%0MN&;N-zUB!F<^Z`(0YFi?zSwt`8;&*?C&d2tnlxh$A@7y%!G~4{mXfPljmdQAX6i}6rM^;`QI|_LnEm{TRVDUops&_##x?1CGR&GGq%Dc&GYZl9 zE1kl@2oN`K2Bg!&1*bFV=n$M-_<9_4oO)z*56uk&N&&9LQvKoY_o(FwCvRqxmZnQo z{x}xY$S{GM*ntS2_%9EbI=Ucj9}e!-aVSup2Z!uj>EV z4ZvCkN%^^kJ~Gaib&)B8ik7|dB0LJ|Y#N6*InkZqutR4mRd0;R7Voo24P{RKswMXF zJL$QkTD=!K9}9c}qnLNn1ujTM9LMynQl40E)gEORW;wo(Bn?D$?p1=aXs*|oiQ(W5 zvueb1;+7l$A#z`OYio=Jb7*rCozs$OIV|cP^-q4H%Z!^JC1sZ+v2kl_8}=>N2TxzJ zYgPD>u{$%cwg}76k->k~6hC-%aF_i7|2BF+7Ja84nLL_bFn-d5sftlz;MWh<6ld$v zZhLc>!g*)!rT%9ZiLSGgB&jcAhu&-EDyP+zn+4aRWj-4~nPU4R<|UK|DowBn!?E? zn;}#d!X;fHpjU4xvO~c8tJ3({7@^E$oqXbdp^?U2GXEWh80n6AS7!xeN23Gy7j8Wv zgfD>&2t;`9Sn>*KXOD~|rJ;=N+Zp{zLaSf7>^!g`3P%wFiytH1#<1wjj2~M`cn|O2 z1?kU)_<6FyBmZnzd$MMo?t$l_PGkUz&f2+tQc-&`#0x&fkE$!+>A}W;QG~ zSc@NgCIEn4VO@urj^mwi85t}D)GQBk8*(2pA9A>oQ3=_YuRO!6*}XIIxJwA~B1w_T z%E4-Nk&!Pzd$EKA8CAtj!*YUUwMl3fe&z#o1v7UiZNcjgG9&g6`qU`5Ts+>rmQYE| zdaFpOM61ZmhFdKvMJskGF;q@XFN&R z_+*JW4BQ@z^thlNP-t9zBBl68!)*defzMh`^VPlHDD-JPWBJfD)i);OxbCz(gJjn8@Udj8#tZ$2aJ&S*Zf*lv-oa+|wdWi=-`@oC#n z%{+43^XSjj+`w6$LSD&IHE&B|j!eJsmv6mCIaZ28Il+0)VmmZS<>=WBaticnr%>EGqMA`|5H#cR@?C!*Te}9ZRJo&<)VW9!o^> zM+>yd34CiTP>QYL;ginsG{7qWoPug))|L-vQd>>?!V@3&Mc55$l7aDfQ`Fl-!TPO# zJ40d{v?4psVW`+YEApLU1SO1#c)VQn5#-J6K$0(!N}E$IietI&sDv{iqNf(F^Q_)h zSP!r5KR;J#8U>ty8^cuwJ=_QMG$~ZtVY$5>d zwwZ<;RRg%7U?e1{f%aI*9_!@HPbo8}_GZA0=Brk)9_f(~T22gYjeWS+{tOYJ^G-C6 z;8&`KR`jA}A9g7kC=PxLv?)v+J}#3lhQpjdf73{>|47&JN=gZ!N8e+pP+yuP(`!#Y zJMJKmNc>rtUzDk1K%>`*v-Qu4Y#{7m;Lz|3dwgeX61QBkh!)CAPyqVY+88wN$h4GRk0-qT$3M@{M^x|U%GQ=Fk719? zFB(2iz|~ENYVdf}9W-4!foAOvr1~v@MoL_2s?9?hP=W}-E-(&IQZ;h!1vtZ`U6G0! zehwr#H$#01pe|=kES<0hcX- z6j5GffR;I^xPK_u`_A?f2^+y?e zloH*2Vk-=FAcW#EkKaWFN}K3Q*}hkod!S&DGWa$g<`_MfW(@!j9|7cp1<;Cww-f*r zEg(6V!6SbVF*PYPZnPMF{C~aO4kmazAHUorD8FUS;J7wOoFYUN7eha1Kg;s{ds&|@ zC}?GL>EP+L0imaW%L;yqkZq!ohljqr!;o5>A&Y#ASF9DkljOADCwRUkX1F)V{zOKo zhZ52*fjZo;Kf;h9?3(#bNUAU19 zPDjjBSj%LC%vYrh<~tz&b)RaE<$Kvw_hYx2w~IbzhuF@S1F53YvtG7EO!ZAamgdIY zLrF_$LGlD*w(^+P#4#&oXuO5w_@%PMbgK z3qB_bgYtJsPWTW1CqE}^%M_+t|H7fSydbQPO`OJMLK8(gi# z5vVKrf$s?vt7)5pf2Q2v-^J`xdrR+q3w<~LdJmrr@9*bXUGW2P=D=K&)=T=EPT;QjdIS!sgb;c$+4w0C6u_6KSL_G!0k*@> zuqO)%L9j^F1I!687A_gQCM|1=T6k05WAD3Z+@6-3$8k%(&l_-^sb!GEm}pb&0C1EQYk^zyR10_ziIHclbN=ys5nSX2Miyt-C(DF6mL@ z8Qq%_Nqt{7U&UKzm|Q7FgTf>x;?8LHP4w6!iBl!;4It&Rq#-X)X-jtCD5Qv#i`Q@1 zX>GX#GknM6K>n?~k)PGp1q2ew#mndC92~r`9xaTG(O->Inb(S7Ss*RGRojE|a1v7LZogUDT+RD^!lc(Jd0TV54`r>7;Bi<#Dz>pib72Qox%h9#Q zdOy_jIE;!nIhDWx2i$|xQW|FvL+8Itqbw@3Z|ZN7OMkOoM&>>4sfp>iqW7DZ7mqyb zPcUj{l#t=XR z?>GR0fpVj67xq&7SH8h-;0CuZuRJ-Xr~K8GzJwPuI=U|4bdEeZ{zKssWXNSq11!dq zN??K7848b}e0L!PX4nLEPD_;2zE*DRf)HCU$dTv4$-YTEp!&*BcD#1(bk#m@;a%Y2 zGJ3F(Pqt5buY%_rgHNStwN{zS2hqw8AZ_|)+if3vBW{V4O_Z|(jkxU#91m05)5$s& zRKdU2ZoB7nd9vCApg|T$D+K4&%dc_`>TNyB_rNk5o8P%JM;cyZ{9tdeiFV3bWD zy2QYdrz>240NF-raL~F4Mx_N;{8$URrb=29Y7UlJ-zA*hr%5p-A*GUQzBk@A_)NLI zFsOZQ`Q>&a4M~Wuv2uK+>OkV*h6G;2WSM6?#nT6Lig8wh)h_soNIEoNx`k5`3Nf0z z&t7WQx&b^HLb4SAMAvKIaP}Ca)TG)3+eczmkjeW%&apYp?H8py;b@-!c4kiNhb?I!T`xX*e%A?j}_o^1ovkAxLA0jx${Os{f~ zhS{<6>`K0Bwav-nWLa10egVzF7LnHx%B#Lku@u6Qa>CvmGM2nu`T!2)S69h*{X=vmq-I`I{vW}fM zXwBXX)0{Ij6z8$Fwl)r>6tYPG;!$JiEb(p<_2=J!z{<;yAfuQEDl_T*VOSIbRiG$n zkd#pZM|x)JR^q4Bnm7kZaw!*TK3_QIQ<5co$t=mgRD#dwZV=M&k%vHizT@KJl)AO1 z3KvnSKOo>5lSqsXbtnHN)!nqz-sU`AeAgwj=&b$DhM5xFI z1a{QJNTm_5^|K0&poq)?@G3wyipY8MW>ef)>6+SY3it>=u-I7*F3?(Fvg-+mhSCk`5KiWf=(B+*DT|TUnJW+=e z6mwJ*P>xj*C|DW&{^2PyQqal(-b{c;L7U?-@!oC#RU2{Vlj;%o(|l6La78tc`&a@p zBuOAM-h;>66yxzqsaOe)y~UJ{yXd?JwEdFP&vLtSq2Q|_H%2nF4ixQB0x^T$Oo=(z z?%%Z`X^;ch>K%FNyz8hH8l_mM`Hx18Dlb(1&r_zh19Zi`P(4>W^b({X@*0X6&LfKsu$ z-+T{-9x3N7f+D*w4KoOss@gWddF?9bvDkvURIJ6}ogj~*5b+YzVrWA#GOmof5dn$c z2-*+pkt^M%I?DBbi&tYAIHnU;?){+};)iHL7SH07*;2#0fM@$uzL`s6w_(`QucZXs zAvcjSR_|N()H&)#Yh$V%e4s1ud%b9C5fuRefk+KIOapo{FBM{5GRNY@96o7mI)a!R zWZnsO32U_XUv_<=FtKy~^Hh}XrJq6zf_&In3Ip_2GAq7_vmwoqOcV3iV-G?PKpQ#+ zfm844Fzcd*I`Yb6+U)f2AKSePBA}!fLx$`3*!@>t`cvmmCdjE(Jss)he~!$8`9?JWcX1x6XyK=)5Ri z2EP1-fdR@-;)nBiW1FXR8wU*!++yRQ-9&#nZpaKpJEV{Wi+cx@5BN1JPbN1U;)od^ ztxlGn{at6a6*sWpc)0oThU;Uh?>Cyjq1~;=K_%i~w{p-)Bq0YxzoRsWB!3x~GD9v= z%N~Mq#a{-!Drgx-g6h0+d${_=x&}LtDr47JNxvNw&R~wXiTmc=XcHnN?YXC3h`4As zg4TX03eA-zgn*Wf(>N0`6VPz?Y!fo(SQPDJ<Jv1U)tBn}3pL%6aMbju8c_7CQ; zKP^~l0m*WgS62|hW5j3JQi`zaSF}d3vuZ2AwuNv-5^)k-JZU#91FD}l|6n@>k(HM$ zA9rvo9X;p(mUMN#D!~#%O_2knTB$c~luILkB4u52rgfJjv#A{J7^V+PS6s;1&7fdQ z9J~usFM?n6E@V8K2+~Vz(+YBmt=*Mbi+jBRd@p=At=YBnbHpZ_JsljhX{Tghve8t` ztru1LvIIVDUqDaolG5u7ZOvHoUGp`J0-_}N93{~al3U`^S&ba%OMmV==)dof6v4~J zq#BK*oS*3jBSbev>33yn^eAy&n>FkB*vc&M3MgTByAdgPxQe zehI4o_W@V6Ksn74qpFewK`kju+tm(49?hpy?k=-N+?2Rqa^wE%Xl8spF1hX3QzTB* zF>pgoG(7CwVB{MhtwWi_1jo^#fCp9=JE_1sdp&wA*1LI`gk1C6vGoZ-$ZM417z$`8 zLq$5_e&qL$(;Qndq5+~oH$u=e!HK?aE>#TlMD+LFpS#V?D5OpsNE-+>GtXvsCCuqp z`%C-#^Xg!Yndrf}Tu{HpQc3}9Pcr}e zsEA>UWrgKSG3kfM3K&B738tH<3{!@GP2us?{se#_?1yWl*c>m~gV%tVRxA)5uxKPj zkLg=it;;X1Vw%V?F^aga4*X~3d4VyEw4>WS7z)iJ$~M4}XPJEhMCMX>=P@if=@xHkw|H%F zYBku$|1rYFje%JB)zpoUu5P#SuIRa@ zo2GY0)~``fgc$hxfNz6`w^&qDSV{AH<7;XnW61wH+b%TBF|h!#7owV&N4K8+(y8}9 zT)bhr!ono9eO{G9qq{~9&Lccc*CmP8wrPjN9#i~a2C5_ZOAP$Ek&We99Y@_RG^n*lwUo_4eI`|9Bj&m3v zrtQXFeWpoYikNF(she`RGtxPXQVJ6$*62y(aIM@t+qDnw#PlEOFQ@-@^g$b9&;aaG z6a<B7bJ!^nNY=;mvK4YCIE=8zxJP5WtXKBz8XR5lQek}U&(f5L1BwvsX7kVjl0(FF) z02Cqkm91(NlYGPBQ5|uTI=I(+=dG>qG4-Tss2Av-6yk+E#cGaV=sb~%A(UdBMnGg0p4EqgVQ zw;b$OX#(0?+L*U@1miapTd=-gz3d>tl3c>;kmWmr@HmXWDK~Zn5DH_#7d;K;crN9P zV>Q`Fl9BUk^@iBD;&UFWv>!6nH`+35nuNu_ zlIYuJcdN*%Zqj#lg*HVJyk3aNav$EAv>Ai7k>aZ=Be={E)>hO)HbG0GEsBi~UNDR# z3KIip5D&V)bRY)@opZ`_-dz9ar5igjhz)FsvSvCd19qgoUV?-&fRO7nI@xnwpP-3AU66(^! zo}K_sM%r??ri)6qJgK4>R+JvWaMM_c?m&8}G}4i>Q^3y^YXfhW^_*6hqf`wRDvhKI zZdt)lr)!&vBP>OOwca6gu$>hTe73^G{UOT9pYNFQJ*g*`S`3xh6DT0sCx+o}5Hf`y zVdaJ)NzK0qbZqCIezEA&r+nw3Tj&^kUMKY6cb#1!OKH&mIMOjwP>}fL=J83=obs2I zv9c+}53}A<0`VX=rL@3JqEe~voB+2Tdr=W@8txE`~h)xIMRyytbrkw-Qs+H>7 z7QaVH6+S9Y{qC(=N>8v7W>wj0CdU5#*m20- zE`OF+Y#foA2WPFH8?VD>})lh5aH_jI$>(k2;u6tM9V9Ug(CD z(3^T6oM~kRg_r_takfA4&xtN}0Cg_9@*7;&Sl^hO*GChuO>p;d91M@g7vrQ^*rqC8 z5a%{Sf$@_!Xbxm9svRgh?1CvoG(bjl4h*I3k-E#Myk#|KhpKgiLnZgIpAsLDkl=Eo{pQ)6kEE^Yh`k?1oWW4ut zYY%%`qJrA5^2qbpA?urd_!9>7qjcH)ybEUOLG9L+yN86PTV)2qNd&j^LcHf>7W}}4 z?4Dc`*x-JeM)8MwPtW{v`fjo&=CQuKF4}ZwHOcd``r8k*JpU)MM%<8!<>Dgk z@3CxeLs73Ck(7(!_-tCBv+!*9aP6I$^KnNj;98|*&VQ9^d&Igqr`U4=-wEUZfk7qR8kQEalkDdSGax zUY}f$9civi{Ym~qOy>Qj&U5n9$T*)8o^U9wFj8Be`wIi<6&wDLbkK~+V7fpl zl!%kRe;>zKCV90))!RDk2Kzt|{xq;S0};b&1N=N~>}bQALv)3=KWA8W0U0&3`sDRQ zE|Lybg3F&6%@ALdjzu_ZV~28&-};-qR)hCKFs7?R9HHJ@LnjSomr#Sy|Lp~UI^TBn zqnq5=*n$x3^XYeHuC4`m1H;g>P-T`q%d3KQ5}7*`85A9wddehBo7E0~rcnYCB-Iu{ zl^ttvZfKm_6T!<@e!B#M zU+V!2=bL!DL%=R6R?eZp$(wMH@w8cv)?hxEw#?W0D6WYhA&!x4K&B z7rVdzC2A-`)mea6Erm|=rV5H4MygG!g6gNwZ&^wYIEkzm)sO1fOnti8TtL4rg%1$G zfcxP30(R16bHCy*Vd~B|s{HiiW!f5(QA#vdnG*`wSk;aAaOE)wu8!yiu6?U+>Lmbz zkBuzxZ>D%){&`MnScOP)m^3c5lIo(h`^(`VhmIKN8v`y9T6R>jJm=z`R;fHN6D|CB ze8H6tA@WJCB)pJ??qX~^)O4CGYx0fC*;f6Aa$W~>wKDzd5s604u_b_7ev__=1`?6T zh=VlLTj{C5NdyD`Armg4V@kJlc)0f$tv0X!bk#2uITfu_jF)jTmjk3Q_H8>l?o%4F zT_=T~!Ka!a+gtWBBv+n@5Z%xv*Jqn-y>E8ci|QE@?OxXr?L|_}KnMNWhcy$wYug_Y zG^*VMGtv={F(rFh%8IJa?K?Dox^lV@CRVzA6J#DOI}!Akn9V73L=764GlbhZ*69j7 zdF1E6&Cbbv&f{EAfBwS0*x8!qOMjOMS$DRnGmQR%Za(~cGfwui2mpEKM`rJ@&21}?jInu%ZOo79|s=g((ZGAE4 z!v~saCUGv#tY^;FS1an;nD`8?h!@Xe$nuM0{DYgJAh<^ z=_79G8C)u5H^SUtc7lo#I!AERww-0S)IBj_j6pEa+7#9(Oeb>o4Pz_Af0=s+zQ7&Tw8=2G>vLoLu)`fuI&0{%D zG&1y8b5$uYXU!;Kf`E^bRO)Xu8B^1N1jGT-l?4WC0uH^rpv6$L0} z@j7VoeoDn4xp4NG1-`&#r-jfxb$Kb6R0(=rrb@1Hr;1La@{hgk6`>n0Gq2vr(jew} z?T=7EgG&ssAhwV^C>y0H+s9UsRFma}Kv}jyb*vfoY$&3%bd5*MKgy=#Gqj5o7Y&VCGj&*(MDn7KUP0T}Brabt7iE8a z^66|h{b@NZU55!H4t}SMkR+n!f#?!IM#-3QuwWWb&UnL>_T>ym=Pr2pLl)n$&^jb84N7|I zJ}*GgHKX)VPC!a0%bO4-&|$_RlB(mM?o<0UJ5y7%_d1|>iL+be{DEH<3F4j5-uX6D zIh!bpF7zY_!j{NMbp$Xb@pWIB7GUKY88VXW~s!lyN4t}h<+5s z)>kWVVPi*?ld9C*P<<$7NSs+Mt4P(9Q%t;G219V028J!sN#XBVP%^@CappJBj(PU+ zwdBXw=ndHRKh>t}k~LEV9^GULxR_VllM`8(GuOIay~QKvrPUyE0A-O(abLFlp{N-Y zs}+M1`WFyft37?qK64LR)8*ASr!GhM0mr!K9tTcCPgr#h|D8W)+DA|_Ut6VIN?iSN zs!2jWH{?3mp5+TLHJVHWtxe=ohyQ_L69Ge#>N!lwv{|9Y`T%k|6UcP z*>TXqYi8;75a_Dpb7Jx3AUa;J$23G6RP@ydy}E`uL%aW`E~ATX2KViXG6b=;-YI{@ z`3d@{ZL!k63w-am(9cNY7T+pqCP(Z0RF-7)H+HuxI5F*2Xo}}j!;<{zuAQo0i@z0f zNN1i_XJA9!DaXI1M`DlxYvSpRiTi@~fsE&4i_JyJ%m2>zT`v}x+XTeY{uJU+owgGV z<*8lRKQwVgu>bYi&(!Hz8*#N^2A8P+&`)8Q8U@TZN+)M|q(o5daTv9FV&^~l%E*kN z42K!#`n*X-6fj0oxpx1~{<6COC3vo4n{$g~ul;*#D?9&~aW5az_Y#Bi~BkA8T~!u&YQy(N)ysqYij z_)jZYD2;T{Xah6}EoMvyk&FusQr3;Nwvs(^n-HBF^8rP$@Mq?FJ>@cTeB~A)XnDoO z#hLHehLp=K2h79}n4+Yw5iotI+^0qMBpy+_Wi(&{YdCPz|lgl^J42%Y6j==53X*j3BBIF z*1mP22&KGyHY(xV}2^314?lLMx`_j|yRV@#to*wi!p={bj?6Gdy~zE2Dd!SWErw#PE!-2r8a9 z+97(#_PaP%E_ntivlezi%9pCfbC9dPM#p<%iUkvl%!l!U(L%TXd7oq z|AkM*65y2OGYtSufH+>vL`wOM-L1Jm;MfppQ8x_mU8EMr!SxAOLQLmM^Yx`tUXk!Rnw}Xg#3FP(%(34D6 zbq#1pwSNGFClW^OOV8Bzs!IUL@Rmbr2sJ5Rkq;6!;EhM*+!JhpiP{&o6aAj$>5zzX zO32du+vO#3vpWHWoPWW79qYn^o(DCx_I8y_i71u+eTJ}?+2`YV)|0AJReO6NFNYhm zB!u9o0Ni9UfCh`&rsmy)7KnQK;aAHZb7iHA@7>q;7#n+Fju!37n)f#MDOB z00MwkIFUsj1viFpg2%#l0487+W~14a&>|IC8!XYo!BaOU#OJbL9*4(Y$sgOD8#FIq z7jg1WV&RcAN@G-=8K=wV9{p3$t^bAqeXF6`ppvsF^j0*gwgGl`IcQ=zG(>2*|NHx- zJ~STg;lX5PD#F3qtomI96E$6UY4WPaKi)%96~Pi0zVxp6cQfKVKtm)_-Xu{`{Hd%o z{tHr(qeH|2rUdCW0Pp-=*=r!h!ah_*kc%14cY z#~uXX040<)?U`!fcjt0#wsMDK+Z7hMgAMeB2fg7#((=x$F8VkEE-!-|SK6;xCyG9k zL|X@mQk)&AydDVr0ppqw{wrJHr4OyN8elvAws(~1n9G=5KBoX>`H7)NONCt6%a_J& zBIYJ!(&p*Fm(GU<8{xOa%wYN-XfN-F#qIwG7@{D`ED%1K)KgBP1;bLN_27Y1c!BK5C*eOm4?V&lTR&l3;Vm?&0kg@U&O63uXTI z4-b839*?^6T{E@WD?mDW?R;+LU@J1G%cP6C)NB7Yoe^^b=G~{ws@rwXp^y#>#FW8G z9u}z`kw|x*>1pi~CRZZzotho#;zf7ig?#+6o=^P}T*jn7Q6hq8R)VOK!yEx3Z3>bJOb|~(GlsnXOA;Me`G zQ(I6|K8hE9x;aRJMrD;nD6Xk-gLO9Mc)aivNe|6?BRnuT)HQ_qc{Is`d)4Y^Zb9|= zb~N9@NC_192$p~0l4Nk=6AoJ4D}#?npi`x%*<$T;Tmt{4v!dcZU&GW~Uq{7xbEtE6 zu#HRo(-Cuvg({QRdBWHwJ|#}LMdA~__LS(U^0~;o=cbtEht8)G3L)}Re^5Xks6<8@ zTFEVTB$yZiVAzAo-``}>UE~9F^dC^hSJKo>rp?5onD|oedF)mrnzp^;vN=HZi<&Hw zOoB_b+Qn((8C8Ggy#hMp1@$A;#H&hkXk|m}DJ?@Y?`frANbc8RMSjM|1>AV9|u+rkm?o>gzv z_~JvVcW__Ui(nsl*51FY-IWg0jbZQKo)742H5Fa|ifm?DC9qV7yj=Mi_rr@T6Bqr- zgm8lJeA*w65tkb&{F7{$jD_=?{0}DX8Vc8^gL%s7#j;wd&7bP z>hJ#2GSU?V1)Q;PJWDlGWzop@L!#x1dHZGCu?Pfd>faTY25QM(M z#74I}jy03ec!}SLb^wjBmlV*~leyK{iMnVZAT_ zl^D|)sS8|?E7wZZBo9SvVqw3`@mjGRqQ6YE9sX=a+x#otfcIexPD2K$*7PzdCiv_B zs3MCKd6LVJ7gMA31)=dgS{vY#AU7a?t$M+5OBie((8z<+{L7gfW{npj)d)yZ5bPBK{9F!<#CvC+R#~$>uz-IkDO^bbLh!zY@ClKYaoa1X z#hzReBe!5&fEbJBRq0Me`f&Qxf|jaW|83AH_(N6r=o7?y&E#*)K@e zY>jBu$Dr6TA-Ee2ie)bZG3ds*=BFvl*0vH-3~mKSu&Q6)P4DE1Vo@=m#GvY#>VL7} z+6)KeRK+#VvFrr0JRWtQj~zug$L>!+1&1GRss-fEoftY##q!EYpE5Y7YP}_&!00qU z6G`L=Zh1824J1UgI#KAzYHphAW?emrh^gw4FzG-MIn@|w7;2@WrDyJJemNhg z;v|{UjO3mBTdxJ|M6S`J;+|?fT~V~9d0!qn(rDsKoRwC zVXz`yi!zTr-^#M9<~d@zDi@A%6|6!UTVQ=OOUi+o%SxToojZF7(IL3$U(E_Q`w!?Z zX-zHF69csr-e70K&ZtC`)U+LU+XT%T4ef5_pb>?%%4d9$l!YO_7?$1J zKbiy&zY;J1d8HjQCJvdEl7do*yO;|K2wJIDvfgi$;@Tp4)qX2C)n_qGrcu)M?2^i+ zLqbu>USV{|noDAe+Yckg%YL1+c*)?>P5fm0s<+h< z@XY{PG*XZktk8)gY-=GRKA_CoE7Y zVg8YLN<-KY%v7gLqx8Bvx7j-BFf3McZBX3xrs4Y5f_eOmemnkYIdV0;IZ74kZprAv{6?`Mm-gkAo4*HKJU~wCxGgqK=yVYtJav`lZGf1*;N@2wqK%xj}(M{?v$L7 zcf?F;;_mA4C-GxSuQk;iEb>LQ8AI$4NJvU`XahzZFiqV5nWi@BhuU_qDbA!x$4;$@P|iljXuc{bd<5-fTewCzqFEU_8-1O=t8 z2l$)d2sa5D2L3`i4Li_`940FDceU{GG&BKx@W-8W+v+|iVYtkxeG!Ql1YXpOC#-`b z-6Egs^J^~#`qyLm{79CAVczWJXtXF!)_ba@6r?VMrh{c~hm#p`eQ19ZDZ7~U^bjld zAd#4&YVeXla_+<{TCa zPh>+>1Z4@(GGbX-Y#r%D=djJUl!OQ68R zqD#}$P^{|odmWSumq|H-RY=Hj5Q9GTZ)4TTMuDZ=pX!r)jgQ&q2@gIz>W1r$n z>Utz66BO+S=qtvwIVMS=VFU1gm}*Y4!AMA?9x4Y-qkQ^~>B6tBrnFGtLawvyCzU;R zAc-lEmjaVetaXIF;NUsH=D;+?NJ9FnPKX}KL-vj)rR6(KO<$YDif9$f0ed792~4Yh z?qf@uC&|W>-OnT?OI@@C1C^S^_3eL@P0IYnf34@m8%(qVcxg>e{3j4DYJ`AHBvSfb zgo7|cncfSZ-v!^MitU<^-ktbk&!YIviB|Nul0w3u7{T%PwnR)lRm>HY97Xnhnvyd= zc!|8d_-yDZrzJAAmpRuv;X?1p6YCrR*x0u z)LWM=chSg+?pG1_n?AMc72JmFa`8|)GVH-$?n58ij6-rQ92H*CfJozt8d!VqJeCot zpQ@}rFe#;%pnZ>7*~1Q!YhqHouHP9BP3q7Ze{-?UYnCsDbP7r2TAIMHDKRC97P@}$ zCmyFOt^~ieBl>Mpb!v*66A3zpv@TaIuet~}U6d7a1;$@GZ;_{=$W0>hSZH5+mTQ%h zIzA}w5$(A00~!T`5rd|Bc^{Hs>=)>c@evIeF_;hz=FlMTqL&YD84R;EyW8%C_1ug& z8ZGuhw&3Uf7Z$X&pIp34J#@&~He7F|RBYB8O5;1d$caKmwws*lUePY-)xMPF{Pm%$ zSs4obfFu7+m&lxTDA6h=PB7+;2L4m4oD}QeTz}$aI)N1f4bBZz=ZGFfNt);5nPC-x z)Po5R-sNlXBL+&gL3(wRE_LwaM90XZcr34T((>O(HCoo2^!GWomQly1DlE=&Yu+Hs zVMoK4C~9sde;Z|m5fjTS*8EpvGlaX+D^WC=NY{(Wsmz1ZY%P@%8`bAt;j>b{n9THN z$y)jT(JaXm`B7z!XR4PjDTT=*BSO{Qc#6DkLF|7^otDUxf_{vTr3p|+lP1q*`Ue}( zx-p~i_@eq@%kn!15Dbho3QGZ8%Kh91u)i{�Kn-87(87yIo2v+x=e$t}=M1nL5t6s56>FFahwB5g;eFx?*h4Fik zk7;lgP2%x3vS5BvBeb82z%1mK13#Bi_`z4h)iNz+dcv^3&PEJwB%LWxqz#!dqF3KR z2e?DMN0J{h${u^%F`zfH;jl?$SDlUfppGutVr+A)L3zHdajFWs~G7ZhKJ4-EM);>j>Mcopcq$2xU@`M{@PRR&H~_lDEPm3Ux4@@k3YI%~QhK4AP@N7LA?Z#K?W1~g#kuQuvv1TCysSl(cU;MOuHW|LDkzQ_7QC>`vo_RcaA z1p*es4^do2^H=e%IbS6Lv>BpYhc%m1$2IvfC>fQd<3oQIzh$#!D^^HBlLgX1# zL7t(x-OKD{bJU|r%^G~Ue00cEAgc0ZJz@mlb zqE{b{<_XJoGqOhiqRv~RRL@0tQ7MFYt)sHW&B9s6UDq>=lonkk4~QGR|KJc zDunmnBngMEDGI#?6^nYfFUmNZAY|%YNtL@a^hQNlIng$SI)g-oy!m_aep>EwrgrS* z#*j(Cg{;S~(mYSg%nMTL8;R*LS2nd1?jxB#zDeiL1q6-U84Tt)V0dq;o6_p*7}n}7 zIB|ZsuC)Y0-F~X|)(27$S>Tl|q=e0A`Nbvl5@l!nso zeAuN#BqJ$7G7{rdi1Pp4uFxt({u)K~*=YWTUmXgx z>n+(dQUuWh>{M+)Ezk~yrH)HC1Lr-HcZxT<>zFUe0M_e7Nz_O3ayJIaji&8@>mD|L z^zW&e#i_KVC*Xk$~oIP%%lAb+9O83$!T^a*dmS+p=b z**TrWy4$sPyWc-b9UY(@k8D`wefYX#&!1MNf*L~?GZjsNJ5pWq>~r&-_sWE=D3$;T zDleK~wg+pE>%jZcg2+vNchD!6ll=O*sOT~T=s3y7rFiDjW&!&$lzvSMLbaQR)s@UH z1Hr?G2j3RnEq>NiL+RpB9O8`7DR(28Jrb2|-dD*jnwrmOUr2l47E2Y&YqAiLIrgVn zMzavbZM?bq{%~$w`8VL3W1icEp!M<2La+5#`3&$H&%H$7{Fea_^j# z``FhmLTn0MjzPI^U}lkK0jW+VlA3%YZ2mEK_&kK>O2~VpJBsE=?O$KmT+*K=K)T_x z+!HqeF~q^=Kr24|)ru&AwWQRwsH@Wo$>`tC##K7(4-|a;%3DT{8Lq6)8TSezlP>^a zfJRGKq|SA;T#4Gvl$4G}PIdT4Xs#Sywx9WRTi8E$`1#|oneORhk&~`=91f^tkrG3V z@W2fa@pz-1{Kkv#cjwcmv6s&W1KI!#sF

vzt&RNNH3?3?TyVjCtoGCY?=lOv z8q|DK1tzk9Hh7{ELqZRQPB+{S{uH~NsR!W9Eednn>9Lz<_02?#Z0V+G(w=wN*een2 z;lic-r}OF08XOPEd8lLrOFg?!?*OlzDxkI|sNwZ3v>O*--5RttcR5Bi85n??Qw%?8 zw7=?cJkCQY$osuDbti-w{>+31R}GA{4~$hYwt9>Ez*$m!(SKs(;)K4K+cVV0)SBc4 zswb&!GsBvuWWFm9gj=2fwMe(NSSWs@9k|Z-?Z^S7z^d#QR+lTV*z<`ELSvQ|3&X%6ZSRGN8fZE~t_p zXqV>KIV%HS+v9Vo`r%cjx4wh0QR^sfJWeDNxE2;QhbgRv1Tm(vVHIL8Q?+ zlEYb0=+x;@2yhsC(&423Up>Dz2$g1d?74Rs`7WRLcd&+w_E*9nr?!K*$U5}qP#7;L zg>3tk17YmLU{M9?s4kjIgm_*~8-6%X&{Q{)ieK)+FDBjAXUsUg5x5Oegw%C{YAX(} zE~^X=uG4;_Dd6E;wUfRcvIC+#F2O-vF2N(V#t_prLTv+GWX@_vq#4rnjdQrV^WJpF z{juPo3-LcgwU-N@rqIF;9`j9?ARO{pee)lnU#NRKy3?dPahej6d`m4B57}e(s#wLFwOZ z=D7dAe)G1dK6DE@2uyeQCU_rb;k6$Pqm~h%1;vgmSx<>Iz;=9oU~~H#?ssr{;xhHU zUY4&SU`vw!V3xg?5f15=Pv{sX4nqr8v)&5lRXa6sjDfv!%TdL*RCF{bbzpQ;TcSJt~HUJwn3x4%Z#DzO!=!tv$-ok1mYS5hx7K`R1xBgyAA)b&U?O3CVYFq z&+j#0WB1NA#C)8dm6TZ$R=a4TPFrJ|ww8}iRoNb6pkapgWhj=A2nVvYu`>~cx=bQ&8`R|dhHIz5=@++#E#*mTS^BB>Fm(kH=f=5GAj&`VzydVav?PLcz?AXk@|yN z!Ce;)$ZChTQXj`%&M5ueo`*BPEb4XUeS<@QL-tvl{1Zl`1P`V>jEGRToLl^(Keh15 zn)IQ`$73ZA9fi(^S31Uy6=Fta6C(HpNlB3IuE+1q!rI%FG;8xI1&>+rT| z)sZ=msEKeoajWYPfPs0W$U*E{>Q|23K5VHCpD_r!c8X#UEAD5f`4O=h2>noEb8rsj>Se z!L*y{%BUpB`0|N?5b5Cjo_71Q4`uvs$F#GBFT=hpxNzA2>yrUs-lcb*cGfzccBD+@ zl(&k(rY9|D*lb(h5fTo;sPVp`}Y0z+Nos(CuY4l#6>CxAW7nalFW6Vt)W zTZ0`WW1HLnN$SX~Ap=e;~SIcBNmr-KUThpTfM)g zcL44k+gU4Rd~ib#{ZyhtLPtIz8r0txKl$G+Tf&mpC~5DiyS;ZZ_eYzQ2abxTVZy4Gng4t?tdW ze}iLb0t|x%zMJ8~;Fva-3&jGsJvj;GW6ni1lw^b6l&9j)c-4};yRd1d7V6)$iwpR) zg5FLz^o)nSuk}R#t@tD}vm#MUj3U4+%I+;UndDeD*xf&`{5912o&j3#sPLQ?A1_E8 zLmpOiZLl6~7rvO)u(~+SXI;E3Z;W4U$PZf5uZBvDr{MeQ3##iF{%bCr$0|b*UXbiZswW?|8d+N$0;87KJ?blBZ~I zvI*fUfa6K0AFfqk)*If!VbKStp>byZxQ#TCavIGx0+%=u*>&VOVwH>d0HpZHb?I^Z zRf~p1WXJGI)t0F1!V{kYsv=q<{7kWrB&Td~6fZobLIyCkh=R@T=$|ZncqF%2NIDQ! zkG`-J1(D{jsorKnX;F~~rbV3%*gim;tscd#Mj}ni`RLd@T&pDP(L?FEZObex3ctH& z)_@T`Kcp$El3lGps&X{>>BVyZ5j3nI3OTQ6rg+&Z%5b@--!StG!s68pp`xX zipCyj*WR2*FMlYm_jM1}hH^p?!4M!oZ!BS+{Va?ww}X{7((w3pg#gElP$vLverEWI zrWi5qthd+KJ0_J@*~XqrtK^H6(h(Vd92oxI)%XgBD@I_g6|#WGZbBN5MM zaOiEAQ`i9OLU6Z~Xu}fUk$IJ0|23|~Z!*e!UmBudvdgfn9!#r1<3RyjLYaxCJ29x@ z@>?5qE7o^#wQB-wlqG)XgiNtbK+|pZmAzQG@nPy^>0_C-e~+wn^RkcY0Zr=)2zm0f zX^deI3nGE%bd5-)1%;|mKYatu9nG-p@O&l&%n}efy4!k~yxV2WC?S>%!X_wtHz9ht z*8A$ciZ?3L z$hou-z3kqjZg2T62%IO+!0}kPo|_8DbvtL?Upu_WzMCghXfbKehI={sE?SjRJKUiT zKmV!8Z@Mf{fpJF7`%P~7#e!V?Q(M4f18+qFbd@g5sFqSfAlO@~zTKOtFnWrL{$3+^ z9pFMeN*}>_m;5i68$I@4d)0UC|GQm%XvNxlB6gJze!m-+de*xqn*(ufdNp7gc0y4l zXS8>j7D>qJyv2xlZ~#o&rCMw5?6*0pV>3&|3Wb@@0j@U%fpbUh7e;5-;cXRlqlYq*+PF zVuPTSDp#Rr`CaoRjxYw4wthyg_f`%o49*WNzMIeWWm;{h__8^cI_I?d7n*^*54fp3 zf{r3pegeJaN{}>@I}zYfVSJ0Yj+T%{1!9D%fz8QpTj3aVKyc3P!7C2OW*sV>It>*4 z9~RMZzhtVZHG%7;RT|suBmKPCwaMXaBzWMj5x>nyByE={K7_&VdYpAtU}G}qB7CaO$2SFB$Baa4l-xm$fPRpVNiRH<78PR{wQVZy_H(jC3Q zfPc$3XTNP4-hL{PiXxrwJmEfHW|ab0sl&n;E>~t;F^_7s;+RD}Ne@Acw+ONXBM%oD zhF$lvwvAb%;DR5cA1T0Eg-U~gJpER1Hg(*{b1_vhYJc>Xdz(+0;P7ckiB-&-*{wc$ zd-YYvs1D`1D2d;Y+^q9@m8_e>aX-%vH}*Yn1qS&jjx!gTW!mTT0%YsoOFXcJYYxk1 zZ>X7HidG)%JtX8?=TXKe?tgwWUJn{x+D4B0xYlxCs~g1&x^2$`+Mzdq53A<>ye6Ol z5Dl&52%P?0emK5AFfU#Jzt55r>^vL>Kj&7w2laE8d;K$W+*CK(YShF8$E{p4g*U0} z&YC{f$3~Z=R6*oli1_UN<+t6#zBuxw7FrJ+C9u`_+_26SIr<$CP91)F=TT|VMlySv z$7z_()IfxHB@Ig)-$PI#m%C5nVSthvYdzN^sh}$;1clWbGm+VDHeWKNvgaGS7R+>6 z6Cxu4Nk**vwes{vyhGcPWg44KQ{5&NhpR}2v1^ipVdp@J3RH8ep;ZDhmDUAlpn<02 zcI)6Q_4_aET%A38RG!T-ho_E`yFwQ|0}TLPY|!pq#~E;J)OOi%TPf92uSx@-t1f<3 zX2?TWnyP2HVB|*e0BwXCyZKq?+wR4dFPk$4*;oTVd1{t~j)6-Ly(i(k^Lhb*jMfC5I zR5stS%I!q*uh+$jR(6y5>(mM8h(Qyy92RHM4=`@5T@D%?tpe}W53$(OA1xUa@}lo& zEP6+5engFjBgRJ%-Z}9>TiYle=x9z98#4}A0=hk}_#h60qVuZO9UWRVxnWneNy{W?2B<00X+tE|o%>mIhHvVH1|JB&c1avkOM%t6%k(tJ; z!%K%#Vd}*VFR|KwrIm*J%bKH5=@9h6f2U3%nrM=wUWO2(O=BXWM}WB_?YTfn?sI|G z#^(Zicy`Y2Minm3=Nl{fbId6cv8EWWkxR3|1S)@kbt9MuRkEJBd^G=90vfkFb#BG% zGc4n^FA0#rIg$#_YVfF)c%KTGFo*lN3^dFztM_L+a;36VT8)Bg(SSb z=b~Hr2WDvM*TrSJ%E5las(ycT8VKs5U)N0;S4Zr^u-2&G}~wpRE^> zA3a^0x=vQS1Oyddq{|&uY%p5o5e8WZ4n4ectkC&~B~p{7l@|}AM^F9I8tgs&B1J3M z@rQ6wnn7KreTLuR1~mUI+}utWysuTs4yYYp^*?^!Y0Ed^e>pf|Z~&a+Fvn}PDVli1 zpg;OorPQ_S>pCgHOSk?%0d-ivE^DKIjq3Mf-Q8t~zc=~zxU4v}lKjHaJV>`{<9`yXZ{`?4$F1ng&0k&2GR&eMG@w7Mi9!&}0v1sbk*XaZ} zyX`{c+2D01!pR*wrkdOVE|6w`f{#D-4fu`CIY?qSO$@#?E>O%wTDmG8Os^DTdU;Pi zwM~=i|6tGm2Oe0T244Q`VfCCFAMZaOee(Z*^#7lx|6>wCl>|8Y-wH8*h&n`n6~E1s zHTlF23<>;2%SzcOJD1C{IPiCuaB&X}iNGw^banSj5bgyBu@ zRIDJfCkx8Je6w|4ym7pd8ZGFb)dnAyV$7eIQ+KK01DjPMNJoU?h>RSag^@7&pL3$n zC5}VRuCvA~Dq6t4D-YY6^p!BN5&SZrRSqow6L^iB&CvsV3m)RQNDO+pkkc_)SX9%2 zFzrl)Y5RCQfAK%2{SIN;5hMm$2(Lhxy>x%oxBuHKSP)*3dN$Jp9?Xw$ss9Px5dJQG zjO<_VYtr1V9E_oW0KE38c_tIQZwOp>rO>IYyl$jok~?ldZer*S%8lRn_RP!@crpXj z^|xp~_m`*o_9`H0O`{E7!zFGKLhJ>ZU@sWz`govfVxuvov-&@ilUK9Q-kJrR8oc=U z>!Y)F?k_C66;9y6_Do>m?2V1aRe(JN z(?+bI{(rov@~lCm#@^jO8y)OGFp%JJ_;coNXCELbtAI2VKTu4u`(7xkIGCRe#vLi4 z3znycE(iS{q#nBIY%sw1^`uRkmOVd3lI8Q3 z4!rfhW&t+a|E}#X$Wim=<2KNHdH|SBzwK8jXY%p`6_z#tlh<156v9Is#}j1cVH_iFkwX$OW9UHn+3=rUBV5Lvw%5bSW*E0 zvN!lHwM`HF;eU*hK^Z9|09?{9#$T+?17^^+-jLKX-@}HCuVC`H=n0-om7qU}u7?uZ zgLzkltz{A2IsvXgK)1&Bs=Uu1gk55uUC57`pH;m@%xIz`_$gBkvYG-^pf=DmQw4l@ zDys{$6&$3fiPeFQ!XH3{VVCOP4oFV@1j^%PAiox4H>PER971J5MT@6KbkY{w2ie$| zGYRmxOkkG3GD3t9fDyP62#zcO&PjU!5^{d>;07RORX}>ic3d|Ec+8M7u>6I#M&U8# z|4w$$Q>H-V4nUS`0BU`{CvG`H9`+yvcZoX%&+JU>NUZ76mARZ?L9>=IV z{kH_%SX#Kq(=Z~t;vpC`?*VZ&ZUGFy;p`S0t33983?qoxA+!~Ya)I9QK!R`_WQu=# z(~+AFOOs_1SbVAgvpI&?QH>BgGQt8S!U-g3oPk1n+zZFWFSaQKpA|aIeDTqDfTq-V zj@a)l+WT1rdV(N#P+&@QWKZrn1+I=)4O(RB)Xh<-1MccXPr6X~-;6=q4(dk@Pq=JV zXEc>eL4ePt>j(6a0AJLLOB(-*|2RM&!GAW^(EoKOG6_=vvr}aOoC^c3UdKl`ulwmk zJDow(kB52OqyA0Z2tWiHCXROwia&{s2+Q`gO%{h?hHvrBP=a}hR$Bv=|5}7r%Z!`^ z3!vgDx_a@QT~jJB&^8oR_DIP6w<&^qRq^J#3V1tj4DjYRRXSyI?A_{Ye(&nQs<2|> zP>IZgJdMqqp63b4`S$8zJ`iFBgb^ zh$FNJGYCLf7D0$~Fl*q{fROni%L9|4lohMm%My(wdxGafFBe-2h$j|-liT}cd;({| z775~}sPFW@o8MF*Qrk3Q$5o1mu4@4``qJ)C&q9;&uE3rN|4o0i1!R@#0ZD2-L(_LV zu!_8?U@YNWh6>@F-RcBy-^GEkxL_t;K1?ZbHvscpT=rTuTFThJdB)TgrIA1-dXT1L z>{9obC{znEhW7|>?HusoO9dw#;toNBf&w3x24q%9fNA%&E?Z6pFTd$?$B01M(5`UDZ)@s}0r zV19`r!2Gz_SY0)MLQO{?c?U)jeh|o2FMtB|dTcMqGnRZGdX~z9SqU;LQA6KygxlUx z{O6ZSB7W&CPl291j3HD3k=!AYTm$k;4c`SbCd^^|q&k97M_)qcWBayi5nwjn|J0DbjzcWI7PcFg-+>A7N*)YRX zO=nbOX_ztAg3tbpaxa=iP^J_9onc6MxdxWLA)OQCNK689+Ozjm!NUuH=V1_geaTUP z$&8wOjgcrJ5IF54Gz$Ft{`PRkGrGxf`K@1WU*p%yl|``6ri%r;7hwBZ%zL=A&iKdwz0tfM>vk(e?I4!# zTkVS#@7{|AYii-&`5^@UJ{|jF_h(x47iizhz=bo>2Qdv!ICLy+J#^azaTD_^^s@Zp!q0l=LtK94w$5K=?;hfV?P zo}JKNgOpzI3Nt+J#UGET{Y)dM-hBd#RHwc03FxA_3tw-g9R9ScQL%~DK73>7wd4|! z<36R(xRdMKDnwDNw+d(~2DzHoP<#Qa3L$w27l{7lse!A2c3)O%P>C49uRz5gVnac| zu#H@Y866`Qnt(vTJlPQ7o^H4YrhC0T`q&e=487S2;0vACcu8pmQ1&fGF8#y>amc;o zfao*NZ|dA+=>qKW6wCs7{uh~tb%RvG?Mvbb&MYc7`m=|C`at!=IN+zQ0`Wj2kPYzS z4HEtis2&`N*|b@Xi2fRhTMXd%f?}A%&Dm~oYl1NcxM*?#r*!st`7`^!L9ZoIP_C%2 zejNGpN^o_%9xzH>2U1VuP16qv0e8I!8NAD%rMd4$U=o1o)&^wlz5s4pJ%EB=s=kvY zHL?V^j1l4PwW2kIj@SO>0ELsg2=vp+5tpejylfz{f}Drx!S;^>;6;viKt~WdUxo=cR;Q?9G9@; z^1aNYvxAWB04@&XpQ+iO1FE!?dueawFIltDvj^2UhIn^1X(*{7tTrqa z%7cR$4r8f!fanDq6N=ZUdX^w-ah0Eywo{Se&)>-5-6ND?lE|zls_BqGZ+g zSd>EieY;yzcjSN-)(dhRFVBBG0|B@HS6g2i4&@uRZ^mH8QZv>;*%_%U*@^6-nvkMG zwn8D2eVrj&WJ!{2NugwEoe;*pMT)4OWh|vdBFfHtJ>KK}@IQ{f4?gIjdG6svpnYu{OwK*~O_-{=OhX&7N>F^&zzqV-$$pbXG$@Q1I2yU+O)!1ULDSuL zCTWUG07EpKVu4Gka7%WdbCS=AT7Gda1;-;p3vN|gZ4B?ozqq^tzH0v^bCIz+%|e$< z4$<~0o!bnbN;{mR=u!(UyNTM7t?fazCEe@wLT>@3BVan9;&R<5Gsr=aM9{(jy#?&U zFj>jOxwf@O%Os-qE0)(f)SBPK&$MAjq6UFm`dbhchT!(|)meU5M1o<_ z>)pYhA&O2=?q3nOI^!>Whhf>FnhVCfx)AL@K03K~Ql~kAYp{jm#!|S1 z%xn44gyXoex)3BYPDHjI)>CWr6jgk;jj;jKtoAZj_E%R*3JDeY)hEi7c~b9NtE~s za|7Rp74vVZFu?3KP+qsTmE$&Jz1ER|VRefAc~akHGLPmw&dpLG7CLd*=iWuIfZ>Q8 zgJ~DP0QU6oeFrO)z?2DyQ{9=#{Q}02f;%IlyHs_86OUrE||eKz4jT zOsI|xZu$1Vs6wbl;tk^6c;+Ot^=~Vji8|5@5e#GEXC!2U8ZIjCi%+9_%Kt*({BD=N z{HbQe=SUUiGlugTBxH9{u9@m#gKtZ< zzI{Wu3Rkb-X)`Mr`2>1ErlseddLTgJB7eXgoD0i|EYPwtSrQis+n4_0v6Xl9 ziYL(o?nT!SBDO$g>A&Z^uMOO>&0exV*gNH3znxzbX^a=Teir(dc@>DnpD<}S2k)2z z8T{Z5{odx2`2M3wY{ddJae{I~0;&-wu2-qKkHAh-6z30n_w4v%^CaAzn?&_>LIB+) z2_t`)Uf?F`)1%{iWa4VvQPU)zdHSkFPI5%TdB=nC1xD{H`7_e?UfQIo5D4vP04Ap^ z;}gjGRB=nIYLG>yFzGJP2z<8&iYp7;n$d+!hPp&r5>Y&psfjOo6-Jc$(yQx6z$wN% zFu%22qg;_5O#+OMP!5LRgDJePf#-yuzz1ug;a!YNAcL;r*v*wv?@BNjo= zm%!M3ViRJEJe+8S;ZPD-?<&{gav|OpiENA-w+Dy!TKg+@j}@~CipHv(mjhM=Nfr0R z&9SsrFFE35)I3)gBK91~Y z5&`w3Iv*v8q$hOJ>Tdc4<+9j(UHx7nz3TO2?rMC=f75B4lvaN4K7h@uZX##1h z{lBw6u!Z~XuNl9lI{S83+n)_DOJ7r#d8+5T%COII{>^vUv&i^wl~C27-*LN2Z2QAL z_2S(*=DeBq@jo?`oc1LtznotOPsZIfjk70+9Aqy1QCbOI*#Q|LIPB+>tn)y2b}fGK z5%xBqfmbwzfxm#q#i8(oMO~@BDo#(8-+Rv&?t{?q_F6Uvt+Yh?HWLY*lbQNe<)zc& zv+KPHM%AjnU&gH2%vqC8PEIZ2Yz~}^BJ(+o*kH_|&T#MQ`hgzS(g8@QOJ%bS@r|CLRc9QH>y& z>QdN5f*M#*JX4Q>pyn1^;GY*W=Bv4WB?;=IS8hS@-tnaN|5A4xx7dG{;QF)hh?BH@ zz+MJzMp8WA$$Ni7^>isxvw4r?+96L4n;({><3$zj*Z|yc-&Q@`I^Hp4z|X@jk;Y-B#jre{K{p2!DR|`4W%3HnsKUf)LgGc>F2PS!2(lT=T#Xy4NJ$xj%&gy}A zOV65UsB+>9LXG8c{bh8HnO}C<3V(~_qv?ctG57xj#*ci_C?Os@b83S0|M?($m~gAR z)V=NM?$0@%{2X)(W+T~Fzs5>YnC_&+AD1n$q+Op;>F1FxdjjpEdQG{5;#%K0Cht6d zuzvP4h!l%!47_;vULQ;h(M{XjLp5PPVU@(Mdf@;mtuaFgQ)K(MwnFHs9G zRh=&677R73NXR|xv?Tt4`@2KOVvLWh>{%`^W z$R1L1UHVf6+9}`@x#A_R9WtKNx6gbtMHmd_0OCnVl z2~uwRt#r9@u~b3@00CJ$e~{3>zR8*-ny4)^V!ca!NZ@+#kcLTn);mbzX*r89})JD{Of0)qaeOSHkYqIBM zjkUpGW+9i$NFib=C$WY1{KBkmzade@J@xU??jBWYbmXKe1{HbEwa2NO)@&hMfYb;6 z*wPztY-myDoX{P@YTYNZ9#%c>8B=KKa*vl{?{F$PQ7V`pjA0&3bI3Gs?ry5_m=FEh z<;KkAJuFPI9tzVmCIUX6>{|YQ^Ws2``Lcu?Z|UTMMUO+Nr#H-?-zkU?mR3(x6qg6TrOrLoCbwM4>iEeNoqt(1GA~yW-(!~` zuj>*lUr%1;sV{Jx^S%a^wA}GWF|zu4sT&DH>=7yBH|Qy6xeD{@&qc&JGA`f!sc20i z=eRc?#)SV6@ z#CXzAA(mmhd6>!j5(Rqp)8L5_O@i}IoyF9%=g2R$0%vEMIe+H3PxfEqk!#LJV#O+) zx~BOrOjk@1gA9fG^(&og9nW_RJ-IKA={dM?)Mg|jmeBX`m@rIkN`o?Y`|G_cOh^CR zoD|-KL_*`9BBb3Ue92I?_lQ;?H>EEvUL98A=GFw@jF;21lA19ph=$-xjE^! zVq}FOudkP*7oQN*^~K+M>yilTUM;NFlEk66tyq&Fm*)%Wub`+>eS#2B>uWY~Zd3QT z<`-T!;s?c3pIAs)NM;+bITtTU#(UJ}*$1ItC*0Iitm^9htu#Z}S(V(&w%Wwh8vlLL%>Ra5Qo=^kJSmK^AeXPt;)3TxGJ*2~&yN4VCFYpOR%7FzkLZxz?r+M(qWX(R}l4@xp?VRU>eo_|RB5SbOK zq~BUwzH*>bRO4-E@}r^NP#A5AI4-j~I^YjMO`0S|x404za-N@?(e(^Qo_Y%;E1V0E z+~{D|emW!P>^Z*h_P6FAx3+$^Oy{F>*UY+^2TK$5`Y}`zC$XH)!y0h+W_Tlh@w=P) z?#ZlNui;aQ-M)JpgNC8sWVizrd1yYJz4^MOp5vT)oSxgD2y>pj=O%(Lk=lM|(=RUj z812XuVplxZSlCh4;H{1flcuQRKUZq`_hwAAC}O6&E})zpH`iCZUx`Kw2ktTAIqF=& z^)I2M@%Ew7tnUIYxdx$S``?n>^$vc8G+GLU^bV)0v52+T{f}vTDz!~^F`vsKl4ojRDXcMJoXj-~4zdijiL-Lh#b6%J z3AWs$Og{PiRuJ34|x;0(X0(W{rfr^YM_{< z7AVpg$B%FbxgDhAciGGZe@y?t)=a3{t@1H$OCTd%J@c6}v2gXdMZJWD1`5ae#?VO2 zM6r3cZbcy=V69!tu<2KP5NjBPbQEl2Fg`1VE4&%f$6l4N5Gbw$#zvx2!tXE(+2u8! z=ZUD4wA=1`JEcp-n6vBoHe-FRFH%u~hlbPeY}jl^mpGyNQ0HAVktKoQl4Z?r=~G;; z>#Itw9(LEq&DII^5N}?xv|ZEQsXbS8Q+he;Cja`TZ|NuSB^zFrd}`>MJob90(D*zi zJzt@2E64h2U)0&u$TwSdIHApm&H341ZAtTQ!>RJ=80iH;qnvAjW^$&AyF$L*Tlq>z zvW!>nF=An+Oy(%QQ&1wPw04?V_`bK=d&H_RG66t#qcrG0CXlJ;XXpQ!<75$GawLqaxdz2mUH%Z{F6AoIM-hj$Y6YuOqfiqV-; z7zg9&a-KPrP|SLvJup6Q-yIglzcn39wiOhZNP5lV*!z%O|6Pk2lYXVrE9PvKy z&f^FvKM9nsx_M<*&E%?ap~)?im9sKmDLlj^JI)J{=21jPm#DWW7{CBYi?7-ADqoU+`V%1)FFX%Q?Ym49yuBK zj5nYLHIg4xMF2lifG51H?({_SDoi{_O~%@NqsP{uyv=6}oMS3~7R#XT=tV8&6^%+= zyFmN}uJC@vSBp@8Nl0d=^0u=2h*P7+#9pTFZa@UqdZk zeZJXIe`&HP?nF|-++Oln^(iQK`U9}DYC9DOG7LQBIZ4mhZ#ddC^bwK}FpOf@dv!Ow zGF&VXsl^KDOup?FN5mOU5w)tJRnt}==6=*B2M#o)EQ#Gj~L@La_I(g=~%YA(e|ZsDD)XzwzMq#Q=yn% zr=3;fs@`AO)&gB%^4LLyRS+15km~VKd-Y5cEzs`Dr3AYpI&qRWG#XiyPbelQ85%$M zBBil}@Ln$j%y}@m|1~OY~~O=@(fSO!`zmgF5@>%M+3h_WBDsgw3}+ z-+iAIptz$9tOUr$k(H24Ef1da6)lvQk0(gliK`^0TMi{loJ(M5O?W^%SU(DM-yk*2P_#ks;o|rz5--=8hnFBf`+cuul{(vK<07)cf3G_ zryQ!DL+>86J%|$i1%n%Wie|q8cUi z7#`5doZM~pJ^9$t4kjj!2s0yn>-UOtN+n2O;rszLtxvdXAQJB7b3QT^gK`lGly*p2 z;~Pl!jlMTt*zC~+V_zp?nnB^~1%AEdtIEccCMhCb#IG~`hX4)DW)GR%avKtNooj1r zC{{8_S^>UP_+-&>jd6@k7AJ*zmbZ$3F?#vlR?BsP8kh)l_iN(!aT8;EYQ)Cs>wTw& zZx;qy7hXsz5__|``qTc}`Pq&)nmY5&?60KcwDWyi1;0Pd37&Mk5f2Pb@xzfJ6)||^ z>+t*+Zy*8|U|%Jp#L4(MfWPl!*|opX1Xnns{uYe z_dlKqN^5Uz=4P}DM-Nz22JSfjhWk6lsj0)f_CvOjLj^*s|6g>67fTE+s2{c4!|Dl; z9NO>apN6K?aIvb+qY#GCX?gVOhhx(GoV1;3I#_yC-;Hr|4))ly-mhC)%55*MShm5vm@>xB>QI}*Wh(`%?w+zaDTblJyW9xaGy zLk#)No3>AuN?eiQ!RVdCTn`Vm&{_vrZSrCMozzTv^LIKZSQgGBNC!VS2ki{6Z*nOR zOO_xX40Z^RQIHsoSo50JZMlR!^>Agn{oQkrlDrpm_%iNF6W~*LvR~ZRX42w8i)M5W#c@ zyyH_o{r0fLMq+z$j01-M@({K@eI+3cpHxL!BZ z!|(<)&cTMQA)*6T6@Z&O^|6AR_SsCnR1x&&8UwSK&R zaNTcZ;*8xLtX*qMe8fxdksl66IGd2}4LGgjN1XxvlI<80*BTlgTm-v`e+VbqUTE|S zM?R)O7b{sC_UrYZh!(p<9h!V_hbvSc_iYv`FdQ?d5cR}9sJz~J;Q$JIICf*mOrmR` z{h_u0=qM!;DK#^7JSX%FWmA=Kx5d%1370hB%(1s)E{x3$8_v|A#Cur>T#$DaZnUj{ zk5Z?_Ns~*{;;YhAlOB$Sq%|k-E)>|uSftoS($;S_@2&iB+}!$54>AcL`K_D#pFs(O zpc^z+vT%DX@4ho8c4+qQ3GHZhQ5y?TU0iS$F z#0~MYn)69Zjr^2BaRWqM7xOVO>O#PVQUlbd&P`qN;h5OOh_H=h(%6=q*OOKNqWG6O z!5jE<_VxKLegTVPFq8RA-QXYXi+k`jWalc7L$JOAxp+z5>u1tjGBfK zUs`?9+F3}VegP2*b^CuPfmm-Ob`?z{S%xJs>|ZR(MER6Lh@HAsM!3`emZ1U_L?baj zKq9WQpFO6waf22rY|3{{Jh}MHD!15Il#P_#tHO&#(=AV~XA_WMe+TI!d^ULeKFIX& zly^A0knazA+C^4fDHL4d^#Y)}(0;?ef8X{dzlttmZPC~T9Kw;w(#-!bK8GZJ?UP8i zezh$!LAq>0KUyh6i#MKPip>EU7?JS}!Ko`^w)I0L1dU}O3zTR)L5gkR5Ide;KBg*( z1P^R36RNy_Pr=X-L}+5j)(=G?^w#E5eR3B6$8*I;`^qRGd=y1MQZf3c!x2=Vw*z1B z)CU9}p5PddjK+t7`3^g}wbYAEzg+s5?HpX@U-{L@o4c>fSgZzF$)(kc0sp&_6!7Eu z@JmY_nBo?7qthz;7CYR6h*s!BOec%CGa$e&AE>5 zgXOtE{2v>=?cx7oE3h5t@ojS4ps!7uJF+|m`YF4DEmXT7E7RB1!S#S*T;Gm_h3cR>&!42S=8G=)In+lc?$OY@G zf`OS#-a8M@MTLVNkF1gEUkrh{k5eWtbOaQ2Jh+1w5%ttKXf=DBB(kgFSN5TD6~F0h zV`mBc9ttQKx^ZmhnkX~o2W(H&OJQFEA^RCbx4EVaUj~KkVV%q0|6Lf6%XoLFIBuf| z2Q+x~`3z7yTEW(4+7{B6k$UsDXI527(mz1#2{pE3ekp2RpKW0d&9Aa&J#rgw{n_9F zl>vB89q4-80S^v&)sRsCC+Td`uyT(^6NG0Z zrhuYu3ikv=frnp4{Mdf-$d3EB>an=}hvsDO0>Y zm#}3EO49wnxwrc?ZxJjVECyiE*Yv1pb|u`$ANWI3`K7ZU6+{vZ2^8lAYM{VATSk&w zz`&azJ-s@?!NWP_85Enb1tS`Ni^{0ShUed~n`qTO5Q)Z;nk@KZZX@2Ijf_ z`FvOnR@j~RBaAJS#X-{h$MU(ROz>r9Og1VvaHReZG!c<_ literal 0 HcmV?d00001 diff --git a/doc/fluid/design/concurrent/select_op.md b/doc/fluid/design/concurrent/select_op.md new file mode 100644 index 0000000000..52c226bc94 --- /dev/null +++ b/doc/fluid/design/concurrent/select_op.md @@ -0,0 +1,265 @@ +# select_op Design + +## Introduction + +In golang, the [**select**](https://golang.org/ref/spec#Select_statements) +statement lets a goroutine wait on multiple communication operations at the +same time. The **select** blocks until one of its cases can run, then +executes the case. If multiple cases are ready to run, then one case is +choosen at random to be executed. + +With the introduction of CSP for Paddle, we mimic this behavior by +creating a ***select_op***. + +## How to use it + +The **select_op** is available as a c++ operator. However most users +will prefer to use the much simplier Python API. + +- **fluid.Select()**: Creates a select operator and adds it to the current +block within the main program. Also creates a sub block and adds it to the +main program. This sub block is used to hold all variables and operators +used by the case statements. + +Within the select block, users can add cases by +calling **select.case** or **select.default** method. + +- **fluid.Select.case(channel_action, channel, result_variable)**: Represents +a fluid channel send/recv case. This method creates a SelectCase block +guard and adds it to the Select block. The arguments into this method tells +the select which channel operation to listen to. + +- **fluid.Select.default()**: Represents the fluid default case. This default +case is executed if none of the channel send/recv cases are available to +execute. + +**Example:** +``` +ch1 = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR) +quit_ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR) + +x = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) +y = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=1) + +while_cond = fill_constant(shape=[1], dtype=core.VarDesc.VarType.BOOL, value=True) +while_op = While(cond=while_cond) + +with while_op.block(): + with fluid.Select() as select: + with select.case(fluid.channel_send, channel, x): + # Send x, then perform Fibonacci calculation on x and y + x_tmp = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) + assign(input=x, output=x_tmp) + assign(input=y, output=x) + assign(elementwise_add(x=x_tmp, y=y), output=y) + with select.case(fluid.channel_recv, quit_channel, result2): + # Exit out of While loop + while_false = fill_constant(shape=[1], dtype=core.VarDesc.VarType.BOOL, value=False) + helper = layer_helper.LayerHelper('assign') + helper.append_op( + type='assign', + inputs={'X': [while_false]}, + outputs={'Out': [while_cond]}) +``` + +## How it Works + +### Program Description + +``` +blocks { + idx: 0 + ... + // Create "case_to_execute" variable + ops { + outputs { + parameter: "Out" + arguments: "fill_constant_110.tmp_0" + } + type: "fill_constant" + attrs { + name: "force_cpu" + type: BOOLEAN + b: false + } + attrs { + name: "value" + type: FLOAT + f: -1.0 + } + attrs { + name: "shape" + type: INTS + ints: 1 + } + attrs { + name: "dtype" + type: INT + i: 2 + } + } + // Create "select" operator. + // inputs: + // X: All input variables used by operators within the select block + // case_to_execute: Variable filled in by select_op when it determines + // which case to execute. + // + // outputs: + // Out: All output variables referenced by operators within select block. + // + // attrs: + // sub_block: The block id containing the select "cases" + // cases: Serialized list of all cases in the select op. + // Each case is serialized as: ',,,' + // where type is 0 for default, 1 for send, and 2 for receive. + // No channel and values are needed for default cases. + ops { + inputs { + parameter: "X" + arguments: "fill_constant_103.tmp_0" + arguments: "fill_constant_104.tmp_0" + } + inputs { + parameter: "case_to_execute" + arguments: "fill_constant_110.tmp_0" + } + outputs { + parameter: "Out" + arguments: "fill_constant_110.tmp_0" + } + type: "select" + attrs { + name: "sub_block" + type: BLOCK + block_idx: 1 + } + attrs { + name: "cases" + type: STRINGS + strings: "0,1,channel_101,fill_constant_109.tmp_0" + strings: "1,2,channel_102,fill_constant_108.tmp_0" + } + } + ... +} +``` + +The python select API will add the **select_op** to the current block. In addition, it will +iterate through all it's case statements and add any input variables required by case statements +into **X**. It will also create a temp variable called **case_to_execute**. This variable is +filled in by the select_op after it has completed processing the case statements. + +If there are no available cases to execute (ie: all cases are blocked on channel operations, and +there is no default statement), then the select_op will block the current thread. The thread will +unblock once there is a channel operation affecting one of the case statements, at which point, the +**select_op** will set the **case_to_execute** variable to the index of the case to execute. + +Finally the select_op will call executor.run on the **sub_block**. + +``` +blocks { + idx: 1 + parent_idx: 0 + ... + // Fill a tensor with the case index (ie: 0,1,2,3,ect.) + ops { + outputs { + parameter: "Out" + arguments: "fill_constant_111.tmp_0" + } + type: "fill_constant" + attrs { + name: "force_cpu" + type: BOOLEAN + b: false + } + attrs { + name: "value" + type: FLOAT + f: 0.0 + } + attrs { + name: "shape" + type: INTS + ints: 1 + } + attrs { + name: "dtype" + type: INT + i: 2 + } + } + // Create an "equal" operator to compare the case index with the "case_to_execute" + // tensor (which was filled in by the select op). + ops { + inputs { + parameter: "X" + arguments: "fill_constant_111.tmp_0" // case 0 + } + inputs { + parameter: "Y" + arguments: "fill_constant_110.tmp_0" // case_to_execute + } + outputs { + parameter: "Out" + arguments: "equal_0.tmp_0" + } + type: "equal" + attrs { + name: "axis" + type: INT + i: -1 + } + } + // Use the output of the "equal" operator as a condition for the "conditional_block". + // If the condition evaluates to true, then execute the "sub_block" (which represents + // the select case's body) + ops { + inputs { + parameter: "Params" + } + inputs { + parameter: "X" + arguments: "equal_0.tmp_0" + } + outputs { + parameter: "Out" + } + outputs { + parameter: "Scope" + arguments: "_generated_var_0" + } + type: "conditional_block" + attrs { + name: "is_scalar_condition" + type: BOOLEAN + b: true + } + attrs { + name: "sub_block" + type: BLOCK + block_idx: 4 + } + } + ... + // Repeat the above operators for each case statements inside the select body +} + +``` + +Cases are represented by a **conditional_block operator**, whose's condition is set as the output of +equal(**case_to_execute**, **case_index**). Since each case index is unique in this sub-block, +only one case will be executed. + +### select_op flow + +

+
+

+ +The select algorithm is inspired by golang's select routine. Please refer to +http://www.tapirgames.com/blog/golang-concurrent-select-implementation for more information. + +## Backward Pass + +TODO From d60180af396bb63dc78727488a8c6467ecb109b9 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Tue, 20 Mar 2018 15:52:26 -0700 Subject: [PATCH 43/79] inital commit --- paddle/fluid/operators/activation_op.cc | 11 ++++++++ paddle/fluid/operators/activation_op.cu | 14 ++++++++++ paddle/fluid/operators/activation_op.h | 1 - .../tests/unittests/test_activation_op.py | 27 ++++++++++++++++--- 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index d74c47b981..ec637658c0 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -613,3 +613,14 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad, ops::grad_functor>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); + +REGISTER_OP_CPU_KERNEL(relu, + ops::ActivationKernel>, + ops::ActivationKernel>); +REGISTER_OP_CPU_KERNEL( + relu_grad, ops::ActivationGradKernel>, + ops::ActivationGradKernel>); diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index b2633d0176..7709a551dc 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -31,3 +32,16 @@ namespace ops = paddle::operators; ops::grad_functor>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); + +REGISTER_OP_CUDA_KERNEL( + relu, ops::ActivationKernel>, + ops::ActivationKernel>, + ops::ActivationKernel>); +REGISTER_OP_CUDA_KERNEL( + relu_grad, ops::ActivationGradKernel>, + ops::ActivationGradKernel>); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 8f791a6ca8..b95e793586 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -772,7 +772,6 @@ struct SwishGradFunctor : public BaseActivationFunctor { __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, ExpFunctor, ExpGradFunctor); \ - __macro(relu, ReluFunctor, ReluGradFunctor); \ __macro(tanh, TanhFunctor, TanhGradFunctor); \ __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index eab41ebe71..6838580ccc 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -212,18 +212,39 @@ class TestRound(OpTest): class TestRelu(OpTest): def setUp(self): self.op_type = "relu" - x = np.random.uniform(-1, 1, [11, 17]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) # The same reason with TestAbs x[np.abs(x) < 0.005] = 0.02 - self.inputs = {'X': x} - self.outputs = {'Out': np.maximum(self.inputs['X'], 0)} + out = np.maximum(x, 0) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Relu(TestRelu): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestBRelu(OpTest): def setUp(self): From 018f3bda3dda63947c0b37aaaf4ccc40e4ecd1e1 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Tue, 20 Mar 2018 16:13:48 -0700 Subject: [PATCH 44/79] small fix --- python/paddle/fluid/tests/unittests/test_activation_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 6838580ccc..1e3decfbaf 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest from scipy.special import expit From 70e71227852cb70d6aa7e4d44afd506ed362ba83 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Tue, 20 Mar 2018 17:22:52 -0700 Subject: [PATCH 45/79] initial commit --- paddle/fluid/operators/math/softmax.cu | 1 + paddle/fluid/operators/softmax_cudnn_op.cu.cc | 8 +++-- paddle/fluid/operators/softmax_op.cc | 29 ++++++----------- .../fluid/tests/unittests/test_softmax_op.py | 31 ++++++++++++++----- 4 files changed, 38 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 34ea6a91ce..5518ebed3f 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -89,6 +89,7 @@ void SoftmaxGradCUDNNFunctor::operator()( XGrad->mutable_data(context.GetPlace()))); } +template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 47cb336d87..5596fa0648 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -56,7 +56,9 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace, - ops::SoftmaxCUDNNKernel); -REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace, +namespace plat = paddle::platform; +REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, + ops::SoftmaxCUDNNKernel, + ops::SoftmaxCUDNNKernel); +REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 1b63f8a499..3e5457bddc 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/softmax_op.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif namespace paddle { namespace operators { @@ -38,19 +41,12 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - bool use_cudnn = ctx.Attr("use_cudnn"); - bool runtime_cudnn_support = false; + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(ctx.GetPlace())) { - auto& dev_ctx = - ctx.template device_context(); - runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false; + if (platform::CanCUDNNBeUsed(ctx)) { + library = framework::LibraryType::kCUDNN; } #endif - framework::LibraryType library_ = framework::LibraryType::kPlain; - if (use_cudnn && runtime_cudnn_support) { - library_ = framework::LibraryType::kCUDNN; - } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), @@ -119,19 +115,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - bool use_cudnn = ctx.Attr("use_cudnn"); - bool runtime_cudnn_support = false; + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(ctx.GetPlace())) { - auto& dev_ctx = - ctx.template device_context(); - runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false; + if (platform::CanCUDNNBeUsed(ctx)) { + library = framework::LibraryType::kCUDNN; } #endif - framework::LibraryType library_ = framework::LibraryType::kPlain; - if (use_cudnn && runtime_cudnn_support) { - library_ = framework::LibraryType::kCUDNN; - } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 4f20da2b92..7fa892cea6 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -29,15 +29,16 @@ class TestSoftmaxOp(OpTest): def setUp(self): self.op_type = "softmax" self.use_cudnn = False - self.inputs = { - 'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32") - } - self.outputs = { - 'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X']) - } + self.dtype = np.float32 + self.init_kernel_type() + + x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype) + out = np.apply_along_axis(stable_softmax, 1, x) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} self.attrs = {'use_cudnn': self.use_cudnn, } - def init_op_type(self): + def init_kernel_type(self): pass def test_check_output(self): @@ -48,6 +49,8 @@ class TestSoftmaxOp(OpTest): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return if self.use_cudnn: place = core.CUDAPlace(0) self.check_grad_with_place( @@ -57,8 +60,20 @@ class TestSoftmaxOp(OpTest): class TestSoftmaxCUDNNOp(TestSoftmaxOp): - def init_op_type(self): + def init_kernel_type(self): + self.use_cudnn = True + + +class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp): + def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) if __name__ == "__main__": From 98685505e4421b5218cf44280abdd79d1cd55967 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Wed, 21 Mar 2018 08:25:00 +0800 Subject: [PATCH 46/79] polish sentences --- doc/v2/faq/index_en.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/v2/faq/index_en.rst b/doc/v2/faq/index_en.rst index 5ce5cfbae7..1044aa711f 100644 --- a/doc/v2/faq/index_en.rst +++ b/doc/v2/faq/index_en.rst @@ -1,7 +1,7 @@ FAQ ==== -This document provides frequently asked questions of PaddlePaddle. If your questions are not here, please go to `PaddlePaddle Community `_ , to find answers or open an `issue `_ , we will reply in time. +This document provides answers to some of the frequently asked questions about PaddlePaddle. If you have a question that is not covered here, please go to `PaddlePaddle Community `_ , to find answers or open an `issue `_ , we will reply in time. .. toctree:: :maxdepth: 1 From b7801b9fcbe0a2c3a1b8a92c1925def166a13e25 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Tue, 20 Mar 2018 17:33:21 -0700 Subject: [PATCH 47/79] small fix --- paddle/fluid/operators/softmax_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 3e5457bddc..2506ffe48e 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -44,7 +44,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { - library = framework::LibraryType::kCUDNN; + library_ = framework::LibraryType::kCUDNN; } #endif std::string data_format = ctx.Attr("data_format"); @@ -118,7 +118,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { - library = framework::LibraryType::kCUDNN; + library_ = framework::LibraryType::kCUDNN; } #endif std::string data_format = ctx.Attr("data_format"); From 552cfe47beec913c6c3a4a2e02e1c3703823a55e Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Wed, 21 Mar 2018 09:14:12 +0800 Subject: [PATCH 48/79] repair image link in rnn.md change the path of image code --- doc/fluid/design/dynamic_rnn/rnn.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/fluid/design/dynamic_rnn/rnn.md b/doc/fluid/design/dynamic_rnn/rnn.md index 2f4854793f..e8c8b71e49 100644 --- a/doc/fluid/design/dynamic_rnn/rnn.md +++ b/doc/fluid/design/dynamic_rnn/rnn.md @@ -5,7 +5,7 @@ This document describes the RNN (Recurrent Neural Network) operator and how it i ## RNN Algorithm Implementation

- +

The above diagram shows an RNN unrolled into a full network. @@ -22,7 +22,7 @@ There are several important concepts here: There could be local variables defined in each step-net. PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step.

-
+
Figure 2 illustrates the RNN's data flow

@@ -93,7 +93,7 @@ For example, we could have a 2-level RNN, where the top level corresponds to par The following figure illustrates feeding in text into the lower level, one sentence at a step, and the feeding in step outputs to the top level. The final top level output is about the whole text.

- +

```python @@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st

- +

From 4be675bcbf34b23b95f30dee9f84d57056f02d08 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Wed, 21 Mar 2018 09:15:32 +0800 Subject: [PATCH 49/79] polish --- doc/fluid/design/dynamic_rnn/rnn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/design/dynamic_rnn/rnn.md b/doc/fluid/design/dynamic_rnn/rnn.md index e8c8b71e49..3e7f38d2d6 100644 --- a/doc/fluid/design/dynamic_rnn/rnn.md +++ b/doc/fluid/design/dynamic_rnn/rnn.md @@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st

- +

From 784e3302663b83484d0117101cb28211a31e335e Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Wed, 21 Mar 2018 09:21:33 +0800 Subject: [PATCH 50/79] repair deadlink --- doc/fluid/design/dynamic_rnn/rnn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/design/dynamic_rnn/rnn.md b/doc/fluid/design/dynamic_rnn/rnn.md index 3e7f38d2d6..cca2e69710 100644 --- a/doc/fluid/design/dynamic_rnn/rnn.md +++ b/doc/fluid/design/dynamic_rnn/rnn.md @@ -49,7 +49,7 @@ or copy the memory value of the previous step to the current ex-memory variable. ### Usage in Python -For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md). +For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/block.md). We can define an RNN's step-net using a Block: From c55bff79697a9f05a903f65a09ac443b6b98aea0 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Wed, 21 Mar 2018 10:30:08 +0800 Subject: [PATCH 51/79] modify error --- doc/fluid/design/dynamic_rnn/rnn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/design/dynamic_rnn/rnn.md b/doc/fluid/design/dynamic_rnn/rnn.md index cca2e69710..6f414e5549 100644 --- a/doc/fluid/design/dynamic_rnn/rnn.md +++ b/doc/fluid/design/dynamic_rnn/rnn.md @@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st

- +

From e438926781199e61e68e4d1978a8adfdc95793e6 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 11:35:17 +0800 Subject: [PATCH 52/79] fluid_cluster_train_cn_doc --- .../howto/cluster/fluid_cluster_train_cn.md | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 doc/fluid/howto/cluster/fluid_cluster_train_cn.md diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md new file mode 100644 index 0000000000..a95dcd180e --- /dev/null +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -0,0 +1,124 @@ +# Fluid 分布式版本使用指南 +本篇文章将说明在PaddlePaddle Fluid版本下进行分布式训练的配置和执行 + +## 准备工作 +* 可用的集群 + 包含一个或多个计算节点的集群,每一个节点都能够执行PaddlePaddle的训练任务且拥有唯一的IP地址,集群内的所有计算节点可以通过网络相互通信。 +* 安装PaddlePaddle Fluid with Distribute 版本 + 所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。 + **注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。 + cmake编译命令中需要将WITH_DISTRIBUTE设置为ON,下面是一个cmake编译指令示例: +``` +cmake .. -DWITH_DOC=OFF -DWITH_GPU=OFF -DWITH_DISTRIBUTE=ON -DWITH_SWIG_PY=ON -DWITH_PYTHON=ON +``` + +## 更新训练脚本 +这里,我们以[Deep Learing 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html)课程中的第一章 fit a line 为例。 +### 单机训练脚本示例 +```python +import paddle.v2 as paddle +import paddle.fluid as fluid + +x = fluid.layers.data(name='x', shape=[13], dtype='float32') +y_predict = fluid.layers.fc(input=x, size=1, act=None) +y = fluid.layers.data(name='y', shape=[1], dtype='float32') + +cost = fluid.layers.square_error_cost(input=y_predict, label=y) +avg_cost = fluid.layers.mean(x=cost) + +sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) +sgd_optimizer.minimize(avg_cost) + +BATCH_SIZE = 20 + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) +exe = fluid.Executor(place) + +exe.run(fluid.default_startup_program()) + +PASS_NUM = 100 +for pass_id in range(PASS_NUM): + fluid.io.save_persistables(exe, "./fit_a_line.model/") + fluid.io.load_persistables(exe, "./fit_a_line.model/") + for data in train_reader(): + avg_loss_value, = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost]) + + if avg_loss_value[0] < 10.0: + exit(0) # if avg cost less than 10.0, we think our code is good. +exit(1) +``` + +我们创建了一个简单的全连接神经网络程序,并且通过fluid的Executor执行了100次迭代,现在我们需要将该非分布式版本的程序更新为分布式版本的程序。 +### 介绍Parameter Server +在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server。 +![](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/cluster/src/trainer.png) +**因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是 Parameter Server 和 Trainer** + +### 分布式训练 +Fliud专门提供了工具"**Distributed Transpiler**"用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recive 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 +```python +optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) +``` +将Distributed Transpiler、优化算子 和梯度函数放在一个代码中如下: +```python +... #define the program, cost, and create sgd optimizer + +optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) #get optimize OPs and gradient parameters + +t = fluid.DistributeTranspiler() # create the transpiler instance +# slice the program into 2 pieces with optimizer_ops and gradient parameters list, as well as pserver_endpoints, which is a comma separated list of [IP:PORT] and number of trainers +t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) + +... #create executor + +# in pserver, run this +#current_endpoint here means current pserver IP:PORT you wish to run on +pserver_prog = t.get_pserver_program(current_endpoint) +pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) +exe.run(pserver_startup) +exe.run(pserver_prog) + +# in trainer, run this +... # define data reader +exe.run(fluid.default_startup_program()) +for pass_id in range(100): + for data in train_reader(): + exe.run(t.get_trainer_program()) +``` +### 分布式训练脚本运行说明 +分布式任务的运行需要外部指定多个参数: +```table +| 参数名 | 值类型 | 说明 | 示例 | +| trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | +| pservers | str | parameter server 列表 | 127.0.0.1:6710,127.0.0.1:6711 | +| trainers | int | 训练节点的总个数,>0的数字 | | +| server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | +| training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | +``` +启动顺序,先启动全部的Pserver后,再启动TRAINER。 +**其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** + +### DEMO +完整的demo代码位于fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。 +``` +cd /paddle/python/paddle/fluid/tests/book +``` +第一步:启动Parameter Server, 启动Parameter Server的命令: +``` +PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py +``` +执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ``` +第二步:启动trainer, 启动trainer的命令: +``` +PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py +``` +由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。 +现在我们就启动了一个包含一个Parameter Server 和两个Trainer的分布式训练任务。 From 85db0ae746867bf290a04513a41f3d6e5af1db80 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 11:44:33 +0800 Subject: [PATCH 53/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index a95dcd180e..7373e00106 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -58,9 +58,9 @@ exit(1) 我们创建了一个简单的全连接神经网络程序,并且通过fluid的Executor执行了100次迭代,现在我们需要将该非分布式版本的程序更新为分布式版本的程序。 ### 介绍Parameter Server -在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server。 -![](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/cluster/src/trainer.png) -**因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是 Parameter Server 和 Trainer** +在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server, ![Parameter Server 设计文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) + +**因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是 Parameter Server 和 Trainer。** ### 分布式训练 Fliud专门提供了工具"**Distributed Transpiler**"用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recive 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 @@ -97,6 +97,7 @@ for pass_id in range(100): 分布式任务的运行需要外部指定多个参数: ```table | 参数名 | 值类型 | 说明 | 示例 | +| :------------- | :---| :--------------------------------------- | :------------- | | trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | | pservers | str | parameter server 列表 | 127.0.0.1:6710,127.0.0.1:6711 | | trainers | int | 训练节点的总个数,>0的数字 | | From 7aa48dea117fbbbe5167ecd4f23bc8ea27a16fb5 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 11:51:28 +0800 Subject: [PATCH 54/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 7373e00106..d4d41943f6 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -58,7 +58,7 @@ exit(1) 我们创建了一个简单的全连接神经网络程序,并且通过fluid的Executor执行了100次迭代,现在我们需要将该非分布式版本的程序更新为分布式版本的程序。 ### 介绍Parameter Server -在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server, ![Parameter Server 设计文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) +在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server, [Parameter Server 设计文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) **因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是 Parameter Server 和 Trainer。** From 7bb4ea9c1326966761396d34bbd86d6844b14fdc Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 21 Mar 2018 12:48:59 +0800 Subject: [PATCH 55/79] Add an argument in Executor.Run to allow users to choose whether to create and destroy variables every time. (#9242) --- paddle/fluid/framework/executor.cc | 4 ++-- paddle/fluid/framework/executor.h | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index a688115b11..0b171e1dcf 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -185,7 +185,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, std::map& feed_targets, std::map& fetch_targets, const std::string& feed_holder_name, - const std::string& fetch_holder_name) { + const std::string& fetch_holder_name, bool create_vars) { platform::RecordBlock b(kProgramId); bool has_feed_ops = has_feed_operators(program.Block(0), feed_targets, feed_holder_name); @@ -255,7 +255,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - Run(*copy_program, scope, 0, true, true); + Run(*copy_program, scope, 0, create_vars, create_vars); // obtain the data of fetch_targets from fetch_holder for (auto* op : global_block->AllOps()) { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index fb29c70f14..d8dd82469a 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -54,7 +54,8 @@ class Executor { std::map& feed_targets, std::map& fetch_targets, const std::string& feed_holder_name = "feed", - const std::string& fetch_holder_name = "fetch"); + const std::string& fetch_holder_name = "fetch", + bool create_vars = true); static std::unique_ptr Prepare( const ProgramDesc& program, int block_id); From 5ed722da9b75fe6318ab2b3e12787090514225e7 Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Wed, 21 Mar 2018 13:02:57 +0800 Subject: [PATCH 56/79] modify some sentences --- doc/v2/faq/index_en.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/v2/faq/index_en.rst b/doc/v2/faq/index_en.rst index 1044aa711f..3fa220792b 100644 --- a/doc/v2/faq/index_en.rst +++ b/doc/v2/faq/index_en.rst @@ -1,7 +1,7 @@ FAQ ==== -This document provides answers to some of the frequently asked questions about PaddlePaddle. If you have a question that is not covered here, please go to `PaddlePaddle Community `_ , to find answers or open an `issue `_ , we will reply in time. +This document provides answers to some of the frequently asked questions about PaddlePaddle. If you have a question that is not covered here, please go to `PaddlePaddle Community `_ , to find an answer or submit new `issue `_ , we will reply in time. .. toctree:: :maxdepth: 1 From 34b7fc7cf5e03e3a614c1d379ceaaab4bdb6222e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 13:20:20 +0800 Subject: [PATCH 57/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index d4d41943f6..b7b30ecf73 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -95,7 +95,6 @@ for pass_id in range(100): ``` ### 分布式训练脚本运行说明 分布式任务的运行需要外部指定多个参数: -```table | 参数名 | 值类型 | 说明 | 示例 | | :------------- | :---| :--------------------------------------- | :------------- | | trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | @@ -103,8 +102,8 @@ for pass_id in range(100): | trainers | int | 训练节点的总个数,>0的数字 | | | server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | -``` -启动顺序,先启动全部的Pserver后,再启动TRAINER。 + +启动顺序,先启动全部的PSERVER (Parameter Server)后,再启动TRAINER(Trainer)。 **其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** ### DEMO @@ -117,9 +116,9 @@ cd /paddle/python/paddle/fluid/tests/book PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py ``` 执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ``` -第二步:启动trainer, 启动trainer的命令: +第二步:启动Trainer, 启动Trainer的命令: ``` PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py ``` 由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。 -现在我们就启动了一个包含一个Parameter Server 和两个Trainer的分布式训练任务。 +现在我们就启动了一个包含一个Parameter Server和两个Trainer的分布式训练任务。 From 5d212da481e16fc9940dd21d129bac6a03e76080 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 14:04:45 +0800 Subject: [PATCH 58/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index b7b30ecf73..a2e3d1556a 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -95,6 +95,7 @@ for pass_id in range(100): ``` ### 分布式训练脚本运行说明 分布式任务的运行需要外部指定多个参数: +``` | 参数名 | 值类型 | 说明 | 示例 | | :------------- | :---| :--------------------------------------- | :------------- | | trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | @@ -102,7 +103,7 @@ for pass_id in range(100): | trainers | int | 训练节点的总个数,>0的数字 | | | server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | - +``` 启动顺序,先启动全部的PSERVER (Parameter Server)后,再启动TRAINER(Trainer)。 **其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** From b3962a934f8ab7714cc60d57ed0c417097720ddb Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 14:12:02 +0800 Subject: [PATCH 59/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index a2e3d1556a..382161587d 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -95,15 +95,15 @@ for pass_id in range(100): ``` ### 分布式训练脚本运行说明 分布式任务的运行需要外部指定多个参数: -``` + | 参数名 | 值类型 | 说明 | 示例 | -| :------------- | :---| :--------------------------------------- | :------------- | +|:-------------|:---|:---------------------------------------|:-------------| | trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | | pservers | str | parameter server 列表 | 127.0.0.1:6710,127.0.0.1:6711 | | trainers | int | 训练节点的总个数,>0的数字 | | | server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | -``` + 启动顺序,先启动全部的PSERVER (Parameter Server)后,再启动TRAINER(Trainer)。 **其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** From 50e8251388a30175cea43cb34b06b9654896b6c0 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 14:16:41 +0800 Subject: [PATCH 60/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 382161587d..49763c30b4 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -100,7 +100,7 @@ for pass_id in range(100): |:-------------|:---|:---------------------------------------|:-------------| | trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | | pservers | str | parameter server 列表 | 127.0.0.1:6710,127.0.0.1:6711 | -| trainers | int | 训练节点的总个数,>0的数字 | | +| trainers | int | 训练节点的总个数,>0的数字 | 4 | | server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | From 529878b156aa8d5be8ec023edbe56203c2a159d1 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 16:08:21 +0800 Subject: [PATCH 61/79] fluid_cluster_train_cn_doc --- .../howto/cluster/fluid_cluster_train_cn.md | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 49763c30b4..c23a06b620 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -56,18 +56,18 @@ for pass_id in range(PASS_NUM): exit(1) ``` -我们创建了一个简单的全连接神经网络程序,并且通过fluid的Executor执行了100次迭代,现在我们需要将该非分布式版本的程序更新为分布式版本的程序。 +我们创建了一个简单的全连接神经网络程序,并且通过fluid的Executor执行了100次迭代,现在我们需要将该单机版本的程序更新为分布式版本的程序。 ### 介绍Parameter Server 在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server, [Parameter Server 设计文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) -**因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是 Parameter Server 和 Trainer。** +**因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是Parameter Server和Trainer。** ### 分布式训练 Fliud专门提供了工具"**Distributed Transpiler**"用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recive 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 ```python optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) ``` -将Distributed Transpiler、优化算子 和梯度函数放在一个代码中如下: +将Distributed Transpiler、优化算子和梯度函数放在一个代码中如下: ```python ... #define the program, cost, and create sgd optimizer @@ -94,7 +94,7 @@ for pass_id in range(100): exe.run(t.get_trainer_program()) ``` ### 分布式训练脚本运行说明 -分布式任务的运行需要外部指定多个参数: +分布式任务的运行需要将表格中说明的多个参数进行赋值,: | 参数名 | 值类型 | 说明 | 示例 | |:-------------|:---|:---------------------------------------|:-------------| @@ -104,11 +104,26 @@ for pass_id in range(100): | server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | -启动顺序,先启动全部的PSERVER (Parameter Server)后,再启动TRAINER(Trainer)。 **其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** +参数赋值及使用的相关代码片段: +```python +t = fluid.DistributeTranspiler() +t.transpile( + optimize_ops, + params_grads, + trainer_id, + pservers=pserver, + trainers=trainers) +if training_role == "PSERVER": + pserver_prog = t.get_pserver_program(server_endpoint) + pserver_startup = t.get_startup_program(server_endpoint, pserver_prog) +``` + +### 启动顺序 +先启动全部的PSERVER (Parameter Server)后,再启动TRAINER(Trainer)。 -### DEMO -完整的demo代码位于fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。 +### Demo +完整的demo代码位于Fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。 ``` cd /paddle/python/paddle/fluid/tests/book ``` From 3b95b55f07d0e8b0f3b7563e52a97b2504e23586 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 1 Mar 2018 13:45:18 +0100 Subject: [PATCH 62/79] - Softmax MKLDNN primitive integration removed diagnostic - Added Unit tests for Softmax MKLDNN Forward Added fix for div by 0 to happen in cross_entropy backward Conflicts: paddle/fluid/operators/CMakeLists.txt - Cosmetic fixes to SoftMax MKLDNN fluid operator Added misssing softmax fluid operator file Disabled MKLDNN softmax operator by default Fix to softmax op unittest merge clang_formater fixes clang_formatter fixes - Name changing of softmax mkldnn operator to maintin consistency across codebase - updated comment fix to comment --- paddle/fluid/operators/cross_entropy_op.h | 2 +- paddle/fluid/operators/softmax_mkldnn_op.cc | 84 +++++++++++++++++++ paddle/fluid/operators/softmax_op.cc | 13 ++- python/paddle/fluid/layer_helper.py | 3 + python/paddle/fluid/layers/nn.py | 8 +- .../fluid/tests/unittests/test_softmax_op.py | 12 ++- 6 files changed, 117 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/softmax_mkldnn_op.cc diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index ec315695a6..6da3a24dc8 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -78,7 +78,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { for (int64_t i = 0; i < batch_size; ++i) { PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); int64_t index = i * class_num + label_data[i]; - dx_data[index] = -dy_data[i] / x_data[index]; + dx_data[index] = math::TolerableValue()(-dy_data[i] / x_data[index]); } } } diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/softmax_mkldnn_op.cc new file mode 100644 index 0000000000..cf0244e866 --- /dev/null +++ b/paddle/fluid/operators/softmax_mkldnn_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +#include + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNMemDesc; + +using mkldnn::memory; // Note: paddle has also "memory" namespace +using mkldnn::primitive; +using mkldnn::softmax_forward; +using mkldnn::prop_kind; +using mkldnn::stream; + +template +class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto& dev_ctx = ctx.template device_context(); + auto mkldnn_engine = dev_ctx.GetEngine(); + const Tensor* input = ctx.Input("X"); + Tensor* output = ctx.Output("Out"); + PADDLE_ENFORCE(input->dims().size() == 2UL, + "The input of softmax op must be a 2D matrix."); + const T* input_data = input->data(); + // allocate memory for output + T* output_data = output->mutable_data(ctx.GetPlace()); + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + // MKL-DNN does support softmax over selected axis. Having 2D Tensor, + // we will make normalization after final eg. axis: 1 + PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])), + "Softmax input and output dimensions should match"); + // Same memory descriptor to be used for input and output + memory::dims softmax_tz = {src_tz[0], src_tz[1]}; + // Currently only supports NC data format + // TODO(jczaja-intel): support more formats + auto softmax_md = + MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc); + // Normalization is made after innermost dimension eg. C out of NC + auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring, + softmax_md, 1 /*dim: C*/); + // create memory primitives + auto softmax_src_memory = + memory({softmax_md, mkldnn_engine}, (void*)input_data); + auto softmax_dst_memory = + memory({softmax_md, mkldnn_engine}, (void*)output_data); + auto softmax_prim_desc = + softmax_forward::primitive_desc(softmax_desc, mkldnn_engine); + auto softmax = softmax_forward(softmax_prim_desc, softmax_src_memory, + softmax_dst_memory); + std::vector pipeline{softmax}; + stream(stream::kind::eager).submit(pipeline).wait(); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace, + ops::SoftmaxMKLDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 1b63f8a499..4c8326eeab 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -14,6 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/softmax_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -51,13 +54,18 @@ class SoftmaxOp : public framework::OperatorWithKernel { if (use_cudnn && runtime_cudnn_support) { library_ = framework::LibraryType::kCUDNN; } +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + } +#endif std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; - class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { public: SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -77,6 +85,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") .SetDefault("AnyLayout"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(R"DOC( Softmax Operator. diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index da7e74c901..58b6682271 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -399,6 +399,9 @@ class LayerHelper(object): if isinstance(act, basestring): act = {'type': act} tmp = self.create_tmp_variable(dtype=input_var.dtype) + + if 'use_mkldnn' in self.kwargs: + act['use_mkldnn'] = self.kwargs.get('use_mkldnn') act_type = act.pop('type') self.append_op( type=act_type, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bf161d6618..3a9a854561 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -81,6 +81,7 @@ def fc(input, num_flatten_dims=1, param_attr=None, bias_attr=None, + use_mkldnn=False, act=None, name=None): """ @@ -162,8 +163,11 @@ def fc(input, inputs={"X": input_var, "Y": w}, outputs={"Out": tmp}, - attrs={"x_num_col_dims": num_flatten_dims, - "y_num_col_dims": 1}) + attrs={ + "x_num_col_dims": num_flatten_dims, + "y_num_col_dims": 1, + 'use_mkldnn': use_mkldnn + }) mul_results.append(tmp) # sum diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 4f20da2b92..d32c719a5f 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -27,15 +27,20 @@ def stable_softmax(x): class TestSoftmaxOp(OpTest): def setUp(self): + self.use_mkldnn = False self.op_type = "softmax" self.use_cudnn = False + self.init_op_type() self.inputs = { 'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32") } self.outputs = { 'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X']) } - self.attrs = {'use_cudnn': self.use_cudnn, } + self.attrs = { + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn + } def init_op_type(self): pass @@ -61,5 +66,10 @@ class TestSoftmaxCUDNNOp(TestSoftmaxOp): self.use_cudnn = True +class TestMKLDNN(TestSoftmaxOp): + def init_op_type(self): + self.use_mkldnn = True + + if __name__ == "__main__": unittest.main() From 0760aaf4401b2e87684a9ae8e7931cf9e51a74b8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Mar 2018 19:20:49 +0800 Subject: [PATCH 63/79] Shrink batch_norm_grad's inputs --- paddle/fluid/operators/batch_norm_op.cc | 31 +++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 5d27f5b60c..36049ee6a4 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -457,12 +457,39 @@ class BatchNormGradKernel } }; +class BatchNormGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDesc(); + op->SetType("batch_norm_grad"); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + + op->SetInput("Scale", Input("Scale")); + op->SetInput("SavedMean", Output("SavedMean")); + op->SetInput("SavedVariance", Output("SavedVariance")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale")); + op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); + + return std::unique_ptr(op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, - batch_norm_grad, ops::BatchNormGradOp); +REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, + ops::BatchNormGradMaker); +REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); + REGISTER_OP_CPU_KERNEL( batch_norm, ops::BatchNormKernel); From a6b8496c651505bc9e3c9ef2872891ed3d328ddb Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 19:31:16 +0800 Subject: [PATCH 64/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index c23a06b620..3f94a40c71 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -3,8 +3,10 @@ ## 准备工作 * 可用的集群 + 包含一个或多个计算节点的集群,每一个节点都能够执行PaddlePaddle的训练任务且拥有唯一的IP地址,集群内的所有计算节点可以通过网络相互通信。 * 安装PaddlePaddle Fluid with Distribute 版本 + 所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。 **注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。 cmake编译命令中需要将WITH_DISTRIBUTE设置为ON,下面是一个cmake编译指令示例: @@ -56,7 +58,7 @@ for pass_id in range(PASS_NUM): exit(1) ``` -我们创建了一个简单的全连接神经网络程序,并且通过fluid的Executor执行了100次迭代,现在我们需要将该单机版本的程序更新为分布式版本的程序。 +我们创建了一个简单的全连接神经网络程序,并且通过Fluid的Executor执行了100次迭代,现在我们需要将该单机版本的程序更新为分布式版本的程序。 ### 介绍Parameter Server 在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server, [Parameter Server 设计文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) @@ -94,7 +96,7 @@ for pass_id in range(100): exe.run(t.get_trainer_program()) ``` ### 分布式训练脚本运行说明 -分布式任务的运行需要将表格中说明的多个参数进行赋值,: +分布式任务的运行需要将表格中说明的多个参数进行赋值: | 参数名 | 值类型 | 说明 | 示例 | |:-------------|:---|:---------------------------------------|:-------------| @@ -105,6 +107,7 @@ for pass_id in range(100): | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | **其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** + 参数赋值及使用的相关代码片段: ```python t = fluid.DistributeTranspiler() From d42187d00e6bdad0ec11120d059622a9a288b45d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 19:37:54 +0800 Subject: [PATCH 65/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 3f94a40c71..1d394d3c2e 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -5,7 +5,7 @@ * 可用的集群 包含一个或多个计算节点的集群,每一个节点都能够执行PaddlePaddle的训练任务且拥有唯一的IP地址,集群内的所有计算节点可以通过网络相互通信。 -* 安装PaddlePaddle Fluid with Distribute 版本 +* 安装PaddlePaddle Fluid with Distributed版本 所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。 **注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。 @@ -65,7 +65,7 @@ exit(1) **因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是Parameter Server和Trainer。** ### 分布式训练 -Fliud专门提供了工具"**Distributed Transpiler**"用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recive 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 +Fliud专门提供了工具[Distributed Transpiler](https://github.com/PaddlePaddle/Paddle/blob/ba65d54d9d3b41cd3c5171b00f476d4e60133ddb/doc/fluid/design/dist_train/distributed_architecture.md#distributed-transpiler)用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recive 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 ```python optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) ``` From 4ccfc046c4639f6f0a57ebbc8c749ad3e65f9012 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 19:45:47 +0800 Subject: [PATCH 66/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 1d394d3c2e..8ac436007e 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -5,7 +5,7 @@ * 可用的集群 包含一个或多个计算节点的集群,每一个节点都能够执行PaddlePaddle的训练任务且拥有唯一的IP地址,集群内的所有计算节点可以通过网络相互通信。 -* 安装PaddlePaddle Fluid with Distributed版本 +* 安装PaddlePaddle Fluid with Distribution版本 所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。 **注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。 From 89b9788810ba15e6456dec9b45fc63eb57648f49 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 19:57:22 +0800 Subject: [PATCH 67/79] fluid_cluster_train_cn_doc --- .../howto/cluster/fluid_cluster_train_cn.md | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 8ac436007e..3bcce85ba8 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -10,12 +10,12 @@ 所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。 **注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。 cmake编译命令中需要将WITH_DISTRIBUTE设置为ON,下面是一个cmake编译指令示例: -``` +``` bash cmake .. -DWITH_DOC=OFF -DWITH_GPU=OFF -DWITH_DISTRIBUTE=ON -DWITH_SWIG_PY=ON -DWITH_PYTHON=ON ``` ## 更新训练脚本 -这里,我们以[Deep Learing 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html)课程中的第一章 fit a line 为例。 +这里,我们以[Deep Learing 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html)课程中的第一章 fit a line 为例,描述如何将单机训练脚本改造成支持集群训练的版本。 ### 单机训练脚本示例 ```python import paddle.v2 as paddle @@ -60,7 +60,7 @@ exit(1) 我们创建了一个简单的全连接神经网络程序,并且通过Fluid的Executor执行了100次迭代,现在我们需要将该单机版本的程序更新为分布式版本的程序。 ### 介绍Parameter Server -在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算和保存任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为Parameter Server, [Parameter Server 设计文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) +在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算、保存和优化任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为[Parameter Server](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md) **因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是Parameter Server和Trainer。** @@ -99,14 +99,14 @@ for pass_id in range(100): 分布式任务的运行需要将表格中说明的多个参数进行赋值: | 参数名 | 值类型 | 说明 | 示例 | -|:-------------|:---|:---------------------------------------|:-------------| +|:-------------|:------|:---------------------------------------|:-------------| | trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 | | pservers | str | parameter server 列表 | 127.0.0.1:6710,127.0.0.1:6711 | | trainers | int | 训练节点的总个数,>0的数字 | 4 | | server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 | | training_role | str | 节点角色, TRAINER/PSERVER | PSERVER | -**其中:training_role 是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,至于如何从外部环境传入,用户可自定义。** +**注意:** ```training_role```是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,样例如下: 参数赋值及使用的相关代码片段: ```python @@ -122,21 +122,18 @@ if training_role == "PSERVER": pserver_startup = t.get_startup_program(server_endpoint, pserver_prog) ``` -### 启动顺序 -先启动全部的PSERVER (Parameter Server)后,再启动TRAINER(Trainer)。 - ### Demo 完整的demo代码位于Fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。 -``` +```bash cd /paddle/python/paddle/fluid/tests/book ``` -第一步:启动Parameter Server, 启动Parameter Server的命令: -``` +第一步:参考如下命令启动Parameter Server: +```bash PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py ``` -执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ``` +执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ```, 表示Paramter Server已经正常启动。 第二步:启动Trainer, 启动Trainer的命令: -``` +```bash PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py ``` 由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。 From 55a55839b6ce52aa2953ba78639b9130775f414b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 20:10:02 +0800 Subject: [PATCH 68/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 3bcce85ba8..2bf9584d81 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -65,7 +65,7 @@ exit(1) **因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是Parameter Server和Trainer。** ### 分布式训练 -Fliud专门提供了工具[Distributed Transpiler](https://github.com/PaddlePaddle/Paddle/blob/ba65d54d9d3b41cd3c5171b00f476d4e60133ddb/doc/fluid/design/dist_train/distributed_architecture.md#distributed-transpiler)用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recive 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 +Fliud专门提供了工具[Distributed Transpiler](https://github.com/PaddlePaddle/Paddle/blob/ba65d54d9d3b41cd3c5171b00f476d4e60133ddb/doc/fluid/design/dist_train/distributed_architecture.md#distributed-transpiler)用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recv 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。 ```python optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) ``` @@ -124,17 +124,21 @@ if training_role == "PSERVER": ### Demo 完整的demo代码位于Fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。 +第一步,进入demo代码所在目录: ```bash cd /paddle/python/paddle/fluid/tests/book ``` -第一步:参考如下命令启动Parameter Server: + +第二步,参考如下命令启动Parameter Server: ```bash PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py ``` 执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ```, 表示Paramter Server已经正常启动。 -第二步:启动Trainer, 启动Trainer的命令: + +第三步,启动Trainer, 启动Trainer的命令: ```bash PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py ``` 由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。 + 现在我们就启动了一个包含一个Parameter Server和两个Trainer的分布式训练任务。 From f5eaa32dcae2ce9df6926fb485ae1073d78dca08 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 20:14:09 +0800 Subject: [PATCH 69/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 2bf9584d81..53ead76324 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -124,6 +124,7 @@ if training_role == "PSERVER": ### Demo 完整的demo代码位于Fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。 + 第一步,进入demo代码所在目录: ```bash cd /paddle/python/paddle/fluid/tests/book From b577277e30fb48aa4705dca060b5c3e141f3592e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 21 Mar 2018 20:24:20 +0800 Subject: [PATCH 70/79] fluid_cluster_train_cn_doc --- doc/fluid/howto/cluster/fluid_cluster_train_cn.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md index 53ead76324..1b6f767869 100644 --- a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md +++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md @@ -1,5 +1,5 @@ # Fluid 分布式版本使用指南 -本篇文章将说明在PaddlePaddle Fluid版本下进行分布式训练的配置和执行 +本篇文章将说明如何在PaddlePaddle Fluid版本下进行分布式训练的配置和执行,以及将单机训练脚本改造成支持集群训练的版本 ## 准备工作 * 可用的集群 @@ -8,6 +8,7 @@ * 安装PaddlePaddle Fluid with Distribution版本 所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。 + **注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。 cmake编译命令中需要将WITH_DISTRIBUTE设置为ON,下面是一个cmake编译指令示例: ``` bash @@ -108,7 +109,6 @@ for pass_id in range(100): **注意:** ```training_role```是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,样例如下: -参数赋值及使用的相关代码片段: ```python t = fluid.DistributeTranspiler() t.transpile( @@ -130,13 +130,13 @@ if training_role == "PSERVER": cd /paddle/python/paddle/fluid/tests/book ``` -第二步,参考如下命令启动Parameter Server: +第二步,启动Parameter Server: ```bash PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py ``` 执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ```, 表示Paramter Server已经正常启动。 -第三步,启动Trainer, 启动Trainer的命令: +第三步,启动Trainer: ```bash PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py ``` From 7c1472427089448989994017c449cc2516f4b99d Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Wed, 21 Mar 2018 09:07:13 -0700 Subject: [PATCH 71/79] Add default value of keyword argument to DocString (#9262) --- python/paddle/fluid/concurrency.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index 0fc4981a8e..3e4292d235 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0): return channel -def channel_send(channel, value, copy=False): +def channel_send(channel, value, is_copy=False): """ Sends a value through a channel variable. Used by an unbuffered or buffered channel to pass data from within or to a concurrent Go block, where @@ -141,8 +141,8 @@ def channel_send(channel, value, copy=False): channel (Variable|Channel): Channel variable created using `make_channel`. value (Variable): Value to send to channel - copy (bool): Copy data while channel send. If False, then data - is moved. The input cannot be used after move. + is_copy (bool): Copy data while channel send. If False, then data + is moved. The input cannot be used after move. (default False) Returns: Variable: The boolean status on whether or not the channel successfully sent the passed value. @@ -166,7 +166,7 @@ def channel_send(channel, value, copy=False): X = value - if copy is True: + if is_copy is True: copied_X = helper.create_variable( name=unique_name.generate(value.name + '_copy'), type=value.type, From 1d8fe2a22026e15b009eca8df3c0d0f2fbef3451 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 22 Mar 2018 08:31:26 +0800 Subject: [PATCH 72/79] Enhance device context pool (#9293) --- paddle/fluid/platform/device_context.cc | 35 ++++++++++++++----------- paddle/fluid/platform/device_context.h | 18 +++---------- paddle/fluid/platform/place.h | 12 +++++++++ 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 98b4178177..59b76a1edb 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -10,43 +10,45 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_context.h" +#include #include "paddle/fluid/memory/memory.h" - namespace paddle { namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Get( - const platform::Place& place) { +platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { PADDLE_THROW( "'Place' is not supported, Please re-compile with WITH_GPU " "option"); } - return it->second; + return it->second.get(); } DeviceContextPool::DeviceContextPool( const std::vector& places) { PADDLE_ENFORCE_GT(places.size(), 0); - for (size_t i = 0; i < places.size(); i++) { - if (platform::is_cpu_place(places[i])) { + using PtrType = std::unique_ptr; + std::unordered_set set; + for (auto& p : places) { + set.insert(p); + } + + for (auto& p : set) { + if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN - device_contexts_.emplace(places[i], - new platform::MKLDNNDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new MKLDNNDeviceContext(boost::get(p)))); #else - device_contexts_.emplace(places[i], - new platform::CPUDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CPUDeviceContext(boost::get(p)))); #endif - } else if (platform::is_gpu_place(places[i])) { + } else if (platform::is_gpu_place(p)) { #ifdef PADDLE_WITH_CUDA - device_contexts_.emplace(places[i], - new platform::CUDADeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CUDADeviceContext(boost::get(p)))); #else PADDLE_THROW( "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " @@ -159,6 +161,7 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { + std::lock_guard guard(mutex_); PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaGetLastError()); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 603b890af1..202394c7be 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; + mutable std::mutex mutex_; cudaStream_t stream_; cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; @@ -159,7 +160,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. */ - const platform::DeviceContext* Get(const platform::Place& place); + platform::DeviceContext* Get(const platform::Place& place); template const typename DefaultDeviceContextType::TYPE* GetByPlace( @@ -172,19 +173,8 @@ class DeviceContextPool { private: static DeviceContextPool* pool; - constexpr static int LEFT_SHIFT = 8; - struct Hash { - std::hash hash_; - size_t operator()(const platform::Place& place) const { - int pre_hash = place.which() << LEFT_SHIFT; - if (platform::is_gpu_place(place)) { - pre_hash += boost::get(place).GetDeviceId(); - } - return hash_(pre_hash); - } - }; - std::unordered_map + std::unordered_map, PlaceHash> device_contexts_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index 501bddfc6e..4cc8b377b8 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -65,6 +65,18 @@ bool is_cpu_place(const Place &); bool places_are_same_class(const Place &, const Place &); bool is_same_place(const Place &, const Place &); +struct PlaceHash { + std::size_t operator()(const Place &p) const { + constexpr size_t num_dev_bits = 4; + std::hash ihash; + size_t dev_id = 0; + if (is_gpu_place(p)) { + dev_id = boost::get(p).device; + } + return ihash(dev_id << num_dev_bits | p.which()); + } +}; + std::ostream &operator<<(std::ostream &, const Place &); template From 8440046b7f69a34e4d593bf1b8c4fe997270a6d9 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 22 Mar 2018 10:14:48 +0800 Subject: [PATCH 73/79] fix doc --- python/paddle/trainer_config_helpers/layers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index eac2cb3168..3684d1e8f7 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2747,17 +2747,17 @@ def img_pool_layer(input, .. math:: - w & = 1 + \\frac{ceil(input\_width + 2 * padding - pool\_size)}{stride} + w & = 1 + ceil(\\frac{input\_width + 2 * padding - pool\_size}{stride}) - h & = 1 + \\frac{ceil(input\_height + 2 * padding\_y - pool\_size\_y)}{stride\_y} + h & = 1 + ceil(\\frac{input\_height + 2 * padding\_y - pool\_size\_y}{stride\_y}) - ceil_mode=False: .. math:: - w & = 1 + \\frac{floor(input\_width + 2 * padding - pool\_size)}{stride} + w & = 1 + floor(\\frac{input\_width + 2 * padding - pool\_size}{stride}) - h & = 1 + \\frac{floor(input\_height + 2 * padding\_y - pool\_size\_y)}{stride\_y} + h & = 1 + floor(\\frac{input\_height + 2 * padding\_y - pool\_size\_y}{stride\_y}) The example usage is: From d70a70bcdac3c7382be999ee685ae8c7e50cd381 Mon Sep 17 00:00:00 2001 From: weixing02 <564445201@qq.com> Date: Thu, 22 Mar 2018 10:18:10 +0800 Subject: [PATCH 74/79] Modified build.sh and remove build_doc.sh --- paddle/scripts/docker/build.sh | 6 +++--- paddle/scripts/tools/build_docs/.gitignore | 2 -- paddle/scripts/tools/build_docs/build_docs.sh | 8 -------- 3 files changed, 3 insertions(+), 13 deletions(-) delete mode 100644 paddle/scripts/tools/build_docs/.gitignore delete mode 100755 paddle/scripts/tools/build_docs/build_docs.sh diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 6be2bd8fad..2e9b088bfa 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -35,7 +35,7 @@ function cmake_gen() { -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} ${PYTHON_FLAGS} -DWITH_DSO=ON - -DWITH_DOC=OFF + -DWITH_DOC=${WITH_DOC:-OFF} -DWITH_GPU=${WITH_GPU:-OFF} -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} -DWITH_MKL=${WITH_MKL:-ON} @@ -60,7 +60,7 @@ EOF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} \ ${PYTHON_FLAGS} \ -DWITH_DSO=ON \ - -DWITH_DOC=OFF \ + -DWITH_DOC=${WITH_DOC:-OFF} \ -DWITH_GPU=${WITH_GPU:-OFF} \ -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} \ -DWITH_MKL=${WITH_MKL:-ON} \ @@ -231,7 +231,7 @@ gen_capi_package gen_fluid_inference_lib if [[ ${WITH_C_API:-OFF} == "ON" ]]; then - printf "PaddlePaddle C-API libraries was generated on build/paddle.tgz\n" + printf "PaddlePaddle C-API libraries was generated on build/paddle.tgz\n" else printf "If you need to install PaddlePaddle in develop docker image," printf "please make install or pip install build/python/dist/*.whl.\n" diff --git a/paddle/scripts/tools/build_docs/.gitignore b/paddle/scripts/tools/build_docs/.gitignore deleted file mode 100644 index 6ec14c8f5b..0000000000 --- a/paddle/scripts/tools/build_docs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -doc -doc_cn diff --git a/paddle/scripts/tools/build_docs/build_docs.sh b/paddle/scripts/tools/build_docs/build_docs.sh deleted file mode 100755 index f9bc8bf63a..0000000000 --- a/paddle/scripts/tools/build_docs/build_docs.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -docker run --rm \ - -v $(git rev-parse --show-toplevel):/paddle \ - -e "WITH_GPU=OFF" \ - -e "WITH_AVX=ON" \ - -e "WITH_DOC=ON" \ - -e "WOBOQ=ON" \ - ${1:-"paddlepaddle/paddle:latest-dev"} From 990d6396fed3708d1f1eaa5ad87a9a4c3e841c5c Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 22 Mar 2018 10:47:05 +0800 Subject: [PATCH 75/79] Reuduce memory copy when communication between trainer and pserver. (#9271) --- benchmark/cluster/vgg16/vgg16_fluid.py | 52 ++- benchmark/cluster/vgg16/vgg16_tf.py | 10 +- paddle/fluid/operators/detail/CMakeLists.txt | 6 +- .../operators/detail/bytebuffer_stream.h | 134 ++++++ paddle/fluid/operators/detail/grpc_client.cc | 39 +- paddle/fluid/operators/detail/grpc_client.h | 38 +- paddle/fluid/operators/detail/grpc_server.cc | 92 ++-- paddle/fluid/operators/detail/grpc_server.h | 36 +- paddle/fluid/operators/detail/grpc_service.h | 118 ++++++ paddle/fluid/operators/detail/send_recv.proto | 6 +- .../operators/detail/sendrecvop_utils.cc | 129 +----- .../fluid/operators/detail/sendrecvop_utils.h | 12 +- paddle/fluid/operators/detail/test_serde.cc | 177 ++++---- .../operators/detail/variable_response.cc | 400 ++++++++++++++++++ .../operators/detail/variable_response.h | 81 ++++ paddle/fluid/operators/listen_and_serv_op.cc | 9 +- python/paddle/fluid/debuger.py | 2 - python/paddle/fluid/distribute_transpiler.py | 2 + 18 files changed, 1021 insertions(+), 322 deletions(-) create mode 100644 paddle/fluid/operators/detail/grpc_service.h create mode 100644 paddle/fluid/operators/detail/variable_response.cc create mode 100644 paddle/fluid/operators/detail/variable_response.h diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 786f224608..8b29227cfa 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -18,12 +18,13 @@ import sys import time import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid as fluid -import paddle.v2.fluid.core as core -import paddle.v2.fluid.profiler as profiler +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.profiler as profiler import argparse import functools import os +from paddle.fluid import debuger def str2bool(v): @@ -182,28 +183,27 @@ def main(): start_time = time.time() num_samples = 0 train_pass_acc.reset() - with profiler.profiler("CPU", 'total') as prof: - for batch_id, data in enumerate(train_reader()): - ts = time.time() - img_data = np.array( - map(lambda x: x[0].reshape(data_shape), data)).astype( - "float32") - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - y_data = y_data.reshape([-1, 1]) - - loss, acc, b_size = exe.run( - trainer_prog, - feed={"pixel": img_data, - "label": y_data}, - fetch_list=[avg_cost, batch_acc, batch_size]) - iters += 1 - num_samples += len(data) - train_pass_acc.add(value=acc, weight=b_size) - print( - "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" - % (pass_id, iters, loss, acc, - len(data) / (time.time() - ts)) - ) # The accuracy is the accumulation of batches, but not the current batch. + for batch_id, data in enumerate(train_reader()): + ts = time.time() + img_data = np.array( + map(lambda x: x[0].reshape(data_shape), data)).astype( + "float32") + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = y_data.reshape([-1, 1]) + + loss, acc, b_size = exe.run( + trainer_prog, + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[avg_cost, batch_acc, batch_size]) + iters += 1 + num_samples += len(data) + train_pass_acc.add(value=acc, weight=b_size) + print( + "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" + % (pass_id, iters, loss, acc, + len(data) / (time.time() - ts)) + ) # The accuracy is the accumulation of batches, but not the current batch. pass_elapsed = time.time() - start_time pass_train_acc = train_pass_acc.eval() @@ -254,9 +254,7 @@ def main(): pserver_prog = t.get_pserver_program(current_endpoint) pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) - print("starting server side startup") exe.run(pserver_startup) - print("starting parameter server...") exe.run(pserver_prog) elif training_role == "TRAINER": # Parameter initialization diff --git a/benchmark/cluster/vgg16/vgg16_tf.py b/benchmark/cluster/vgg16/vgg16_tf.py index 996df0e314..2d220478ac 100644 --- a/benchmark/cluster/vgg16/vgg16_tf.py +++ b/benchmark/cluster/vgg16/vgg16_tf.py @@ -292,14 +292,18 @@ def run_benchmark(cluster_spec, server): return np.mean(test_accs) config = tf.ConfigProto( - intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1, + log_device_placement=True) config.gpu_options.allow_growth = True hooks = [tf.train.StopAtStepHook(last_step=1000000)] with tf.train.MonitoredTrainingSession( - master=server.target, is_chief=(args.task_index == 0), - hooks=hooks) as sess: + master=server.target, + is_chief=(args.task_index == 0), + hooks=hooks, + config=config) as sess: iters, num_samples, start_time = 0, 0, 0.0 for pass_id in range(args.num_passes): # train diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 94395ccfbc..2b19f04489 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,6 +1,8 @@ if(WITH_DISTRIBUTE) - grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc + grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) + cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr + cares zlib protobuf sendrecvop_grpc) endif() diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/detail/bytebuffer_stream.h index 099deb12d0..0cbe514d04 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.h +++ b/paddle/fluid/operators/detail/bytebuffer_stream.h @@ -23,9 +23,107 @@ limitations under the License. */ #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" +namespace grpc { +// A ZeroCopyInputStream that reads from grpc_byte_buffer +class GrpcBufferReader final + : public ::google::protobuf::io::ZeroCopyInputStream { + typedef void (CoreCodegenInterface::*OldReaderInitAPI)( + grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); + typedef int (CoreCodegenInterface::*NewReaderInitAPI)( + grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); + void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader, + grpc_byte_buffer* buffer) { + (g_core_codegen_interface->*ptr)(reader, buffer); + } + void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader, + grpc_byte_buffer* buffer) { + int result = (g_core_codegen_interface->*ptr)(reader, buffer); + (void)result; + } + + public: + explicit GrpcBufferReader(grpc_byte_buffer* buffer) + : byte_count_(0), backup_count_(0) { + ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_, + buffer); + } + ~GrpcBufferReader() override { + g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_); + } + + bool Next(const void** data, int* size) override { + if (backup_count_ > 0) { + *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) - + backup_count_; + GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX); + *size = (int)backup_count_; + backup_count_ = 0; + return true; + } + if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_, + &slice_)) { + return false; + } + g_core_codegen_interface->grpc_slice_unref(slice_); + *data = GRPC_SLICE_START_PTR(slice_); + // On win x64, int is only 32bit + GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); + byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_); + return true; + } + + void BackUp(int count) override { backup_count_ = count; } + + bool Skip(int count) override { + const void* data; + int size; + while (Next(&data, &size)) { + if (size >= count) { + BackUp(size - count); + return true; + } + // size < count; + count -= size; + } + // error or we have too large count; + return false; + } + + ::google::protobuf::int64 ByteCount() const override { + return byte_count_ - backup_count_; + } + + private: + int64_t byte_count_; + int64_t backup_count_; + grpc_byte_buffer_reader reader_; + grpc_slice slice_; +}; + +}; // namespace grpc + namespace paddle { namespace operators { namespace detail { +// Source provides a way for a particular RPC implementation to provide +// received data to ParseFrom. +class Source { + public: + virtual ~Source() {} + + // Return the stream that contains the data to be parsed. + // Note that this method might be invoked more than once if + // ParseFrom needs to fall back to a more expensive parsing method. + // Every call must return a stream pointing at the beginning of + // the serialized RecvTensorResponse. + // + // Note that a subsequent call to contents() invalidates previous + // results of contents(). + // + // Ownership of the returned stream is retained by the Source and + // should not be deleted by the caller. + virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0; +}; // A ZeroCopyInputStream that reads from a grpc::ByteBuffer. class GrpcByteBufferSource @@ -46,6 +144,42 @@ class GrpcByteBufferSource ::google::protobuf::int64 byte_count_; }; +class GrpcByteBufferSourceWrapper : public Source { + public: + GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) : source_(source) {} + virtual ::google::protobuf::io::ZeroCopyInputStream* contents() override { + return source_; + } + + private: + GrpcByteBufferSource* source_; +}; + +class GrpcByteSource : public Source { + public: + explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {} + ~GrpcByteSource() override { DeleteStream(); } + + typedef ::grpc::GrpcBufferReader Reader; + + ::google::protobuf::io::ZeroCopyInputStream* contents() override { + DeleteStream(); + stream_ = new (&space_) Reader(buffer_); + return stream_; + } + + private: + void DeleteStream() { + if (stream_) { + stream_->~Reader(); + } + } + + grpc_byte_buffer* buffer_; // Not owned + Reader* stream_ = nullptr; // Points into space_ if non-nullptr + char space_[sizeof(Reader)]; +}; + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ddeeebec58..eb19685aa6 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "grpc_client.h" +#include #include "paddle/fluid/framework/threadpool.h" + namespace paddle { namespace operators { namespace detail { @@ -31,8 +33,9 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { auto* var = p_scope->FindVar(var_name_val); - sendrecv::VariableMessage req; - SerializeToMessage(var_name_val, var, *p_ctx, &req); + + ::grpc::ByteBuffer req; + SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); // varhandle VarHandle var_h; @@ -46,8 +49,11 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, s->Prepare(var_h, time_out); s->response_call_back_ = NULL; - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); + auto call = std::move(s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, + &cq_)); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, (void*)s); }); req_count_++; @@ -56,9 +62,19 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, } void ProcGetResponse(const VarHandle& var_h, - const sendrecv::VariableMessage& ret_msg) { - auto* outvar = var_h.scope->FindVar(var_h.name); - DeserializeFromMessage(ret_msg, *var_h.ctx, outvar); + // const sendrecv::VariableMessage& ret_msg) { + const ::grpc::ByteBuffer& ret_msg) { + framework::Variable* outvar = NULL; + DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, outvar); +} + +template +void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { + ::grpc::Slice slice(proto.ByteSizeLong()); + proto.SerializeWithCachedSizesToArray( + const_cast(reinterpret_cast(slice.begin()))); + ::grpc::ByteBuffer tmp(&slice, 1); + result->Swap(&tmp); } bool RPCClient::AsyncGetVariable(const std::string& ep, @@ -88,8 +104,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); + + auto call = std::move(s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_)); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, (void*)s); }); req_count_++; diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index f520367dd9..8216ac52fb 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -25,6 +25,11 @@ limitations under the License. */ #include #include +#include +#include +#include +#include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -49,15 +54,11 @@ struct VarHandle { } }; -void ProcGetResponse(const VarHandle& var_h, - const sendrecv::VariableMessage& msg); +void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { - stub_ = sendrecv::SendRecvService::NewStub(ch); - context_ = NULL; - } + explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; } virtual ~BaseProcessor() {} @@ -82,19 +83,18 @@ class BaseProcessor { virtual void Process() = 0; - std::unique_ptr stub_; std::unique_ptr context_; grpc::Status status_; VarHandle var_h_; }; -typedef std::function +typedef std::function RequestSendCallBack; class SendProcessor : public BaseProcessor { public: explicit SendProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch), stub_g_(ch) {} virtual ~SendProcessor() {} @@ -104,17 +104,18 @@ class SendProcessor : public BaseProcessor { } } - sendrecv::VoidMessage reply_; + ::grpc::GenericStub stub_g_; + ::grpc::ByteBuffer reply_; RequestSendCallBack response_call_back_ = NULL; }; -typedef std::function +typedef std::function RequestGetCallBack; class GetProcessor : public BaseProcessor { public: explicit GetProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch), stub_g_(ch) {} virtual ~GetProcessor() {} @@ -124,30 +125,37 @@ class GetProcessor : public BaseProcessor { } } - sendrecv::VariableMessage reply_; + ::grpc::ByteBuffer reply_; + ::grpc::GenericStub stub_g_; RequestGetCallBack response_call_back_ = ProcGetResponse; }; class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } virtual ~BatchBarrierProcessor() {} virtual void Process() {} sendrecv::VoidMessage reply_; + std::unique_ptr stub_; }; class FetchBarrierProcessor : public BaseProcessor { public: explicit FetchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } virtual ~FetchBarrierProcessor() {} virtual void Process() {} sendrecv::VariableMessage reply_; + std::unique_ptr stub_; }; class RPCClient { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 8fff430cc4..9691d1e86b 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/grpc_server.h" -using grpc::ServerAsyncResponseWriter; +using ::grpc::ServerAsyncResponseWriter; namespace paddle { namespace operators { @@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH }; // https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server class RequestBase { public: - explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq) - : service_(service), cq_(cq), status_(PROCESS) { + explicit RequestBase(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + const platform::DeviceContext* dev_ctx) + : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { PADDLE_ENFORCE(cq_); } virtual ~RequestBase() {} @@ -42,55 +43,58 @@ class RequestBase { } protected: - grpc::ServerContext ctx_; - sendrecv::SendRecvService::AsyncService* service_; - grpc::ServerCompletionQueue* cq_; + ::grpc::ServerContext ctx_; + GrpcService::AsyncService* service_; + ::grpc::ServerCompletionQueue* cq_; CallStatus status_; + const platform::DeviceContext* dev_ctx_; }; -typedef std::pair MessageWithName; - class RequestSend final : public RequestBase { public: - explicit RequestSend(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq, - SimpleBlockQueue* queue) - : RequestBase(service, cq), queue_(queue), responder_(&ctx_) { - service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_, - this); + explicit RequestSend(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, ReceivedQueue* queue, + const platform::DeviceContext* dev_ctx) + : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { + request_.reset(new VariableResponse(scope, dev_ctx_)); + int method_id = static_cast(detail::GrpcMethod::kSendVariable); + service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, + cq_, cq_, this); } virtual ~RequestSend() {} - virtual std::string GetReqName() { return request_.varname(); } + virtual std::string GetReqName() { return request_->Varname(); } virtual void Process() { - MessageWithName msg_with_name = - std::make_pair(request_.varname(), std::move(request_)); - queue_->Push(std::move(msg_with_name)); - responder_.Finish(reply_, grpc::Status::OK, this); + queue_->Push(std::make_pair(request_->Varname(), request_)); + + sendrecv::VoidMessage reply; + responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; } protected: - sendrecv::VariableMessage request_; - sendrecv::VoidMessage reply_; - SimpleBlockQueue* queue_; + std::shared_ptr request_; + ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; }; class RequestGet final : public RequestBase { public: - explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq, framework::Scope* scope, + explicit RequestGet(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, const platform::DeviceContext* dev_ctx, SimpleBlockQueue* queue) - : RequestBase(service, cq), + : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope), - dev_ctx_(dev_ctx), queue_(queue) { - service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); + int method_id = static_cast(detail::GrpcMethod::kGetVariable); + service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, + cq_, this); } virtual ~RequestGet() {} @@ -101,24 +105,26 @@ class RequestGet final : public RequestBase { // proc request. std::string var_name = request_.varname(); auto* var = scope_->FindVar(var_name); + + ::grpc::ByteBuffer reply; if (var_name != FETCH_BARRIER_MESSAGE) { - SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); } - // TODO(gongwb): check var's info. - responder_.Finish(reply_, grpc::Status::OK, this); + + responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; - MessageWithName msg_with_name = - // request name reply - std::make_pair(var_name, std::move(reply_)); - queue_->Push(msg_with_name); + + if (var_name == FETCH_BARRIER_MESSAGE) { + sendrecv::VariableMessage msg; + MessageWithName msg_with_name = std::make_pair(var_name, msg); + queue_->Push(msg_with_name); + } } protected: sendrecv::VariableMessage request_; - sendrecv::VariableMessage reply_; - ServerAsyncResponseWriter responder_; + ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; - const platform::DeviceContext* dev_ctx_; SimpleBlockQueue* queue_; }; @@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) { } void AsyncGRPCServer::RunSyncUpdate() { - grpc::ServerBuilder builder; - builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials()); builder.SetMaxSendMessageSize(std::numeric_limits::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.RegisterService(&service_); @@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { if (is_shut_down_) { return; } - RequestSend* send = - new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); + RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, + &var_recv_queue_, dev_ctx_); VLOG(4) << "Create RequestSend status:" << send->Status(); } @@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { } // FIXME(typhoonzero): change cq_name to enum. -void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, +void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { TryToRegisterNewOne(); diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index b6666bcf96..9c21a07432 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -14,28 +14,35 @@ limitations under the License. */ #pragma once +#include +#include + #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/simple_block_queue.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" -#include -#include -#include -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/grpc_service.h" + +//#include namespace paddle { namespace operators { namespace detail { +typedef std::pair> + ReceivedMessage; +typedef SimpleBlockQueue ReceivedQueue; + typedef std::pair MessageWithName; class RequestBase; -class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { +class AsyncGRPCServer final { public: explicit AsyncGRPCServer(const std::string &address) : address_(address) {} @@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } - const MessageWithName Get() { return this->var_recv_queue_.Pop(); } + const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } - void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } + void Push(const std::string &msg_name) { + this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); + } void ShutDown(); protected: - void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name, + void HandleRequest(::grpc::ServerCompletionQueue *cq, std::string cq_name, std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); @@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { private: std::mutex cq_mutex_; volatile bool is_shut_down_ = false; - std::unique_ptr cq_send_; - std::unique_ptr cq_get_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; - sendrecv::SendRecvService::AsyncService service_; - std::unique_ptr server_; + GrpcService::AsyncService service_; + std::unique_ptr<::grpc::Server> server_; std::string address_; framework::Scope *scope_; const platform::DeviceContext *dev_ctx_; + // received variable from RPC, operators fetch variable from this queue. - SimpleBlockQueue var_recv_queue_; SimpleBlockQueue var_get_queue_; + ReceivedQueue var_recv_queue_; // condition of the sub program std::mutex barrier_mutex_; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h new file mode 100644 index 0000000000..ae6f9db3bd --- /dev/null +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -0,0 +1,118 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/operators/detail/variable_response.h" + +// NOTE: This method was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// method and did some modifications so that we can parse gRPC +// requests without too much copying of the tensor data. + +namespace grpc { +class CompletionQueue; +class Channel; +class RpcService; +class ServerCompletionQueue; +class ServerContext; + +// Support parsing/unparsing of tensorflow::VariableResponse. +// Wire-format is identical to RecvVariableResponse. +template <> +class SerializationTraits { + public: + static Status Serialize( + const paddle::operators::detail::VariableResponse& msg, + grpc_byte_buffer** bp, bool* own_buffer) { + PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); + return Status(); + } + static Status Deserialize(grpc_byte_buffer* buffer, + paddle::operators::detail::VariableResponse* msg, + int max_message_size = INT_MAX) { + if (buffer == nullptr) { + return Status(StatusCode::INTERNAL, "No payload"); + } + + Status result = g_core_codegen_interface->ok(); + if (result.ok()) { + paddle::operators::detail::GrpcByteSource source(buffer); + int ret = msg->Parse(&source); + if (ret != 0) { + result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); + } + } + g_core_codegen_interface->grpc_byte_buffer_destroy(buffer); + return result; + } +}; +} // namespace grpc + +namespace paddle { +namespace operators { +namespace detail { + +enum class GrpcMethod { + kSendVariable, + kGetVariable, +}; + +static const int kGrpcNumMethods = + static_cast(GrpcMethod::kGetVariable) + 1; + +inline const char* GrpcMethodName(GrpcMethod id) { + switch (id) { + case GrpcMethod::kSendVariable: + return "/sendrecv.SendRecvService/SendVariable"; + case GrpcMethod::kGetVariable: + return "/sendrecv.SendRecvService/GetVariable"; + } + + // Shouldn't be reached. + PADDLE_ENFORCE(false, "Invalid id: not found valid method name"); + return nullptr; +} + +class GrpcService final { + public: + class AsyncService : public ::grpc::Service { + public: + AsyncService() { + for (int i = 0; i < kGrpcNumMethods; ++i) { + AddMethod(new ::grpc::internal::RpcServiceMethod( + GrpcMethodName(static_cast(i)), + ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); + ::grpc::Service::MarkMethodAsync(i); + } + } + virtual ~AsyncService() {} + + // Make RequestAsyncUnary public for grpc_call.h + using ::grpc::Service::RequestAsyncUnary; + }; +}; + +} // namespace detail +} // namespace operator +} // namespace paddle diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index b0215d4a80..598aaa4c51 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -32,6 +32,9 @@ enum VarType { SELECTED_ROWS = 1; } +// NOTICE(gongwb):don't modify this proto if you are not +// not familar with how we serialize in sendrecvop_utils.h +// and deserilize it in variable_response.h. message VariableMessage { enum Type { // Pod Types @@ -45,7 +48,6 @@ message VariableMessage { } message LodData { repeated int64 lod_data = 1; } - string varname = 1; // TODO(Yancey1989): reference framework::proto::VarDesc::VarType VarType type = 2; @@ -64,3 +66,5 @@ message VariableMessage { } message VoidMessage {} + +message TestMessage { int64 test_1 = 1; } diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 39117eeeb6..d7bbf79c50 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -13,61 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include +#include #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/detail/proto_encoder_helper.h" +#include "paddle/fluid/operators/detail/variable_response.h" namespace paddle { namespace operators { namespace detail { -void SerializeToMessage(const std::string& name, const framework::Variable* var, - const platform::DeviceContext& ctx, - sendrecv::VariableMessage* msg) { - msg->set_varname(name); - std::ostringstream oss; - switch (framework::ToVarType(var->Type())) { - case framework::proto::VarType_Type_LOD_TENSOR: - msg->set_type(sendrecv::VarType::LOD_TENSOR); - framework::SerializeToStream(oss, var->Get(), ctx); - break; - case framework::proto::VarType_Type_SELECTED_ROWS: - msg->set_type(sendrecv::VarType::SELECTED_ROWS); - framework::SerializeToStream(oss, var->Get(), - ctx); - break; - default: { - PADDLE_THROW("Serialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } - msg->set_serialized(oss.str()); -} - -void DeserializeFromMessage(const sendrecv::VariableMessage& msg, - const platform::DeviceContext& ctx, - framework::Variable* var) { - std::istringstream iss(msg.serialized()); - switch (msg.type()) { - case sendrecv::VarType::LOD_TENSOR: - DeserializeFromStream(iss, var->GetMutable(), ctx); - break; - case sendrecv::VarType::SELECTED_ROWS: { - DeserializeFromStream(iss, var->GetMutable(), - ctx); - break; - } - default: { - PADDLE_THROW("Deserialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } -} - void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg) { @@ -123,6 +81,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, static_cast(ctx); auto copy_size = tensor.memory_size(); payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, boost::get(tensor.place()), reinterpret_cast(tensor.data()), @@ -132,6 +91,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, platform::CPUPlace cpu; memory::Free(cpu, backing); }; + #endif } else { payload = tensor.data(); @@ -219,80 +179,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, - framework::Variable* var) { - sendrecv::VariableMessage meta; - GrpcByteBufferSource source; - source.Init(msg); - ::google::protobuf::io::CodedInputStream input(&source); - // do zerocopy parsing - PADDLE_ENFORCE(meta.ParseFromCodedStream(&input)); - PADDLE_ENFORCE(input.ConsumedEntireMessage()); - // dims is needed by both tensor and selectedrows - std::vector vecdims; - for (auto& d : meta.dims()) { - vecdims.push_back(d); - } - framework::DDim dims = framework::make_ddim(vecdims); - - if (meta.type() == sendrecv::LOD_TENSOR) { - auto* tensor = var->GetMutable(); - tensor->Resize(dims); - void* tensor_data = tensor->mutable_data( - ctx.GetPlace(), - paddle::operators::detail::ToTypeIndex(meta.data_type())); - framework::LoD lod; - for (int i = 0; i < meta.lod_level(); ++i) { - framework::Vector v; - for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) { - v.push_back(meta.lod(i).lod_data(j)); - } - lod.push_back(v); - } - tensor->set_lod(lod); - // How to avoid copying and use the message buffer directly? - // Maybe need to find a way to release all memory except tensor content. - if (platform::is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA - platform::CPUPlace cpu; - auto& gpu_dev_ctx = static_cast(ctx); - memory::Copy(boost::get(tensor->place()), - tensor_data, cpu, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size(), gpu_dev_ctx.stream()); - ctx.Wait(); -#endif - } else { - memcpy(tensor_data, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size()); - } - } else if (meta.type() == sendrecv::SELECTED_ROWS) { - auto* slr = var->GetMutable(); - auto* tensor = slr->mutable_value(); - int64_t* rows_data = slr->mutable_rows()->data(); - tensor->Resize(dims); - void* tensor_data = tensor->mutable_data( - ctx.GetPlace(), - paddle::operators::detail::ToTypeIndex(meta.data_type())); - if (platform::is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA - platform::CPUPlace cpu; - auto& gpu_dev_ctx = static_cast(ctx); - memory::Copy(boost::get(tensor->place()), - tensor_data, cpu, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size(), gpu_dev_ctx.stream()); - ctx.Wait(); -#endif - } else { - memcpy(tensor_data, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size()); - } - // copy rows CPU data, GPU data will be copied lazly - memcpy(rows_data, reinterpret_cast(meta.rows().data()), - meta.rows().size()); - } + const framework::Scope* scope, + framework::Variable*& var) { + operators::detail::VariableResponse resp(scope, &ctx); + PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); + var = resp.GetVar(); } } // namespace detail diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 4fa6aefd3e..3b87562703 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" @@ -36,21 +37,14 @@ namespace detail { typedef void (*DestroyCallback)(void*); -void SerializeToMessage(const std::string& name, const framework::Variable* var, - const platform::DeviceContext& ctx, - sendrecv::VariableMessage* msg); - -void DeserializeFromMessage(const sendrecv::VariableMessage& msg, - const platform::DeviceContext& ctx, - framework::Variable* var); - void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, - framework::Variable* var); + const framework::Scope* scope, + framework::Variable*& var); inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { switch (type) { diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc index 2f06e5a686..4be5963794 100644 --- a/paddle/fluid/operators/detail/test_serde.cc +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -16,11 +16,13 @@ limitations under the License. */ #include #include +#include #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" @@ -31,19 +33,21 @@ namespace operators = paddle::operators; namespace math = paddle::operators::math; namespace memory = paddle::memory; -void RunSerdeTestTensor(platform::Place place) { - // serialize var to ByteBuffer - framework::Variable var; - auto* tensor = var.GetMutable(); - tensor->Resize(framework::make_ddim({4, 8, 4, 2})); - framework::LoD lod; - lod.push_back(framework::Vector({1, 3, 8})); - tensor->set_lod(lod); - int tensor_numel = 4 * 8 * 4 * 2; +void RunSerdeTestSelectedRows(platform::Place place) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); + + // serialize var to ByteBuffer + framework::Variable var; + auto* slr = var.GetMutable(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + tensor->Resize(framework::make_ddim({2, 10})); tensor->mutable_data(place); - math::set_constant(ctx, tensor, 31.9); + int tensor_numel = 2 * 10; + math::set_constant(ctx, tensor, 32.7); + rows->push_back(3); + rows->push_back(10); ::grpc::ByteBuffer msg; operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); @@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) { for (const auto& s : slices) { tmp.append(reinterpret_cast(s.begin()), s.size()); } + sendrecv::VariableMessage varmsg; EXPECT_TRUE(varmsg.ParseFromString(tmp)); + EXPECT_EQ(varmsg.varname(), "myvar"); - EXPECT_EQ(varmsg.type(), 0); - EXPECT_EQ(varmsg.dims()[0], 4); - EXPECT_EQ(varmsg.dims()[1], 8); - EXPECT_EQ(varmsg.dims()[2], 4); - EXPECT_EQ(varmsg.dims()[3], 2); - EXPECT_EQ(varmsg.lod_level(), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); - EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + EXPECT_EQ(varmsg.type(), 1); const float* tensor_data = reinterpret_cast(varmsg.serialized().data()); + const int64_t* rows_data = + reinterpret_cast(varmsg.rows().data()); for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data[i], 31.9); + EXPECT_FLOAT_EQ(tensor_data[i], 32.7); } - + EXPECT_EQ(rows_data[0], 3); + EXPECT_EQ(rows_data[1], 10); // deserialize zero-copy - framework::Variable var2; - operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); - auto tensor2 = var2.Get(); + // framework::Variable var2; + // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + framework::Scope scope; + scope.Var("myvar"); + operators::detail::TensorResponse resp(&scope, &ctx); + EXPECT_EQ(resp.Parse(msg), 0); + + framework::Variable* var2 = resp.GetVar(); + + auto* slr2 = var2->GetMutable(); + auto* tensor2 = slr2->mutable_value(); + auto* rows2 = slr2->mutable_rows(); float* tensor_data2 = nullptr; framework::Tensor tmp_tensor; if (platform::is_gpu_place(ctx.GetPlace())) { platform::CPUPlace cpu; - framework::TensorCopy(tensor2, cpu, &tmp_tensor); + framework::TensorCopy(*tensor2, cpu, &tmp_tensor); tensor_data2 = tmp_tensor.data(); } else { - tensor_data2 = const_cast(tensor2.data()); + tensor_data2 = const_cast(tensor2->data()); } + const int64_t* rows_data2 = rows2->data(); - EXPECT_EQ(varmsg.lod_level(), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); - EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); - for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); + for (int i = 0; i < tensor_numel; ++i) { + EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); + } + EXPECT_EQ(rows_data2[0], 3); + EXPECT_EQ(rows_data2[1], 10); } -void RunSerdeTestSelectedRows(platform::Place place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - +void RunTestLodTensor(platform::Place place, int from_type = 0) { // serialize var to ByteBuffer framework::Variable var; - auto* slr = var.GetMutable(); - auto* tensor = slr->mutable_value(); - auto* rows = slr->mutable_rows(); - tensor->Resize(framework::make_ddim({2, 10})); + auto* tensor = var.GetMutable(); + tensor->Resize(framework::make_ddim({4, 8, 4, 2})); + framework::LoD lod; + lod.push_back(framework::Vector({1, 3, 8})); + tensor->set_lod(lod); + int tensor_numel = 4 * 8 * 4 * 2; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); tensor->mutable_data(place); - int tensor_numel = 2 * 10; - math::set_constant(ctx, tensor, 32.7); - rows->push_back(3); - rows->push_back(10); + math::set_constant(ctx, tensor, 31.9); ::grpc::ByteBuffer msg; operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); @@ -126,43 +135,75 @@ void RunSerdeTestSelectedRows(platform::Place place) { } sendrecv::VariableMessage varmsg; EXPECT_TRUE(varmsg.ParseFromString(tmp)); - EXPECT_EQ(varmsg.varname(), "myvar"); - EXPECT_EQ(varmsg.type(), 1); + EXPECT_EQ(varmsg.type(), 0); + EXPECT_EQ(varmsg.dims()[0], 4); + EXPECT_EQ(varmsg.dims()[1], 8); + EXPECT_EQ(varmsg.dims()[2], 4); + EXPECT_EQ(varmsg.dims()[3], 2); + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); const float* tensor_data = reinterpret_cast(varmsg.serialized().data()); - const int64_t* rows_data = - reinterpret_cast(varmsg.rows().data()); for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data[i], 32.7); + EXPECT_FLOAT_EQ(tensor_data[i], 31.9); } - EXPECT_EQ(rows_data[0], 3); - EXPECT_EQ(rows_data[1], 10); + + // message binary + std::string str; + varmsg.SerializeToString(&str); + + // message bytebuffer + ::grpc::Slice slices_2[1]; + int num_slices = 1; + slices_2[0] = ::grpc::Slice(str.length()); + memcpy(const_cast(slices_2[0].begin()), str.c_str(), str.length()); + ::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices); + // deserialize zero-copy - framework::Variable var2; - operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + framework::Scope scope; + scope.Var("myvar"); + operators::detail::TensorResponse resp(&scope, &ctx); + if (from_type == 0) { + EXPECT_EQ(resp.Parse(msg), 0); + } else { + EXPECT_EQ(resp.Parse(bytebuffer2), 0); + } - auto* slr2 = var2.GetMutable(); - auto* tensor2 = slr2->mutable_value(); - auto* rows2 = slr2->mutable_rows(); + framework::Variable* var2 = resp.GetVar(); + + auto tensor2 = var2->Get(); float* tensor_data2 = nullptr; framework::Tensor tmp_tensor; if (platform::is_gpu_place(ctx.GetPlace())) { platform::CPUPlace cpu; - framework::TensorCopy(*tensor2, cpu, &tmp_tensor); + framework::TensorCopy(tensor2, cpu, &tmp_tensor); tensor_data2 = tmp_tensor.data(); } else { - tensor_data2 = const_cast(tensor2->data()); + tensor_data2 = const_cast(tensor2.data()); } - const int64_t* rows_data2 = rows2->data(); - for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); - } - EXPECT_EQ(rows_data2[0], 3); - EXPECT_EQ(rows_data2[1], 10); + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); +} + +TEST(LodTensor, GPU) { + platform::CUDAPlace place; + RunTestLodTensor(place); + RunTestLodTensor(place, 1); +} + +TEST(LodTensor, CPU) { + platform::CPUPlace place; + RunTestLodTensor(place); + RunTestLodTensor(place, 1); } TEST(SelectedRows, CPU) { @@ -174,13 +215,3 @@ TEST(SelectedRows, GPU) { platform::CUDAPlace place; RunSerdeTestSelectedRows(place); } - -TEST(Tensor, CPU) { - platform::CPUPlace place; - RunSerdeTestTensor(place); -} - -TEST(Tensor, GPU) { - platform::CUDAPlace place; - RunSerdeTestTensor(place); -} \ No newline at end of file diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc new file mode 100644 index 0000000000..12e8eb0b4d --- /dev/null +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -0,0 +1,400 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/detail/variable_response.h" +#include +#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +enum WireType { + WIRETYPE_VARINT = 0, + WIRETYPE_LENGTH_DELIMITED = 2, +}; + +inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; } + +inline WireType GetTagWireType(uint32_t tag) { + return static_cast(tag & 0x7); +} + +bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input, + int* result) { + uint64_t v; + if (input->ReadVarint64(&v) && v <= static_cast(INT_MAX)) { + *result = static_cast(v); + return true; + } else { + return false; + } +} + +bool ReadRaw(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& dev_ctx, platform::Place place, + void* dest, int size) { + const void* data = NULL; + int size_to_write = 0; + + if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + auto& gpu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + + char* p = reinterpret_cast(dest); + while (size > 0) { + if (!input->GetDirectBufferPointer(&data, &size_to_write)) { + return false; + } + + memory::Copy(boost::get(place), + reinterpret_cast(p), cpu, data, size_to_write, + gpu_dev_ctx.stream()); + p += size_to_write; + size -= size_to_write; + + input->Skip(size_to_write); + } + gpu_dev_ctx.Wait(); +#else + PADDLE_THROW("Unexpected branch"); +#endif + return true; + } + + char* p = reinterpret_cast(dest); + while (size > 0) { + if (!input->GetDirectBufferPointer(&data, &size_to_write)) { + return false; + } + // TODO(gongwb): can we avoid copy? + platform::CPUPlace cpu; + memory::Copy(cpu, reinterpret_cast(p), cpu, data, size_to_write); + + p += size_to_write; + size -= size_to_write; + + input->Skip(size_to_write); + } + + return true; +} + +bool VariableResponse::CopyLodTensorData( + ::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, framework::DDim& dims, int length) { + auto var = scope_->FindVar(meta_.varname()); + auto* tensor = var->GetMutable(); + tensor->Resize(dims); + + framework::LoD lod; + for (int i = 0; i < meta_.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) { + v.push_back(meta_.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + + void* tensor_data = + tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); + + if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { + return false; + } + + return true; +} + +inline framework::DDim GetDims( + const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) { + std::vector vecdims; + for (auto& d : dims) { + vecdims.push_back(d); + } + return framework::make_ddim(vecdims); +} + +bool VariableResponse::CopySelectRowsTensorData( + ::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, framework::DDim& dims, int length) { + auto var = scope_->FindVar(meta_.varname()); + auto* slr = var->GetMutable(); + auto* tensor = slr->mutable_value(); + tensor->Resize(dims); + void* tensor_data = tensor->mutable_data( + ctx.GetPlace(), + paddle::operators::detail::ToTypeIndex(meta_.data_type())); + + if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { + return false; + } + + return true; +} + +bool VariableResponse::CopySelectRowsData( + ::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, int length) { + auto var = scope_->FindVar(meta_.varname()); + auto* slr = var->GetMutable(); + int64_t* rows_data = slr->mutable_rows()->data(); + + // copy rows CPU data, GPU data will be copied lazily. + platform::CPUPlace cpu; + if (!ReadRaw(input, ctx, cpu, rows_data, length)) { + return false; + } + + return true; +} + +bool ParseLodData(::google::protobuf::io::CodedInputStream* input, + std::vector* lod) { + while (true) { + auto p = input->ReadTagWithCutoff(127); + int tag = GetTagFieldNumber(p.first); + WireType wt = GetTagWireType(p.first); + + if (!p.second) { + return (tag == 0); + } + + switch (tag) { + case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: { + uint64_t v; + if (wt == WIRETYPE_VARINT) { + if (!input->ReadVarint64(&v)) { + return false; + } + lod->push_back(v); + break; + } + + if (wt == WIRETYPE_LENGTH_DELIMITED) { + int length = 0; + if (!input->ReadVarintSizeAsInt(&length)) { + return tag; + } + + for (int i = 0; i < length; i++) { + uint64_t v; + if (!input->ReadVarint64(&v)) { + return false; + } + lod->push_back(v); + } + break; + } + + return false; + } + default: { return false; } + } + } + + return true; +} + +int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { + GrpcByteBufferSource source; + source.Init(byte_buffer); + GrpcByteBufferSourceWrapper r(&source); + + return Parse(&r); +} + +int VariableResponse::Parse(Source* source) { + ::google::protobuf::io::ZeroCopyInputStream* input_stream = + source->contents(); + ::google::protobuf::io::CodedInputStream input(input_stream); + input.SetTotalBytesLimit(INT_MAX, INT_MAX); + + while (true) { + auto p = input.ReadTagWithCutoff(127); + int tag = GetTagFieldNumber(p.first); + WireType wt = GetTagWireType(p.first); + if (!p.second) { + if (tag != 0) { + return -1; + } + + return 0; + } + + switch (tag) { + case sendrecv::VariableMessage::kVarnameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_varname(temp); + break; + } + case sendrecv::VariableMessage::kTypeFieldNumber: { + uint64_t v; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + + meta_.set_type(static_cast<::sendrecv::VarType>(v)); + break; + } + case sendrecv::VariableMessage::kDataTypeFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + + meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v)); + break; + } + case sendrecv::VariableMessage::kDimsFieldNumber: { + // not packed + if (wt == WIRETYPE_VARINT) { + uint64_t v; + if (!input.ReadVarint64(&v)) { + return tag; + } + meta_.add_dims(v); + break; + } + + // packed + if (wt == WIRETYPE_LENGTH_DELIMITED) { + int length = 0; + if (!input.ReadVarintSizeAsInt(&length)) { + return tag; + } + for (int i = 0; i < length; i++) { + uint64_t v; + if (!input.ReadVarint64(&v)) { + return tag; + } + meta_.add_dims(v); + } + break; + } + + return tag; + } + case sendrecv::VariableMessage::kLodLevelFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + meta_.set_lod_level(static_cast(v)); + break; + } + case sendrecv::VariableMessage::kLodFieldNumber: { + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p = + input.IncrementRecursionDepthAndPushLimit(length); + + std::vector lod_data; + if (p.second < 0 || !ParseLodData(&input, &lod_data)) { + return tag; + } + + if (!input.DecrementRecursionDepthAndPopLimit(p.first)) { + return false; + } + + if (lod_data.size() == 0) { + break; + } + + auto lod = meta_.add_lod(); + for (uint32_t i = 0; i < lod_data.size(); i++) { + lod->add_lod_data(lod_data[i]); + } + break; + } + case sendrecv::VariableMessage::kSerializedFieldNumber: { + PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || + meta_.type() == sendrecv::LOD_TENSOR) && + meta_.varname() != "", + "meta info should be got first!"); + + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + framework::DDim dims = GetDims(meta_.dims()); + if (meta_.type() == sendrecv::LOD_TENSOR) { + PADDLE_ENFORCE(meta_.lod_size() >= 0, + "lod info should be got first!"); + if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) { + return tag; + } + break; + } + + if (meta_.type() == sendrecv::SELECTED_ROWS) { + if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) { + return tag; + } + break; + } + + return tag; + } + case sendrecv::VariableMessage::kRowsFieldNumber: { + PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || + meta_.type() == sendrecv::LOD_TENSOR) && + meta_.varname() != "", + "meta info should be got first!"); + + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + if (!CopySelectRowsData(&input, *dev_ctx_, length)) { + return tag; + } + break; + } + + default: { + // Unknown tag, return unknown error. + return -1; + } + } + } + + return 0; +} + +}; // namespace detail +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h new file mode 100644 index 0000000000..c7bc7a46e7 --- /dev/null +++ b/paddle/fluid/operators/detail/variable_response.h @@ -0,0 +1,81 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/var_type.h" + +#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/detail/send_recv.pb.h" + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/detail/bytebuffer_stream.h" + +namespace paddle { +namespace operators { +namespace detail { + +class VariableResponse { + public: + VariableResponse(const framework::Scope* scope, + const platform::DeviceContext* dev_ctx) + : scope_(scope), dev_ctx_(dev_ctx){}; + + virtual ~VariableResponse(){}; + + // return: + // 0:ok. + // -1: unkown error. + // other: number of error field. + int Parse(Source* source); + + // return: + // 0:ok. + // -1: unkown error. + // other: number of error field. + int Parse(const ::grpc::ByteBuffer& byte_buffer); + + inline std::string Varname() { return meta_.varname(); } + + // should call parse first. + framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); } + + private: + bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, + framework::DDim& dims, int length); + + bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, int length); + + bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, + framework::DDim& dims, int length); + + private: + const framework::Scope* scope_; + const platform::DeviceContext* dev_ctx_; + // only Skeleton + sendrecv::VariableMessage meta_; +}; + +}; // namespace detail +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index a594de67e0..31ea2a7e58 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase { } void Stop() override { - detail::MessageWithName term_msg; - term_msg.first = LISTEN_TERMINATE_MESSAGE; - rpc_service_->Push(term_msg); + rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); rpc_service_->ShutDown(); server_thread_->join(); } @@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase { size_t recv_var_cnt = 0; int batch_barrier = 0; while (batch_barrier != fan_in) { - const detail::MessageWithName &v = rpc_service_->Get(); + const detail::ReceivedMessage v = rpc_service_->Get(); auto recv_var_name = v.first; if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { LOG(INFO) << "received terminate message and exit"; @@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase { } else { VLOG(3) << "received grad: " << recv_var_name; recv_var_cnt++; - auto *var = recv_scope.FindVar(recv_var_name); + auto var = v.second->GetVar(); if (var == nullptr) { LOG(ERROR) << "Can not find server side var: " << recv_var_name; PADDLE_THROW("Can not find server side var"); } - detail::DeserializeFromMessage(v.second, dev_ctx, var); if (var->IsType()) { sparse_vars.push_back(var); } diff --git a/python/paddle/fluid/debuger.py b/python/paddle/fluid/debuger.py index 97fa182c40..7b4afa9bf6 100644 --- a/python/paddle/fluid/debuger.py +++ b/python/paddle/fluid/debuger.py @@ -16,7 +16,6 @@ import sys import re from graphviz import GraphPreviewGenerator import proto.framework_pb2 as framework_pb2 -import paddle.fluid.core as core _vartype2str_ = [ "UNK", @@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False): def is_var_backward(var_desc): return "@GRAD" in var_desc.name - #print(type(block_desc)) if type(block_desc) is not framework_pb2.BlockDesc: block_desc = framework_pb2.BlockDesc.FromString( block_desc.serialize_to_string()) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index ad655ee96c..33cea96421 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -20,6 +20,7 @@ from layer_helper import LayerHelper from distributed_spliter import * import math from . import core +import debuger class VarBlock: @@ -289,6 +290,7 @@ class DistributeTranspiler: dtype=v.dtype, shape=v.shape) recv_inputs.append(var) + # step3 optimize_block = pserver_program.create_block(0) # step 4 From ab5ecdf60ebecdd4e18dd4208dee873ba0bb8dfc Mon Sep 17 00:00:00 2001 From: weixing Date: Thu, 22 Mar 2018 13:02:09 +0800 Subject: [PATCH 76/79] Adjust some contents in write_docs_en.rst for Contribue Documentation (#9147) * Add some contents * Adjust the content of the English version * Fix some error, replace word generate with build * Replace document with documentation * Adjust contents * Make links more visible --- doc/v2/dev/write_docs_cn.rst | 9 +++-- doc/v2/dev/write_docs_en.rst | 78 +++++++++++++++++++++++++++--------- 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/doc/v2/dev/write_docs_cn.rst b/doc/v2/dev/write_docs_cn.rst index a055bb04c0..23615f8830 100644 --- a/doc/v2/dev/write_docs_cn.rst +++ b/doc/v2/dev/write_docs_cn.rst @@ -2,13 +2,14 @@ 如何贡献文档 ############# -PaddlePaddle的文档包括中英文两个部分。文档都是通过 ``cmake`` 驱动 ``sphinx`` 编译生成,也可以利用paddlepaddle.org工具来编译和预览文档。 +PaddlePaddle的文档包括中英文两个部分。文档都是通过 ``cmake`` 驱动 ``sphinx`` 编译生成的,PaddlePaddle.org工具可以帮助我们实现这一编译过程,并提供更好的预览效果。 如何构建文档 ============ PaddlePaddle的文档构建有两种方式,分别为使用paddlepaddle.org工具和不使用paddlepaddle.org工具,两种方式都有各自的优点,前者方便预览,后者方便开发者进行调试。这两种方式中又分别有使用docker和不使用docker的两种构建方法。 +我们建议使用PaddlePaddle.org工具来构建文档。 使用PaddlePaddle.org工具 ------------------------ @@ -31,7 +32,7 @@ PaddlePaddle.org工具可以配合Docker使用,需要在系统里先安装好D docker run -it -p 8000:8000 -v `pwd`:/var/content paddlepaddle/paddlepaddle.org:latest 注意: PaddlePaddle.org 会在 -v (volume) 指定的内容存储库运行命令 -之后再用网页连到http://localhost:8000就可以在网页上生成需要的文档 +之后再用网页连到 http://localhost:8000 就可以在网页上生成需要的文档 编译后的文件将被存储在工作目录 /.ppo_workspace/content。 如果不想使用Docker,你还可以通过运行Django框架直接激活工具的服务器。使用下面的命令来运行它。 @@ -56,7 +57,7 @@ PaddlePaddle.org工具可以配合Docker使用,需要在系统里先安装好D python manage.py runserver 工具服务器将读取环境变量 CONTENT_DIR 搜索代码库。请指定的PaddlePaddle工作目录给环境变量 CONTENT_DIR。 -之后再用网页连到http://localhost:8000就可以在网页上生成需要的文档。 +之后再用网页连到 http://localhost:8000 就可以在网页上生成需要的文档。 编译后的文件将被存储在工作目录 /.ppo_workspace/content。 想了解更多PaddlePaddle.org工具的详细信息,可以 `点击这里 `_ 。 @@ -96,7 +97,7 @@ PaddlePaddle.org工具可以配合Docker使用,需要在系统里先安装好D python -m SimpleHTTPServer 8088 -在浏览器中输入http://localhost:8088就可以看到编译生成的中/英文的文档页面和英文的API页面,下图为生成的英文文档首页示例。注意,示例中由于使用了sphinx的原始主题,所以页面的风格与官网并不一致,但这并不影响开发者进行调试。 +在浏览器中输入 http://localhost:8088 就可以看到编译生成的中/英文的文档页面和英文的API页面,下图为生成的英文文档首页示例。注意,示例中由于使用了sphinx的原始主题,所以页面的风格与官网并不一致,但这并不影响开发者进行调试。 .. image:: src/doc_en.png :align: center diff --git a/doc/v2/dev/write_docs_en.rst b/doc/v2/dev/write_docs_en.rst index f3408a8426..15ff0d34ad 100644 --- a/doc/v2/dev/write_docs_en.rst +++ b/doc/v2/dev/write_docs_en.rst @@ -2,21 +2,20 @@ Contribute Documentation ######################## -PaddlePaddle supports English documentation ``doc`` and Chinese documentation ``doc_cn``. -Both are compiled by `cmake`_ and `sphinx`_ , the compiled documentations will be stored under ``doc`` and ``doc_cn`` directories. -When using the PaddlePaddle.org to compile documentations, the compiled documentations will be stored under a consolidated directory: .ppo_workspace/content +PaddlePaddle's documentation includes both Chinese and English versions. The documentation is built using the ``cmake`` command to drive the ``sphinx`` compiler. The PaddlePaddle.org tool helps us to implement this compilation process and provides better preview results. -How to Build Documentations -============ +How to build Documentation +=========================== -We recommend using PaddlePaddle.org tool to build documentation +PaddlePaddle's documentation is built in two ways: using the PaddlePaddle.org tool and without using it. Both methods have their own advantages. The former facilitates previewing, while the latter facilitates debugging by the developer. We could choose to build the documentation with Docker or without it in each of the above ways. +We recommend using PaddlePaddle.org tool to build documentation. -Use PaddlePaddle.org tool --------------- -This is the recommended method to build documentation. It can compile documentation and preview the documentation in a web browser. +Using PaddlePaddle.org tool +----------------------------- +This is the recommended method to build documentation, because it can automatically compile the documentation and preview the documentation directly in a web page. Note that, although you can preview the documentation in other ways, its style may not be consistent with the official website. Compiling with the PaddlePaddle.org tool produces a preview that will be consistent with the official website documentation style. -The tool uses Docker, please install it on your system. Please check Docker official website on how to install Docker. You may use the following commands to activate the tool +The PaddlePaddle.org tool can be used with Docker and Docker needs to be installed first. Please refer to `Docker's official website `_ on how to install Docker. After installing Docker, you may use the following commands to activate the tool .. code-block:: bash @@ -32,8 +31,8 @@ The tool uses Docker, please install it on your system. Please check Docker offi # Please specify the working directory through -v docker run -it -p 8000:8000 -v `pwd`:/var/content paddlepaddle/paddlepaddle.org:latest -Note: PaddlePaddle.org will read the content repos specified in the -v (volume) flag of the docker run command -Use a web browser and navigate to http://localhost:8000, click the buttons to compile the documentation +Note: PaddlePaddle.org will read the content repos specified in the -v (volume) flag of the docker run commands +Use a web browser and navigate to http://localhost:8000. Click the buttons to compile the documentation. The compiled documentations will be stored in /.ppo_workspace/content @@ -58,19 +57,62 @@ If you don't wish to use Docker, you can also activate the tool through Django. pip install -r requirements.txt python manage.py runserver -Use a web browser and navigate to http://localhost:8000, click the buttons to compile the documentation +Specify the PaddlePaddle working directory for the environment variable CONTENT_DIR so that the tool could find where the working directory is. + +Use a web browser and navigate to http://localhost:8000. Click the buttons to compile the documentation The compiled documentations will be stored in /.ppo_workspace/content -If you want to learn more on the PaddlePaddle.org, please `click here `_ 。 +Please `click here `_ for more information about the PaddlePaddle.org tool. + + +Manually Building the Documentation +------------------------------------- + +Build PaddlePaddle's documentation with Docker,you need to install Docker first. Please refer to `Docker's official website `_ on how to install Docker. After Docker is installed, you could use the scripts in the source directory to build the documentation. + +[TBD] + +If you do not wish to use Docker, you can also use the following commands to directly build the PaddlePaddle documentation. + +.. code-block:: bash + + mkdir paddle + cd paddle + git clone https://github.com/PaddlePaddle/Paddle.git + mkdir -p build + cd build + cmake .. -DCMAKE_BUILD_TYPE=Release -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON + + # If you only need to build documents, use the following commands + make -j $processors gen_proto_py + make -j $processors paddle_docs paddle_docs_cn + + # If you only need to build APIs, use the following commands + make -j $processors gen_proto_py framework_py_proto + make -j $processors copy_paddle_pybind + make -j $processors paddle_api_docs + +$processors indicates that as many processes as the CPU cores are started to compile in parallel. It should be set according to the number of CPU cores of your machine. + +After the compilation is complete, enter the ``doc/v2`` directory. If you chose to build documents, it will generate ``cn/html/`` and ``en/html`` subdirectories under this directory. If you chose to build APIs,it will generate``api/en/html`` subdirectory. Please enter these directories respectively and execute the following commands: + +.. code-block:: bash + + python -m SimpleHTTPServer 8088 + +Use a web browser and navigate to http://localhost:8000, you could see the compiled Chinese/English documents page and the English APIs page. The following figure is an example of the built English documents home page. Note that due to the sphinx's original theme used in the example, the style of the page is not consistent with the official website, but this does not affect the developer's debugging. -How to write Documentations -============ +.. image:: src/doc_en.png + :align: center + :scale: 60 % -PaddlePaddle uses `sphinx`_ to compile documentations,Please check sphinx official website for more detail. +How to write Documentation +=========================== +PaddlePaddle uses `sphinx`_ to compile documentation,Please check sphinx official website for more detail. How to update www.paddlepaddle.org -============================ +=================================== Please create PRs and submit them to github, please check `Contribute Code `_ 。 PaddlePaddle develop branch will update the documentation once the PR is merged. User may check latest `Chinese Docs `_ and From 3c8bbd306f254841dd7c0af820739d945bf096d7 Mon Sep 17 00:00:00 2001 From: legend06hvl Date: Thu, 22 Mar 2018 15:10:04 +0800 Subject: [PATCH 77/79] Update index_en.rst (#9280) * Update index_en.rst * Update index_en.rst Update refer to commits --- doc/v2/howto/index_en.rst | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/doc/v2/howto/index_en.rst b/doc/v2/howto/index_en.rst index 2079be766f..bf2320a169 100644 --- a/doc/v2/howto/index_en.rst +++ b/doc/v2/howto/index_en.rst @@ -1,11 +1,37 @@ HOW TO -======= +======== + +PaddlePaddle provides the users the ability to flexibly set various command line parameters to control the model training and inference process. Please refer to the following instructions on using PaddlePaddle: + +.. toctree:: + :maxdepth: 1 + + cmd_parameter/index_cn.rst + +PaddlePaddle supports distributed training tasks on fabric clusters, MPI clusters, and Kubernetes clusters. For detailed configuration and usage instructions, refer to: + +.. toctree:: + :maxdepth: 1 + + cluster/index_cn.rst + +PaddlePaddle provides a C-API for inference. We provide the following guidelines for using the C-API: + +.. toctree:: + :maxdepth: 1 + + capi/index_cn.rst + +PaddlePaddle supports a variety of flexible and efficient recurrent neural networks. For details, please refer to: + +.. toctree:: + :maxdepth: 1 + + rnn/index_cn.rst + +How to use the built-in timing tool, nvprof, or nvvp to run performance analysis and tuning, please refer to: .. toctree:: :maxdepth: 1 - cmd_parameter/index_en.rst - cluster/index_en.rst - capi/index_en.rst - rnn/index_en.rst - optimization/gpu_profiling_en.rst + optimization/gpu_profiling_cn.rst From 13f1050ab0f5113fea223f47e99f7c6b4f9644a7 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 22 Mar 2018 15:15:02 +0800 Subject: [PATCH 78/79] "fix mixed_vector bug" (#9319) --- paddle/fluid/framework/mixed_vector.h | 2 +- paddle/fluid/framework/mixed_vector_test.cu | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 6a6fa53871..d99a15547b 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -176,7 +176,7 @@ class Vector { // resize the vector void resize(size_t size) { - if (size + 1 < capacity()) { + if (size + 1 <= capacity()) { size_ = size; } else { MutableCPU(); diff --git a/paddle/fluid/framework/mixed_vector_test.cu b/paddle/fluid/framework/mixed_vector_test.cu index 4bf78499f2..d57f825108 100644 --- a/paddle/fluid/framework/mixed_vector_test.cu +++ b/paddle/fluid/framework/mixed_vector_test.cu @@ -104,3 +104,11 @@ TEST(mixed_vector, ForEach) { for (auto& v : tmp) { } } + +TEST(mixed_vector, Reserve) { + paddle::framework::Vector vec; + vec.reserve(1); + vec.push_back(0); + vec.push_back(0); + vec.push_back(0); +} From 466f28a6b18f56fe0b2686091a49802ea97334b7 Mon Sep 17 00:00:00 2001 From: legend06hvl Date: Thu, 22 Mar 2018 15:16:01 +0800 Subject: [PATCH 79/79] Update index_en.rst (#9286) * Update index_en.rst Update en version * Update index_en.rst Update refer to commits and thank you for the suggestion. --- doc/v2/howto/capi/index_en.rst | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/doc/v2/howto/capi/index_en.rst b/doc/v2/howto/capi/index_en.rst index 2cbbe362fd..4ec39c9d52 100644 --- a/doc/v2/howto/capi/index_en.rst +++ b/doc/v2/howto/capi/index_en.rst @@ -1,6 +1,23 @@ -C-API Prediction Library +C-API Inference Library ======================== +After we train a neural network, we use it to do inference. Inference is the process of preparing input data and propagating it through the model to produce the result. + +Compared with model training, prediction has the following features: + +#. Inference does not require backpropagation and parameter updates, as required during training. +#. Labels are not needed in prediction. +#. Most of the time, predictions need to be integrated with the user system. + +Therefore, the model prediction SDK needs to be designed separately and has the following features: + +#. The predictive SDK does not include backpropagation and parameter updates to reduce the size of the SDK. +#. The predictive SDK needs a simple user interface for ease of use. +#. Since the input data may have a variety of structures, the format of the input data is clearly and compactly packaged. +#. In order to be compatible with user's system, the SDK's interface must conform to the C-standard interface. + +PaddlePaddle provides C-API to solve the above problem. Following are the guidelines to use the C-API: + .. toctree:: :maxdepth: 1